YUNTA88 commited on
Commit
5d9fc4e
·
verified ·
1 Parent(s): 7e56b55

Upload scripts/train_sft_math_lora_freeze.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/train_sft_math_lora_freeze.py +116 -0
scripts/train_sft_math_lora_freeze.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SFT Training: math | lora | vision=freeze
3
+ """
4
+ import os
5
+ import json
6
+ import torch
7
+ from PIL import Image
8
+ from torch.utils.data import Dataset
9
+ from transformers import (
10
+ Qwen2_5_VLForConditionalGeneration,
11
+ AutoProcessor,
12
+ TrainingArguments,
13
+ Trainer,
14
+ )
15
+ from peft import LoraConfig, get_peft_model, TaskType
16
+
17
+
18
+ MODEL_NAME = "/workspace/rl4phyx/models/Qwen2.5-VL-3B-Instruct"
19
+ DATA_PATH = "/workspace/rl4phyx/RL4Phyx/SFT/sft_train/coldstart_formatted.jsonl"
20
+ OUTPUT_DIR = "/workspace/rl4phyx/RL4Phyx/SFT/checkpoints/lora_math_f"
21
+
22
+ NUM_EPOCHS = 3
23
+ LEARNING_RATE = 1e-5
24
+ PER_DEVICE_BATCH_SIZE = 1
25
+ GRAD_ACCUM_STEPS = 16
26
+ MAX_LENGTH = 4096
27
+ FREEZE_VISION = True
28
+ LORA_R = 64
29
+ LORA_ALPHA = 128
30
+ LORA_DROPOUT = 0.05
31
+
32
+
33
+ # === Dataset and Collator classes imported from template ===
34
+ exec(open('/workspace/rl4phyx/RL4Phyx/SFT/_sft_classes.py').read())
35
+
36
+
37
+ def main():
38
+ print("Config: math | lora | vision=freeze")
39
+ print(f"Data: {DATA_PATH}")
40
+ print(f"Output: {OUTPUT_DIR}")
41
+
42
+ processor = AutoProcessor.from_pretrained(
43
+ MODEL_NAME, min_pixels=3136, max_pixels=1204224,
44
+ )
45
+
46
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
47
+ MODEL_NAME, torch_dtype=torch.bfloat16, attn_implementation="sdpa",
48
+ )
49
+
50
+ if FREEZE_VISION:
51
+ for name, param in model.named_parameters():
52
+ if 'visual' in name:
53
+ param.requires_grad = False
54
+ print("Froze vision encoder")
55
+ else:
56
+ print("Vision encoder trainable")
57
+
58
+ lora_config = LoraConfig(
59
+ r=LORA_R, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT,
60
+ target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
61
+ task_type=TaskType.CAUSAL_LM,
62
+ )
63
+ model = get_peft_model(model, lora_config)
64
+ model.enable_input_require_grads()
65
+ model.print_trainable_parameters()
66
+
67
+ dataset = PhysicsCoTDataset(data_path=DATA_PATH, processor=processor, max_length=MAX_LENGTH)
68
+
69
+ training_args = TrainingArguments(
70
+ output_dir=OUTPUT_DIR,
71
+ num_train_epochs=NUM_EPOCHS,
72
+ per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
73
+ gradient_accumulation_steps=GRAD_ACCUM_STEPS,
74
+ learning_rate=LEARNING_RATE,
75
+ lr_scheduler_type="cosine",
76
+ warmup_ratio=0.03,
77
+ weight_decay=0.01,
78
+ bf16=True,
79
+ logging_steps=10,
80
+ save_strategy="steps",
81
+ save_steps=20,
82
+ save_total_limit=2,
83
+ dataloader_num_workers=4,
84
+ gradient_checkpointing=True,
85
+ gradient_checkpointing_kwargs={'use_reentrant': False},
86
+ remove_unused_columns=False,
87
+ report_to="none",
88
+ deepspeed="ds_zero2.json",
89
+ save_only_model=True,
90
+ )
91
+
92
+ collator = VLMDataCollator(processor)
93
+ trainer = Trainer(model=model, args=training_args, train_dataset=dataset, data_collator=collator)
94
+
95
+ print("\n===== Starting SFT Training =====")
96
+ trainer.train()
97
+
98
+ print("\n===== Saving final model =====")
99
+ trainer.save_model(os.path.join(OUTPUT_DIR, "final"))
100
+ processor.save_pretrained(os.path.join(OUTPUT_DIR, "final"))
101
+ print("Merging LoRA weights...")
102
+ merged_model = model.merge_and_unload()
103
+ merged_output = os.path.join(OUTPUT_DIR, "merged")
104
+ merged_model.save_pretrained(merged_output)
105
+ processor.save_pretrained(merged_output)
106
+ import shutil
107
+ for fn in ['preprocessor_config.json', 'chat_template.json']:
108
+ src = os.path.join(MODEL_NAME, fn)
109
+ if os.path.exists(src): shutil.copy2(src, os.path.join(merged_output, fn))
110
+ print(f"Merged model saved to: {merged_output}")
111
+
112
+ print("\n===== SFT Training Complete =====")
113
+
114
+
115
+ if __name__ == "__main__":
116
+ main()