Multitask DiT Policy — Coffee Capsules (Qwen-Pooled)

Diffusion Transformer (DiT) policy for coffee capsule bin picking, trained with pooled Qwen3-VL-4B-Instruct multimodal conditioning.

Trained using pravsels/multitask_dit_policy.

Training Config

parameter           value
dataset             villekuosmanen/bin_pick_pack_coffee_capsules
batch_size          16
train_steps         50000
optimizer_lr        2e-5
objective           DDIM (num_train_timesteps=100, num_inference_steps=20)
transformer         RoPE enabled
multimodal encoder  pooled Qwen3-VL-4B-Instruct (output_dim=2048, freeze_backbone=false, lr_multiplier=0.1, gradient_checkpointing=true)
normalization       RAMEN (seeded from baseline)
hardware            1x A100 80GB (GCloud a2-ultragpu-1g)
training time       35h 17m
use_amp             true

Full config: config/train_coffee_capsules_qwen_pooled_gcloud.yaml in the training repo.
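
A few derived quantities follow directly from the table above. The sketch below is only arithmetic on the listed values; it assumes that lr_multiplier scales optimizer_lr for the Qwen backbone and that each training step is one batch of 16 with no gradient accumulation, neither of which is confirmed here. The variable names are illustrative, not taken from the training repo.

```python
# Values copied from the training config table.
batch_size = 16
train_steps = 50_000
optimizer_lr = 2e-5
encoder_lr_multiplier = 0.1
num_train_timesteps = 100
num_inference_steps = 20
training_time_s = 35 * 3600 + 17 * 60  # 35h 17m in seconds

# Assumption: the multiplier is applied to the base LR for the backbone parameters.
encoder_lr = optimizer_lr * encoder_lr_multiplier

# Assumption: one batch per step, so this counts samples drawn (with repetition).
samples_seen = batch_size * train_steps

# Average wall-clock time per training step on the listed hardware.
seconds_per_step = training_time_s / train_steps

# Average number of training timesteps covered by one DDIM inference step.
ddim_stride = num_train_timesteps // num_inference_steps

print(f"backbone lr      : {encoder_lr:.1e}")
print(f"samples seen     : {samples_seen}")
print(f"seconds per step : {seconds_per_step:.2f}")
print(f"DDIM stride      : {ddim_stride}")
```

Under these assumptions the backbone trains at about 2e-6, the run draws 800,000 samples at roughly 2.54 s per step, and sampling visits every fifth training timestep on average.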

Checkpoints

step   loss     sha256 manifest
35000  0.00816  365922a3e53e7f0a6c7bb36356eeb43373dac5b5993f1f2efe66a047c5e1002f
45000  0.00727  3771f8e6fdf7810988fde0f44d9c6132db1bc187fc7fbc8ed07c9c6b57621c58
50000  0.00672  e218da1ee6fdd6d15d44d0244356b9836700eae50683028a60b220e4ed12b737

Each checkpoint contains inference-only artifacts: model.safetensors, config.json, ramen_stats.pt.
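
Before running inference it can help to confirm that a downloaded checkpoint directory is complete. The following is a minimal sketch using only the three file names listed above; the helper name missing_artifacts is hypothetical and not part of the training repo.

```python
from pathlib import Path

# File names taken from the checkpoint description above.
EXPECTED_ARTIFACTS = ("model.safetensors", "config.json", "ramen_stats.pt")


def missing_artifacts(checkpoint_dir):
    """Return the expected artifact names that are absent from checkpoint_dir."""
    root = Path(checkpoint_dir)
    return [name for name in EXPECTED_ARTIFACTS if not (root / name).is_file()]
```

An empty list means all three files are present, for example `missing_artifacts("checkpoint_50000") == []`.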

Verifying hashes

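The manifest is the SHA-256 of the sorted per-file sha256sum output. Run the following inside a checkpoint directory and compare the printed hash with the manifest listed under Checkpoints:

cd checkpoint_<step>
find . -type f \( -name "model.safetensors" -o -name "config.json" -o -name "ramen_stats.pt" \) | sort | xargs sha256sum | sha256sum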
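
Where sha256sum is not available, the same manifest can be approximated in Python. This is a sketch, not code from the training repo: it assumes the three artifacts sit directly in the checkpoint directory and that sha256sum emits lines of the form `<hex digest>  ./<name>` (two spaces), as GNU coreutils does. The function names are hypothetical.

```python
import hashlib
from pathlib import Path

ARTIFACTS = ("model.safetensors", "config.json", "ramen_stats.pt")


def file_sha256(path, chunk_size=1 << 20):
    """Hash a file in chunks so large weight files are not loaded into memory at once."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def manifest_sha256(checkpoint_dir):
    """Reproduce `find ... | sort | xargs sha256sum | sha256sum` for the three artifacts."""
    root = Path(checkpoint_dir)
    # Mirror `find . | sort`: relative paths prefixed with "./", sorted as strings.
    rel_paths = sorted(f"./{name}" for name in ARTIFACTS)
    # Mirror the assumed sha256sum line format: digest, two spaces, path, newline.
    lines = [f"{file_sha256(root / p)}  {p}\n" for p in rel_paths]
    # The manifest is the SHA-256 of that concatenated text.
    return hashlib.sha256("".join(lines).encode()).hexdigest()
```

The returned string can be compared against the manifest column, for example `manifest_sha256("checkpoint_50000")`.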

Usage

git clone https://github.com/pravsels/multitask_dit_policy
cd multitask_dit_policy
uv sync

uv run -m multitask_dit_policy.examples.inference \
    --checkpoint_dir=<path_to_checkpoint> \
    --dataset_path=<path_to_dataset> \
    --device=cuda

Loss Progression

step loss
5000 0.0151
10000 0.0148
15000 0.0134
20000 0.0176
25000 0.0107
30000 0.0109
35000 0.00816
40000 0.00974
45000 0.00727
50000 0.00672
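
The logged losses do not decrease monotonically. As a small illustration, the sketch below uses only the values listed above to find the intervals where the loss rose and the overall relative reduction; it makes no claim about how the loss was measured or averaged.

```python
# Step -> loss, copied from the table above.
losses = {
    5000: 0.0151, 10000: 0.0148, 15000: 0.0134, 20000: 0.0176, 25000: 0.0107,
    30000: 0.0109, 35000: 0.00816, 40000: 0.00974, 45000: 0.00727, 50000: 0.00672,
}

steps = sorted(losses)

# Consecutive logged steps where the loss went up rather than down.
regressions = [
    (prev, cur) for prev, cur in zip(steps, steps[1:]) if losses[cur] > losses[prev]
]

# Logged step with the lowest loss.
best_step = min(losses, key=losses.get)

# Fractional reduction between the first and last logged values.
relative_drop = 1 - losses[steps[-1]] / losses[steps[0]]

print("loss increased between:", regressions)
print("lowest logged loss at step:", best_step)
print(f"relative reduction: {relative_drop:.1%}")
```

The loss rises over three intervals (15000 to 20000, 25000 to 30000, 35000 to 40000), and the final checkpoint has the lowest logged value, about 55% below the value at step 5000.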