Safetensors
qwen2
olanigan committed on
Commit
90e71a6
·
verified ·
1 Parent(s): 44fda62

Add SFT training script with multi-template tool formatting

Browse files
Files changed (1) hide show
  1. train_sft.py +209 -0
train_sft.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Reference SFT training script for agentic coding.
4
+ Loads a 60/30/10 mix of SWE + tool-use + code-act datasets,
5
+ normalizes to unified message format with multi-template tool formats.
6
+
7
+ Usage:
8
+ python train_sft.py \
9
+ --model nvidia/Nemotron-Terminal-8B \
10
+ --output_dir ./nexus-coder-sft
11
+ """
12
+
13
+ import argparse
14
+ import random
15
+ import json
16
+ from datasets import load_dataset, concatenate_datasets, Dataset
17
+ from transformers import AutoModelForCausalLM, AutoTokenizer
18
+ from trl import SFTTrainer, SFTConfig
19
+
20
+ # ---------------------------------------------------------------------------
21
+ # Tool template formatters (multi-template trick for generalization)
22
+ # ---------------------------------------------------------------------------
23
+
24
def format_openai_json(tool_calls: list) -> str:
    """Render tool calls as OpenAI-style JSON objects inside ``<tool_call>`` tags.

    Args:
        tool_calls: Sequence of dicts. Each dict is read for the keys
            ``"name"`` and ``"arguments"``; both default to ``""`` when
            absent. ``None`` or an empty sequence yields an empty string.

    Returns:
        One ``<tool_call>...</tool_call>`` element per call, joined with
        newlines, with leading/trailing whitespace stripped.
    """
    rendered = []
    for tc in tool_calls or []:
        # Encode the name with json.dumps too: splicing it in raw produces
        # malformed JSON as soon as it contains a quote or a backslash.
        # ensure_ascii=False keeps non-ASCII names verbatim.
        name = json.dumps(tc.get("name", ""), ensure_ascii=False)
        arguments = json.dumps(tc.get("arguments", ""))
        rendered.append(
            f'<tool_call>{{"type": "function", "function": '
            f'{{"name": {name}, "arguments": {arguments}}}}}</tool_call>'
        )
    # join instead of repeated += keeps the build linear in the output size.
    return "\n".join(rendered).strip()
29
+
30
def format_xml(tool_calls: list) -> str:
    """Render tool calls as ``<tool_call><name>..</name><arguments>..</arguments></tool_call>``.

    Args:
        tool_calls: Sequence of dicts. Each dict is read for the keys
            ``"name"`` and ``"arguments"``; both default to ``""`` when
            absent. ``None`` or an empty sequence yields an empty string.

    Returns:
        One element per call, joined with newlines, with leading/trailing
        whitespace stripped. String arguments are inserted verbatim;
        anything else is serialized as JSON.
    """
    rendered = []
    for tc in tool_calls or []:
        arguments = tc.get("arguments", "")
        if not isinstance(arguments, str):
            # Interpolating a dict/list directly would emit Python repr
            # syntax (single quotes, True/None) rather than JSON.
            try:
                arguments = json.dumps(arguments, ensure_ascii=False)
            except TypeError:
                # Not JSON-serializable: fall back to plain str() rendering.
                arguments = str(arguments)
        name = tc.get("name", "")
        rendered.append(
            f"<tool_call><name>{name}</name><arguments>{arguments}</arguments></tool_call>"
        )
    return "\n".join(rendered).strip()
35
+
36
def format_python(tool_calls: list) -> str:
    """Render tool calls as Python-style call expressions, ``name(arguments)``.

    Args:
        tool_calls: Sequence of dicts. Each dict is read for the keys
            ``"name"`` and ``"arguments"``; both default to ``""`` when
            absent. ``None`` or an empty sequence yields an empty string.

    Returns:
        One ``name(arguments)`` line per call, joined with newlines, with
        leading/trailing whitespace stripped. Arguments are interpolated
        with their ``str()`` form, so a dict appears as a Python literal.
    """
    # `or []` lets a missing/None tool-call list render as "" instead of
    # raising; join replaces the quadratic `out +=` accumulation.
    lines = [
        f"{tc.get('name', '')}({tc.get('arguments', '')})"
        for tc in tool_calls or []
    ]
    return "\n".join(lines).strip()
41
+
42
def format_typescript(tool_calls: list) -> str:
    """Render tool calls as object literals, ``{ tool: 'name', args: ... }``.

    Args:
        tool_calls: Sequence of dicts. Each dict is read for the keys
            ``"name"`` and ``"arguments"``; both default to ``""`` when
            absent. ``None`` or an empty sequence yields an empty string.

    Returns:
        One object literal per call, joined with newlines, with
        leading/trailing whitespace stripped. String arguments are inserted
        verbatim; anything else is serialized as JSON.
    """
    rendered = []
    for tc in tool_calls or []:
        arguments = tc.get("arguments", "")
        if not isinstance(arguments, str):
            # Python's repr (True/False/None, single-quoted keys) is not a
            # valid object literal here; JSON is.
            try:
                arguments = json.dumps(arguments, ensure_ascii=False)
            except TypeError:
                # Not JSON-serializable: fall back to plain str() rendering.
                arguments = str(arguments)
        name = tc.get("name", "")
        rendered.append(f"{{ tool: '{name}', args: {arguments} }}")
    return "\n".join(rendered).strip()
47
+
48
def format_qwen3_xml(tool_calls: list) -> str:
    """Render tool calls as ``<qwen3_coder><tool>..</tool><params>..</params></qwen3_coder>``.

    Args:
        tool_calls: Sequence of dicts. Each dict is read for the keys
            ``"name"`` and ``"arguments"``; both default to ``""`` when
            absent. ``None`` or an empty sequence yields an empty string.

    Returns:
        One element per call, joined with newlines, with leading/trailing
        whitespace stripped. String arguments are inserted verbatim;
        anything else is serialized as JSON.
    """
    rendered = []
    for tc in tool_calls or []:
        arguments = tc.get("arguments", "")
        if not isinstance(arguments, str):
            # Serialize structured arguments as JSON so Python repr syntax
            # (single quotes, True/None) does not end up in the payload.
            try:
                arguments = json.dumps(arguments, ensure_ascii=False)
            except TypeError:
                # Not JSON-serializable: fall back to plain str() rendering.
                arguments = str(arguments)
        name = tc.get("name", "")
        rendered.append(
            f"<qwen3_coder><tool>{name}</tool><params>{arguments}</params></qwen3_coder>"
        )
    return "\n".join(rendered).strip()
53
+
54
# Pool of tool-call renderers; the dataset loader draws one of these at
# random for each assistant turn that carries tool calls.
FORMAT_CHOICES = [
    format_openai_json,
    format_xml,
    format_python,
    format_typescript,
    format_qwen3_xml,
]
55
+
56
+ # ---------------------------------------------------------------------------
57
+ # Dataset loaders
58
+ # ---------------------------------------------------------------------------
59
+
60
def load_swe_smith(tokenizer) -> Dataset:
    """Load resolved SWE-smith trajectories as chat-template text.

    Reads the ``tool`` split of ``SWE-bench/SWE-smith-trajectories``, keeps
    only rows whose ``resolved`` field is exactly ``True``, and replaces all
    columns with a single ``text`` column produced by
    ``tokenizer.apply_chat_template``.

    Args:
        tokenizer: Object exposing ``apply_chat_template``.

    Returns:
        Dataset containing only a ``text`` column.
    """
    def is_resolved(row):
        # Identity check: a missing or non-boolean value is rejected.
        return row.get("resolved", False) is True

    def to_text(row):
        messages = row.get("messages", [])
        # The messages column may hold a JSON-encoded string; decode it first.
        if isinstance(messages, str):
            messages = json.loads(messages)
        rendered = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
        )
        return {"text": rendered}

    trajectories = load_dataset("SWE-bench/SWE-smith-trajectories", split="tool")
    resolved = trajectories.filter(is_resolved)
    return resolved.map(to_text, remove_columns=resolved.column_names)
71
+
72
def load_nemotron_agentic(tokenizer, seed: int = 42) -> Dataset:
    """Load Nemotron-Agentic-v1 (interactive_agent + tool_calling) as chat text.

    Assistant turns that carry ``tool_calls`` get those calls rendered into
    the message content with one formatter drawn from ``FORMAT_CHOICES``;
    the whole conversation is then passed through
    ``tokenizer.apply_chat_template``.

    Args:
        tokenizer: Object exposing ``apply_chat_template``.
        seed: Base seed for the per-example formatter choice. Combined with
            the example index so the same example always receives the same
            formatter, independent of global ``random`` state.

    Returns:
        Dataset containing only a ``text`` column.
    """
    ds_ia = load_dataset("nvidia/Nemotron-Agentic-v1", split="interactive_agent")
    ds_tc = load_dataset("nvidia/Nemotron-Agentic-v1", split="tool_calling")
    ds = concatenate_datasets([ds_ia, ds_tc])

    def flatten(tc):
        # The formatters read flat "name"/"arguments" keys. If a call is in
        # the OpenAI-style nested shape ({"function": {"name", "arguments"}})
        # hand them the inner dict instead of rendering an empty name.
        # NOTE(review): the dataset's actual tool_calls schema is not visible
        # here — confirm which shape it uses.
        fn = tc.get("function")
        if isinstance(fn, dict) and not tc.get("name"):
            return fn
        return tc

    def normalize(example, idx):
        msgs = example.get("messages", [])
        if isinstance(msgs, str):
            msgs = json.loads(msgs)
        # Private RNG keyed on (seed, index): reproducible across runs and
        # unaffected by other users of the global `random` module.
        rng = random.Random(f"{seed}:{idx}")
        rendered = []
        for m in msgs:
            if m.get("role") == "assistant" and m.get("tool_calls"):
                fmt = rng.choice(FORMAT_CHOICES)
                calls_text = fmt([flatten(tc) for tc in m["tool_calls"]])
                prior = m.get("content")
                # Copy without "tool_calls": the calls now live in the content,
                # and a chat template that understands tool_calls would
                # presumably render them a second time in its native format.
                m = {k: v for k, v in m.items() if k != "tool_calls"}
                # Keep any text the assistant wrote before the calls rather
                # than overwriting it.
                if isinstance(prior, str) and prior.strip():
                    m["content"] = f"{prior}\n{calls_text}"
                else:
                    m["content"] = calls_text
            rendered.append(m)
        text = tokenizer.apply_chat_template(
            rendered, tokenize=False, add_generation_prompt=False
        )
        return {"text": text}

    return ds.map(normalize, with_indices=True, remove_columns=ds.column_names)
89
+
90
def load_code_act(tokenizer) -> Dataset:
    """Load the ``codeact`` split of ``xingyaoww/code-act`` as chat text.

    Each entry of the ``conversations`` column is converted to a
    ``{"role", "content"}`` message and the conversation is rendered with
    ``tokenizer.apply_chat_template``.

    Role mapping: ``"system"`` stays ``"system"``; ``"human"``/``"user"``
    become ``"user"``; every other speaker becomes ``"assistant"``.

    Args:
        tokenizer: Object exposing ``apply_chat_template``.

    Returns:
        Dataset containing only a ``text`` column.
    """
    ds = load_dataset("xingyaoww/code-act", split="codeact")

    def normalize(example):
        conv = example.get("conversations", [])
        if isinstance(conv, str):
            conv = json.loads(conv)
        msgs = []
        for c in conv:
            # Accept both ShareGPT-style ("from"/"value") and chat-style
            # ("role"/"content") turns. Reading only "from"/"value" would map
            # every chat-style turn to an empty assistant message.
            # NOTE(review): confirm which schema this dataset actually uses.
            speaker = c.get("from") or c.get("role")
            content = c.get("value")
            if content is None:
                content = c.get("content")
            if content is None:
                content = ""
            if speaker == "system":
                role = "system"
            elif speaker in ("human", "user"):
                role = "user"
            else:
                role = "assistant"
            msgs.append({"role": role, "content": content})
        text = tokenizer.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=False
        )
        return {"text": text}

    return ds.map(normalize, remove_columns=ds.column_names)
106
+
107
+ # ---------------------------------------------------------------------------
108
+ # Main
109
+ # ---------------------------------------------------------------------------
110
+
111
def main():
    """Parse CLI options, build the mixed SFT corpus, and fine-tune the model.

    Steps, matching the progress messages printed at runtime:
      1. Load the causal LM and tokenizer (optionally prepare a LoRA config).
      2. Load the three datasets, cap each at a fixed example count, and
         concatenate + shuffle them.
      3. Drop examples whose rendered text is 200 characters or shorter.
      4. Build the ``SFTConfig`` / ``SFTTrainer``.
      5. Train, then save the model and tokenizer to ``--output_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="nvidia/Nemotron-Terminal-8B")
    parser.add_argument("--output_dir", default="./nexus-coder-sft")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--grad_accum", type=int, default=8)
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--max_seq_length", type=int, default=16384)
    parser.add_argument("--hub_model_id", default=None)
    parser.add_argument("--lora", action="store_true", help="Use LoRA if VRAM-constrained")
    parser.add_argument("--lora_r", type=int, default=64)
    parser.add_argument("--lora_alpha", type=int, default=128)
    # Previously hard-coded as 42 in every shuffle; the default keeps that.
    parser.add_argument("--seed", type=int, default=42, help="Seed for sampling, shuffling and training")
    args = parser.parse_args()

    # Seed the module-level RNG so any `random`-based choices made while the
    # datasets are being built are repeatable from run to run.
    random.seed(args.seed)

    print("[1/5] Loading model and tokenizer...")
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype="bfloat16",
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Reuse EOS as the padding token when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token

    # LoRA setup if requested
    peft_config = None
    if args.lora:
        # Imported lazily so peft is only required when --lora is passed.
        from peft import LoraConfig, TaskType
        peft_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            task_type=TaskType.CAUSAL_LM,
            lora_dropout=0.05,
            bias="none",
        )
        print(f" Using LoRA (r={args.lora_r}, alpha={args.lora_alpha})")

    print("[2/5] Loading and mixing datasets...")
    ds_swe = load_swe_smith(tokenizer)
    ds_agentic = load_nemotron_agentic(tokenizer)
    ds_code = load_code_act(tokenizer)

    # Shuffle and sample to approximate 60/30/10 by token count
    # Simple heuristic: sample proportional to raw example counts
    n_swe = min(len(ds_swe), 10000)
    n_agentic = min(len(ds_agentic), 5000)
    n_code = min(len(ds_code), 2000)
    ds_swe = ds_swe.shuffle(seed=args.seed).select(range(n_swe))
    ds_agentic = ds_agentic.shuffle(seed=args.seed).select(range(n_agentic))
    ds_code = ds_code.shuffle(seed=args.seed).select(range(n_code))

    mixed = concatenate_datasets([ds_swe, ds_agentic, ds_code])
    mixed = mixed.shuffle(seed=args.seed)
    print(f" Mixed dataset: {len(mixed)} examples")

    print("[3/5] Applying multi-template normalization...")
    # Every loader already emits a "text" column, so the former identity
    # map over the dataset was a wasted pass; only the length filter remains.
    mixed = mixed.filter(lambda x: len(x.get("text", "")) > 200)

    print("[4/5] Configuring SFT trainer...")
    # NOTE(review): newer TRL releases appear to rename `max_seq_length` to
    # `max_length` and the trainer's `tokenizer` argument to
    # `processing_class` — confirm against the TRL version this runs with.
    sft_config = SFTConfig(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        max_seq_length=args.max_seq_length,
        logging_strategy="steps",
        logging_steps=10,
        logging_first_step=True,
        save_strategy="epoch",
        bf16=True,
        gradient_checkpointing=True,
        disable_tqdm=True,
        seed=args.seed,
        # Only push when a target repository was given on the command line.
        push_to_hub=args.hub_model_id is not None,
        hub_model_id=args.hub_model_id,
    )

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=mixed,
        args=sft_config,
        peft_config=peft_config,
    )

    print("[5/5] Starting SFT training...")
    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    print(f"Done. Model saved to {args.output_dir}")


if __name__ == "__main__":
    main()