# Copyright 2025 starVLA community. All rights reserved.
# Licensed under the MIT License (the "License");
# Implemented by [Haron Wan / CUHK Shenzhen] in [2026].
import warnings
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn
import transformers
from accelerate.logging import get_logger
from transformers.modeling_outputs import CausalLMOutputWithPast

warnings.filterwarnings("ignore")
logger = get_logger(__name__)

ROOT = Path(__file__).parents[1]
SEPARATOR = "-" * 20
PIXELS_PER_TOKEN = 32**2
"""Number of pixels per visual token."""
class _CosmosReason2_Interface(nn.Module):
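    """Wrapper around nvidia/Cosmos-Reason2-2B, a Qwen3-VL-based reasoning VLM.

    Bundles the Hugging Face model and processor and provides helpers for
    building chat-formatted vision-language inputs.
    """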
def __init__(self, config: Optional[dict] = None, **kwargs):
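        """Load the pretrained Cosmos-Reason2 model and processor; `config`
        may override the checkpoint path and attention backend (see below)."""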
super().__init__()
        # Use the checkpoint and attention backend from the config when given
        # (see __main__ below, which sets both); otherwise fall back to the
        # released checkpoint with SDPA attention.
        model_name = "nvidia/Cosmos-Reason2-2B"
        attn_impl = "sdpa"
        if config is not None and hasattr(config, "framework"):
            model_name = config.framework.qwenvl.get("base_vlm", model_name)
            attn_impl = config.framework.qwenvl.get("attn_implementation", attn_impl)
        self.model = transformers.Qwen3VLForConditionalGeneration.from_pretrained(
            model_name,
            dtype=torch.bfloat16,
            attn_implementation=attn_impl,
        )
        self.processor = transformers.Qwen3VLProcessor.from_pretrained(model_name)
self.config = config
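        # Alias the text backbone's hidden size onto the top-level config so
        # downstream code can read model.config.hidden_size uniformly.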
self.model.config.hidden_size = self.model.config.text_config.hidden_size
    def forward(self, **kwargs) -> CausalLMOutputWithPast:
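        """Forward pass under bfloat16 autocast; kwargs go straight to the model."""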
with torch.autocast("cuda", dtype=torch.bfloat16):
            outputs = self.model(**kwargs)
return outputs
    def generate(self, **kwargs):
        # Match the bfloat16 weights loaded in __init__; autocasting generation
        # in float16 against bfloat16 weights is a dtype mismatch.
        with torch.autocast("cuda", dtype=torch.bfloat16):
            generation_output = self.model.generate(**kwargs)
        return generation_output
def build_qwenvl_inputs(self, images, instructions, **kwargs):
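        """Build batched chat-formatted inputs from per-sample image lists and
        instruction strings; returns processor tensors on the model's device."""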
messages = []
assert len(images) == len(instructions), "Images and instructions must have the same length"
for imgs, instruction in zip(images, instructions):
content = [{"type": "image", "image": img} for img in imgs]
if "CoT_prompt" in self.config.datasets.vla_data: # If using a grounding prompt to task
CoT_prompt = self.config.datasets.vla_data.get("CoT_prompt", "")
prompt = CoT_prompt.replace("{instruction}", instruction)
else:
prompt = instruction
content.append({"type": "text", "text": prompt})
msg = [{"role": "user", "content": content}]
messages.append(msg)
# Process inputs
inputs = self.processor.apply_chat_template(
messages,
tokenize=True,
padding=True,
add_generation_prompt=True,
return_dict=True,
return_tensors="pt",
# fps=4,
)
return inputs.to(self.model.device)
if __name__ == "__main__":
from omegaconf import OmegaConf
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--config_yaml", type=str, default="/mnt/workspace/users/wanhanwen/JoyRA/examples/Robocasa_tabletop/train_files/starvla_cotrain_robocasa_gr1.yaml", help="Path to YAML config")
args, clipargs = parser.parse_known_args()
cfg = OmegaConf.load(args.config_yaml)
cfg.framework.qwenvl.base_vlm = "path/to/Cosmos-Reason2-2B"
cfg.framework.qwenvl.attn_implementation = "sdpa"
qwen_vl = _CosmosReason2_Interface(cfg)
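    # A minimal sketch of batched input construction via build_qwenvl_inputs
    # (hypothetical image path; assumes cfg.datasets.vla_data is configured):
    # batch_inputs = qwen_vl.build_qwenvl_inputs(
    #     images=[["path/to/sample.png"]],
    #     instructions=["What is the robot most likely to do?"],
    # )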
conversation = [
{
"role": "system",
"content": [{"type": "text", "text": "You are a helpful assistant."}],
},
{
"role": "user",
"content": [
{
"type": "image",
"image": f"path/to/sample.png",
},
{"type": "text", "text": "What is the robot most likely to do?"},
],
},
]
# Process inputs
inputs = qwen_vl.processor.apply_chat_template(
conversation,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
return_tensors="pt",
    ).to(qwen_vl.model.device)
# Run inference
generated_ids = qwen_vl.model.generate(**inputs, max_new_tokens=4096)
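    # Strip the prompt tokens from each sequence so only newly generated text is decoded.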
generated_ids_trimmed = [
out_ids[len(in_ids) :]
for in_ids, out_ids in zip(inputs.input_ids, generated_ids, strict=False)
]
output_text = qwen_vl.processor.batch_decode(
generated_ids_trimmed,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
print(SEPARATOR)
print(output_text[0])
print(SEPARATOR)
# print(f"last_hidden: {last_hidden}")