| """ |
| SFT Training: math | fullft | vision=freeze |
| """ |
| import os |
| import json |
| import torch |
| from PIL import Image |
| from torch.utils.data import Dataset |
| from transformers import ( |
| Qwen2_5_VLForConditionalGeneration, |
| AutoProcessor, |
| TrainingArguments, |
| Trainer, |
| ) |
|
|
|
|
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# Base checkpoint, loaded with from_pretrained() from this absolute path.
MODEL_NAME: str = "/workspace/rl4phyx/models/Qwen2.5-VL-3B-Instruct"
# JSONL training file consumed by PhysicsCoTDataset.
DATA_PATH: str = "/workspace/rl4phyx/RL4Phyx/SFT/sft_train/coldstart_formatted.jsonl"
# Trainer output directory; the final model goes to OUTPUT_DIR/final.
OUTPUT_DIR: str = "/workspace/rl4phyx/RL4Phyx/SFT/checkpoints/fullft_math_f"

# ---------------------------------------------------------------------------
# Optimisation hyper-parameters
# ---------------------------------------------------------------------------
NUM_EPOCHS: int = 3
LEARNING_RATE: float = 1e-5
PER_DEVICE_BATCH_SIZE: int = 1
# 1 sample x 16 accumulation steps = 16 samples per optimizer step, per device.
GRAD_ACCUM_STEPS: int = 16
# Forwarded to PhysicsCoTDataset as its max_length argument.
MAX_LENGTH: int = 4096

# ---------------------------------------------------------------------------
# Model options
# ---------------------------------------------------------------------------
# When True, main() sets requires_grad=False on every parameter whose name
# contains 'visual'.
FREEZE_VISION: bool = True
|
|
|
|
| |
# Load the shared training classes (PhysicsCoTDataset, VLMDataCollator, used
# by main()) into this module's global namespace.
# - The `with` block closes the file handle; the original bare open() leaked it.
# - UTF-8 is Python's source-file default, so the file decodes the same on any
#   locale.
# - compile() with the real path makes tracebacks from that code name
#   _sft_classes.py instead of "<string>".
# NOTE(review): this runs arbitrary code from a fixed local path. Acceptable
# for a trusted repo file; never point it at untrusted input.
_SFT_CLASSES_PATH = "/workspace/rl4phyx/RL4Phyx/SFT/_sft_classes.py"
with open(_SFT_CLASSES_PATH, encoding="utf-8") as _sft_classes_file:
    exec(compile(_sft_classes_file.read(), _SFT_CLASSES_PATH, "exec"))
|
|
|
|
def _freeze_vision_encoder(model):
    """Set ``requires_grad=False`` on every parameter whose name contains ``'visual'``.

    Args:
        model: Any ``torch.nn.Module`` exposing ``named_parameters()``.

    Returns:
        The number of parameter tensors that were frozen.
    """
    frozen = 0
    for name, param in model.named_parameters():
        if "visual" in name:
            param.requires_grad = False
            frozen += 1
    return frozen


def main():
    """Run full-parameter SFT of Qwen2.5-VL on the math cold-start data.

    Steps:
    1. Load the processor and bf16 model from ``MODEL_NAME``.
    2. Freeze the vision tower if ``FREEZE_VISION`` is set.
    3. Train with the HF ``Trainer`` (DeepSpeed config ``ds_zero2.json``).
    4. Write the final model and processor to ``OUTPUT_DIR/final``.

    ``PhysicsCoTDataset`` and ``VLMDataCollator`` are not imported here. They
    must already exist in module globals, injected by the ``exec`` of
    ``_sft_classes.py``.

    Raises:
        RuntimeError: if ``FREEZE_VISION`` is set but no parameter name
            contains ``'visual'``. Without this check, freezing would silently
            do nothing and the vision tower would be trained.
    """
    # Derive the banner from the flag so the log matches the actual run.
    vision_mode = "freeze" if FREEZE_VISION else "train"
    print(f"Config: math | fullft | vision={vision_mode}")
    print(f"Data: {DATA_PATH}")
    print(f"Output: {OUTPUT_DIR}")

    # Image resolution bounds: 3136 = 4*28*28 and 1204224 = 1536*28*28 pixels.
    processor = AutoProcessor.from_pretrained(
        MODEL_NAME, min_pixels=3136, max_pixels=1204224,
    )

    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_NAME, torch_dtype=torch.bfloat16, attn_implementation="sdpa",
    )

    if FREEZE_VISION:
        n_frozen = _freeze_vision_encoder(model)
        if n_frozen == 0:
            raise RuntimeError(
                "FREEZE_VISION is True but no parameter name contains 'visual'; "
                "the vision encoder would be trained."
            )
        print(f"Froze vision encoder ({n_frozen} parameter tensors)")
    else:
        print("Vision encoder trainable")

    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    n_total = sum(p.numel() for p in model.parameters())
    print(f"Trainable params: {n_trainable:,} / {n_total:,}")

    # Makes the input-embedding outputs require grad, so gradients flow
    # through the checkpointed (non-reentrant) segments.
    model.enable_input_require_grads()

    dataset = PhysicsCoTDataset(data_path=DATA_PATH, processor=processor, max_length=MAX_LENGTH)

    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACCUM_STEPS,
        learning_rate=LEARNING_RATE,
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,
        weight_decay=0.01,
        bf16=True,
        logging_steps=10,
        save_strategy="steps",
        save_steps=20,
        save_total_limit=2,
        dataloader_num_workers=4,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={'use_reentrant': False},
        # Keep every dataset column; the custom collator decides what is used.
        remove_unused_columns=False,
        report_to="none",
        # NOTE(review): relative path, resolved against the current working
        # directory rather than this script's location. Launch from the
        # directory that contains ds_zero2.json, or make it absolute.
        deepspeed="ds_zero2.json",
        # Also checkpoint optimizer/scheduler state so a run can be resumed.
        save_only_model=False,
    )

    collator = VLMDataCollator(processor)
    trainer = Trainer(model=model, args=training_args, train_dataset=dataset, data_collator=collator)

    print("\n===== Starting SFT Training =====")
    trainer.train()

    print("\n===== Saving final model =====")
    final_dir = os.path.join(OUTPUT_DIR, "final")
    # save_model must run on every rank (it coordinates with DeepSpeed and
    # writes only from the main process itself).
    trainer.save_model(final_dir)
    # The processor has no such coordination. Calling save_pretrained on every
    # rank makes all processes write the same files concurrently, so only the
    # main process writes them.
    if trainer.is_world_process_zero():
        processor.save_pretrained(final_dir)

    print("\n===== SFT Training Complete =====")


if __name__ == "__main__":
    main()
|
|