"""
SFT Training Script for Qwen2.5-VL-3B-Instruct on Physics CoT Data.

Uses HuggingFace Transformers Trainer with Qwen2.5-VL processor.
Freezes the vision encoder to save memory and prevent catastrophic forgetting.
"""

import os
import json
import shutil
import warnings

import torch
from PIL import Image
from torch.utils.data import Dataset
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    TrainingArguments,
    Trainer,
)
from peft import LoraConfig, get_peft_model, TaskType

# ===== Configuration =====
MODEL_NAME = "/workspace/rl4phyx/models/Qwen2.5-VL-3B-Instruct"
DATA_PATH = "/workspace/rl4phyx/RL4Phyx/SFT/sft_train/coldstart_formatted.jsonl"
OUTPUT_DIR = "/workspace/rl4phyx/RL4Phyx/SFT/checkpoints/sft_qwen25vl_3b_math_lora"

# Training hyperparameters
NUM_EPOCHS = 3
LEARNING_RATE = 1e-4  # LoRA uses higher LR
PER_DEVICE_BATCH_SIZE = 1  # Small batch for 40GB A100 with VLM
GRAD_ACCUM_STEPS = 8  # Effective batch = 1 per device * num_GPUs * 8 (e.g. 32 on 4 GPUs)
MAX_LENGTH = 4096  # Max total sequence length
USE_LORA = True  # LoRA saves memory, merge afterwards for RLVR
FREEZE_VISION = True  # Always freeze vision encoder

# LoRA config
LORA_R = 64
LORA_ALPHA = 128
LORA_DROPOUT = 0.05
# Linear layers that receive LoRA adapters. Qwen2.5-VL's vision blocks also
# contain gate_proj/up_proj/down_proj, so a plain list of names (which PEFT
# matches by suffix) would attach trainable adapters to the vision encoder.
# A string target_modules is full-matched by PEFT as a regex against module
# names, which lets us exclude every module whose path contains "visual".
LORA_TARGET_MODULES = (
    r"^(?!.*visual).*\.(q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj)$"
)


class PhysicsCoTDataset(Dataset):
    """Dataset for Qwen2.5-VL SFT with physics CoT.

    Each JSONL record holds a ``messages`` list containing a user turn
    (``image`` and ``text`` content items) and an assistant turn whose first
    content item carries the target text. Only tokens after the assistant
    marker contribute to the loss; everything before it is masked with -100.
    """

    # Text the chat template emits right before the assistant reply.
    ASSISTANT_MARKER = "<|im_start|>assistant\n"

    def __init__(self, data_path, processor, max_length=2048):
        """Load all JSONL records from ``data_path``.

        Args:
            data_path: Path to a UTF-8 JSONL file, one record per line.
            processor: Qwen2.5-VL processor (chat template + tokenizer + image processor).
            max_length: Truncation length passed to the processor.
        """
        self.processor = processor
        self.max_length = max_length
        with open(data_path, 'r', encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing newline) instead of crashing in json.loads.
            self.records = [json.loads(line) for line in f if line.strip()]
        print(f"Loaded {len(self.records)} records from {data_path}")
        # The marker's token ids are constant, so tokenize it once up front.
        self._assistant_token_ids = processor.tokenizer.encode(
            self.ASSISTANT_MARKER, add_special_tokens=False
        )

    def __len__(self):
        return len(self.records)

    @staticmethod
    def _find_message(messages, role, fallback_index):
        """Return the first message with ``role``; else ``messages[fallback_index]``."""
        for message in messages:
            if isinstance(message, dict) and message.get('role') == role:
                return message
        return messages[fallback_index]

    @staticmethod
    def _index_after(haystack, needle):
        """Return the index just past the first occurrence of ``needle``, or -1."""
        width = len(needle)
        for i in range(len(haystack) - width + 1):
            if haystack[i:i + width] == needle:
                return i + width
        return -1

    def __getitem__(self, idx):
        """Build one training example.

        Returns:
            dict with 1-D ``input_ids``/``attention_mask``/``labels`` tensors and
            ``pixel_values``/``image_grid_thw`` (``None`` when the sample has no image).
        """
        record = self.records[idx]
        messages = record['messages']

        # Extract image path and prompt text from the user message
        # (looked up by role so a leading system message does not break it).
        user_msg = self._find_message(messages, 'user', 0)
        image_path = None
        text_content = ""
        for content in user_msg['content']:
            if content['type'] == 'image':
                image_path = content['image'].replace('file://', '')
            elif content['type'] == 'text':
                text_content = content['text']

        # Extract assistant response
        assistant_msg = self._find_message(messages, 'assistant', 1)
        assistant_text = assistant_msg['content'][0]['text']

        # Load image (if any); the context manager closes the file handle and
        # convert() returns an independent in-memory copy.
        user_content = [{"type": "text", "text": text_content}]
        image = None
        if image_path is not None:
            with Image.open(image_path) as img:
                image = img.convert('RGB')
            user_content.insert(0, {"type": "image", "image": image})

        # Build conversation for apply_chat_template
        conversation = [
            {"role": "user", "content": user_content},
            {
                "role": "assistant",
                "content": [{"type": "text", "text": assistant_text}],
            },
        ]

        text = self.processor.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=False,
        )
        inputs = self.processor(
            text=[text],
            images=[image] if image is not None else None,
            padding=False,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        # Squeeze batch dimension
        input_ids = inputs['input_ids'].squeeze(0)
        attention_mask = inputs['attention_mask'].squeeze(0)

        # Create labels: mask everything up to and including the assistant
        # marker so only the assistant response is trained on.
        labels = input_ids.clone()
        assistant_start = self._index_after(input_ids.tolist(), self._assistant_token_ids)
        if assistant_start > 0:
            labels[:assistant_start] = -100  # Mask user prompt
        else:
            # Marker not found (e.g. truncated away or tokenized differently).
            # Keep the original heuristic, but make it visible.
            warnings.warn(
                f"Assistant marker not found in sample {idx}; "
                "masking the first 30% of tokens as an approximation."
            )
            mask_len = int(len(labels) * 0.3)
            labels[:mask_len] = -100

        # Also mask padding (none with padding=False, kept as a safeguard)
        labels[attention_mask == 0] = -100

        pixel_values = inputs.get('pixel_values')
        image_grid_thw = inputs.get('image_grid_thw')
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            'pixel_values': pixel_values.squeeze(0) if pixel_values is not None else None,
            'image_grid_thw': image_grid_thw.squeeze(0) if image_grid_thw is not None else None,
        }


class VLMDataCollator:
    """Custom data collator for variable-length VLM inputs (right padding)."""

    def __init__(self, processor):
        self.processor = processor
        tokenizer = processor.tokenizer
        # Explicit None check: `a or b` would wrongly discard a valid pad id of 0.
        if tokenizer.pad_token_id is not None:
            self.pad_token_id = tokenizer.pad_token_id
        else:
            self.pad_token_id = tokenizer.eos_token_id

    def __call__(self, features):
        """Right-pad token tensors to the batch max and concatenate image tensors."""
        max_len = max(f['input_ids'].size(0) for f in features)
        input_ids = []
        attention_mask = []
        labels = []
        pixel_values = []
        image_grid_thw = []
        for f in features:
            pad_len = max_len - f['input_ids'].size(0)
            input_ids.append(torch.cat([
                f['input_ids'],
                torch.full((pad_len,), self.pad_token_id, dtype=f['input_ids'].dtype),
            ]))
            attention_mask.append(torch.cat([
                f['attention_mask'],
                torch.zeros(pad_len, dtype=f['attention_mask'].dtype),
            ]))
            labels.append(torch.cat([
                f['labels'],
                torch.full((pad_len,), -100, dtype=f['labels'].dtype),
            ]))
            if f.get('pixel_values') is not None:
                pixel_values.append(f['pixel_values'])
            grid = f.get('image_grid_thw')
            if grid is not None:
                # The dataset yields a (3,) row for a single image; promote it to
                # (1, 3) so per-sample (N, 3) grids also concatenate correctly.
                image_grid_thw.append(grid.unsqueeze(0) if grid.dim() == 1 else grid)

        batch = {
            'input_ids': torch.stack(input_ids),
            'attention_mask': torch.stack(attention_mask),
            'labels': torch.stack(labels),
        }
        if pixel_values:
            batch['pixel_values'] = torch.cat(pixel_values, dim=0)
        if image_grid_thw:
            batch['image_grid_thw'] = torch.cat(image_grid_thw, dim=0)
        return batch


def main():
    """Fine-tune the model, save the adapter, and (with LoRA) a merged copy."""
    print(f"Loading model: {MODEL_NAME}")
    print(f"Data: {DATA_PATH}")
    print(f"Output: {OUTPUT_DIR}")
    print(f"LoRA: {USE_LORA}, Freeze Vision: {FREEZE_VISION}")
    print(f"Epochs: {NUM_EPOCHS}, LR: {LEARNING_RATE}, Batch: {PER_DEVICE_BATCH_SIZE} x {GRAD_ACCUM_STEPS}")

    # Load processor
    processor = AutoProcessor.from_pretrained(
        MODEL_NAME,
        min_pixels=3136,  # 56x56
        max_pixels=200704,  # ~256 image tokens
    )

    # Load model
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
    )

    # Freeze vision encoder
    if FREEZE_VISION:
        for name, param in model.named_parameters():
            if 'visual' in name:
                param.requires_grad = False
        print("Froze vision encoder parameters")

    # Apply LoRA
    if USE_LORA:
        # Target only language model layers (regex excludes all "visual" modules)
        lora_config = LoraConfig(
            r=LORA_R,
            lora_alpha=LORA_ALPHA,
            lora_dropout=LORA_DROPOUT,
            target_modules=LORA_TARGET_MODULES,
            task_type=TaskType.CAUSAL_LM,
        )
        model = get_peft_model(model, lora_config)
        model.enable_input_require_grads()  # Required for gradient_checkpointing + LoRA
        model.print_trainable_parameters()

    # Create dataset
    dataset = PhysicsCoTDataset(data_path=DATA_PATH, processor=processor, max_length=MAX_LENGTH)

    # Training arguments
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        overwrite_output_dir=True,
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACCUM_STEPS,
        learning_rate=LEARNING_RATE,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        weight_decay=0.01,
        bf16=True,
        logging_steps=5,
        save_strategy="epoch",
        save_total_limit=3,
        dataloader_num_workers=4,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={'use_reentrant': False},
        remove_unused_columns=False,
        report_to="none",
        # Kept True defensively. Note: DDP only tracks parameters with
        # requires_grad=True, so frozen weights by themselves do not need it.
        ddp_find_unused_parameters=True,
    )

    # Collator
    collator = VLMDataCollator(processor)

    # Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=collator,
    )

    # Train
    print("\n===== Starting SFT Training =====")
    trainer.train()

    # Save final model (Trainer.save_model handles multi-process saving itself)
    print("\n===== Saving final model =====")
    trainer.save_model(os.path.join(OUTPUT_DIR, "final"))

    # Every rank runs main() under torchrun; only the main process merges and
    # writes, otherwise all ranks would write the same files concurrently.
    if USE_LORA and trainer.is_world_process_zero():
        # Also merge LoRA and save full model for RLVR
        print("Merging LoRA weights...")
        merged_model = model.merge_and_unload()
        merged_output = os.path.join(OUTPUT_DIR, "merged")
        merged_model.save_pretrained(merged_output)
        processor.save_pretrained(merged_output)

        # Copy visual processor config files from original model, overwriting
        # the versions just written by processor.save_pretrained.
        model_dir = MODEL_NAME
        for fname in ['preprocessor_config.json', 'chat_template.json']:
            src = os.path.join(model_dir, fname)
            if os.path.exists(src):
                shutil.copy2(src, os.path.join(merged_output, fname))
                print(f'Copied {fname} to merged output')

        print(f"Merged model saved to: {merged_output}")

    print("\n===== SFT Training Complete =====")


if __name__ == "__main__":
    main()