"""
SFT Training: math | lora | vision=freeze

Supervised fine-tuning of Qwen2.5-VL with LoRA adapters using the Hugging
Face ``Trainer`` (DeepSpeed config ``ds_zero2.json``).  After training, the
adapter is saved under ``OUTPUT_DIR/final`` and a merged full-weight copy is
written to ``OUTPUT_DIR/merged``.
"""
import os
import json
import shutil

import torch
from PIL import Image
from torch.utils.data import Dataset
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    TrainingArguments,
    Trainer,
)
from peft import LoraConfig, get_peft_model, TaskType

MODEL_NAME = "/workspace/rl4phyx/models/Qwen2.5-VL-3B-Instruct"
DATA_PATH = "/workspace/rl4phyx/RL4Phyx/SFT/sft_train/coldstart_formatted.jsonl"
OUTPUT_DIR = "/workspace/rl4phyx/RL4Phyx/SFT/checkpoints/lora_math_f"

NUM_EPOCHS = 3
LEARNING_RATE = 1e-5
PER_DEVICE_BATCH_SIZE = 1
GRAD_ACCUM_STEPS = 16
MAX_LENGTH = 4096
FREEZE_VISION = True

LORA_R = 64
LORA_ALPHA = 128
LORA_DROPOUT = 0.05

_SFT_CLASSES_PATH = '/workspace/rl4phyx/RL4Phyx/SFT/_sft_classes.py'

# === Dataset and Collator classes imported from template ===
# The template is exec'd into this module's globals; main() expects it to
# define PhysicsCoTDataset and VLMDataCollator.
# NOTE(review): the template presumably relies on names imported above
# (json, Image, Dataset, torch, ...) -- keep those imports even though this
# file does not reference them directly.
# SECURITY: exec() runs arbitrary code; the path must remain a trusted,
# local file.  A regular import would be preferable if the template can be
# made importable.
with open(_SFT_CLASSES_PATH, encoding="utf-8") as _template_file:
    # compile() with the real filename so tracebacks point at the template.
    exec(compile(_template_file.read(), _SFT_CLASSES_PATH, "exec"))


def _freeze_vision_parameters(model):
    """Disable gradients for every parameter whose name contains 'visual'."""
    for name, param in model.named_parameters():
        if 'visual' in name:
            param.requires_grad = False


def _export_merged_model(model, processor):
    """Merge the LoRA weights into the base model and save it with the processor."""
    print("Merging LoRA weights...")
    merged_model = model.merge_and_unload()
    merged_output = os.path.join(OUTPUT_DIR, "merged")
    merged_model.save_pretrained(merged_output)
    processor.save_pretrained(merged_output)

    # Copy the base model's processor files over the regenerated ones, when
    # the base checkpoint ships them.
    for fn in ['preprocessor_config.json', 'chat_template.json']:
        src = os.path.join(MODEL_NAME, fn)
        if os.path.exists(src):
            shutil.copy2(src, os.path.join(merged_output, fn))
    print(f"Merged model saved to: {merged_output}")


def main():
    """Run LoRA SFT end to end: load, wrap with LoRA, train, save, merge."""
    print("Config: math | lora | vision=freeze")
    print(f"Data: {DATA_PATH}")
    print(f"Output: {OUTPUT_DIR}")

    processor = AutoProcessor.from_pretrained(
        MODEL_NAME,
        min_pixels=3136,
        max_pixels=1204224,
    )
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
    )

    lora_config = LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, lora_config)

    # Freeze AFTER wrapping with PEFT.  get_peft_model() already freezes the
    # base weights, but it attaches trainable adapters to every module that
    # matches target_modules.  The vision tower's MLP layers also match
    # gate_proj/up_proj/down_proj, so freezing beforehand left LoRA weights
    # inside the vision encoder trainable.  Frozen adapters keep their
    # initial zero delta, so merging them later is a no-op.
    if FREEZE_VISION:
        _freeze_vision_parameters(model)
        print("Froze vision encoder")
    else:
        # NOTE(review): base vision weights are still frozen by PEFT here;
        # only LoRA adapters attached inside the vision tower are trained.
        print("Vision encoder trainable")

    model.enable_input_require_grads()
    model.print_trainable_parameters()

    dataset = PhysicsCoTDataset(
        data_path=DATA_PATH,
        processor=processor,
        max_length=MAX_LENGTH,
    )

    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACCUM_STEPS,
        learning_rate=LEARNING_RATE,
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,
        weight_decay=0.01,
        bf16=True,
        logging_steps=10,
        save_strategy="steps",
        save_steps=20,
        save_total_limit=2,
        dataloader_num_workers=4,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={'use_reentrant': False},
        remove_unused_columns=False,
        report_to="none",
        # Relative path: resolved against the current working directory.
        deepspeed="ds_zero2.json",
        save_only_model=True,
    )

    collator = VLMDataCollator(processor)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=collator,
    )

    print("\n===== Starting SFT Training =====")
    trainer.train()

    print("\n===== Saving final model =====")
    final_output = os.path.join(OUTPUT_DIR, "final")
    # save_model() must run on every rank; the Trainer decides who writes.
    trainer.save_model(final_output)

    # Only the main process exports, so ranks of a distributed launch do not
    # write the same files concurrently.
    # NOTE(review): assumes the DeepSpeed config keeps full weights on each
    # rank (ZeRO stage <= 2) -- confirm against ds_zero2.json.
    if trainer.is_world_process_zero():
        processor.save_pretrained(final_output)
        _export_merged_model(model, processor)

    print("\n===== SFT Training Complete =====")


if __name__ == "__main__":
    main()