davidsmts commited on
Commit
247caa6
·
verified ·
1 Parent(s): b98d4e6

Upload train_sft_capybara.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_sft_capybara.py +99 -0
train_sft_capybara.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # dependencies = ["trl>=0.28.0", "peft>=0.18.1", "datasets", "trackio"]
3
+ # ///
4
+
5
+ """SFT training job that fine-tunes Qwen/Qwen2.5-0.5B on the Capybara instructions."""
6
+
7
+ from __future__ import annotations
8
+
9
+ import datetime
10
+ import os
11
+
12
+ from datasets import load_dataset
13
+ from trl import SFTConfig, SFTTrainer
14
+
15
+
16
def to_prompt_completion(example: dict) -> dict:
    """Collapse a multi-turn conversation into a single instruction/response text.

    Takes the first user message as the prompt and the first assistant
    message that follows a non-empty prompt as the completion. Returns
    ``{"text": None}`` when no such pair exists so callers can filter it out.
    """
    first_user = None
    first_reply = None
    for turn in example.get("messages", []):
        turn_role = (turn.get("role") or "").lower()
        if turn_role == "user" and first_user is None:
            first_user = turn.get("content", "").strip()
        elif turn_role == "assistant" and first_user and first_reply is None:
            first_reply = turn.get("content", "").strip()
            # Only the first valid assistant reply matters; stop scanning.
            break
    if not (first_user and first_reply):
        return {"text": None}
    return {
        "text": f"### Instruction:\n{first_user}\n\n### Response:\n{first_reply}"
    }
38
+
39
+
40
def prepare_dataset() -> tuple:
    """Load Capybara, drop conversations without a valid pair, and split it.

    Returns a ``(train, test)`` tuple produced by a seeded 95/5 split so the
    evaluation set is reproducible across runs.
    """
    raw = load_dataset("trl-lib/Capybara", split="train")
    # Replace the original columns with a single "text" field per example.
    mapped = raw.map(to_prompt_completion, remove_columns=raw.column_names)
    mapped = mapped.filter(lambda row: row["text"] is not None)
    parts = mapped.train_test_split(test_size=0.05, seed=42)
    return parts["train"], parts["test"]
47
+
48
+
49
def main() -> None:
    """Run the full SFT job: build the dataset, train, and push to the Hub.

    Reads ``HF_TOKEN`` from the environment (used implicitly by the Hub
    libraries). Trains Qwen/Qwen2.5-0.5B on the filtered Capybara pairs and
    uploads checkpoints to ``davidsmts/sft-capybara-demo``.
    """
    hf_token = os.environ.get("HF_TOKEN")
    # Security fix: never print the token itself — it would leak the secret
    # into job logs. Only report whether it is present.
    print("HF_TOKEN set:", hf_token is not None)

    print("Building dataset …")
    train_dataset, eval_dataset = prepare_dataset()
    print("Train samples:", len(train_dataset))
    print("Eval samples:", len(eval_dataset))

    model_name = "Qwen/Qwen2.5-0.5B"
    # datetime.utcnow() is deprecated since Python 3.12; use an aware UTC
    # timestamp instead (same %Y%m%d-%H%M%S formatting, so run names are
    # unchanged in shape).
    timestamp = datetime.datetime.now(datetime.timezone.utc)
    run_name = f"capybara-sft-{timestamp:%Y%m%d-%H%M%S}"

    config = SFTConfig(
        output_dir="capybara-sft-output",
        dataset_text_field="text",
        # Metrics are reported to a Trackio space.
        report_to="trackio",
        project="capybara-sft",
        run_name=run_name,
        trackio_space_id="trackio",
        # Evaluate and checkpoint every 200 steps; keep only the 3 newest
        # checkpoints to bound disk usage.
        eval_strategy="steps",
        eval_steps=200,
        logging_steps=50,
        logging_dir="capybara-sft-output/logs",
        save_strategy="steps",
        save_steps=200,
        save_total_limit=3,
        num_train_epochs=2,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        # Effective batch size = 4 * 2 = 8 per device.
        gradient_accumulation_steps=2,
        gradient_checkpointing=True,
        learning_rate=2e-5,
        # Every saved checkpoint is also pushed to the Hub repo below.
        push_to_hub=True,
        hub_model_id="davidsmts/sft-capybara-demo",
        hub_strategy="every_save",
    )

    trainer = SFTTrainer(
        model=model_name,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        args=config,
    )

    trainer.train()
    trainer.push_to_hub()
95
+
96
+
97
# Script entry point — only runs the training job when executed directly.
if __name__ == "__main__":
    main()
99
+