"""
SFT Training: math | fullft | vision=freeze

Full-parameter supervised fine-tuning of a Qwen2.5-VL checkpoint on the
cold-start JSONL data with the vision encoder frozen, using the HF Trainer
and the DeepSpeed config ``ds_zero2.json`` (a relative path, so it is
resolved against the current working directory).
"""
import os
import json  # not used directly here; presumably needed by the exec'd template
import torch
from PIL import Image  # not used directly here; presumably needed by the template
from torch.utils.data import Dataset  # base class presumably used by the template
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    TrainingArguments,
    Trainer,
)

MODEL_NAME = "/workspace/rl4phyx/models/Qwen2.5-VL-3B-Instruct"
DATA_PATH = "/workspace/rl4phyx/RL4Phyx/SFT/sft_train/coldstart_formatted.jsonl"
OUTPUT_DIR = "/workspace/rl4phyx/RL4Phyx/SFT/checkpoints/fullft_math_f"
NUM_EPOCHS = 3
LEARNING_RATE = 1e-5
PER_DEVICE_BATCH_SIZE = 1
GRAD_ACCUM_STEPS = 16
MAX_LENGTH = 4096
FREEZE_VISION = True

# === Dataset and Collator classes imported from template ===
# The template runs in this module's global namespace; main() expects it to
# define PhysicsCoTDataset and VLMDataCollator. compile() is given the real
# path so tracebacks point into the template instead of "<string>", and the
# `with` block closes the file handle (the old open(...).read() leaked it).
_SFT_CLASSES_PATH = "/workspace/rl4phyx/RL4Phyx/SFT/_sft_classes.py"
with open(_SFT_CLASSES_PATH, encoding="utf-8") as _template_file:
    exec(compile(_template_file.read(), _SFT_CLASSES_PATH, "exec"))


def main():
    """Load the model and processor, optionally freeze vision, train, and save.

    Intermediate checkpoints go to OUTPUT_DIR (every 20 steps, at most 2 kept).
    The final model and processor are written to OUTPUT_DIR/final.
    """
    print("Config: math | fullft | vision=freeze")
    print(f"Data: {DATA_PATH}")
    print(f"Output: {OUTPUT_DIR}")

    processor = AutoProcessor.from_pretrained(
        MODEL_NAME,
        min_pixels=3136,
        max_pixels=1204224,
    )
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
    )

    if FREEZE_VISION:
        # Name-based match: every parameter whose name contains 'visual'
        # (assumed to be exactly the vision tower in this model; verify).
        for name, param in model.named_parameters():
            if 'visual' in name:
                param.requires_grad = False
        print("Froze vision encoder")
    else:
        print("Vision encoder trainable")

    # Make the input-embedding outputs require grad so gradient checkpointing
    # always has a grad-carrying input, even when parts of the model are frozen.
    model.enable_input_require_grads()

    dataset = PhysicsCoTDataset(data_path=DATA_PATH, processor=processor, max_length=MAX_LENGTH)

    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACCUM_STEPS,
        learning_rate=LEARNING_RATE,
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,
        weight_decay=0.01,
        bf16=True,
        logging_steps=10,
        save_strategy="steps",
        save_steps=20,
        save_total_limit=2,
        dataloader_num_workers=4,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={'use_reentrant': False},
        # Keep all dataset columns; the custom collator decides what to use.
        remove_unused_columns=False,
        report_to="none",
        deepspeed="ds_zero2.json",
        save_only_model=False,
    )

    collator = VLMDataCollator(processor)
    trainer = Trainer(model=model, args=training_args, train_dataset=dataset, data_collator=collator)

    print("\n===== Starting SFT Training =====")
    trainer.train()

    print("\n===== Saving final model =====")
    final_dir = os.path.join(OUTPUT_DIR, "final")
    # Trainer.save_model handles multi-process saving itself. The processor
    # save does not, so only the main process writes it; otherwise every rank
    # would write the same files concurrently.
    trainer.save_model(final_dir)
    if trainer.is_world_process_zero():
        processor.save_pretrained(final_dir)
    print("\n===== SFT Training Complete =====")


if __name__ == "__main__":
    main()