File size: 8,280 Bytes
90e71a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
#!/usr/bin/env python3
"""
Reference SFT training script for agentic coding.
Loads a 60/30/10 mix of SWE + tool-use + code-act datasets,
normalizes to unified message format with multi-template tool formats.

Usage:
    python train_sft.py \
        --model nvidia/Nemotron-Terminal-8B \
        --output_dir ./nexus-coder-sft
"""

import argparse
import random
import json
from datasets import load_dataset, concatenate_datasets, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig

# ---------------------------------------------------------------------------
# Tool template formatters (multi-template trick for generalization)
# ---------------------------------------------------------------------------

def format_openai_json(tool_calls: list) -> str:
    """Render tool calls as OpenAI-style JSON wrapped in <tool_call> tags.

    The ``arguments`` value is JSON-encoded; ``name`` is interpolated
    verbatim (not escaped), matching the other formatters in this family.
    """
    lines = []
    for call in tool_calls:
        payload = (
            '{"type": "function", "function": {"name": "%s", "arguments": %s}}'
            % (call.get("name", ""), json.dumps(call.get("arguments", "")))
        )
        lines.append("<tool_call>" + payload + "</tool_call>")
    return "\n".join(lines).strip()

def format_xml(tool_calls: list) -> str:
    """Render tool calls as nested-tag XML <tool_call> entries, one per line."""
    template = "<tool_call><name>{}</name><arguments>{}</arguments></tool_call>"
    lines = [
        template.format(call.get("name", ""), call.get("arguments", ""))
        for call in tool_calls
    ]
    return "\n".join(lines).strip()

def format_python(tool_calls: list) -> str:
    """Render tool calls as Python-style function invocations, one per line."""
    return "\n".join(
        "{}({})".format(call.get("name", ""), call.get("arguments", ""))
        for call in tool_calls
    ).strip()

def format_typescript(tool_calls: list) -> str:
    """Render tool calls as TypeScript-style object literals, one per line."""
    entries = []
    for call in tool_calls:
        tool_name = call.get("name", "")
        tool_args = call.get("arguments", "")
        entries.append("{ tool: '%s', args: %s }" % (tool_name, tool_args))
    return "\n".join(entries).strip()

def format_qwen3_xml(tool_calls: list) -> str:
    """Render tool calls in a Qwen3-coder-style tagged format, one per line."""
    parts = []
    for call in tool_calls:
        tool = call.get("name", "")
        params = call.get("arguments", "")
        parts.append(
            f"<qwen3_coder><tool>{tool}</tool><params>{params}</params></qwen3_coder>"
        )
    return "\n".join(parts).strip()

FORMAT_CHOICES = [format_openai_json, format_xml, format_python, format_typescript, format_qwen3_xml]

# ---------------------------------------------------------------------------
# Dataset loaders
# ---------------------------------------------------------------------------

def load_swe_smith(tokenizer) -> Dataset:
    """Load SWE-smith trajectories (tool split), keeping resolved runs only.

    Each trajectory's message list is serialized through the tokenizer's
    chat template into a single ``text`` column.
    """
    raw = load_dataset("SWE-bench/SWE-smith-trajectories", split="tool")
    raw = raw.filter(lambda row: row.get("resolved", False) is True)

    def to_text(row):
        messages = row.get("messages", [])
        # Some rows store messages as a JSON string rather than a list.
        if isinstance(messages, str):
            messages = json.loads(messages)
        rendered = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        return {"text": rendered}

    return raw.map(to_text, remove_columns=raw.column_names)

def load_nemotron_agentic(tokenizer) -> Dataset:
    """Load Nemotron-Agentic-v1 interactive_agent + tool_calling splits.

    Assistant turns carrying structured ``tool_calls`` have their content
    replaced by one randomly chosen textual tool-call template (the
    multi-template trick), then the conversation is serialized through the
    tokenizer's chat template into a single ``text`` column.
    """
    ds_ia = load_dataset("nvidia/Nemotron-Agentic-v1", split="interactive_agent")
    ds_tc = load_dataset("nvidia/Nemotron-Agentic-v1", split="tool_calling")
    ds = concatenate_datasets([ds_ia, ds_tc])

    def normalize(example):
        msgs = example.get("messages", [])
        # Some rows store messages as a JSON string rather than a list.
        if isinstance(msgs, str):
            msgs = json.loads(msgs)
        # Apply a random template to any assistant tool_calls.
        for m in msgs:
            if m.get("role") == "assistant" and m.get("tool_calls"):
                fmt = random.choice(FORMAT_CHOICES)
                m["content"] = fmt(m["tool_calls"])
                # BUGFIX: drop the structured field after templating. Many
                # chat templates render "tool_calls" natively, which would
                # duplicate the text we just injected into "content".
                m.pop("tool_calls", None)
        text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False)
        return {"text": text}

    return ds.map(normalize, remove_columns=ds.column_names)

def load_code_act(tokenizer) -> Dataset:
    """Load the xingyaoww/code-act 'codeact' split into chat-template text.

    Maps ShareGPT-style ``from``/``value`` turns onto ``role``/``content``
    messages (system -> system, human/user -> user, everything else ->
    assistant), then serializes via the tokenizer's chat template.
    """
    raw = load_dataset("xingyaoww/code-act", split="codeact")

    def to_text(row):
        turns = row.get("conversations", [])
        # Some rows store conversations as a JSON string rather than a list.
        if isinstance(turns, str):
            turns = json.loads(turns)
        messages = []
        for turn in turns:
            source = turn.get("from")
            if source == "system":
                role = "system"
            elif source in ("human", "user"):
                role = "user"
            else:
                role = "assistant"
            messages.append({"role": role, "content": turn.get("value", "")})
        rendered = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        return {"text": rendered}

    return raw.map(to_text, remove_columns=raw.column_names)

# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    """CLI entry point: load model, build the 60/30/10 dataset mix, run SFT.

    Pipeline:
      1. Load model/tokenizer (optionally wrap with LoRA).
      2. Load and normalize the three source datasets.
      3. Mix, shuffle, and length-filter examples.
      4. Configure and run the TRL SFTTrainer, then save.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="nvidia/Nemotron-Terminal-8B")
    parser.add_argument("--output_dir", default="./nexus-coder-sft")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--grad_accum", type=int, default=8)
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--max_seq_length", type=int, default=16384)
    parser.add_argument("--hub_model_id", default=None)
    parser.add_argument("--lora", action="store_true", help="Use LoRA if VRAM-constrained")
    parser.add_argument("--lora_r", type=int, default=64)
    parser.add_argument("--lora_alpha", type=int, default=128)
    args = parser.parse_args()

    # BUGFIX: the dataset shuffles below pin seed=42, but the per-turn
    # template choice in load_nemotron_agentic uses the global `random`
    # module, which was previously unseeded -- seed it here so runs are
    # reproducible end to end.
    random.seed(42)

    print("[1/5] Loading model and tokenizer...")
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype="bfloat16",
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Causal LMs often ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

    # Optional LoRA adapter config for VRAM-constrained training.
    peft_config = None
    if args.lora:
        from peft import LoraConfig, TaskType
        peft_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            task_type=TaskType.CAUSAL_LM,
            lora_dropout=0.05,
            bias="none",
        )
        print(f"    Using LoRA (r={args.lora_r}, alpha={args.lora_alpha})")

    print("[2/5] Loading and mixing datasets...")
    ds_swe = load_swe_smith(tokenizer)
    ds_agentic = load_nemotron_agentic(tokenizer)
    ds_code = load_code_act(tokenizer)

    # Approximate a 60/30/10 mix via capped example counts (10k/5k/2k).
    # NOTE(review): this is a count-based heuristic, not a true token-count
    # ratio -- the realized token share depends on per-dataset lengths.
    n_swe = min(len(ds_swe), 10000)
    n_agentic = min(len(ds_agentic), 5000)
    n_code = min(len(ds_code), 2000)
    ds_swe = ds_swe.shuffle(seed=42).select(range(n_swe))
    ds_agentic = ds_agentic.shuffle(seed=42).select(range(n_agentic))
    ds_code = ds_code.shuffle(seed=42).select(range(n_code))

    mixed = concatenate_datasets([ds_swe, ds_agentic, ds_code])
    mixed = mixed.shuffle(seed=42)
    print(f"    Mixed dataset: {len(mixed)} examples")

    print("[3/5] Applying multi-template normalization...")
    # All loaders already emit a single "text" column, so the previous
    # identity map over {"text": ...} was a wasted pass; only the
    # short-example filter does real work here.
    mixed = mixed.filter(lambda x: len(x.get("text", "")) > 200)

    print("[4/5] Configuring SFT trainer...")
    sft_config = SFTConfig(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        max_seq_length=args.max_seq_length,
        logging_strategy="steps",
        logging_steps=10,
        logging_first_step=True,
        save_strategy="epoch",
        bf16=True,
        gradient_checkpointing=True,
        disable_tqdm=True,
        # Push only when a hub repo id was explicitly provided.
        push_to_hub=args.hub_model_id is not None,
        hub_model_id=args.hub_model_id,
    )

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=mixed,
        args=sft_config,
        peft_config=peft_config,
    )

    print("[5/5] Starting SFT training...")
    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    print(f"Done. Model saved to {args.output_dir}")


if __name__ == "__main__":
    main()