| import re |
| import random |
| import argparse |
| from dataclasses import dataclass, field |
| from typing import List |
|
|
| import torch |
| import wandb |
| from tqdm import tqdm |
| from PIL import Image |
| from datasets import load_dataset |
| from transformers import ( |
| Qwen2_5_VLForConditionalGeneration, |
| AutoProcessor, |
| BitsAndBytesConfig, |
| ) |
| from qwen_vl_utils import process_vision_info |
| from peft import LoraConfig, get_peft_model |
| from trl import SFTConfig, SFTTrainer |
|
|
|
|
| def extract_question(raw_text: str) -> str: |
| pattern = r"<\|start_header_id\|>user<\|end_header_id\|>\s*(.*?)\s*<\|eot_id\|>" |
| m = re.search(pattern, raw_text, re.DOTALL) |
| return m.group(1).strip() if m else raw_text.strip() |
|
|
| def format_data_spacethinker(sample): |
| system_message = { |
| "role": "system", |
| "content": [ |
| { |
| "type": "text", |
| "text": ( |
| "You are VL-Thinking U+1F914, a helpful assistant with excellent reasoning ability.\n" |
| "A user asks you a question, and you should try to solve it." |
| "You should first think about the reasoning process in the mind and then provides the user with the answer.\n" |
| "The reasoning process and answer are enclosed within <think></think> and <answer></answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>." |
| ) |
| } |
| ] |
| } |
| formatted = [system_message] |
|
|
| user_msg = {"role": "user", "content": []} |
| question = extract_question(sample.get("input", "")) |
| if question: |
| user_msg["content"].append({"type": "text", "text": question}) |
| images = sample.get("images") or [] |
| if images: |
| user_msg["content"].append({"type": "image", "image": images[0]}) |
| formatted.append(user_msg) |
|
|
| if sample.get("output"): |
| formatted.append({ |
| "role": "assistant", |
| "content": [{"type": "text", "text": sample["output"]}] |
| }) |
| return formatted |
|
|
|
|
| def collate_fn(examples, processor): |
| |
| texts = [processor.apply_chat_template(sample, tokenize=False) for sample in examples] |
| image_batches = [process_vision_info(sample)[0] for sample in examples] |
| batch = processor(text=texts, images=image_batches, return_tensors="pt", padding=True) |
| batch = {k: v.cpu() for k, v in batch.items()} |
|
|
| labels = batch["input_ids"].clone() |
| labels[labels == processor.tokenizer.pad_token_id] = -100 |
|
|
| image_token_ids = ( |
| [151652, 151653, 151655] |
| if hasattr(processor, "image_processor") |
| else [processor.tokenizer.convert_tokens_to_ids(processor.image_token)] |
| ) |
| for tid in image_token_ids: |
| labels[labels == tid] = -100 |
|
|
| batch["labels"] = labels |
| return batch |
|
|
|
|
| @dataclass |
| class TrainingConfig: |
| model_id: str = "UCSC-VLAA/VLAA-Thinker-Qwen2.5VL-3B" |
| lora_r: int = 128 |
| lora_alpha: int = 256 |
| lora_dropout: float = 0.05 |
| target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj", "o_proj"]) |
| num_train_epochs: int = 3 |
| train_batch_size: int = 1 |
| eval_batch_size: int = 1 |
| gradient_accumulation_steps: int = 8 |
| learning_rate: float = 2e-5 |
| warmup_ratio: float = 0.03 |
| output_dir: str = "spaceom" |
| wandb_project: str = "spaceom" |
| wandb_run_name: str = "spaceom" |
|
|
|
|
| def parse_args() -> TrainingConfig: |
| default_cfg = TrainingConfig() |
| parser = argparse.ArgumentParser(description="Train a VL Spacethinker model with LoRA") |
| parser.add_argument("--model_id", default=default_cfg.model_id) |
| parser.add_argument("--lora_r", type=int, default=default_cfg.lora_r) |
| parser.add_argument("--lora_alpha", type=int, default=default_cfg.lora_alpha) |
| parser.add_argument("--lora_dropout", type=float, default=default_cfg.lora_dropout) |
| parser.add_argument( |
| "--target_modules", |
| default=','.join(default_cfg.target_modules), |
| help="Comma-separated list of target modules for LoRA" |
| ) |
| parser.add_argument("--num_train_epochs", type=int, default=default_cfg.num_train_epochs) |
| parser.add_argument("--train_batch_size", type=int, default=default_cfg.train_batch_size) |
| parser.add_argument("--eval_batch_size", type=int, default=default_cfg.eval_batch_size) |
| parser.add_argument( |
| "--gradient_accumulation_steps", type=int, default=default_cfg.gradient_accumulation_steps |
| ) |
| parser.add_argument("--learning_rate", type=float, default=default_cfg.learning_rate) |
| parser.add_argument("--warmup_ratio", type=float, default=default_cfg.warmup_ratio) |
| parser.add_argument("--output_dir", default=default_cfg.output_dir) |
| parser.add_argument("--wandb_project", default=default_cfg.wandb_project) |
| parser.add_argument("--wandb_run_name", default=default_cfg.wandb_run_name) |
|
|
| args = parser.parse_args() |
| return TrainingConfig( |
| model_id=args.model_id, |
| lora_r=args.lora_r, |
| lora_alpha=args.lora_alpha, |
| lora_dropout=args.lora_dropout, |
| target_modules=args.target_modules.split(","), |
| num_train_epochs=args.num_train_epochs, |
| train_batch_size=args.train_batch_size, |
| eval_batch_size=args.eval_batch_size, |
| gradient_accumulation_steps=args.gradient_accumulation_steps, |
| learning_rate=args.learning_rate, |
| warmup_ratio=args.warmup_ratio, |
| output_dir=args.output_dir, |
| wandb_project=args.wandb_project, |
| wandb_run_name=args.wandb_run_name, |
| ) |
|
|
|
|
| def prepare_datasets(cfg: TrainingConfig): |
| print(f"Loading dataset: SpaceThinker") |
| raw_train_spacethinker = load_dataset("remyxai/SpaceThinker", split="train") |
| raw_eval_spacethinker = load_dataset("remyxai/SpaceThinker", split="test") |
|
|
| print(f"Loading dataset: SpaceOm") |
| raw_train_spaceom = load_dataset("remyxai/SpaceOm", split="train") |
| raw_eval_spaceom = load_dataset("remyxai/SpaceOm", split="test") |
|
|
| print(f"Loading dataset: Robo2VLM") |
| raw_train_robo2vlm = load_dataset("remyxai/Robo2VLM-Reasoning", split="train") |
| raw_eval_robo2vlm = load_dataset("remyxai/Robo2VLM-Reasoning", split="test") |
|
|
| print("Formatting train samples…") |
| train_ds_spacethinker = [format_data_spacethinker(s) for s in tqdm(raw_train_spacethinker, desc="Train")] |
| train_ds_spaceom = [format_data_spacethinker(s) for s in tqdm(raw_train_spaceom, desc="Train")] |
| train_ds_robo2vlm = [format_data_spacethinker(s) for s in tqdm(raw_train_robo2vlm, desc="Train")] |
| print("Formatting eval samples…") |
| eval_ds_spacethinker = [format_data_spacethinker(s) for s in tqdm(raw_eval_spacethinker, desc="Eval")] |
| eval_ds_spaceom = [format_data_spacethinker(s) for s in tqdm(raw_eval_spaceom, desc="Eval")] |
| eval_ds_robo2vlm = [format_data_spacethinker(s) for s in tqdm(raw_eval_robo2vlm, desc="Eval")] |
|
|
| train_ds = train_ds_spacethinker + train_ds_spaceom + train_ds_robo2vlm |
| eval_ds = eval_ds_spacethinker + eval_ds_spaceom + eval_ds_robo2vlm |
| random.shuffle(train_ds) |
| random.shuffle(eval_ds) |
|
|
| return train_ds, eval_ds |
|
|
|
|
| def prepare_model_and_optimizer(cfg: TrainingConfig): |
| bnb = BitsAndBytesConfig( |
| load_in_4bit=True, |
| bnb_4bit_use_double_quant=True, |
| bnb_4bit_quant_type="nf4", |
| bnb_4bit_compute_dtype=torch.bfloat16 |
| ) |
| model = Qwen2_5_VLForConditionalGeneration.from_pretrained( |
| cfg.model_id, |
| device_map="auto", |
| torch_dtype=torch.bfloat16, |
| quantization_config=bnb |
| ) |
| processor = AutoProcessor.from_pretrained(cfg.model_id) |
|
|
| peft_cfg = LoraConfig( |
| r=cfg.lora_r, |
| lora_alpha=cfg.lora_alpha, |
| lora_dropout=cfg.lora_dropout, |
| bias="none", |
| target_modules=cfg.target_modules, |
| task_type="CAUSAL_LM", |
| ) |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| peft_model = get_peft_model(model, peft_cfg).to(device) |
| peft_model.print_trainable_parameters() |
| return peft_model, processor, peft_cfg |
|
|
|
|
| def main(): |
| cfg = parse_args() |
| train_ds, eval_ds = prepare_datasets(cfg) |
| model, processor, peft_cfg = prepare_model_and_optimizer(cfg) |
|
|
| sft_args = SFTConfig( |
| output_dir=cfg.output_dir, |
| num_train_epochs=cfg.num_train_epochs, |
| per_device_train_batch_size=cfg.train_batch_size, |
| per_device_eval_batch_size=cfg.eval_batch_size, |
| gradient_accumulation_steps=cfg.gradient_accumulation_steps, |
| gradient_checkpointing=True, |
| optim="adamw_torch_fused", |
| learning_rate=cfg.learning_rate, |
| lr_scheduler_type="constant", |
| logging_steps=10, |
| eval_steps=10, |
| eval_strategy="steps", |
| save_strategy="steps", |
| save_steps=20, |
| metric_for_best_model="eval_loss", |
| greater_is_better=False, |
| load_best_model_at_end=True, |
| bf16=True, |
| tf32=True, |
| max_grad_norm=0.3, |
| warmup_ratio=cfg.warmup_ratio, |
| gradient_checkpointing_kwargs={"use_reentrant": False}, |
| push_to_hub=True, |
| report_to="wandb", |
| dataset_kwargs={"skip_prepare_dataset": True}, |
| ) |
| sft_args.remove_unused_columns = False |
|
|
| wandb.init( |
| project=cfg.wandb_project, |
| name=cfg.wandb_run_name, |
| config=sft_args, |
| ) |
|
|
| trainer = SFTTrainer( |
| model=model, |
| args=sft_args, |
| train_dataset=train_ds, |
| eval_dataset=eval_ds, |
| data_collator=lambda ex: collate_fn(ex, processor), |
| peft_config=peft_cfg, |
| tokenizer=processor.tokenizer, |
| ) |
|
|
| trainer.train() |
| trainer.save_model(cfg.output_dir) |
|
|
|
|
| if __name__ == "__main__": |
| main() |