#!/usr/bin/env python3
"""
Ultron-Sec SFT — Cybersecurity Instruction Fine-Tuning (Phase 3)

Loads the CPT checkpoint from trojan0x/ultron-sec-cpt and fine-tunes it on
cybersecurity + code instruction data rendered in ChatML format.

Data mix:
  - m-a-p/Code-Feedback (67K, multi-turn code conversations)
  - glaiveai/glaive-code-assistant-v3 (~136K, code Q&A; capped at 50K by default)
  - Bouquets/Cybersecurity-LLM-CVE (reformatted as instruction-following)
  - CyberNative/CyberSecurityEval (security Q&A, 2x upsampled, optional)

Training recipe (grounded in the Primus/CyberPal-2 papers):
  - LR: 2.5e-5 with linear warmup and cosine decay
  - Prompt masking: loss is computed only on assistant replies (content plus
    the closing <|im_end|>), never on system/user turns or the
    "<|im_start|>assistant" header
  - Weight decay: 0.05
  - 3000 optimizer steps

Usage:
  python train_sft.py --hub_model_id trojan0x/ultron-sec

  # Quick test
  python train_sft.py --max_steps 50 --log_interval 5
"""

import os
import sys
import math
import time
import json
import random
import argparse
import functools
from dataclasses import asdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, ConcatDataset

from datasets import load_dataset
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download, snapshot_download, HfApi


# ── Ultron model code ─────────────────────────────────────────────
def setup_ultron():
    """Download the Ultron model sources from the hub and make them importable."""
    repo_path = snapshot_download("trojan0x/ultron", allow_patterns=["ultron/*.py"])
    sys.path.insert(0, repo_path)
    print(f"Ultron loaded from: {repo_path}")


setup_ultron()
from ultron.model import Ultron, UltronConfig  # noqa: E402  (needs setup_ultron first)


# ===========================================================================
# ChatML formatting
# ===========================================================================

CHATML_TEMPLATE = {
    "system": "<|im_start|>system\n{content}<|im_end|>\n",
    "user": "<|im_start|>user\n{content}<|im_end|>\n",
    "assistant": "<|im_start|>assistant\n{content}<|im_end|>\n",
}

# Text that precedes the reply inside an assistant turn ("<|im_start|>assistant\n").
# It is supplied by the caller at inference time, so it is excluded from the loss.
_ASSISTANT_HEADER = CHATML_TEMPLATE["assistant"].partition("{content}")[0]

DEFAULT_SYSTEM = (
    "You are Ultron-Sec, a cybersecurity AI assistant. You provide detailed, "
    "accurate analysis of security vulnerabilities, exploit techniques, defensive "
    "strategies, and code. Always explain your reasoning step by step."
)


def _format_turn(role, content):
    """Render one message with its ChatML template; unknown roles render as ""."""
    template = CHATML_TEMPLATE.get(role)
    if template is None:
        return ""
    # A missing/None content renders as an empty turn rather than the text "None".
    return template.format(content=content or "")


def format_chatml(messages, system_prompt=None):
    """Convert a list of {role, content} dicts to a ChatML string.

    Args:
        messages: iterable of dicts with optional "role" (default "user") and
            "content" keys. Messages whose role has no template are skipped.
        system_prompt: optional system message prepended to the conversation.

    Returns:
        The concatenated ChatML text.
    """
    parts = []
    if system_prompt:
        parts.append(CHATML_TEMPLATE["system"].format(content=system_prompt))
    parts.extend(_format_turn(m.get("role", "user"), m.get("content")) for m in messages)
    return "".join(parts)


def get_assistant_mask(tokens, tokenizer, messages, system_prompt=None):
    """Return a bool mask over ``tokens`` that is True on assistant-reply tokens.

    ``tokens`` is expected to be the (possibly truncated) encoding of
    ``format_chatml(messages, system_prompt)``. For every assistant turn the
    reply content and its closing ``<|im_end|>\\n`` are marked; the
    ``<|im_start|>assistant\\n`` header and all system/user turns are not.
    A reply cut off by truncation is marked up to the truncation point.

    Boundaries are found by re-encoding text prefixes, which assumes the
    tokenizer encodes a prefix of the text as a prefix of the full token list
    (true at these ChatML boundaries for GPT-2 BPE).

    Note the mask is indexed by *token position*. When training on
    ``labels = tokens[1:]``, use ``mask[1:]`` so that position ``i`` is
    supervised iff its target ``tokens[i + 1]`` is an assistant token.
    """
    n_tokens = len(tokens)
    mask = torch.zeros(n_tokens, dtype=torch.bool)

    prefix = ""
    if system_prompt:
        prefix = CHATML_TEMPLATE["system"].format(content=system_prompt)

    for msg in messages:
        role = msg.get("role", "user")
        formatted = _format_turn(role, msg.get("content"))
        if role == "assistant":
            start = len(tokenizer.encode(prefix + _ASSISTANT_HEADER))
            if start >= n_tokens:
                break  # this reply (and everything after it) was truncated away
            end = len(tokenizer.encode(prefix + formatted))
            mask[start:min(end, n_tokens)] = True
        prefix += formatted
    return mask


# ===========================================================================
# Dataset classes
# ===========================================================================

def _qa_messages(question, answer):
    """Build a single user/assistant exchange, or [] when either side is empty."""
    if not question or not answer:
        return []
    return [
        {"role": "user", "content": question},
        {"role": "assistant", "content": answer},
    ]


class _ChatSFTDataset(Dataset):
    """Shared loading, ChatML encoding and prompt masking for the SFT sources.

    Subclasses set ``hf_name`` and implement ``_to_messages(row)``. Each item
    is a dict with ``input_ids`` and ``labels`` (1-D long tensors of equal
    length); labels are -100 everywhere except where the next token belongs
    to an assistant reply. Rows that are empty, too short, or have no
    supervised token yield a 1-token, fully-ignored placeholder.
    """

    hf_name = ""

    def __init__(self, tokenizer, max_len=1024, limit=None):
        print(f"[data] Loading {self.hf_name}...")
        ds = load_dataset(self.hf_name, split="train")
        if limit:
            ds = ds.select(range(min(limit, len(ds))))
        self.data = ds
        self.tokenizer = tokenizer
        self.max_len = max_len
        print(f"  Loaded {len(ds)} examples")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        messages = self._to_messages(self.data[idx])
        if not messages:
            return self._empty()
        return self._encode(messages)

    def _to_messages(self, row):
        """Map one raw dataset row to a list of {role, content} messages."""
        raise NotImplementedError

    def _encode(self, messages):
        """Tokenize ``messages`` and build next-token labels with prompt masking."""
        text = format_chatml(messages, DEFAULT_SYSTEM)
        # max_len + 1 tokens -> max_len (input, label) pairs after the shift.
        tokens = self.tokenizer.encode(text, max_length=self.max_len + 1, truncation=True)
        if len(tokens) < 10:
            return self._empty()

        # labels[i] == tokens[i + 1], so shift the token-position mask by one.
        target_mask = get_assistant_mask(tokens, self.tokenizer, messages, DEFAULT_SYSTEM)[1:]
        if not target_mask.any():
            return self._empty()

        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
        labels = torch.tensor(tokens[1:], dtype=torch.long)
        labels[~target_mask] = -100  # ignore non-assistant targets
        return {"input_ids": input_ids, "labels": labels}

    @staticmethod
    def _empty():
        return {
            "input_ids": torch.zeros(1, dtype=torch.long),
            "labels": torch.full((1,), -100, dtype=torch.long),
        }


class CodeFeedbackDataset(_ChatSFTDataset):
    """m-a-p/Code-Feedback — multi-turn code conversations (messages format)."""

    hf_name = "m-a-p/Code-Feedback"

    def _to_messages(self, row):
        return row.get("messages", [])


class GlaiveCodeDataset(_ChatSFTDataset):
    """glaiveai/glaive-code-assistant-v3 — code Q&A pairs."""

    hf_name = "glaiveai/glaive-code-assistant-v3"

    def _to_messages(self, row):
        return _qa_messages(row.get("question", ""), row.get("answer", ""))


class CVEInstructDataset(_ChatSFTDataset):
    """Bouquets/Cybersecurity-LLM-CVE — CVE entries reformatted as instructions."""

    hf_name = "Bouquets/Cybersecurity-LLM-CVE"

    def _to_messages(self, row):
        instruction = row.get("instruction", row.get("input", ""))
        output = row.get("outputs", row.get("output", ""))
        return _qa_messages(instruction, output)


class CyberSecEvalDataset(_ChatSFTDataset):
    """CyberNative/CyberSecurityEval — security Q&A."""

    hf_name = "CyberNative/CyberSecurityEval"

    def _to_messages(self, row):
        # Try common column names.
        question = row.get("question", row.get("prompt", row.get("input", "")))
        answer = row.get("answer", row.get("response", row.get("output", "")))
        return _qa_messages(question, answer)


# ===========================================================================
# Collation (variable-length padding)
# ===========================================================================

def collate_fn(batch, pad_id=0):
    """Right-pad a batch to its longest example: inputs with pad_id, labels with -100."""
    max_len = max(b["input_ids"].shape[0] for b in batch)
    input_ids = torch.full((len(batch), max_len), pad_id, dtype=torch.long)
    labels = torch.full((len(batch), max_len), -100, dtype=torch.long)
    for i, b in enumerate(batch):
        length = b["input_ids"].shape[0]
        input_ids[i, :length] = b["input_ids"]
        labels[i, :length] = b["labels"]
    return {"input_ids": input_ids, "labels": labels}


# ===========================================================================
# Training
# ===========================================================================

def get_lr(step, warmup_steps, max_steps, max_lr, min_lr):
    """Linear warmup to ``max_lr``, then cosine decay to ``min_lr`` at ``max_steps``."""
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    if step >= max_steps:
        return min_lr
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))


def sample_loop_depth(mu_rec, batch_size):
    """Sample a recurrence depth for one batch.

    Draws ``batch_size`` depths from 1 + Geometric(p = 1/mu_rec) (mean
    ``mu_rec``), clamps each to [1, 2 * mu_rec], and returns their integer
    mean (at least 1). The whole batch runs at that single depth.
    """
    if batch_size < 1:
        return 1
    geo = torch.distributions.Geometric(probs=1.0 / max(mu_rec, 1))
    depths = (geo.sample((batch_size,)).long() + 1).clamp(min=1, max=max(1, 2 * mu_rec))
    return max(1, int(depths.sum().item()) // batch_size)


def _load_base_model(base_model, device):
    """Download the CPT checkpoint from ``base_model`` and rebuild the model on ``device``."""
    print(f"\nLoading base model from {base_model}...")
    ckpt_path = hf_hub_download(base_model, "ultron_sec_cpt_final.pt")
    # NOTE(security): weights_only=False unpickles arbitrary objects from the
    # downloaded file — only point --base_model at repositories you trust.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    cfg = UltronConfig(**ckpt["config"])
    model = Ultron(cfg)
    model.load_state_dict(ckpt["model_state_dict"])
    model = model.to(device)
    model.train()
    print(f"  Loaded step {ckpt.get('step')}, loss {ckpt.get('loss', 'N/A')}")
    print(f"  Params: {model.get_num_params(False):,}")
    print(f"  rho(A): {model.get_spectral_radius():.6f}")
    return model, cfg  # the CPU copy in `ckpt` is released on return


def _build_sft_dataset(tokenizer, max_len, data_limit):
    """Assemble the SFT data mix; CyberSecEval is optional and upsampled 2x."""
    limit = data_limit or None
    parts = [
        CodeFeedbackDataset(tokenizer, max_len, limit=limit),
        GlaiveCodeDataset(tokenizer, max_len, limit=limit or 50000),
        CVEInstructDataset(tokenizer, max_len, limit=limit),
    ]
    try:
        cyber_ds = CyberSecEvalDataset(tokenizer, max_len, limit=limit)
    except Exception as e:  # optional source: train without it rather than abort
        print(f"  Warning: CyberSecEval failed to load: {e}")
    else:
        parts += [cyber_ds, cyber_ds]  # 2x upsample
    return ConcatDataset(parts)


def _push_final(hub_model_id, final_path, cfg):
    """Best-effort upload of the final checkpoint and config.json to ``hub_model_id``."""
    try:
        api = HfApi()
        api.upload_file(
            path_or_fileobj=final_path,
            path_in_repo="ultron_sec_final.pt",
            repo_id=hub_model_id,
        )
        config_path = "config.json"
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(asdict(cfg), f, indent=2, default=str)
        api.upload_file(
            path_or_fileobj=config_path,
            path_in_repo="config.json",
            repo_id=hub_model_id,
        )
        print(f"Pushed to {hub_model_id}")
    except Exception as e:  # the local checkpoint is already saved; don't crash
        print(f"Push failed: {e}")


def train(args):
    """Run supervised fine-tuning as configured by the parsed CLI ``args``."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_bf16 = device.type == "cuda" and torch.cuda.is_bf16_supported()
    dtype = torch.bfloat16 if use_bf16 else torch.float32
    print(f"Device: {device} | dtype: {dtype}")

    # ── Load CPT checkpoint ───────────────────────────────────────
    model, cfg = _load_base_model(args.base_model, device)

    # ── Tokenizer ─────────────────────────────────────────────────
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    # ── Datasets ──────────────────────────────────────────────────
    combined = _build_sft_dataset(tokenizer, cfg.max_seq_len, args.data_limit)
    print(f"\n[data] Combined: {len(combined)} examples")

    loader = DataLoader(
        combined,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=device.type == "cuda",
        # partial (unlike a lambda) pickles under the "spawn" worker start method.
        collate_fn=functools.partial(collate_fn, pad_id=tokenizer.eos_token_id),
        drop_last=True,
    )
    if len(loader) == 0:
        raise ValueError(
            f"Combined dataset ({len(combined)} examples) is smaller than one "
            f"batch ({args.batch_size}); nothing to train on."
        )

    # ── Optimizer (Primus SFT recipe) ─────────────────────────────
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.95),
        eps=1e-8,
        weight_decay=0.05,
    )

    # ── Training loop ─────────────────────────────────────────────
    step = 0          # optimizer steps taken
    micro_step = 0    # forward/backward passes taken
    tokens_seen = 0   # supervised (non -100) label tokens processed
    loss_sum = 0.0    # sum of micro-batch losses since the last log line
    loss_count = 0    # number of micro-batches in loss_sum
    n_loops = 0
    lr = args.lr
    t0 = time.time()
    log_t0 = t0
    epoch = 0

    print(f"\nSFT Training for {args.max_steps} steps")
    print(f"  Batch: {args.batch_size} x {args.grad_accum} accum = {args.batch_size * args.grad_accum}")
    print(f"  LR: {args.lr} -> {args.min_lr}")
    print(f"  bf16: {use_bf16}\n")

    optimizer.zero_grad()
    while step < args.max_steps:
        epoch += 1
        print(f"--- Epoch {epoch} ---")
        micro_at_epoch_start = micro_step

        for batch in loader:
            if step >= args.max_steps:
                break

            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)
            n_targets = int((labels != -100).sum().item())
            if n_targets == 0:
                # Mean cross-entropy over zero targets is NaN and would poison
                # the weights; such a batch carries no signal anyway.
                continue

            lr = get_lr(step, args.warmup_steps, args.max_steps, args.lr, args.min_lr)
            for g in optimizer.param_groups:
                g["lr"] = lr

            n_loops = sample_loop_depth(cfg.max_loop_iters, input_ids.shape[0])

            with torch.autocast(device_type=device.type, dtype=dtype, enabled=use_bf16):
                logits = model(input_ids, n_loops=n_loops)
                loss = F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)),
                    labels.reshape(-1),
                    ignore_index=-100,
                )
            (loss / args.grad_accum).backward()

            micro_step += 1
            loss_sum += loss.item()
            loss_count += 1
            tokens_seen += n_targets

            # Keyed on micro_step: `step` only advances inside this branch.
            if micro_step % args.grad_accum != 0:
                continue

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()
            step += 1

            if step % args.log_interval == 0:
                avg = loss_sum / loss_count  # mean per micro-batch
                ppl = math.exp(min(avg, 20))
                rho = model.get_spectral_radius()
                dt = time.time() - log_t0
                print(f"step {step:>5d}/{args.max_steps} | loss {avg:.4f} | ppl {ppl:.1f} | "
                      f"lr {lr:.2e} | rho(A) {rho:.4f} | depth {n_loops} | "
                      f"{tokens_seen:,} tokens | {dt:.1f}s")
                loss_sum = 0.0
                loss_count = 0
                log_t0 = time.time()

            if step % args.save_interval == 0:
                save_checkpoint(model, cfg, step, tokens_seen, args)

        if step < args.max_steps and micro_step == micro_at_epoch_start:
            raise RuntimeError(
                "A full pass over the data produced no supervised tokens; "
                "check the dataset columns and prompt masking."
            )

    # ── Final save ────────────────────────────────────────────────
    elapsed = time.time() - t0
    print(f"\nSFT complete! {step} steps in {elapsed:.0f}s ({elapsed/3600:.1f}h)")
    print(f"Final rho(A): {model.get_spectral_radius():.6f}")

    final = {
        "step": step,
        "tokens_seen": tokens_seen,
        "model_state_dict": model.state_dict(),
        "config": asdict(cfg),
        "training": "sft",
        "base_model": args.base_model,
    }
    final_path = "ultron_sec_final.pt"
    torch.save(final, final_path)
    print(f"Saved: {final_path}")

    if args.hub_model_id:
        _push_final(args.hub_model_id, final_path, cfg)

    print("Done!")


def save_checkpoint(model, cfg, step, tokens_seen, args):
    """Save an intermediate checkpoint and, if configured, push it to the hub.

    The local file is removed only after a successful upload, so a failed or
    disabled push never loses the checkpoint.
    """
    ckpt = {
        "step": step,
        "tokens_seen": tokens_seen,
        "model_state_dict": model.state_dict(),
        "config": asdict(cfg),
    }
    path = f"ultron_sec_sft_step{step}.pt"
    torch.save(ckpt, path)
    print(f"  Checkpoint: {path}")

    if not args.hub_model_id:
        return
    try:
        HfApi().upload_file(
            path_or_fileobj=path,
            path_in_repo=f"checkpoints/{path}",
            repo_id=args.hub_model_id,
        )
    except Exception as e:  # best-effort push; keep the local copy
        print(f"  Push failed: {e}")
        return
    print(f"  Pushed to {args.hub_model_id}")
    if os.path.exists(path):
        os.remove(path)


def main():
    """Parse CLI arguments and launch SFT training."""
    parser = argparse.ArgumentParser(description="Ultron-Sec SFT Training")
    parser.add_argument("--base_model", type=str, default="trojan0x/ultron-sec-cpt",
                        help="CPT model to fine-tune from")
    parser.add_argument("--hub_model_id", type=str,
                        default=os.environ.get("HUB_MODEL_ID", "trojan0x/ultron-sec"),
                        help="Where to push the SFT model")
    parser.add_argument("--max_steps", type=int, default=3000)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--grad_accum", type=int, default=8)
    parser.add_argument("--lr", type=float, default=2.5e-5)
    parser.add_argument("--min_lr", type=float, default=2.5e-6)
    parser.add_argument("--warmup_steps", type=int, default=100)
    parser.add_argument("--log_interval", type=int, default=10)
    parser.add_argument("--save_interval", type=int, default=1000)
    parser.add_argument("--data_limit", type=int, default=None,
                        help="Limit examples per dataset (for testing)")
    args = parser.parse_args()

    # These are used as divisors / batch sizes in the training loop.
    for name in ("batch_size", "grad_accum", "log_interval", "save_interval"):
        if getattr(args, name) < 1:
            parser.error(f"--{name} must be >= 1")

    train(args)


if __name__ == "__main__":
    main()