ultron / train_sft.py
trojan0x's picture
Add SFT training script for Phase 3 (cybersecurity instruction fine-tuning)
7d24350 verified
#!/usr/bin/env python3
"""
Ultron-Sec SFT — Cybersecurity Instruction Fine-Tuning (Phase 3)
Loads the CPT checkpoint from trojan0x/ultron-sec-cpt and fine-tunes on
cybersecurity + code instruction data in ChatML format.
Data mix:
- m-a-p/Code-Feedback (67K, multi-turn code conversations)
- glaiveai/glaive-code-assistant-v3 (~136K, code Q&A)
- Bouquets/Cybersecurity-LLM-CVE (reformatted as instruction-following)
- CyberNative/CyberSecurityEval (security Q&A, 2x upsampled)
Training recipe (grounded in Primus/CyberPal-2 papers):
- LR: 2.5e-5 with cosine decay
- Prompt masking: only train on assistant responses
- Weight decay: 0.05
- 3000 steps
Usage:
python train_sft.py --hub_model_id trojan0x/ultron-sec
# Quick test
python train_sft.py --max_steps 50 --log_interval 5
"""
import os
import sys
import math
import time
import json
import random
import argparse
from dataclasses import asdict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from datasets import load_dataset
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download, snapshot_download, HfApi
# ── Ultron model code ─────────────────────────────────────────────
def setup_ultron():
    """Download the Ultron model sources and make them importable.

    Fetches only the ``ultron/*.py`` files of the ``trojan0x/ultron`` Hub
    repository, puts the snapshot directory at the front of ``sys.path`` and
    prints where the snapshot was placed.
    """
    snapshot_dir = snapshot_download(
        "trojan0x/ultron",
        allow_patterns=["ultron/*.py"],
    )
    # Position 0 so the downloaded snapshot is searched before other entries.
    sys.path.insert(0, snapshot_dir)
    print(f"Ultron loaded from: {snapshot_dir}")
# Order matters: setup_ultron() adds the downloaded snapshot to sys.path,
# which is what makes the `ultron` import on the next line resolvable.
setup_ultron()
from ultron.model import Ultron, UltronConfig
# ===========================================================================
# ChatML formatting
# ===========================================================================
# One ChatML turn per role: "<|im_start|>{role}\n{content}<|im_end|>\n".
# Generated from the role list so the three entries always share one shape;
# "{content}" stays a str.format placeholder to be filled in per message.
CHATML_TEMPLATE = {
    role: f"<|im_start|>{role}\n{{content}}<|im_end|>\n"
    for role in ("system", "user", "assistant")
}

# System prompt used for every training conversation.
DEFAULT_SYSTEM = (
    "You are Ultron-Sec, a cybersecurity AI assistant. "
    "You provide detailed, accurate analysis of security vulnerabilities, "
    "exploit techniques, defensive strategies, and code. "
    "Always explain your reasoning step by step."
)
def format_chatml(messages, system_prompt=None):
    """Render a conversation as one ChatML string.

    ``messages`` is a list of dicts with ``role`` and ``content`` keys; a
    missing role is treated as ``"user"`` and missing content as ``""``.
    Messages whose role has no entry in ``CHATML_TEMPLATE`` are left out.
    A truthy ``system_prompt`` is emitted first as a system turn.
    """
    pieces = []
    if system_prompt:
        pieces.append(CHATML_TEMPLATE["system"].format(content=system_prompt))
    for message in messages:
        template = CHATML_TEMPLATE.get(message.get("role", "user"))
        if template is None:
            # Unknown role: nothing to render for this message.
            continue
        pieces.append(template.format(content=message.get("content", "")))
    return "".join(pieces)
def get_assistant_mask(tokens, tokenizer, messages, system_prompt=None):
    """Return a bool mask over ``tokens`` that is True on assistant turns.

    This implements prompt masking: only positions inside an assistant turn
    (as rendered by ``CHATML_TEMPLATE["assistant"]``) are marked, so the loss
    can be restricted to assistant outputs.

    Turn boundaries are found by re-tokenizing the conversation text up to
    the start and up to the end of each assistant message and using the two
    token counts as ``[start, end)``.

    Args:
        tokens: token ids of the rendered conversation; may be shorter than
            the full conversation (truncated, or with the last token cut off).
        tokenizer: object with an ``encode(text)`` method returning a
            sequence of token ids.
        messages: list of ``{"role", "content"}`` dicts.
        system_prompt: optional system prompt that precedes the messages.

    Returns:
        ``torch.BoolTensor`` of shape ``(len(tokens),)``.
    """
    n_tokens = len(tokens)
    mask = torch.zeros(n_tokens, dtype=torch.bool)
    prefix = ""
    if system_prompt:
        prefix += CHATML_TEMPLATE["system"].format(content=system_prompt)
    for msg in messages:
        role = msg.get("role", "user")
        content = msg.get("content", "")
        # Unknown roles render as "" so they do not shift later boundaries.
        formatted = CHATML_TEMPLATE.get(role, "").format(content=content)
        if role == "assistant":
            # Token count of everything before this assistant message.
            start = len(tokenizer.encode(prefix))
            if start >= n_tokens:
                # This turn (and every later one) starts past the end of
                # `tokens`; nothing left to mark.
                break
            # Token count including this message.
            end = len(tokenizer.encode(prefix + formatted))
            # Clamp instead of skipping: a turn that runs past the end of
            # `tokens` must still have its visible part marked. Skipping it
            # left truncated turns (and the final turn whenever the caller
            # passes tokens[:-1]) entirely unsupervised.
            mask[start:min(end, n_tokens)] = True
        prefix += formatted
    return mask
# ===========================================================================
# Dataset classes
# ===========================================================================
class CodeFeedbackDataset(Dataset):
    """m-a-p/Code-Feedback — multi-turn code conversations (messages format).

    Items are dicts with ``input_ids`` and ``labels`` (next-token targets).
    Labels are ``-100`` wherever the target token is not part of an
    assistant turn, so only assistant outputs contribute to the loss.
    """

    def __init__(self, tokenizer, max_len=1024, limit=None):
        """Load the train split; a truthy ``limit`` keeps only the first rows."""
        print("[data] Loading m-a-p/Code-Feedback...")
        ds = load_dataset("m-a-p/Code-Feedback", split="train")
        if limit:
            ds = ds.select(range(min(limit, len(ds))))
        self.data = ds
        self.tokenizer = tokenizer
        self.max_len = max_len
        print(f" Loaded {len(ds)} examples")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize row ``idx``; rows without messages yield a masked stub."""
        row = self.data[idx]
        messages = row.get("messages", [])
        if not messages:
            return self._empty()
        text = format_chatml(messages, DEFAULT_SYSTEM)
        # max_len + 1 tokens so that inputs and shifted targets both have max_len.
        tokens = self.tokenizer.encode(text, max_length=self.max_len + 1, truncation=True)
        if len(tokens) < 10:
            return self._empty()
        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
        labels = torch.tensor(tokens[1:], dtype=torch.long)
        # Prompt masking. labels[i] is tokens[i + 1], so the mask is built
        # over the full token sequence and shifted by one to line up with the
        # targets. Building it over tokens[:-1] instead mis-aligned it by one
        # position and made the final assistant turn fall outside the mask.
        mask = get_assistant_mask(tokens, self.tokenizer, messages, DEFAULT_SYSTEM)
        labels[~mask[1:]] = -100  # ignore non-assistant targets
        return {"input_ids": input_ids, "labels": labels}

    def _empty(self):
        """One-token placeholder whose only label is ignored by the loss."""
        return {
            "input_ids": torch.zeros(1, dtype=torch.long),
            "labels": torch.full((1,), -100, dtype=torch.long),
        }
class GlaiveCodeDataset(Dataset):
    """glaiveai/glaive-code-assistant-v3 — code Q&A pairs.

    Each row's ``question``/``answer`` becomes a single user/assistant
    exchange. Items are dicts with ``input_ids`` and ``labels``; labels are
    ``-100`` wherever the target token is not part of the assistant answer.
    """

    def __init__(self, tokenizer, max_len=1024, limit=None):
        """Load the train split; a truthy ``limit`` keeps only the first rows."""
        print("[data] Loading glaiveai/glaive-code-assistant-v3...")
        ds = load_dataset("glaiveai/glaive-code-assistant-v3", split="train")
        if limit:
            ds = ds.select(range(min(limit, len(ds))))
        self.data = ds
        self.tokenizer = tokenizer
        self.max_len = max_len
        print(f" Loaded {len(ds)} examples")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize row ``idx``; incomplete rows yield a masked stub."""
        row = self.data[idx]
        question = row.get("question", "")
        answer = row.get("answer", "")
        if not question or not answer:
            return self._empty()
        messages = [
            {"role": "user", "content": question},
            {"role": "assistant", "content": answer},
        ]
        text = format_chatml(messages, DEFAULT_SYSTEM)
        # max_len + 1 tokens so that inputs and shifted targets both have max_len.
        tokens = self.tokenizer.encode(text, max_length=self.max_len + 1, truncation=True)
        if len(tokens) < 10:
            return self._empty()
        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
        labels = torch.tensor(tokens[1:], dtype=torch.long)
        # labels[i] is tokens[i + 1]: build the mask over the full sequence
        # and shift it by one. Building it over tokens[:-1] pushed the (only)
        # assistant turn outside the mask, leaving every label at -100.
        mask = get_assistant_mask(tokens, self.tokenizer, messages, DEFAULT_SYSTEM)
        labels[~mask[1:]] = -100
        return {"input_ids": input_ids, "labels": labels}

    def _empty(self):
        """One-token placeholder whose only label is ignored by the loss."""
        return {
            "input_ids": torch.zeros(1, dtype=torch.long),
            "labels": torch.full((1,), -100, dtype=torch.long),
        }
class CVEInstructDataset(Dataset):
    """Bouquets/Cybersecurity-LLM-CVE — CVE entries reformatted as instructions.

    The prompt is read from ``instruction`` (falling back to ``input``) and
    the response from ``outputs`` (falling back to ``output``). Items are
    dicts with ``input_ids`` and ``labels``; labels are ``-100`` wherever the
    target token is not part of the assistant response.
    """

    def __init__(self, tokenizer, max_len=1024, limit=None):
        """Load the train split; a truthy ``limit`` keeps only the first rows."""
        print("[data] Loading Bouquets/Cybersecurity-LLM-CVE...")
        ds = load_dataset("Bouquets/Cybersecurity-LLM-CVE", split="train")
        if limit:
            ds = ds.select(range(min(limit, len(ds))))
        self.data = ds
        self.tokenizer = tokenizer
        self.max_len = max_len
        print(f" Loaded {len(ds)} examples")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize row ``idx``; incomplete rows yield a masked stub."""
        row = self.data[idx]
        # `or` chains so an empty/None primary column also falls back to the
        # alternative, not only a missing key.
        instruction = row.get("instruction") or row.get("input") or ""
        output = row.get("outputs") or row.get("output") or ""
        if not instruction or not output:
            return self._empty()
        messages = [
            {"role": "user", "content": instruction},
            {"role": "assistant", "content": output},
        ]
        text = format_chatml(messages, DEFAULT_SYSTEM)
        # max_len + 1 tokens so that inputs and shifted targets both have max_len.
        tokens = self.tokenizer.encode(text, max_length=self.max_len + 1, truncation=True)
        if len(tokens) < 10:
            return self._empty()
        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
        labels = torch.tensor(tokens[1:], dtype=torch.long)
        # labels[i] is tokens[i + 1]: build the mask over the full sequence
        # and shift it by one. Building it over tokens[:-1] pushed the (only)
        # assistant turn outside the mask, leaving every label at -100.
        mask = get_assistant_mask(tokens, self.tokenizer, messages, DEFAULT_SYSTEM)
        labels[~mask[1:]] = -100
        return {"input_ids": input_ids, "labels": labels}

    def _empty(self):
        """One-token placeholder whose only label is ignored by the loss."""
        return {
            "input_ids": torch.zeros(1, dtype=torch.long),
            "labels": torch.full((1,), -100, dtype=torch.long),
        }
class CyberSecEvalDataset(Dataset):
    """CyberNative/CyberSecurityEval — security Q&A.

    The prompt is read from the first non-empty of ``question``, ``prompt``,
    ``input`` and the response from the first non-empty of ``answer``,
    ``response``, ``output``. Items are dicts with ``input_ids`` and
    ``labels``; labels are ``-100`` wherever the target token is not part of
    the assistant response.
    """

    def __init__(self, tokenizer, max_len=1024, limit=None):
        """Load the train split; a truthy ``limit`` keeps only the first rows."""
        print("[data] Loading CyberNative/CyberSecurityEval...")
        ds = load_dataset("CyberNative/CyberSecurityEval", split="train")
        if limit:
            ds = ds.select(range(min(limit, len(ds))))
        self.data = ds
        self.tokenizer = tokenizer
        self.max_len = max_len
        print(f" Loaded {len(ds)} examples")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize row ``idx``; incomplete rows yield a masked stub."""
        row = self.data[idx]
        # Try common column names. `or` chains so an empty/None column also
        # falls through to the next candidate, not only a missing key.
        question = row.get("question") or row.get("prompt") or row.get("input") or ""
        answer = row.get("answer") or row.get("response") or row.get("output") or ""
        if not question or not answer:
            return self._empty()
        messages = [
            {"role": "user", "content": question},
            {"role": "assistant", "content": answer},
        ]
        text = format_chatml(messages, DEFAULT_SYSTEM)
        # max_len + 1 tokens so that inputs and shifted targets both have max_len.
        tokens = self.tokenizer.encode(text, max_length=self.max_len + 1, truncation=True)
        if len(tokens) < 10:
            return self._empty()
        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
        labels = torch.tensor(tokens[1:], dtype=torch.long)
        # labels[i] is tokens[i + 1]: build the mask over the full sequence
        # and shift it by one. Building it over tokens[:-1] pushed the (only)
        # assistant turn outside the mask, leaving every label at -100.
        mask = get_assistant_mask(tokens, self.tokenizer, messages, DEFAULT_SYSTEM)
        labels[~mask[1:]] = -100
        return {"input_ids": input_ids, "labels": labels}

    def _empty(self):
        """One-token placeholder whose only label is ignored by the loss."""
        return {
            "input_ids": torch.zeros(1, dtype=torch.long),
            "labels": torch.full((1,), -100, dtype=torch.long),
        }
# ===========================================================================
# Collation (variable-length padding)
# ===========================================================================
def collate_fn(batch, pad_id=0):
    """Right-pad a list of examples into rectangular ``(batch, max_len)`` tensors.

    ``input_ids`` are padded with ``pad_id``; ``labels`` are padded with
    ``-100``. Each example is a dict holding 1-D ``input_ids`` and ``labels``
    tensors.
    """
    lengths = [example["input_ids"].shape[0] for example in batch]
    shape = (len(batch), max(lengths))
    padded_inputs = torch.full(shape, pad_id, dtype=torch.long)
    padded_labels = torch.full(shape, -100, dtype=torch.long)
    for row, (example, length) in enumerate(zip(batch, lengths)):
        padded_inputs[row, :length] = example["input_ids"]
        padded_labels[row, :length] = example["labels"]
    return {"input_ids": padded_inputs, "labels": padded_labels}
# ===========================================================================
# Training
# ===========================================================================
def get_lr(step, warmup_steps, max_steps, max_lr, min_lr):
    """Learning rate at ``step``: linear warmup, then cosine decay to ``min_lr``.

    During warmup the rate rises linearly and reaches ``max_lr`` on the last
    warmup step. From ``max_steps`` onwards it stays at ``min_lr``. In
    between it follows half a cosine from ``max_lr`` down to ``min_lr``.
    """
    in_warmup = step < warmup_steps
    if in_warmup:
        return max_lr * (step + 1) / warmup_steps
    if step >= max_steps:
        return min_lr
    # max(1, ...) guards the division when warmup covers the whole schedule.
    decay_span = max(1, max_steps - warmup_steps)
    progress = (step - warmup_steps) / decay_span
    cosine = 1 + math.cos(math.pi * progress)
    return min_lr + 0.5 * (max_lr - min_lr) * cosine
def sample_loop_depth(mu_rec, batch_size):
    """Draw one loop depth for a batch.

    Takes ``batch_size`` independent Geometric samples with success
    probability ``1 / max(mu_rec, 1)``, adds 1 to each, clamps each to
    ``[1, 2 * mu_rec]`` (never below 1), and returns the floor of their mean,
    again never below 1.
    """
    geometric = torch.distributions.Geometric(probs=1.0 / max(mu_rec, 1))
    upper = 2 * mu_rec
    depths = []
    for _ in range(batch_size):
        draw = int(geometric.sample().item()) + 1
        depths.append(max(1, min(upper, draw)))
    mean_depth = sum(depths) // len(depths)
    return max(1, mean_depth)
def train(args):
    """Fine-tune the CPT checkpoint on the instruction mix, then save and push it.

    Reads from ``args``: ``base_model``, ``hub_model_id``, ``max_steps``,
    ``batch_size``, ``grad_accum``, ``lr``, ``min_lr``, ``warmup_steps``,
    ``log_interval``, ``save_interval`` and ``data_limit``.

    Writes ``ultron_sec_final.pt`` to the working directory and, when
    ``args.hub_model_id`` is set, uploads it together with ``config.json``.

    NOTE(review): ``step`` counts micro-batches, not optimizer updates, so
    ``max_steps``, ``warmup_steps`` and the LR schedule are all expressed in
    micro-batches; the optimizer steps once every ``grad_accum`` of them.
    Confirm this is the intended meaning of "steps".

    Raises:
        RuntimeError: if a full pass over the data produces no batch with at
            least one supervised (non ``-100``) target.
    """
    # Function-scope import: the only name this block needs beyond the
    # file's existing imports.
    from functools import partial

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    on_cuda = device.type == "cuda"
    use_bf16 = on_cuda and torch.cuda.is_bf16_supported()
    dtype = torch.bfloat16 if use_bf16 else torch.float32
    print(f"Device: {device} | dtype: {dtype}")

    # ── Load CPT checkpoint ───────────────────────────────────────
    print(f"\nLoading base model from {args.base_model}...")
    ckpt_path = hf_hub_download(args.base_model, "ultron_sec_cpt_final.pt")
    # SECURITY: weights_only=False unpickles arbitrary objects from the
    # downloaded file. Only point --base_model at repositories you trust.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    cfg = UltronConfig(**ckpt["config"])
    model = Ultron(cfg)
    model.load_state_dict(ckpt["model_state_dict"])
    model = model.to(device)
    model.train()
    print(f" Loaded step {ckpt.get('step')}, loss {ckpt.get('loss', 'N/A')}")
    print(f" Params: {model.get_num_params(False):,}")
    print(f" rho(A): {model.get_spectral_radius():.6f}")

    # ── Tokenizer ─────────────────────────────────────────────────
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    # ── Datasets ──────────────────────────────────────────────────
    data_limit = args.data_limit if args.data_limit else None
    datasets_list = []
    datasets_list.append(CodeFeedbackDataset(tokenizer, cfg.max_seq_len, limit=data_limit))
    datasets_list.append(GlaiveCodeDataset(tokenizer, cfg.max_seq_len, limit=data_limit or 50000))
    datasets_list.append(CVEInstructDataset(tokenizer, cfg.max_seq_len, limit=data_limit))
    try:
        # 2x upsample security eval data
        cyber_ds = CyberSecEvalDataset(tokenizer, cfg.max_seq_len, limit=data_limit)
        datasets_list.append(cyber_ds)
        datasets_list.append(cyber_ds)  # 2x
    except Exception as e:
        # Deliberate best-effort: training proceeds without this dataset.
        print(f" Warning: CyberSecEval failed to load: {e}")
    combined = ConcatDataset(datasets_list)
    print(f"\n[data] Combined: {len(combined)} examples")

    loader = DataLoader(
        combined,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=2,
        # Pinning host memory only makes sense when batches go to a GPU.
        pin_memory=on_cuda,
        # partial instead of a lambda: worker processes started with "spawn"
        # must pickle collate_fn, and lambdas cannot be pickled.
        collate_fn=partial(collate_fn, pad_id=tokenizer.eos_token_id),
        drop_last=True,
    )

    # ── Optimizer (Primus SFT recipe) ─────────────────────────────
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.95),
        eps=1e-8,
        weight_decay=0.05,
    )

    # ── Training loop ─────────────────────────────────────────────
    step = 0
    tokens_seen = 0
    running_loss = 0.0
    t0 = time.time()
    log_t0 = time.time()
    epoch = 0
    print(f"\nSFT Training for {args.max_steps} steps")
    print(f" Batch: {args.batch_size} x {args.grad_accum} accum = {args.batch_size * args.grad_accum}")
    print(f" LR: {args.lr} -> {args.min_lr}")
    print(f" bf16: {use_bf16}\n")
    optimizer.zero_grad()
    while step < args.max_steps:
        epoch += 1
        print(f"--- Epoch {epoch} ---")
        step_at_epoch_start = step
        for batch in loader:
            if step >= args.max_steps:
                break
            labels = batch["labels"]
            n_supervised = int((labels != -100).sum().item())
            if n_supervised == 0:
                # Every target is ignored: the mean cross-entropy would be
                # NaN and one backward pass would poison the weights.
                continue
            input_ids = batch["input_ids"].to(device)
            labels = labels.to(device)
            lr = get_lr(step, args.warmup_steps, args.max_steps, args.lr, args.min_lr)
            for g in optimizer.param_groups:
                g["lr"] = lr
            n_loops = sample_loop_depth(cfg.max_loop_iters, input_ids.shape[0])
            with torch.autocast(device_type=device.type, dtype=dtype, enabled=use_bf16):
                logits = model(input_ids, n_loops=n_loops)
                loss = F.cross_entropy(
                    logits.view(-1, cfg.vocab_size),
                    labels.view(-1),
                    ignore_index=-100,
                )
            # Scale so the accumulated gradient is the mean over micro-batches.
            loss_scaled = loss / args.grad_accum
            loss_scaled.backward()
            running_loss += loss.item()
            tokens_seen += n_supervised
            if (step + 1) % args.grad_accum == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()
            step += 1
            if step % args.log_interval == 0:
                avg = running_loss / args.log_interval
                ppl = math.exp(min(avg, 20))  # cap the exponent to avoid overflow
                rho = model.get_spectral_radius()
                dt = time.time() - log_t0
                print(f"step {step:>5d}/{args.max_steps} | loss {avg:.4f} | ppl {ppl:.1f} | "
                      f"lr {lr:.2e} | rho(A) {rho:.4f} | depth {n_loops} | "
                      f"{tokens_seen:,} tokens | {dt:.1f}s")
                running_loss = 0.0
                log_t0 = time.time()
            if step % args.save_interval == 0 and step > 0:
                save_checkpoint(model, cfg, step, tokens_seen, args)
        if step == step_at_epoch_start:
            # Empty loader (fewer examples than batch_size with drop_last) or
            # no supervised targets anywhere: looping again would never end.
            raise RuntimeError(
                "No trainable batch in a full pass over the data: the loader is "
                "empty or every label is masked (-100)."
            )

    # ── Final save ────────────────────────────────────────────────
    elapsed = time.time() - t0
    print(f"\nSFT complete! {step} steps in {elapsed:.0f}s ({elapsed/3600:.1f}h)")
    print(f"Final rho(A): {model.get_spectral_radius():.6f}")
    final = {
        "step": step,
        "tokens_seen": tokens_seen,
        "model_state_dict": model.state_dict(),
        "config": asdict(cfg),
        "training": "sft",
        "base_model": args.base_model,
    }
    final_path = "ultron_sec_final.pt"
    torch.save(final, final_path)
    print(f"Saved: {final_path}")
    if args.hub_model_id:
        try:
            api = HfApi()
            api.upload_file(
                path_or_fileobj=final_path,
                path_in_repo="ultron_sec_final.pt",
                repo_id=args.hub_model_id,
            )
            config_path = "config.json"
            with open(config_path, "w") as f:
                json.dump(asdict(cfg), f, indent=2, default=str)
            api.upload_file(
                path_or_fileobj=config_path,
                path_in_repo="config.json",
                repo_id=args.hub_model_id,
            )
            print(f"Pushed to {args.hub_model_id}")
        except Exception as e:
            # Best-effort push: the local file saved above is still there.
            print(f"Push failed: {e}")
    print("Done!")
def save_checkpoint(model, cfg, step, tokens_seen, args):
    """Write an intermediate checkpoint and, if configured, push it to the Hub.

    The checkpoint is saved as ``ultron_sec_sft_step{step}.pt`` in the
    working directory. When ``args.hub_model_id`` is set it is uploaded to
    ``checkpoints/`` in that repository, and the local file is removed only
    after the upload succeeded. If no repository is configured or the upload
    fails, the local file is kept so the checkpoint is never lost.
    """
    ckpt = {
        "step": step,
        "tokens_seen": tokens_seen,
        "model_state_dict": model.state_dict(),
        "config": asdict(cfg),
    }
    path = f"ultron_sec_sft_step{step}.pt"
    torch.save(ckpt, path)
    print(f" Checkpoint: {path}")
    if not args.hub_model_id:
        return
    try:
        api = HfApi()
        api.upload_file(
            path_or_fileobj=path,
            path_in_repo=f"checkpoints/{path}",
            repo_id=args.hub_model_id,
        )
    except Exception as e:
        # Best-effort push: report and keep the only copy of the checkpoint.
        print(f" Push failed: {e}")
        return
    print(f" Pushed to {args.hub_model_id}")
    # The Hub now holds the checkpoint; free the local disk space.
    if os.path.exists(path):
        os.remove(path)
def main():
    """Parse the command-line options and start SFT training."""
    parser = argparse.ArgumentParser(description="Ultron-Sec SFT Training")
    parser.add_argument(
        "--base_model",
        type=str,
        default="trojan0x/ultron-sec-cpt",
        help="CPT model to fine-tune from",
    )
    # The HUB_MODEL_ID environment variable overrides the built-in default.
    default_hub_id = os.environ.get("HUB_MODEL_ID", "trojan0x/ultron-sec")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=default_hub_id,
        help="Where to push the SFT model",
    )
    # Plain numeric options without help text: (flag, type, default).
    numeric_options = (
        ("--max_steps", int, 3000),
        ("--batch_size", int, 4),
        ("--grad_accum", int, 8),
        ("--lr", float, 2.5e-5),
        ("--min_lr", float, 2.5e-6),
        ("--warmup_steps", int, 100),
        ("--log_interval", int, 10),
        ("--save_interval", int, 1000),
    )
    for flag, value_type, default in numeric_options:
        parser.add_argument(flag, type=value_type, default=default)
    parser.add_argument(
        "--data_limit",
        type=int,
        default=None,
        help="Limit examples per dataset (for testing)",
    )
    train(parser.parse_args())


if __name__ == "__main__":
    main()