PhysicalAI-reason-VLA

A vision-language driving policy fine-tuned from mjf-su/PhysicalAI-base-VLA (itself based on Qwen/Qwen3-VL-4B-Thinking) using supervised fine-tuning with TRL.

This model extends the base waypoint-prediction VLA with structured chain-of-thought reasoning and discrete driving decisions. It was trained for 2 epochs on 10k Gemini-annotated driving scenes.


Input / Output

Inputs

  • A forward-facing camera image
  • Past ego-vehicle waypoints in the vehicle's relative frame
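
The Quick Start below builds the past-motion string by hand. Here is a minimal sketch of a helper that produces the same `<wp>` lines from numeric triples. It assumes the same precision as the output template (two decimals for x and y, four for the third value). `format_past_waypoints` is a hypothetical name and is not part of this release.

```python
from typing import Iterable, Tuple


def format_past_waypoints(points: Iterable[Tuple[float, float, float]]) -> str:
    """Render (x, y, t) triples as newline-separated <wp>[...]</wp> lines.

    Hypothetical helper: precision mirrors the <wp>[x.xx,y.yy,t.tttt]</wp>
    template (two decimals for x and y, four for the third value).
    """
    return "\n".join(f"<wp>[{x:.2f},{y:.2f},{t:.4f}]</wp>" for x, y, t in points)


past_waypoints = format_past_waypoints([(0.0, 0.0, 0.0), (0.51, 0.0, 0.0001)])
print(past_waypoints)
```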

Output

<think>
{
  "scene": "2–3 sentence static scene description",
  "move_justification": "2–3 sentence causal explanation linking scene to decisions"
}
</think>
<action>
<longitudinal_token><lateral_token>
</action>
<wp>[x.xx,y.yy,t.tttt]</wp>
<wp>[x.xx,y.yy,t.tttt]</wp>
...

The model produces three outputs in sequence: a reasoning trace (<think>), discrete longitudinal and lateral driving decisions (<action>), and future trajectory waypoints (<wp>).
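
Here is a minimal sketch of splitting a decoded completion into those three parts with Python's `re` and `json` modules. It assumes that the `<think>` body is valid JSON, that the decision tokens survive decoding as literal text, and that waypoints follow `</action>`. `parse_output` and the sample string are illustrative only.

```python
import json
import re
from typing import List, Tuple

# Patterns for the <think>, <action> and <wp> sections of a completion.
THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)
ACTION_RE = re.compile(r"<action>\s*(<[a-z_]+>)\s*(<[a-z_]+>)\s*</action>")
WP_RE = re.compile(r"<wp>\[([-\d.]+),([-\d.]+),([-\d.]+)\]</wp>")


def parse_output(text: str) -> Tuple[dict, Tuple[str, str], List[Tuple[float, ...]]]:
    """Split a decoded completion into (reasoning, decisions, waypoints)."""
    think = THINK_RE.search(text)
    action = ACTION_RE.search(text)
    if think is None or action is None:
        raise ValueError("completion is missing a <think> or <action> block")
    reasoning = json.loads(think.group(1))  # assumes the think body is valid JSON
    decisions = (action.group(1), action.group(2))  # (longitudinal, lateral)
    # Only collect waypoints that appear after the action block.
    waypoints = [
        tuple(float(v) for v in m.groups()) for m in WP_RE.finditer(text, action.end())
    ]
    return reasoning, decisions, waypoints


# Made-up completion in the documented format, for illustration only.
sample = """<think>
{
  "scene": "A straight two-lane road with no vehicles ahead.",
  "move_justification": "The lane is clear, so the ego vehicle keeps its speed and lane."
}
</think>
<action>
<cruise><lane_keep>
</action>
<wp>[0.52,0.00,0.0001]</wp>
<wp>[1.05,0.01,0.0003]</wp>"""

reasoning, decisions, waypoints = parse_output(sample)
print(decisions)  # ('<cruise>', '<lane_keep>')
print(len(waypoints))  # 2
```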


Decision Tokens

Each <action> block contains exactly one longitudinal and one lateral token.

Longitudinal: <stop> · <yield> · <follow> · <gap_search> · <pass> · <adapt> · <cruise>

Lateral: <turn_left> · <turn_right> · <lc_left> · <lc_right> · <merge> · <nudge_out_left> · <nudge_out_right> · <nudge_in_left> · <nudge_in_right> · <pull_over> · <abort> · <lane_keep>

These are registered as genuine single tokens in the vocabulary (not subword decompositions), enabling efficient probability measurement over the full decision space with a single forward pass.
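
Because each decision is a single token, the next-token distribution at the position just before the longitudinal token already scores all seven longitudinal options. Here is a minimal sketch of reading that distribution. It assumes you already have one row of logits from a forward pass and a token-to-id map. With transformers, these would typically come from `model(**inputs).logits[0, -1]` on a prompt ending just before the longitudinal token, and from `processor.tokenizer.convert_tokens_to_ids`. `decision_distribution` and the toy logits are hypothetical.

```python
import math
from typing import Dict, Sequence

LONGITUDINAL = ["<stop>", "<yield>", "<follow>", "<gap_search>", "<pass>", "<adapt>", "<cruise>"]
LATERAL = [
    "<turn_left>", "<turn_right>", "<lc_left>", "<lc_right>", "<merge>",
    "<nudge_out_left>", "<nudge_out_right>", "<nudge_in_left>", "<nudge_in_right>",
    "<pull_over>", "<abort>", "<lane_keep>",
]


def decision_distribution(
    logits: Sequence[float], token_ids: Dict[str, int], candidates: Sequence[str]
) -> Dict[str, float]:
    """Softmax restricted to the candidate tokens' next-token logits."""
    selected = [logits[token_ids[tok]] for tok in candidates]
    peak = max(selected)  # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in selected]
    total = sum(exps)
    return {tok: e / total for tok, e in zip(candidates, exps)}


# Toy stand-in for one row of next-token logits over a 19-token vocabulary.
token_ids = {tok: i for i, tok in enumerate(LONGITUDINAL + LATERAL)}
logits = [0.0] * len(token_ids)
logits[token_ids["<cruise>"]] = 3.0

probs = decision_distribution(logits, token_ids, LONGITUDINAL)
print(max(probs, key=probs.get))  # <cruise>
```

Renormalizing over the candidate set gives the distribution conditioned on the model emitting a longitudinal token. Under the same assumptions, the lateral distribution would be read one position later, after appending the chosen longitudinal token.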


Training

  • Base model: mjf-su/PhysicalAI-base-VLA
  • Dataset: mjf-su/PhysicalAI-reason-US
  • Annotation: Gemini batch API (chain-of-thought labels on real US driving data)
  • Samples: 10,000
  • Epochs: 2
  • Method: completion-only SFT via TRL
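
Completion-only SFT computes the loss on the assistant's completion (reasoning, decision tokens and waypoints) and not on the prompt (system text, image tokens, past motion). TRL handles this internally. The sketch below only illustrates the idea. It assumes the common Hugging Face convention of marking ignored label positions with -100, which is the default `ignore_index` of `torch.nn.CrossEntropyLoss`. `completion_only_labels` is a hypothetical helper, not TRL's API.

```python
from typing import List

# Label positions set to -100 are skipped by the loss: this is the default
# ignore_index of torch.nn.CrossEntropyLoss, which Hugging Face models follow.
IGNORE_INDEX = -100


def completion_only_labels(prompt_ids: List[int], completion_ids: List[int]) -> List[int]:
    """Build labels for prompt + completion that exclude prompt tokens from the loss.

    Hugging Face causal-LM models shift labels internally, so the labels stay
    aligned position-for-position with input_ids.
    """
    return [IGNORE_INDEX] * len(prompt_ids) + list(completion_ids)


# Hypothetical token ids: 3 prompt tokens followed by 2 completion tokens.
prompt_ids = [101, 102, 103]
completion_ids = [7, 8]
input_ids = prompt_ids + completion_ids
labels = completion_only_labels(prompt_ids, completion_ids)
print(labels)  # [-100, -100, -100, 7, 8]
```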

Quick Start

from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image

model_id = "mjf-su/PhysicalAI-reason-VLA"
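processor = AutoProcessor.from_pretrained(model_id)
# torch_dtype="auto" loads the weights in their stored dtype (BF16) rather than float32
model = AutoModelForImageTextToText.from_pretrained(model_id, torch_dtype="auto", device_map="auto")

image = Image.open("forward_camera.jpg").convert("RGB")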

past_waypoints = "<wp>[0.00,0.00,0.0000]</wp>\n<wp>[0.51,0.00,0.0001]</wp>\n..."

messages = [
    {
        "role": "system",
        "content": [{"type": "text", "text": "You are a helpful AI assistant ..."}]
    },
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": f"[PAST-VEHICLE-MOTION]:\n{past_waypoints}"}
        ]
    }
]

prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[prompt], images=[image], return_tensors="pt").to(model.device)
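outputs = model.generate(**inputs, max_new_tokens=512)
# generate() returns prompt + completion; keep only the newly generated tokens
generated = outputs[:, inputs["input_ids"].shape[1]:]
# If the decision tokens are registered as special tokens, pass
# skip_special_tokens=False so they are not stripped from the decoded text
print(processor.batch_decode(generated, skip_special_tokens=True)[0])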

Citation

@misc{vonwerra2022trl,
  title        = {{TRL: Transformer Reinforcement Learning}},
  author       = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching
                  and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul
                  and Quentin Gallou{\'e}dec},
  year         = 2022,
  journal      = {GitHub repository},
  publisher    = {GitHub},
  howpublished = {\url{https://github.com/huggingface/trl}}
}
Safetensors · Model size: 4B params · Tensor type: BF16