davidsmts committed on
Commit
e3524fe
·
verified ·
1 Parent(s): 72eeda4

Upload train_sft_qwen25_05b_uv.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_sft_qwen25_05b_uv.py +185 -0
train_sft_qwen25_05b_uv.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # dependencies = [
3
+ # "accelerate>=0.30.0",
4
+ # "datasets>=2.19.0",
5
+ # "huggingface_hub>=0.24.0",
6
+ # "peft>=0.10.0",
7
+ # "requests>=2.31.0",
8
+ # "torch>=2.2.0",
9
+ # "trackio",
10
+ # "transformers>=4.44.0",
11
+ # "trl>=0.12.0",
12
+ # ]
13
+ # ///
14
+
15
+ import json
16
+ import os
17
+ import time
18
+ from datetime import datetime, timezone
19
+
20
+ import requests
21
+ import torch
22
+ from datasets import load_dataset
23
+ from peft import LoraConfig
24
+ from transformers import AutoTokenizer
25
+ from trl import SFTConfig, SFTTrainer
26
+
27
+ import trackio # noqa: F401 (used via `report_to="trackio"`)
28
+
29
+
30
# URL that receives the JSON "start"/"finish" events posted by this script.
# Overridable through the CENTRAL_LOG_ENDPOINT environment variable.
CENTRAL_LOG_ENDPOINT = os.getenv(
    "CENTRAL_LOG_ENDPOINT", "https://agenskill.onrender.com/training-logs"
)
33
+
34
+
35
def _utc_now_iso() -> str:
    """Return the current time as a timezone-aware UTC ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()
37
+
38
+
39
def _post_central_log(payload: dict) -> None:
    """Send one event to ``CENTRAL_LOG_ENDPOINT`` as JSON, best effort.

    Any failure (serialisation, network, non-2xx response) is printed and
    swallowed so that central logging can never abort a training run.

    Args:
        payload: Event data to send. Values that are not JSON-native are
            stringified, matching how this script prints the same payload.
    """
    try:
        # Encode ourselves with default=str so a single non-JSON-native value
        # in the payload does not cause the whole event to be dropped.
        body = json.dumps(payload, default=str)
        response = requests.post(
            CENTRAL_LOG_ENDPOINT,
            data=body,
            headers={"Content-Type": "application/json"},
            timeout=15,
        )
        # Without this check a 4xx/5xx reply would silently count as success.
        response.raise_for_status()
    except Exception as exc:  # deliberate catch-all: logging is best effort
        print(f"[central-log] failed: {exc}")
44
+
45
+
46
def main() -> None:
    """Fine-tune a causal LM with LoRA through TRL's ``SFTTrainer`` and push it.

    Configuration comes from environment variables:

    * ``MODEL_NAME`` (default ``Qwen/Qwen2.5-0.5B``)
    * ``DATASET_NAME`` (default ``trl-lib/Capybara``) and ``DATASET_SPLIT``
      (default ``train``)
    * ``HUB_MODEL_ID`` (required)
    * ``MAX_SAMPLES`` (default 200; values <= 0 keep the whole split),
      ``MAX_STEPS`` (default 100), ``MAX_LENGTH`` (default 512)
    * ``TRACKIO_PROJECT`` (default ``hf-jobs-sft``) and ``RUN_NAME``
      (default: a UTC-timestamped name)

    A "start" event is posted to the central log before training and a
    "finish" event (with metrics and the tail of the trainer log history)
    afterwards; the finish payload is also printed as JSON.

    Raises:
        SystemExit: If ``HUB_MODEL_ID`` is unset, or fewer than two samples
            are available to build the train/eval split.
    """
    start_ts = time.time()

    model_name = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-0.5B")
    dataset_name = os.getenv("DATASET_NAME", "trl-lib/Capybara")
    dataset_split = os.getenv("DATASET_SPLIT", "train")

    hub_model_id = os.getenv("HUB_MODEL_ID")
    if not hub_model_id:
        raise SystemExit(
            "Missing HUB_MODEL_ID (e.g. 'username/qwen25-05b-sft-test'). "
            "This must be set because Jobs storage is ephemeral."
        )

    max_samples = int(os.getenv("MAX_SAMPLES", "200"))
    max_steps = int(os.getenv("MAX_STEPS", "100"))
    max_length = int(os.getenv("MAX_LENGTH", "512"))

    trackio_project = os.getenv("TRACKIO_PROJECT", "hf-jobs-sft")
    run_name = os.getenv(
        "RUN_NAME",
        f"qwen25-05b-sft-{datetime.now(timezone.utc).strftime('%Y%m%d-%H%M%S')}",
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

    ds = load_dataset(dataset_name, split=dataset_split)
    if max_samples > 0:
        ds = ds.select(range(min(len(ds), max_samples)))

    # A 90/10 split needs at least one row on each side; fail with a clear
    # message instead of an opaque error from train_test_split.
    if len(ds) < 2:
        raise SystemExit(
            f"Need at least 2 samples to build a train/eval split, got {len(ds)}."
        )
    split = ds.train_test_split(test_size=0.1, seed=42)

    def formatting_func(example):
        # Render the chat turns to plain text; the trainer tokenizes later.
        return tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
            add_generation_prompt=False,
        )

    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
    )

    # Only request mixed precision when a CUDA device is actually present;
    # otherwise "not bf16" would turn fp16 on for a CPU-only run.
    cuda_available = torch.cuda.is_available()
    use_bf16 = cuda_available and torch.cuda.is_bf16_supported()
    use_fp16 = cuda_available and not use_bf16

    args = SFTConfig(
        output_dir="outputs",
        max_length=max_length,
        learning_rate=2e-4,
        warmup_ratio=0.03,
        max_steps=max_steps,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=8,
        gradient_checkpointing=True,
        fp16=use_fp16,
        bf16=use_bf16,
        logging_steps=5,
        eval_strategy="steps",
        eval_steps=25,
        save_strategy="steps",
        save_steps=50,
        save_total_limit=2,
        report_to="trackio",
        project=trackio_project,
        run_name=run_name,
        push_to_hub=True,
        hub_model_id=hub_model_id,
        hub_strategy="end",
        dataset_num_proc=2,
    )

    trainer = SFTTrainer(
        model=model_name,
        # `tokenizer=` was deprecated in TRL 0.12 in favour of
        # `processing_class` and is rejected by later releases.
        processing_class=tokenizer,
        train_dataset=split["train"],
        eval_dataset=split["test"],
        args=args,
        formatting_func=formatting_func,
        peft_config=peft_config,
    )

    # Run metadata shared by the "start" and "finish" events.
    run_info = {
        "hub_model_id": hub_model_id,
        "model_name": model_name,
        "dataset_name": dataset_name,
        "dataset_split": dataset_split,
        "max_samples": max_samples,
        "max_steps": max_steps,
        "max_length": max_length,
        "trackio_project": trackio_project,
        "run_name": run_name,
    }

    _post_central_log(
        {
            "event": "start",
            "timestamp": _utc_now_iso(),
            **run_info,
        }
    )

    train_result = trainer.train()
    eval_metrics = trainer.evaluate()

    trainer.push_to_hub()

    end_ts = time.time()
    payload = {
        "event": "finish",
        "timestamp": _utc_now_iso(),
        "duration_seconds": round(end_ts - start_ts, 3),
        **run_info,
        "train_metrics": getattr(train_result, "metrics", None),
        "eval_metrics": eval_metrics,
        "trainer_log_history_tail": trainer.state.log_history[-50:],
    }

    _post_central_log(payload)
    print(json.dumps(payload, indent=2, default=str))
181
+
182
+
183
# Run the training job only when executed as a script, not on import.
if __name__ == "__main__":
    main()
185
+