YUNTA88 commited on
Commit
2f23710
·
verified ·
1 Parent(s): 6ee8119

Upload scripts/train_sft_phyx_fullft_unfreeze.py with huggingface_hub

Browse files
scripts/train_sft_phyx_fullft_unfreeze.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SFT Training: phyx | fullft | vision=unfreeze
3
+ """
4
+ import os
5
+ import json
6
+ import torch
7
+ from PIL import Image
8
+ from torch.utils.data import Dataset
9
+ from transformers import (
10
+ Qwen2_5_VLForConditionalGeneration,
11
+ AutoProcessor,
12
+ TrainingArguments,
13
+ Trainer,
14
+ )
15
+
16
+
17
+ MODEL_NAME = "/workspace/rl4phyx/models/Qwen2.5-VL-3B-Instruct"
18
+ DATA_PATH = "/workspace/rl4phyx/RL4Phyx/SFT/sft_train/sft_train_formatted.jsonl"
19
+ OUTPUT_DIR = "/workspace/rl4phyx/RL4Phyx/SFT/checkpoints/fullft_phyx_nf"
20
+
21
+ NUM_EPOCHS = 3
22
+ LEARNING_RATE = 1e-5
23
+ PER_DEVICE_BATCH_SIZE = 1
24
+ GRAD_ACCUM_STEPS = 8
25
+ MAX_LENGTH = 4096
26
+ FREEZE_VISION = False
27
+
28
+
29
+ # === Dataset and Collator classes imported from template ===
30
+ exec(open('/workspace/rl4phyx/RL4Phyx/SFT/_sft_classes.py').read())
31
+
32
+
33
+ def main():
34
+ print("Config: phyx | fullft | vision=unfreeze")
35
+ print(f"Data: {DATA_PATH}")
36
+ print(f"Output: {OUTPUT_DIR}")
37
+
38
+ processor = AutoProcessor.from_pretrained(
39
+ MODEL_NAME, min_pixels=3136, max_pixels=1204224,
40
+ )
41
+
42
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
43
+ MODEL_NAME, torch_dtype=torch.bfloat16, attn_implementation="sdpa",
44
+ )
45
+
46
+ if FREEZE_VISION:
47
+ for name, param in model.named_parameters():
48
+ if 'visual' in name:
49
+ param.requires_grad = False
50
+ print("Froze vision encoder")
51
+ else:
52
+ print("Vision encoder trainable")
53
+
54
+ model.enable_input_require_grads()
55
+
56
+ dataset = PhysicsCoTDataset(data_path=DATA_PATH, processor=processor, max_length=MAX_LENGTH)
57
+
58
+ training_args = TrainingArguments(
59
+ output_dir=OUTPUT_DIR,
60
+ num_train_epochs=NUM_EPOCHS,
61
+ per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
62
+ gradient_accumulation_steps=GRAD_ACCUM_STEPS,
63
+ learning_rate=LEARNING_RATE,
64
+ lr_scheduler_type="cosine",
65
+ warmup_ratio=0.03,
66
+ weight_decay=0.01,
67
+ bf16=True,
68
+ logging_steps=10,
69
+ save_strategy="steps",
70
+ save_steps=20,
71
+ save_total_limit=2,
72
+ dataloader_num_workers=4,
73
+ gradient_checkpointing=True,
74
+ gradient_checkpointing_kwargs={'use_reentrant': False},
75
+ remove_unused_columns=False,
76
+ report_to="none",
77
+ deepspeed="ds_zero2.json",
78
+ save_only_model=True,
79
+ )
80
+
81
+ collator = VLMDataCollator(processor)
82
+ trainer = Trainer(model=model, args=training_args, train_dataset=dataset, data_collator=collator)
83
+
84
+ print("\n===== Starting SFT Training =====")
85
+ trainer.train()
86
+
87
+ print("\n===== Saving final model =====")
88
+ trainer.save_model(os.path.join(OUTPUT_DIR, "final"))
89
+ processor.save_pretrained(os.path.join(OUTPUT_DIR, "final"))
90
+
91
+ print("\n===== SFT Training Complete =====")
92
+
93
+
94
+ if __name__ == "__main__":
95
+ main()