wheattoast11 commited on
Commit
3520fb7
·
verified ·
1 Parent(s): a98794c

GLM-4.7 training script

Browse files
Files changed (1) hide show
  1. train_glm.py +86 -0
train_glm.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "trl>=0.12.0",
5
+ # "peft>=0.7.0",
6
+ # "transformers>=4.36.0",
7
+ # "accelerate>=0.24.0",
8
+ # "trackio",
9
+ # "datasets",
10
+ # ]
11
+ # ///
12
+
13
+ """
14
+ Agent Zero SFT: zai-org/GLM-4.7-Flash (30B MoE)
15
+ LoRA fine-tuning on agent-zero-sft-v1 dataset.
16
+ Router layers frozen - only attention layers trained.
17
+ """
18
+
19
+ import trackio
20
+ from datasets import load_dataset
21
+ from peft import LoraConfig
22
+ from trl import SFTTrainer, SFTConfig
23
+
24
+ # Load dataset
25
+ print("Loading dataset...")
26
+ train_ds = load_dataset("wheattoast11/agent-zero-sft-v1", data_files="data/train.jsonl", split="train")
27
+ val_ds = load_dataset("wheattoast11/agent-zero-sft-v1", data_files="data/validation.jsonl", split="train")
28
+ print(f"Train: {len(train_ds)}, Val: {len(val_ds)}")
29
+
30
+ config = SFTConfig(
31
+ output_dir="agent-zero-glm-4.7-v1",
32
+ push_to_hub=True,
33
+ hub_model_id="wheattoast11/agent-zero-glm-4.7-v1",
34
+ hub_strategy="every_save",
35
+ hub_private_repo=True,
36
+
37
+ num_train_epochs=2,
38
+ per_device_train_batch_size=1,
39
+ gradient_accumulation_steps=16,
40
+ learning_rate=1e-4,
41
+ bf16=True,
42
+ gradient_checkpointing=True,
43
+
44
+ logging_steps=10,
45
+ save_strategy="steps",
46
+ save_steps=50,
47
+ save_total_limit=2,
48
+
49
+ eval_strategy="steps",
50
+ eval_steps=50,
51
+
52
+ warmup_ratio=0.1,
53
+ lr_scheduler_type="cosine",
54
+
55
+ report_to="trackio",
56
+ project="agent-zero-finetune",
57
+ run_name="glm-4.7-flash-sft-v1",
58
+ )
59
+
60
+ # LoRA targeting attention layers only (router layers frozen)
61
+ peft_config = LoraConfig(
62
+ r=16,
63
+ lora_alpha=32,
64
+ lora_dropout=0.05,
65
+ bias="none",
66
+ task_type="CAUSAL_LM",
67
+ target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
68
+ )
69
+
70
+ print("Initializing trainer...")
71
+ trainer = SFTTrainer(
72
+ model="zai-org/GLM-4.7-Flash",
73
+ train_dataset=train_ds,
74
+ eval_dataset=val_ds,
75
+ args=config,
76
+ peft_config=peft_config,
77
+ )
78
+
79
+ print("Starting training...")
80
+ trainer.train()
81
+
82
+ print("Pushing to Hub...")
83
+ trainer.push_to_hub()
84
+
85
+ trackio.finish()
86
+ print("Done! Model at: https://huggingface.co/wheattoast11/agent-zero-glm-4.7-v1")