# rl4phyx-backup / scripts/train_sft_math_lora_freeze.py
# YUNTA88's picture
# Upload scripts/train_sft_math_lora_freeze.py with huggingface_hub
# 5d9fc4e verified
"""
SFT Training: math | lora | vision=freeze
"""
import os
import json
import torch
from PIL import Image
from torch.utils.data import Dataset
from transformers import (
Qwen2_5_VLForConditionalGeneration,
AutoProcessor,
TrainingArguments,
Trainer,
)
from peft import LoraConfig, get_peft_model, TaskType
# ---------------------------------------------------------------------------
# Run configuration (read by main()).
# ---------------------------------------------------------------------------

# Filesystem locations: base checkpoint, training data (JSONL), output root.
MODEL_NAME = "/workspace/rl4phyx/models/Qwen2.5-VL-3B-Instruct"
DATA_PATH = "/workspace/rl4phyx/RL4Phyx/SFT/sft_train/coldstart_formatted.jsonl"
OUTPUT_DIR = "/workspace/rl4phyx/RL4Phyx/SFT/checkpoints/lora_math_f"

# Optimisation schedule. Each device processes
# PER_DEVICE_BATCH_SIZE * GRAD_ACCUM_STEPS samples per optimizer step.
PER_DEVICE_BATCH_SIZE = 1
GRAD_ACCUM_STEPS = 16
NUM_EPOCHS = 3
LEARNING_RATE = 1e-5

# Passed to the dataset as ``max_length``.
MAX_LENGTH = 4096

# When True, parameters whose name contains 'visual' get requires_grad=False.
FREEZE_VISION = True

# LoRA adapter hyper-parameters forwarded to peft.LoraConfig.
LORA_R = 64
LORA_ALPHA = 128
LORA_DROPOUT = 0.05
# === Dataset and Collator classes imported from template ===
# The template source is executed in this module's namespace, so whatever it
# defines becomes a module-level global here. main() relies on it to provide
# PhysicsCoTDataset and VLMDataCollator (presumably defined there -- verify
# against the template file).
#
# NOTE(review): exec() runs arbitrary code from the file below; the path must
# only ever point at trusted, project-owned source. A regular import would be
# preferable if the template can be made importable.
_SFT_CLASSES_PATH = '/workspace/rl4phyx/RL4Phyx/SFT/_sft_classes.py'
# Use a context manager so the handle is closed deterministically, and read
# the file as UTF-8 (Python's default source encoding) instead of the locale
# default.
with open(_SFT_CLASSES_PATH, encoding="utf-8") as _sft_classes_file:
    # Compiling with the real filename makes tracebacks point at the template
    # instead of an anonymous "<string>".
    exec(compile(_sft_classes_file.read(), _SFT_CLASSES_PATH, "exec"))
def _freeze_vision_parameters(model):
    """Set ``requires_grad = False`` on every parameter named like '*visual*'.

    Args:
        model: Any object exposing ``named_parameters()`` (a torch module).

    Returns:
        The number of parameter tensors whose gradients were disabled.
    """
    frozen = 0
    for name, param in model.named_parameters():
        if 'visual' in name:
            param.requires_grad = False
            frozen += 1
    return frozen


def main():
    """Run LoRA supervised fine-tuning and export adapter + merged weights.

    Steps, in order:
      1. Load the processor and the bf16 base model from ``MODEL_NAME``.
      2. Optionally freeze parameters whose name contains 'visual'.
      3. Wrap the model with LoRA adapters and build the dataset/collator.
      4. Train with ``transformers.Trainer`` (DeepSpeed config
         ``ds_zero2.json``, resolved relative to the working directory).
      5. Save the adapter to ``OUTPUT_DIR/final`` and the merged model to
         ``OUTPUT_DIR/merged``.

    All file exports other than ``trainer.save_model`` are performed by the
    main process only, so concurrent ranks do not write the same files.
    """
    # Local import: only needed for the final file copies.
    import shutil

    print("Config: math | lora | vision=freeze")
    print(f"Data: {DATA_PATH}")
    print(f"Output: {OUTPUT_DIR}")

    processor = AutoProcessor.from_pretrained(
        MODEL_NAME, min_pixels=3136, max_pixels=1204224,
    )
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_NAME, torch_dtype=torch.bfloat16, attn_implementation="sdpa",
    )

    if FREEZE_VISION:
        _freeze_vision_parameters(model)
        print("Froze vision encoder")
    else:
        print("Vision encoder trainable")

    lora_config = LoraConfig(
        r=LORA_R, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT,
        target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
        task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, lora_config)

    if FREEZE_VISION:
        # A list of target_modules is matched against module-name suffixes,
        # so names such as gate_proj/up_proj/down_proj can also match layers
        # inside the vision tower. Any adapter weights injected there would be
        # trainable despite the freeze above, so re-apply the freeze now that
        # the adapters exist.
        _freeze_vision_parameters(model)

    model.enable_input_require_grads()
    model.print_trainable_parameters()

    dataset = PhysicsCoTDataset(data_path=DATA_PATH, processor=processor, max_length=MAX_LENGTH)

    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACCUM_STEPS,
        learning_rate=LEARNING_RATE,
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,
        weight_decay=0.01,
        bf16=True,
        logging_steps=10,
        save_strategy="steps",
        save_steps=20,
        save_total_limit=2,
        dataloader_num_workers=4,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={'use_reentrant': False},
        # Keep every field the collator produces; the Trainer would otherwise
        # drop columns not named in the model's forward signature.
        remove_unused_columns=False,
        report_to="none",
        deepspeed="ds_zero2.json",
        save_only_model=True,
    )
    collator = VLMDataCollator(processor)
    trainer = Trainer(model=model, args=training_args, train_dataset=dataset, data_collator=collator)

    print("\n===== Starting SFT Training =====")
    trainer.train()

    print("\n===== Saving final model =====")
    final_output = os.path.join(OUTPUT_DIR, "final")
    # save_model is called on every rank; the Trainer decides internally
    # which process actually writes.
    trainer.save_model(final_output)

    # Everything below writes files directly, so restrict it to the main
    # process to avoid several ranks racing on the same paths.
    if trainer.is_world_process_zero():
        processor.save_pretrained(final_output)

        print("Merging LoRA weights...")
        merged_model = model.merge_and_unload()
        merged_output = os.path.join(OUTPUT_DIR, "merged")
        merged_model.save_pretrained(merged_output)
        processor.save_pretrained(merged_output)

        # Copy the base checkpoint's processor files over the ones just
        # written, when the base checkpoint ships them.
        for fn in ['preprocessor_config.json', 'chat_template.json']:
            src = os.path.join(MODEL_NAME, fn)
            if os.path.exists(src):
                shutil.copy2(src, os.path.join(merged_output, fn))
        print(f"Merged model saved to: {merged_output}")

    print("\n===== SFT Training Complete =====")


if __name__ == "__main__":
    main()