| |
| |
| |
|
|
# --- standard library ---
import warnings
from pathlib import Path
from typing import Dict, List, Optional

# --- third-party ---
import torch
import torch.nn as nn
import transformers
from accelerate.logging import get_logger
from torch.nn.utils.rnn import pad_sequence
from transformers.modeling_outputs import CausalLMOutputWithPast

logger = get_logger(__name__)

# NOTE(review): this silences *every* warning for the whole process as soon as
# this module is imported, not only warnings raised by this file.
warnings.filterwarnings("ignore")

# Parent of the directory that contains this file (presumably the project root).
ROOT = Path(__file__).parents[1]

# Horizontal rule used to frame console output.
SEPARATOR = 20 * "-"

# Number of pixels per visual token (32 * 32).
PIXELS_PER_TOKEN = 32 * 32
|
|
|
|
class _CosmosReason2_Interface(nn.Module):
    """Thin ``nn.Module`` wrapper around a Cosmos-Reason2 Qwen3-VL checkpoint.

    Owns the Hugging Face model and processor, runs ``forward``/``generate``
    under CUDA bfloat16 autocast, and converts (images, instructions) batches
    into processor inputs via :meth:`build_qwenvl_inputs`.
    """

    # Checkpoint loaded when no ``model_name`` override is passed to __init__.
    DEFAULT_MODEL_NAME = "nvidia/Cosmos-Reason2-2B"

    def __init__(self, config: Optional[dict] = None, **kwargs):
        """Load the model weights and processor.

        Args:
            config: Experiment config stored as ``self.config``.
                :meth:`build_qwenvl_inputs` reads ``config.datasets.vla_data``
                from it with attribute access (e.g. an OmegaConf config).
                May be ``None``, in which case no CoT prompt template is used.
            **kwargs: ``model_name`` optionally overrides
                :attr:`DEFAULT_MODEL_NAME` (Hub id or local path). Any other
                keyword is accepted and ignored.
        """
        super().__init__()
        model_name = kwargs.get("model_name") or self.DEFAULT_MODEL_NAME
        # Weights are loaded directly in bfloat16 with PyTorch SDPA attention.
        self.model = transformers.Qwen3VLForConditionalGeneration.from_pretrained(
            model_name,
            dtype=torch.bfloat16,
            attn_implementation="sdpa",
        )
        self.processor = transformers.Qwen3VLProcessor.from_pretrained(model_name)
        self.config = config

        # Mirror text_config.hidden_size onto the top-level config, presumably
        # for downstream code that reads ``model.config.hidden_size``.
        # TODO confirm against callers.
        self.model.config.hidden_size = self.model.config.text_config.hidden_size

    def forward(self, **kwargs) -> "CausalLMOutputWithPast":
        """Run the wrapped model under CUDA bfloat16 autocast.

        All keyword arguments (e.g. the output of :meth:`build_qwenvl_inputs`)
        are passed unchanged to the underlying model.
        """
        with torch.autocast("cuda", dtype=torch.bfloat16):
            outputs = self.model(**kwargs)
        return outputs

    def generate(self, **kwargs):
        """Call ``self.model.generate`` under CUDA bfloat16 autocast.

        All keyword arguments are passed through to ``generate`` unchanged.
        """
        # bfloat16 matches the dtype the weights are loaded in and the dtype used
        # by forward(). float16's narrower exponent range can overflow
        # activations of a bf16 model.
        with torch.autocast("cuda", dtype=torch.bfloat16):
            generation_output = self.model.generate(**kwargs)
        return generation_output

    def build_qwenvl_inputs(self, images, instructions, **kwargs):
        """Build a padded, tokenized chat batch from images and instructions.

        Each sample becomes a single user turn containing all of that sample's
        images followed by one text prompt. If the config defines a non-empty
        ``CoT_prompt`` template, ``{instruction}`` in it is replaced by the
        sample's instruction. Otherwise the instruction is used as-is.

        Args:
            images: One entry per sample; each entry is an iterable of images
                attached to that sample's user turn.
            instructions: One instruction string per sample.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The processor's ``apply_chat_template`` result (tokenized, padded,
            ``return_tensors="pt"``) moved to ``self.model.device``.

        Raises:
            ValueError: If ``images`` and ``instructions`` differ in length.
        """
        if len(images) != len(instructions):
            raise ValueError("Images and instructions must have the same length")

        # The template does not depend on the sample, so look it up once.
        template = self._cot_prompt_template()
        messages = []
        for sample_images, instruction in zip(images, instructions):
            content = [{"type": "image", "image": img} for img in sample_images]
            # An empty or None template must not swallow the instruction.
            prompt = (
                template.replace("{instruction}", instruction) if template else instruction
            )
            content.append({"type": "text", "text": prompt})
            messages.append([{"role": "user", "content": content}])

        inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            padding=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        )
        return inputs.to(self.model.device)

    def _cot_prompt_template(self):
        """Return the configured ``CoT_prompt`` template, or ``None`` if unset.

        ``config.datasets.vla_data`` is read only when a config was provided.
        A missing ``CoT_prompt`` key yields ``None``.
        """
        if self.config is None:
            return None
        vla_data = self.config.datasets.vla_data
        return vla_data.get("CoT_prompt") if "CoT_prompt" in vla_data else None
|
|
|
|
if __name__ == "__main__":
    # Demo: load the backbone and answer one question about one image.
    import argparse

    from omegaconf import OmegaConf

    parser = argparse.ArgumentParser(
        description="Run one Cosmos-Reason2 chat turn on a single image."
    )
    parser.add_argument(
        "--config_yaml",
        type=str,
        default="/mnt/workspace/users/wanhanwen/JoyRA/examples/Robocasa_tabletop/train_files/starvla_cotrain_robocasa_gr1.yaml",
        help="Path to YAML config",
    )
    parser.add_argument(
        "--image",
        type=str,
        default="path/to/sample.png",
        help="Image passed to the processor's chat template",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="What is the robot most likely to do?",
        help="Question asked about the image",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=4096,
        help="Maximum number of generated tokens",
    )
    # Unrecognised CLI flags are tolerated and ignored.
    args, _unknown_args = parser.parse_known_args()

    cfg = OmegaConf.load(args.config_yaml)
    # NOTE(review): _CosmosReason2_Interface.__init__ does not read these two
    # fields; it chooses the checkpoint and attention implementation itself.
    # They only record the intended setup in the stored config.
    cfg.framework.qwenvl.base_vlm = "path/to/Cosmos-Reason2-2B"
    cfg.framework.qwenvl.attn_implementation = "sdpa"
    qwen_vl = _CosmosReason2_Interface(cfg)
    # from_pretrained is called without a device map, so the weights start on
    # CPU. Use the GPU when one is available.
    if torch.cuda.is_available():
        qwen_vl.to("cuda")

    conversation = [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are a helpful assistant."}],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": args.image},
                {"type": "text", "text": args.prompt},
            ],
        },
    ]

    inputs = qwen_vl.processor.apply_chat_template(
        conversation,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    # Inputs must live on the same device as the weights.
    inputs = inputs.to(qwen_vl.model.device)

    generated_ids = qwen_vl.model.generate(**inputs, max_new_tokens=args.max_new_tokens)
    # generate() returns prompt + continuation; keep only the new tokens.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids, strict=False)
    ]
    output_text = qwen_vl.processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    print(SEPARATOR)
    print(output_text[0])
    print(SEPARATOR)
| |
| |