wheattoast11 committed on
Commit
a98794c
·
verified ·
1 Parent(s): 8ec7ecf

LFM2.5 training script

Browse files
Files changed (1) hide show
  1. train_lfm.py +83 -0
train_lfm.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "trl>=0.12.0",
# "peft>=0.7.0",
# "transformers>=4.36.0",
# "accelerate>=0.24.0",
# "trackio",
# "datasets",
# ]
# ///

"""
Agent Zero SFT: LiquidAI/LFM2.5-1.2B-Instruct
LoRA fine-tuning on agent-zero-sft-v1 dataset.

Runs as a uv inline-metadata script (PEP 723 header above). Training
artifacts are pushed to the Hub repo wheattoast11/agent-zero-lfm-1.2b-v1
every save; metrics are reported to trackio.
"""

import trackio
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig


def _load_datasets():
    """Load the train/validation JSONL splits from the SFT dataset repo.

    Both files live in the same dataset repo; each is loaded as its own
    "train" split because `data_files` pins a single file per call.
    """
    print("Loading dataset...")
    train_ds = load_dataset("wheattoast11/agent-zero-sft-v1", data_files="data/train.jsonl", split="train")
    val_ds = load_dataset("wheattoast11/agent-zero-sft-v1", data_files="data/validation.jsonl", split="train")
    print(f"Train: {len(train_ds)}, Val: {len(val_ds)}")
    return train_ds, val_ds


def _build_sft_config():
    """Return the SFTConfig: output/Hub settings, optim schedule, logging."""
    return SFTConfig(
        # Output + Hub publishing (private repo, checkpoint pushed each save).
        output_dir="agent-zero-lfm-1.2b-v1",
        push_to_hub=True,
        hub_model_id="wheattoast11/agent-zero-lfm-1.2b-v1",
        hub_strategy="every_save",
        hub_private_repo=True,
        # Effective batch size = 4 * 4 = 16 sequences per optimizer step.
        num_train_epochs=3,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=2e-4,
        bf16=True,
        # Checkpointing: keep only the two most recent checkpoints.
        logging_steps=10,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=2,
        # Evaluate on the same cadence as saves so each checkpoint has metrics.
        eval_strategy="steps",
        eval_steps=100,
        # 10% linear warmup, then cosine decay.
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        # Metric reporting via trackio.
        report_to="trackio",
        project="agent-zero-finetune",
        run_name="lfm-1.2b-sft-v1",
    )


def _build_lora_config():
    """Return the LoRA adapter config targeting the attention projections.

    r=16 with alpha=32 gives the common alpha/r = 2 scaling; only the
    q/k/v/o projections are adapted, keeping the trainable footprint small.
    """
    return LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    )


def main():
    """Run LoRA SFT of LFM2.5-1.2B-Instruct and push the result to the Hub."""
    train_ds, val_ds = _load_datasets()
    config = _build_sft_config()
    peft_config = _build_lora_config()

    print("Initializing trainer...")
    # Passing the model id as a string lets SFTTrainer handle loading.
    trainer = SFTTrainer(
        model="LiquidAI/LFM2.5-1.2B-Instruct",
        train_dataset=train_ds,
        eval_dataset=val_ds,
        args=config,
        peft_config=peft_config,
    )

    print("Starting training...")
    trainer.train()

    print("Pushing to Hub...")
    trainer.push_to_hub()

    # Close the trackio run explicitly so final metrics are flushed.
    trackio.finish()
    print("Done! Model at: https://huggingface.co/wheattoast11/agent-zero-lfm-1.2b-v1")


# Guard so importing this module (e.g. for inspection) does not start training.
if __name__ == "__main__":
    main()