#!/usr/bin/env python3
"""
Reference SFT training script for agentic coding.
Loads a 60/30/10 mix of SWE + tool-use + code-act datasets,
normalizes to unified message format with multi-template tool formats.
Usage:
python train_sft.py \
--model nvidia/Nemotron-Terminal-8B \
--output_dir ./nexus-coder-sft
"""
import argparse
import random
import json
from datasets import load_dataset, concatenate_datasets, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig
# ---------------------------------------------------------------------------
# Tool template formatters (multi-template trick for generalization)
# ---------------------------------------------------------------------------
def format_openai_json(tool_calls: list) -> str:
    """Render tool calls as OpenAI-style ``<tool_call>`` JSON blocks.

    Each call becomes one ``<tool_call>{...}</tool_call>`` line; blocks are
    newline-joined. The payload is built as a dict and serialized with
    ``json.dumps`` so the result is valid JSON even when a tool name or
    argument value contains quotes/backslashes (the previous manual
    interpolation of ``name`` broke in that case). For well-behaved names
    the output is byte-identical to the old format.
    """
    blocks = []
    for tc in tool_calls:
        payload = {
            "type": "function",
            "function": {
                "name": tc.get("name", ""),
                "arguments": tc.get("arguments", ""),
            },
        }
        blocks.append(f"<tool_call>{json.dumps(payload)}</tool_call>")
    return "\n".join(blocks)
def format_xml(tool_calls: list) -> str:
    """Render tool calls in a plain XML-tag template, one block per line."""
    rendered = [
        f"<tool_call><name>{call.get('name', '')}</name>"
        f"<arguments>{call.get('arguments', '')}</arguments></tool_call>"
        for call in tool_calls
    ]
    return "\n".join(rendered)
def format_python(tool_calls: list) -> str:
    """Render tool calls as Python-style function invocations, one per line."""
    lines = []
    for call in tool_calls:
        lines.append(f"{call.get('name', '')}({call.get('arguments', '')})")
    return "\n".join(lines)
def format_typescript(tool_calls: list) -> str:
    """Render tool calls as TypeScript-flavored object literals, one per line."""
    entries = []
    for call in tool_calls:
        entries.append(
            "{{ tool: '{0}', args: {1} }}".format(
                call.get("name", ""), call.get("arguments", "")
            )
        )
    return "\n".join(entries)
def format_qwen3_xml(tool_calls: list) -> str:
    """Render tool calls in the qwen3_coder tag template, one block per line."""
    parts = []
    for call in tool_calls:
        tool_name = call.get("name", "")
        params = call.get("arguments", "")
        parts.append(
            f"<qwen3_coder><tool>{tool_name}</tool><params>{params}</params></qwen3_coder>"
        )
    return "\n".join(parts)
# Pool of tool-call renderers; a template is sampled at random per assistant
# turn so the model generalizes across tool-call formats rather than
# overfitting to a single one.
FORMAT_CHOICES = [format_openai_json, format_xml, format_python, format_typescript, format_qwen3_xml]
# ---------------------------------------------------------------------------
# Dataset loaders
# ---------------------------------------------------------------------------
def load_swe_smith(tokenizer) -> Dataset:
    """Load SWE-smith trajectories (tool split), keeping resolved runs only.

    Returns a dataset with a single "text" column produced by the
    tokenizer's chat template.
    """
    dataset = load_dataset("SWE-bench/SWE-smith-trajectories", split="tool")
    dataset = dataset.filter(lambda row: row.get("resolved", False) is True)

    def to_text(example):
        messages = example.get("messages", [])
        # Some rows store the conversation as a JSON string; decode first.
        if isinstance(messages, str):
            messages = json.loads(messages)
        rendered = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        return {"text": rendered}

    return dataset.map(to_text, remove_columns=dataset.column_names)
def load_nemotron_agentic(tokenizer) -> Dataset:
    """Load Nemotron-Agentic-v1 interactive_agent + tool_calling splits.

    Assistant turns carrying tool_calls are re-rendered with a randomly
    chosen template from FORMAT_CHOICES before chat templating.
    Returns a dataset with a single "text" column.
    """
    splits = [
        load_dataset("nvidia/Nemotron-Agentic-v1", split="interactive_agent"),
        load_dataset("nvidia/Nemotron-Agentic-v1", split="tool_calling"),
    ]
    merged = concatenate_datasets(splits)

    def to_text(example):
        messages = example.get("messages", [])
        # Conversations may arrive JSON-encoded; decode before processing.
        if isinstance(messages, str):
            messages = json.loads(messages)
        # Multi-template trick: rewrite assistant tool_calls in a random format.
        for message in messages:
            if message.get("role") == "assistant" and message.get("tool_calls"):
                template = random.choice(FORMAT_CHOICES)
                message["content"] = template(message["tool_calls"])
        rendered = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        return {"text": rendered}

    return merged.map(to_text, remove_columns=merged.column_names)
def load_code_act(tokenizer) -> Dataset:
    """Load the xingyaoww/code-act codeact split.

    Maps ShareGPT-style {"from", "value"} turns onto {"role", "content"}
    messages, then renders them with the tokenizer's chat template.
    Returns a dataset with a single "text" column.
    """
    dataset = load_dataset("xingyaoww/code-act", split="codeact")

    def to_text(example):
        turns = example.get("conversations", [])
        # Conversations may be stored as a JSON string; decode first.
        if isinstance(turns, str):
            turns = json.loads(turns)
        messages = []
        for turn in turns:
            source = turn.get("from")
            if source == "system":
                role = "system"
            elif source in ("human", "user"):
                role = "user"
            else:
                role = "assistant"
            messages.append({"role": role, "content": turn.get("value", "")})
        rendered = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        return {"text": rendered}

    return dataset.map(to_text, remove_columns=dataset.column_names)
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """CLI entry point: load model/tokenizer, build the mixed SFT dataset,
    configure the trainer, and run supervised fine-tuning."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="nvidia/Nemotron-Terminal-8B")
    parser.add_argument("--output_dir", default="./nexus-coder-sft")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--grad_accum", type=int, default=8)
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--max_seq_length", type=int, default=16384)
    parser.add_argument("--hub_model_id", default=None)
    parser.add_argument("--lora", action="store_true", help="Use LoRA if VRAM-constrained")
    parser.add_argument("--lora_r", type=int, default=64)
    parser.add_argument("--lora_alpha", type=int, default=128)
    args = parser.parse_args()
    print("[1/5] Loading model and tokenizer...")
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype="bfloat16",
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    # Some checkpoints ship without a pad token; reuse EOS so batched padding works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # LoRA setup if requested (peft is imported lazily so full fine-tuning
    # runs do not require it to be installed)
    peft_config = None
    if args.lora:
        from peft import LoraConfig, TaskType
        peft_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            task_type=TaskType.CAUSAL_LM,
            lora_dropout=0.05,
            bias="none",
        )
        print(f" Using LoRA (r={args.lora_r}, alpha={args.lora_alpha})")
    print("[2/5] Loading and mixing datasets...")
    ds_swe = load_swe_smith(tokenizer)
    ds_agentic = load_nemotron_agentic(tokenizer)
    ds_code = load_code_act(tokenizer)
    # Shuffle and sample to approximate 60/30/10 by token count
    # Simple heuristic: sample proportional to raw example counts
    # NOTE(review): the 10000/5000/2000 caps give a 59/29/12 example-count
    # ratio at best and do not account for per-example token length — the
    # actual token-level mix may drift; confirm against tokenized stats.
    n_swe = min(len(ds_swe), 10000)
    n_agentic = min(len(ds_agentic), 5000)
    n_code = min(len(ds_code), 2000)
    ds_swe = ds_swe.shuffle(seed=42).select(range(n_swe))
    ds_agentic = ds_agentic.shuffle(seed=42).select(range(n_agentic))
    ds_code = ds_code.shuffle(seed=42).select(range(n_code))
    mixed = concatenate_datasets([ds_swe, ds_agentic, ds_code])
    # Final shuffle interleaves the three sources for training.
    mixed = mixed.shuffle(seed=42)
    print(f" Mixed dataset: {len(mixed)} examples")
    print("[3/5] Applying multi-template normalization...")
    def ensure_text(example):
        # Guarantee a "text" key exists so the length filter below is safe.
        return {"text": example.get("text", "")}
    # Drop near-empty examples (< 200 chars of rendered text).
    mixed = mixed.map(ensure_text).filter(lambda x: len(x.get("text", "")) > 200)
    print("[4/5] Configuring SFT trainer...")
    sft_config = SFTConfig(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        # NOTE(review): newer TRL releases renamed this kwarg to `max_length`;
        # confirm against the pinned trl version.
        max_seq_length=args.max_seq_length,
        logging_strategy="steps",
        logging_steps=10,
        logging_first_step=True,
        save_strategy="epoch",
        bf16=True,
        gradient_checkpointing=True,
        disable_tqdm=True,
        # Push only when a hub repo id was provided on the CLI.
        push_to_hub=args.hub_model_id is not None,
        hub_model_id=args.hub_model_id,
    )
    trainer = SFTTrainer(
        model=model,
        # NOTE(review): recent TRL deprecates `tokenizer=` in favor of
        # `processing_class=` — verify against the installed version.
        tokenizer=tokenizer,
        train_dataset=mixed,
        args=sft_config,
        peft_config=peft_config,
    )
    print("[5/5] Starting SFT training...")
    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    print(f"Done. Model saved to {args.output_dir}")


if __name__ == "__main__":
    main()
|