---
license: apache-2.0
base_model: Qwen/Qwen2.5-VL-3B-Instruct
datasets:
  - TESS-Computer/quickdraw-circles
tags:
  - trajectory-prediction
  - diffusion-transformer
  - vision-language
  - robotics
  - drawing
pipeline_tag: image-to-image
---

# Qwen-DiT-Draw

A Vision-Language Model with a Diffusion Transformer (DiT) action head for trajectory prediction. Given a canvas image and a text instruction, the model predicts drawing trajectories as chunks of (x, y, state) points.

**Architecture:** Frozen Qwen2.5-VL-3B backbone + trainable DiT action head (36.7M params)

## Model Details

- **Base Model:** [Qwen/Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)
- **Training Data:** [TESS-Computer/quickdraw-circles](https://huggingface.co/datasets/TESS-Computer/quickdraw-circles) (21k circle drawings)
- **Architecture:** GR00T-style chunked prediction with flow matching
- **Trainable Parameters:** 36.7M (DiT head only, VLM frozen)
- **Chunk Size:** 16 points per chunk
- **Output:** (x, y, state) where state > 0.5 indicates stop signal
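
Since everything after the first stop signal in a chunk should be discarded, a small helper is handy. This is my illustration, not part of the released code; `trim_at_stop` is a hypothetical name, and the 0.5 threshold follows the convention above.

```python
# Hypothetical helper: keep points up to and including the first one whose
# state value exceeds the stop threshold; discard the rest of the chunk.
def trim_at_stop(chunk, threshold=0.5):
    points = []
    for x, y, state in chunk:
        points.append((x, y))
        if state > threshold:  # stop signal reached
            break
    return points
```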

## Usage

```python
import torch
from PIL import Image
from transformers import AutoProcessor
from qwen_vl_utils import process_vision_info

# You need the model code from: https://github.com/HusseinLezzaik/Qwen-DiT-Draw
from src.model import Qwen2_5_VL_Draw, TrajectoryConfig

# Load model
config = TrajectoryConfig(chunk_size=16, dit_hidden_size=512, dit_num_layers=6)
model = Qwen2_5_VL_Draw(
    model_id="Qwen/Qwen2.5-VL-3B-Instruct",
    config=config,
    freeze_backbone=True,
    dtype=torch.bfloat16,
)

# Load trained weights
from huggingface_hub import hf_hub_download
weights_path = hf_hub_download(repo_id="TESS-Computer/qwen-dit-draw", filename="trajectory_head.pt")
model.trajectory_head.load_state_dict(torch.load(weights_path, weights_only=True))
model = model.to("cuda").eval()

# Load processor
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")

# Create input
image = Image.new("RGB", (512, 512), "white")  # White canvas
instruction = "draw a circle"

messages = [{
    "role": "user",
    "content": [
        {"type": "image", "image": image, "min_pixels": 200704, "max_pixels": 401408},
        {"type": "text", "text": instruction},
    ],
}]

text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, _, _ = process_vision_info(messages, return_video_kwargs=True)
inputs = processor(text=[text], images=image_inputs, return_tensors="pt")
inputs = {k: v.to("cuda") if torch.is_tensor(v) else v for k, v in inputs.items()}

# Predict trajectory chunk
with torch.no_grad():
    chunk = model.predict_chunk(**inputs)

chunk = chunk[0].float().cpu().numpy()  # (16, 3) - (x, y, state)
print(f"Predicted {len(chunk)} points")
for i, (x, y, state) in enumerate(chunk):
    print(f"  Point {i}: ({x:.3f}, {y:.3f}), stop={state > 0.5}")
```
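
The multi-chunk example below scales coordinates by the canvas size before drawing, which implies the predicted (x, y) are normalized to [0, 1]. Under that assumption, converting a chunk to pixel coordinates is a one-liner (`to_pixels` is a hypothetical helper, not part of the released code):

```python
# Assumes (x, y) are normalized to [0, 1]; maps them onto a size x size canvas.
def to_pixels(chunk, size=512):
    return [(int(x * size), int(y * size)) for x, y, _ in chunk]
```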

## Multi-Chunk Inference (Full Drawing)

For complete drawings, run the model in a visual feedback loop: after each predicted chunk, draw it onto the canvas and feed the updated canvas back in as the next input image:

```python
from PIL import ImageDraw

canvas = Image.new("RGB", (512, 512), "white")
all_points = []
max_chunks = 10

for chunk_idx in range(max_chunks):
    # Prepare inputs with the current canvas
    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": canvas, "min_pixels": 200704, "max_pixels": 401408},
            {"type": "text", "text": "draw a circle"},
        ],
    }]
    # ... process inputs and predict `chunk` as in the single-chunk example ...

    # Draw on canvas (use BLACK lines to match training!)
    draw = ImageDraw.Draw(canvas)
    stop = False
    for i in range(1, len(chunk)):
        x1, y1 = int(chunk[i-1][0] * 512), int(chunk[i-1][1] * 512)
        x2, y2 = int(chunk[i][0] * 512), int(chunk[i][1] * 512)
        draw.line([(x1, y1), (x2, y2)], fill="black", width=2)
        all_points.append((x2, y2))

        if chunk[i][2] > 0.5:  # Stop signal: drawing is complete
            stop = True
            break

    if stop:  # propagate the stop signal out of the chunk loop
        break
```

## Training

Trained on Modal H100 for 2 epochs using flow matching loss. See [training code](https://github.com/HusseinLezzaik/Qwen-DiT-Draw).
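
For intuition, the flow matching objective can be sketched as follows. This is a minimal generic illustration, not the repo's training code: sample a time `t`, linearly interpolate between Gaussian noise `x0` and the target trajectory `x1`, and regress the model's predicted velocity onto the constant target velocity `x1 - x0`. The function and argument names are my own.

```python
import numpy as np

# Minimal flow-matching loss sketch (illustration only).
# predict_velocity: callable taking (x_t, t) and returning a velocity field
# x1: target trajectories, shape (batch, chunk_size, 3)
def flow_matching_loss(predict_velocity, x1, rng):
    x0 = rng.standard_normal(x1.shape)           # noise sample
    t = rng.uniform(size=(x1.shape[0], 1, 1))    # per-sample time in [0, 1]
    xt = (1 - t) * x0 + t * x1                   # straight-line interpolant
    v_target = x1 - x0                           # constant target velocity
    v_pred = predict_velocity(xt, t)
    return np.mean((v_pred - v_target) ** 2)     # MSE on the velocity field
```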

## Citation

```bibtex
@misc{qwen-dit-draw,
  author = {TESS Computer},
  title = {Qwen-DiT-Draw: VLM + DiT for Trajectory Prediction},
  year = {2025},
  url = {https://huggingface.co/TESS-Computer/qwen-dit-draw}
}
```

## Links

- **Code:** [GitHub - Qwen-DiT-Draw](https://github.com/HusseinLezzaik/Qwen-DiT-Draw)
- **Dataset:** [TESS-Computer/quickdraw-circles](https://huggingface.co/datasets/TESS-Computer/quickdraw-circles)
- **Base Model:** [Qwen/Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)