NTA-Dev committed on
Commit
c0b8ce5
·
verified ·
1 Parent(s): 76196a1

Upload train_manny.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_manny.py +38 -0
train_manny.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# /// script
# dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "datasets", "accelerate", "bitsandbytes"]
# ///
"""Supervised fine-tuning of the "Manny" agent with LoRA adapters.

Loads the ``train`` split of the ``NTA-Dev/bf-manny-training`` dataset,
fine-tunes ``Qwen/Qwen2.5-3B-Instruct`` through TRL's ``SFTTrainer`` with a
PEFT ``LoraConfig``, and pushes the result to the ``NTA-Dev/bf-manny-agent``
Hub repository.

The comment block at the top of the file is inline script metadata
(PEP 723); it must stay the first thing in the file and keep its exact
``# /// script`` ... ``# ///`` framing so script runners can read it.
"""

from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig

# NOTE(review): `max_length` is passed to SFTConfig below. Older TRL releases
# named this argument `max_seq_length`, so the `trl>=0.12.0` floor in the
# script metadata may be too permissive -- confirm against the TRL changelog
# before relying on an old pinned environment.

BASE_MODEL = "Qwen/Qwen2.5-3B-Instruct"
TRAIN_DATASET = "NTA-Dev/bf-manny-training"
HUB_MODEL_ID = "NTA-Dev/bf-manny-agent"
OUTPUT_DIR = "manny-agent"

# Projection layers that receive LoRA adapters.
LORA_TARGET_MODULES = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]


def main() -> None:
    """Load the dataset, run LoRA fine-tuning, and push the result to the Hub."""
    dataset = load_dataset(TRAIN_DATASET, split="train")
    print(f"Loaded {len(dataset)} examples for Manny")

    trainer = SFTTrainer(
        # The model is given by Hub id; the trainer is responsible for loading it.
        model=BASE_MODEL,
        train_dataset=dataset,
        peft_config=LoraConfig(
            r=32,
            lora_alpha=64,
            target_modules=LORA_TARGET_MODULES,
            lora_dropout=0.05,
        ),
        args=SFTConfig(
            output_dir=OUTPUT_DIR,
            push_to_hub=True,
            hub_model_id=HUB_MODEL_ID,
            num_train_epochs=5,
            # 2 sequences per device x 4 accumulation steps per optimizer update.
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,
            learning_rate=2e-4,
            warmup_ratio=0.1,
            logging_steps=5,
            save_strategy="epoch",
            bf16=True,
            gradient_checkpointing=True,
            max_length=2048,
        ),
    )
    trainer.train()
    # Explicit final upload once training has finished.
    trainer.push_to_hub()
    print("Manny training complete!")


if __name__ == "__main__":
    main()