---
license: apache-2.0
base_model: Qwen/Qwen2.5-VL-3B-Instruct
datasets:
- TESS-Computer/quickdraw-circles
tags:
- trajectory-prediction
- diffusion-transformer
- vision-language
- robotics
- drawing
pipeline_tag: image-to-image
---
# Qwen-DiT-Draw
A vision-language model (VLM) with a Diffusion Transformer (DiT) head for trajectory prediction. Given an image and an instruction, the model predicts drawing trajectories.
**Architecture:** Frozen Qwen2.5-VL-3B backbone + trainable DiT action head (36.7M params)
## Model Details
- **Base Model:** [Qwen/Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)
- **Training Data:** [TESS-Computer/quickdraw-circles](https://huggingface.co/datasets/TESS-Computer/quickdraw-circles) (21k circle drawings)
- **Architecture:** GR00T-style chunked prediction with flow matching
- **Trainable Parameters:** 36.7M (DiT head only, VLM frozen)
- **Chunk Size:** 16 points per chunk
- **Output:** `(x, y, state)` per point, where `state > 0.5` is the stop signal
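
To make the output format concrete, here is a minimal sketch of turning one predicted chunk into pixel coordinates. It assumes `x` and `y` are normalized to `[0, 1]` (the multi-chunk example in this card scales them by 512). The helper name `chunk_to_pixels` and the clamping to the canvas bounds are illustrative additions, not part of the released code.

```python
def chunk_to_pixels(chunk, canvas_size=512, stop_threshold=0.5):
    """Convert (x, y, state) rows to pixel coordinates, ending at the stop signal.

    Hypothetical helper: assumes x and y are normalized to [0, 1].
    Returns (pixels, stopped), where `stopped` is True if a row had state > stop_threshold.
    """
    pixels = []
    for x, y, state in chunk:
        # Clamp so slightly out-of-range predictions still land on the canvas
        px = min(max(int(x * canvas_size), 0), canvas_size - 1)
        py = min(max(int(y * canvas_size), 0), canvas_size - 1)
        pixels.append((px, py))
        if state > stop_threshold:
            return pixels, True  # the stop point itself is kept
    return pixels, False


example = [(0.25, 0.5, 0.0), (0.5, 0.75, 0.1), (1.2, -0.1, 0.9), (0.9, 0.9, 0.0)]
pixels, stopped = chunk_to_pixels(example)
print(pixels, stopped)  # [(128, 256), (256, 384), (511, 0)] True
```

Points after the first stop signal are dropped, so the last row of `example` never reaches the canvas.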
## Usage
```python
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import AutoProcessor
from qwen_vl_utils import process_vision_info

# You need the model code from: https://github.com/HusseinLezzaik/Qwen-DiT-Draw
from src.model import Qwen2_5_VL_Draw, TrajectoryConfig

# Load model
config = TrajectoryConfig(chunk_size=16, dit_hidden_size=512, dit_num_layers=6)
model = Qwen2_5_VL_Draw(
    model_id="Qwen/Qwen2.5-VL-3B-Instruct",
    config=config,
    freeze_backbone=True,
    dtype=torch.bfloat16,
)

# Load trained weights
weights_path = hf_hub_download(repo_id="TESS-Computer/qwen-dit-draw", filename="trajectory_head.pt")
model.trajectory_head.load_state_dict(torch.load(weights_path, weights_only=True))
model = model.to("cuda").eval()

# Load processor
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")

# Create input
image = Image.new("RGB", (512, 512), "white")  # White canvas
instruction = "draw a circle"
messages = [{
    "role": "user",
    "content": [
        {"type": "image", "image": image, "min_pixels": 200704, "max_pixels": 401408},
        {"type": "text", "text": instruction},
    ],
}]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, _, _ = process_vision_info(messages, return_video_kwargs=True)
inputs = processor(text=[text], images=image_inputs, return_tensors="pt")
inputs = {k: v.to("cuda") if torch.is_tensor(v) else v for k, v in inputs.items()}

# Predict trajectory chunk
with torch.no_grad():
    chunk = model.predict_chunk(**inputs)
chunk = chunk[0].float().cpu().numpy()  # (16, 3) - (x, y, state)

print(f"Predicted {len(chunk)} points")
for i, (x, y, state) in enumerate(chunk):
    print(f"  Point {i}: ({x:.3f}, {y:.3f}), stop={state > 0.5}")
```
## Multi-Chunk Inference (Full Drawing)
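For complete drawings, use a visual feedback loop: draw each predicted chunk onto the canvas, then feed the updated canvas back to the model for the next chunk.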
```python
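from PIL import ImageDraw

canvas = Image.new("RGB", (512, 512), "white")
all_points = []
max_chunks = 10

for chunk_idx in range(max_chunks):
    # Prepare inputs with the current canvas
    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": canvas, "min_pixels": 200704, "max_pixels": 401408},
            {"type": "text", "text": "draw a circle"},
        ],
    }]

    # Process and predict (same steps as in Usage)
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, _, _ = process_vision_info(messages, return_video_kwargs=True)
    inputs = processor(text=[text], images=image_inputs, return_tensors="pt")
    inputs = {k: v.to("cuda") if torch.is_tensor(v) else v for k, v in inputs.items()}
    with torch.no_grad():
        chunk = model.predict_chunk(**inputs)
    chunk = chunk[0].float().cpu().numpy()  # (16, 3) - (x, y, state)

    # Draw on canvas (use BLACK lines to match training!)
    draw = ImageDraw.Draw(canvas)
    stop = False
    for i in range(1, len(chunk)):
        x1, y1 = int(chunk[i-1][0] * 512), int(chunk[i-1][1] * 512)
        x2, y2 = int(chunk[i][0] * 512), int(chunk[i][1] * 512)
        draw.line([(x1, y1), (x2, y2)], fill='black', width=2)
        if chunk[i][2] > 0.5:  # Stop signal
            stop = True
            break
    all_points.extend(chunk[: i + 1].tolist())  # Keep the points drawn so far
    if stop:  # The inner break only exits the point loop, so end the drawing here
        break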
```
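
The loop's termination logic can be checked without a GPU by swapping the model for a stub. The sketch below is illustrative only: `stub_predict_chunk` and `run_drawing_loop` are hypothetical names, and the stub is indexed by chunk number, whereas the real model conditions on the updated canvas.

```python
import math


def stub_predict_chunk(chunk_idx, chunk_size=16, total_chunks=3):
    """Hypothetical stand-in for the model: (x, y, state) rows along a circle.

    The stop state is raised only on the final point of the final chunk.
    """
    total = chunk_size * total_chunks
    rows = []
    for j in range(chunk_size):
        k = chunk_idx * chunk_size + j
        angle = 2.0 * math.pi * k / (total - 1)
        x = 0.5 + 0.3 * math.cos(angle)
        y = 0.5 + 0.3 * math.sin(angle)
        state = 1.0 if k == total - 1 else 0.0
        rows.append((x, y, state))
    return rows


def run_drawing_loop(predict_chunk, max_chunks=10, stop_threshold=0.5):
    """Collect points chunk by chunk until a stop signal or the chunk budget is hit."""
    all_points = []
    for chunk_idx in range(max_chunks):
        for x, y, state in predict_chunk(chunk_idx):
            all_points.append((x, y))
            if state > stop_threshold:
                return all_points, True  # the stop signal ends the whole drawing
    return all_points, False  # ran out of chunks without a stop signal


points, stopped = run_drawing_loop(stub_predict_chunk)
print(len(points), stopped)  # 48 True
```

Returning from inside the point loop ends the whole drawing at once, which avoids the need for a separate flag in the outer loop.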
## Training
Trained on Modal (H100) for 2 epochs with a flow matching loss. See the [training code](https://github.com/HusseinLezzaik/Qwen-DiT-Draw).
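
For readers new to flow matching, the following is a minimal sketch of the general idea, not the training code from the repository. It assumes one common convention: a straight path from noise at `t = 0` to the target action at `t = 1`, a network that regresses the constant velocity `action - noise`, and Euler integration at inference. The function names and the sign and direction convention are assumptions; the actual implementation may differ.

```python
def interpolate(noise, action, t):
    """Point on the straight path from noise (t=0) to action (t=1)."""
    return [(1.0 - t) * n + t * a for n, a in zip(noise, action)]


def flow_matching_loss(predict_velocity, action, noise, t):
    """Mean squared error between predicted and target velocity at time t.

    During training, `noise` would be sampled from a Gaussian and `t` from [0, 1).
    """
    x_t = interpolate(noise, action, t)
    v_pred = predict_velocity(x_t, t)
    v_true = [a - n for n, a in zip(noise, action)]  # constant along the path
    return sum((p - q) ** 2 for p, q in zip(v_pred, v_true)) / len(action)


def euler_sample(predict_velocity, noise, num_steps=4):
    """Integrate the learned velocity field from t=0 to t=1 with Euler steps."""
    x = list(noise)
    dt = 1.0 / num_steps
    for step in range(num_steps):
        v = predict_velocity(x, step * dt)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x


# Toy check with an oracle velocity field that knows the target point.
target = [0.25, 0.75, 0.0]  # one (x, y, state) row
start = [1.0, -1.0, 0.5]    # stands in for a noise sample


def oracle(x, t):
    return [(a - xi) / (1.0 - t) for a, xi in zip(target, x)]


loss = flow_matching_loss(oracle, target, start, t=0.5)
sample = euler_sample(oracle, start)
print(loss, sample)
```

With the oracle field the loss is zero and the Euler steps land on the target; a trained head only approximates this, so the number of integration steps trades speed against accuracy.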
## Citation
```bibtex
@misc{qwen-dit-draw,
author = {TESS Computer},
title = {Qwen-DiT-Draw: VLM + DiT for Trajectory Prediction},
year = {2025},
url = {https://huggingface.co/TESS-Computer/qwen-dit-draw}
}
```
## Links
- **Code:** [GitHub - Qwen-DiT-Draw](https://github.com/HusseinLezzaik/Qwen-DiT-Draw)
- **Dataset:** [TESS-Computer/quickdraw-circles](https://huggingface.co/datasets/TESS-Computer/quickdraw-circles)
- **Base Model:** [Qwen/Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)