| |
| """ |
| Ultron-Sec SFT — Cybersecurity Instruction Fine-Tuning (Phase 3) |
| |
| Loads the CPT checkpoint from trojan0x/ultron-sec-cpt and fine-tunes on |
| cybersecurity + code instruction data in ChatML format. |
| |
| Data mix: |
| - m-a-p/Code-Feedback (67K, multi-turn code conversations) |
| - glaiveai/glaive-code-assistant-v3 (~136K, code Q&A) |
| - Bouquets/Cybersecurity-LLM-CVE (reformatted as instruction-following) |
| - CyberNative/CyberSecurityEval (security Q&A, 2x upsampled) |
| |
| Training recipe (grounded in Primus/CyberPal-2 papers): |
| - LR: 2.5e-5 with cosine decay |
| - Prompt masking: only train on assistant responses |
| - Weight decay: 0.05 |
| - 3000 steps |
| |
| Usage: |
| python train_sft.py --hub_model_id trojan0x/ultron-sec |
| |
| # Quick test |
| python train_sft.py --max_steps 50 --log_interval 5 |
| """ |
|
|
| import os |
| import sys |
| import math |
| import time |
| import json |
| import random |
| import argparse |
| from dataclasses import asdict |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader, ConcatDataset |
|
|
| from datasets import load_dataset |
| from transformers import AutoTokenizer |
| from huggingface_hub import hf_hub_download, snapshot_download, HfApi |
|
|
| |
def setup_ultron():
    """Fetch the Ultron model sources from the Hub and make them importable.

    Downloads only the ``ultron/*.py`` files of the ``trojan0x/ultron``
    snapshot and places the snapshot directory at the front of ``sys.path``
    so that ``import ultron`` resolves to the downloaded package.
    """
    repo_path = snapshot_download(
        repo_id="trojan0x/ultron",
        allow_patterns=["ultron/*.py"],
    )
    sys.path.insert(0, repo_path)
    print(f"Ultron loaded from: {repo_path}")


# The import below only resolves once setup_ultron() has extended sys.path,
# so the call has to stay ahead of it.
setup_ultron()
from ultron.model import Ultron, UltronConfig
|
|
|
|
| |
| |
| |
|
|
# Every ChatML turn has the same shape:
#   <|im_start|>{role}\n{content}<|im_end|>\n
# The role is baked in here; "{content}" stays a str.format placeholder that
# is filled per message.
CHATML_TEMPLATE = {
    role: "<|im_start|>" + role + "\n{content}<|im_end|>\n"
    for role in ("system", "user", "assistant")
}
|
|
# System prompt placed at the start of every training conversation.
DEFAULT_SYSTEM = (
    "You are Ultron-Sec, a cybersecurity AI assistant. "
    "You provide detailed, accurate analysis of security vulnerabilities, "
    "exploit techniques, defensive strategies, and code. "
    "Always explain your reasoning step by step."
)
|
|
|
|
def format_chatml(messages, system_prompt=None):
    """Render a conversation as one ChatML-formatted string.

    Args:
        messages: Iterable of dicts with ``role`` and ``content`` keys. A
            missing role is treated as ``"user"`` and missing content as the
            empty string. Messages whose role has no entry in
            ``CHATML_TEMPLATE`` are skipped.
        system_prompt: Optional system text, emitted as the first turn when
            truthy.

    Returns:
        The concatenation of all rendered turns.
    """
    rendered = []
    if system_prompt:
        rendered.append(CHATML_TEMPLATE["system"].format(content=system_prompt))
    for turn in messages:
        template = CHATML_TEMPLATE.get(turn.get("role", "user"))
        if template is None:
            continue
        rendered.append(template.format(content=turn.get("content", "")))
    return "".join(rendered)
|
|
|
|
def get_assistant_mask(tokens, tokenizer, messages, system_prompt=None):
    """Build a boolean mask that is True only on assistant-turn tokens.

    This implements prompt masking: the caller uses the mask to restrict the
    loss to assistant outputs. A turn's span covers the whole rendered
    assistant turn, i.e. the ``<|im_start|>assistant`` header, the content
    and the closing ``<|im_end|>`` line.

    Span boundaries are found by tokenizing the conversation text up to the
    start and up to the end of each assistant turn and taking the token
    counts. NOTE(review): this assumes the tokenization of a prefix is a
    prefix of the tokenization of the full text; confirm for the tokenizer
    in use.

    Args:
        tokens: The (possibly truncated) token sequence the mask is for; only
            its length is used.
        tokenizer: Object exposing ``encode(text) -> sequence of ids``.
        messages: Iterable of ``{"role", "content"}`` dicts, rendered the same
            way as ``format_chatml`` does (unknown roles contribute nothing).
        system_prompt: Optional system text that precedes the messages.

    Returns:
        A 1-D ``torch.bool`` tensor of length ``len(tokens)``.
    """
    seq_len = len(tokens)
    mask = torch.zeros(seq_len, dtype=torch.bool)

    prefix = ""
    if system_prompt:
        prefix += CHATML_TEMPLATE["system"].format(content=system_prompt)

    for msg in messages:
        role = msg.get("role", "user")
        content = msg.get("content", "")
        formatted = CHATML_TEMPLATE.get(role, "").format(content=content)

        if role == "assistant":
            start = len(tokenizer.encode(prefix))
            if start >= seq_len:
                # This turn, and every later one, lies entirely beyond the
                # (truncated) sequence.
                break
            # Clamp instead of dropping the turn: a turn that runs past the
            # end of `tokens` (truncation, or a caller passing tokens[:-1])
            # must still have its visible part supervised. Dropping it left
            # single-turn examples with no supervised token at all.
            end = min(len(tokenizer.encode(prefix + formatted)), seq_len)
            mask[start:end] = True

        prefix += formatted

    return mask
|
|
|
|
| |
| |
| |
|
|
class CodeFeedbackDataset(Dataset):
    """m-a-p/Code-Feedback — multi-turn code conversations (messages format).

    Each item is a dict with ``input_ids`` (tokens ``[:-1]``) and ``labels``
    (tokens ``[1:]``), where every label that is not an assistant-turn token
    is set to ``-100``. Rows without messages, or with fewer than 10 tokens,
    yield a one-token placeholder whose label is ``-100``.
    """

    def __init__(self, tokenizer, max_len=1024, limit=None):
        """Load the train split, optionally keeping only the first ``limit`` rows."""
        print("[data] Loading m-a-p/Code-Feedback...")
        ds = load_dataset("m-a-p/Code-Feedback", split="train")
        if limit:
            ds = ds.select(range(min(limit, len(ds))))
        self.data = ds
        self.tokenizer = tokenizer
        self.max_len = max_len
        print(f" Loaded {len(ds)} examples")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data[idx]
        messages = row.get("messages", [])
        if not messages:
            return self._empty()

        text = format_chatml(messages, DEFAULT_SYSTEM)
        # max_len + 1 tokens so that inputs and shifted labels are max_len long.
        tokens = self.tokenizer.encode(text, max_length=self.max_len + 1, truncation=True)
        if len(tokens) < 10:
            return self._empty()

        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
        labels = torch.tensor(tokens[1:], dtype=torch.long)

        # labels[i] holds tokens[i + 1], so position i is supervised only when
        # token i + 1 belongs to an assistant turn: build the mask over the
        # full token list and drop its first entry to align it with labels.
        mask = get_assistant_mask(tokens, self.tokenizer, messages, DEFAULT_SYSTEM)
        labels[~mask[1:]] = -100

        return {"input_ids": input_ids, "labels": labels}

    def _empty(self):
        """Return a one-token placeholder example that contributes no loss."""
        return {
            "input_ids": torch.zeros(1, dtype=torch.long),
            "labels": torch.full((1,), -100, dtype=torch.long),
        }
|
|
|
|
class GlaiveCodeDataset(Dataset):
    """glaiveai/glaive-code-assistant-v3 — code Q&A pairs.

    Each row's ``question``/``answer`` pair becomes a single user/assistant
    exchange. Items are dicts with ``input_ids`` (tokens ``[:-1]``) and
    ``labels`` (tokens ``[1:]``) where every non-assistant label is ``-100``.
    Rows with a missing question or answer, or with fewer than 10 tokens,
    yield a one-token placeholder whose label is ``-100``.
    """

    def __init__(self, tokenizer, max_len=1024, limit=None):
        """Load the train split, optionally keeping only the first ``limit`` rows."""
        print("[data] Loading glaiveai/glaive-code-assistant-v3...")
        ds = load_dataset("glaiveai/glaive-code-assistant-v3", split="train")
        if limit:
            ds = ds.select(range(min(limit, len(ds))))
        self.data = ds
        self.tokenizer = tokenizer
        self.max_len = max_len
        print(f" Loaded {len(ds)} examples")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data[idx]
        question = row.get("question", "")
        answer = row.get("answer", "")
        if not question or not answer:
            return self._empty()

        messages = [
            {"role": "user", "content": question},
            {"role": "assistant", "content": answer},
        ]
        text = format_chatml(messages, DEFAULT_SYSTEM)
        # max_len + 1 tokens so that inputs and shifted labels are max_len long.
        tokens = self.tokenizer.encode(text, max_length=self.max_len + 1, truncation=True)
        if len(tokens) < 10:
            return self._empty()

        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
        labels = torch.tensor(tokens[1:], dtype=torch.long)

        # labels[i] holds tokens[i + 1], so position i is supervised only when
        # token i + 1 belongs to the assistant turn: build the mask over the
        # full token list and drop its first entry to align it with labels.
        mask = get_assistant_mask(tokens, self.tokenizer, messages, DEFAULT_SYSTEM)
        labels[~mask[1:]] = -100

        return {"input_ids": input_ids, "labels": labels}

    def _empty(self):
        """Return a one-token placeholder example that contributes no loss."""
        return {
            "input_ids": torch.zeros(1, dtype=torch.long),
            "labels": torch.full((1,), -100, dtype=torch.long),
        }
|
|
|
|
class CVEInstructDataset(Dataset):
    """Bouquets/Cybersecurity-LLM-CVE — CVE entries reformatted as instructions.

    The prompt is read from ``instruction`` (falling back to ``input``) and
    the response from ``outputs`` (falling back to ``output``). Items are
    dicts with ``input_ids`` (tokens ``[:-1]``) and ``labels`` (tokens
    ``[1:]``) where every non-assistant label is ``-100``. Rows with a missing
    prompt or response, or with fewer than 10 tokens, yield a one-token
    placeholder whose label is ``-100``.
    """

    def __init__(self, tokenizer, max_len=1024, limit=None):
        """Load the train split, optionally keeping only the first ``limit`` rows."""
        print("[data] Loading Bouquets/Cybersecurity-LLM-CVE...")
        ds = load_dataset("Bouquets/Cybersecurity-LLM-CVE", split="train")
        if limit:
            ds = ds.select(range(min(limit, len(ds))))
        self.data = ds
        self.tokenizer = tokenizer
        self.max_len = max_len
        print(f" Loaded {len(ds)} examples")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data[idx]
        instruction = row.get("instruction", row.get("input", ""))
        output = row.get("outputs", row.get("output", ""))
        if not instruction or not output:
            return self._empty()

        messages = [
            {"role": "user", "content": instruction},
            {"role": "assistant", "content": output},
        ]
        text = format_chatml(messages, DEFAULT_SYSTEM)
        # max_len + 1 tokens so that inputs and shifted labels are max_len long.
        tokens = self.tokenizer.encode(text, max_length=self.max_len + 1, truncation=True)
        if len(tokens) < 10:
            return self._empty()

        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
        labels = torch.tensor(tokens[1:], dtype=torch.long)

        # labels[i] holds tokens[i + 1], so position i is supervised only when
        # token i + 1 belongs to the assistant turn: build the mask over the
        # full token list and drop its first entry to align it with labels.
        mask = get_assistant_mask(tokens, self.tokenizer, messages, DEFAULT_SYSTEM)
        labels[~mask[1:]] = -100

        return {"input_ids": input_ids, "labels": labels}

    def _empty(self):
        """Return a one-token placeholder example that contributes no loss."""
        return {
            "input_ids": torch.zeros(1, dtype=torch.long),
            "labels": torch.full((1,), -100, dtype=torch.long),
        }
|
|
|
|
class CyberSecEvalDataset(Dataset):
    """CyberNative/CyberSecurityEval — security Q&A.

    The prompt is read from ``question`` (falling back to ``prompt``, then
    ``input``) and the response from ``answer`` (falling back to
    ``response``, then ``output``). Items are dicts with ``input_ids``
    (tokens ``[:-1]``) and ``labels`` (tokens ``[1:]``) where every
    non-assistant label is ``-100``. Rows with a missing prompt or response,
    or with fewer than 10 tokens, yield a one-token placeholder whose label
    is ``-100``.
    """

    def __init__(self, tokenizer, max_len=1024, limit=None):
        """Load the train split, optionally keeping only the first ``limit`` rows."""
        print("[data] Loading CyberNative/CyberSecurityEval...")
        ds = load_dataset("CyberNative/CyberSecurityEval", split="train")
        if limit:
            ds = ds.select(range(min(limit, len(ds))))
        self.data = ds
        self.tokenizer = tokenizer
        self.max_len = max_len
        print(f" Loaded {len(ds)} examples")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data[idx]
        question = row.get("question", row.get("prompt", row.get("input", "")))
        answer = row.get("answer", row.get("response", row.get("output", "")))
        if not question or not answer:
            return self._empty()

        messages = [
            {"role": "user", "content": question},
            {"role": "assistant", "content": answer},
        ]
        text = format_chatml(messages, DEFAULT_SYSTEM)
        # max_len + 1 tokens so that inputs and shifted labels are max_len long.
        tokens = self.tokenizer.encode(text, max_length=self.max_len + 1, truncation=True)
        if len(tokens) < 10:
            return self._empty()

        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
        labels = torch.tensor(tokens[1:], dtype=torch.long)

        # labels[i] holds tokens[i + 1], so position i is supervised only when
        # token i + 1 belongs to the assistant turn: build the mask over the
        # full token list and drop its first entry to align it with labels.
        mask = get_assistant_mask(tokens, self.tokenizer, messages, DEFAULT_SYSTEM)
        labels[~mask[1:]] = -100

        return {"input_ids": input_ids, "labels": labels}

    def _empty(self):
        """Return a one-token placeholder example that contributes no loss."""
        return {
            "input_ids": torch.zeros(1, dtype=torch.long),
            "labels": torch.full((1,), -100, dtype=torch.long),
        }
|
|
|
|
| |
| |
| |
|
|
def collate_fn(batch, pad_id=0):
    """Right-pad a list of examples into rectangular batch tensors.

    Args:
        batch: List of dicts holding 1-D ``input_ids`` and ``labels`` tensors
            of equal length per example.
        pad_id: Fill value for padded ``input_ids`` positions.

    Returns:
        Dict with ``input_ids`` and ``labels`` tensors of dtype ``torch.long``
        and shape ``(len(batch), longest example)``. Padded label positions
        are ``-100``.
    """
    lengths = [example["input_ids"].shape[0] for example in batch]
    shape = (len(batch), max(lengths))
    padded_inputs = torch.full(shape, pad_id, dtype=torch.long)
    padded_labels = torch.full(shape, -100, dtype=torch.long)

    for row, (example, n) in enumerate(zip(batch, lengths)):
        padded_inputs[row, :n] = example["input_ids"]
        padded_labels[row, :n] = example["labels"]

    return {"input_ids": padded_inputs, "labels": padded_labels}
|
|
|
|
| |
| |
| |
|
|
def get_lr(step, warmup_steps, max_steps, max_lr, min_lr):
    """Return the learning rate for ``step``: linear warmup, then cosine decay.

    During the first ``warmup_steps`` steps the rate rises linearly and
    reaches ``max_lr`` on the last warmup step. From ``max_steps`` onward it
    is ``min_lr``. In between it follows a half cosine from ``max_lr`` down
    to ``min_lr``.
    """
    in_warmup = step < warmup_steps
    if in_warmup:
        return max_lr * (step + 1) / warmup_steps
    if step >= max_steps:
        return min_lr

    decay_span = max(1, max_steps - warmup_steps)
    progress = (step - warmup_steps) / decay_span
    cosine = 1 + math.cos(math.pi * progress)
    return min_lr + 0.5 * (max_lr - min_lr) * cosine
|
|
|
|
def sample_loop_depth(mu_rec, batch_size):
    """Sample one loop depth for a batch.

    Draws ``batch_size`` values from a Geometric distribution with success
    probability ``1 / max(mu_rec, 1)``, adds one to each, clamps each to
    ``[1, 2 * mu_rec]`` (never below 1), and returns the floored mean of the
    clamped values, itself never below 1.
    """
    distribution = torch.distributions.Geometric(probs=1.0 / max(mu_rec, 1))
    upper = 2 * mu_rec

    depths = []
    for _ in range(batch_size):
        draw = int(distribution.sample().item()) + 1
        depths.append(max(1, min(upper, draw)))

    return max(1, sum(depths) // len(depths))
|
|
|
|
def train(args):
    """Fine-tune the CPT checkpoint on the instruction mix and save the result.

    Loads ``ultron_sec_cpt_final.pt`` from ``args.base_model``, builds the
    combined dataset (Code-Feedback, Glaive capped at 50000 rows unless
    ``args.data_limit`` is set, CVE, and CyberSecurityEval added twice), then
    trains for ``args.max_steps`` steps with a warmup + cosine learning-rate
    schedule. Writes ``ultron_sec_final.pt`` and, when ``args.hub_model_id``
    is set, uploads it together with ``config.json``.

    ``step`` counts micro-batches: the optimizer is stepped once every
    ``args.grad_accum`` steps, and the LR schedule, logging interval and
    checkpoint interval are all expressed in micro-batches.

    Args:
        args: Parsed command-line namespace. Reads ``base_model``,
            ``hub_model_id``, ``max_steps``, ``batch_size``, ``grad_accum``,
            ``lr``, ``min_lr``, ``warmup_steps``, ``log_interval``,
            ``save_interval`` and ``data_limit``.

    Raises:
        RuntimeError: If a full pass over the data loader produces no
            trainable batch (empty loader, or no supervised tokens at all).
    """
    # Local import so the block does not depend on the file's import section.
    from functools import partial

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_bf16 = device.type == "cuda" and torch.cuda.is_bf16_supported()
    dtype = torch.bfloat16 if use_bf16 else torch.float32
    print(f"Device: {device} | dtype: {dtype}")

    # --- Base model ---
    print(f"\nLoading base model from {args.base_model}...")
    ckpt_path = hf_hub_download(args.base_model, "ultron_sec_cpt_final.pt")
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects from the
    # downloaded file; only point --base_model at repositories you trust.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    cfg = UltronConfig(**ckpt["config"])
    model = Ultron(cfg)
    model.load_state_dict(ckpt["model_state_dict"])
    model = model.to(device)
    model.train()

    print(f" Loaded step {ckpt.get('step')}, loss {ckpt.get('loss', 'N/A')}")
    print(f" Params: {model.get_num_params(False):,}")
    print(f" rho(A): {model.get_spectral_radius():.6f}")

    # --- Tokenizer ---
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    # --- Data ---
    data_limit = args.data_limit if args.data_limit else None
    datasets_list = []

    datasets_list.append(CodeFeedbackDataset(tokenizer, cfg.max_seq_len, limit=data_limit))
    datasets_list.append(GlaiveCodeDataset(tokenizer, cfg.max_seq_len, limit=data_limit or 50000))
    datasets_list.append(CVEInstructDataset(tokenizer, cfg.max_seq_len, limit=data_limit))

    try:
        cyber_ds = CyberSecEvalDataset(tokenizer, cfg.max_seq_len, limit=data_limit)
        # Appending the same dataset twice upsamples it 2x in the mix.
        datasets_list.append(cyber_ds)
        datasets_list.append(cyber_ds)
    except Exception as e:
        # Best effort: training proceeds without this dataset.
        print(f" Warning: CyberSecEval failed to load: {e}")

    combined = ConcatDataset(datasets_list)
    print(f"\n[data] Combined: {len(combined)} examples")

    loader = DataLoader(
        combined,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=2,
        # Pinned host memory only helps host-to-GPU copies.
        pin_memory=device.type == "cuda",
        # partial() of a module-level function can be pickled for worker
        # processes on spawn-based platforms; a local lambda cannot.
        collate_fn=partial(collate_fn, pad_id=tokenizer.eos_token_id),
        drop_last=True,
    )

    # --- Optimizer ---
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.95),
        eps=1e-8,
        weight_decay=0.05,
    )

    # --- Training loop ---
    step = 0
    tokens_seen = 0
    running_loss = 0.0
    skipped_batches = 0
    t0 = time.time()
    log_t0 = time.time()
    epoch = 0

    print(f"\nSFT Training for {args.max_steps} steps")
    print(f" Batch: {args.batch_size} x {args.grad_accum} accum = {args.batch_size * args.grad_accum}")
    print(f" LR: {args.lr} -> {args.min_lr}")
    print(f" bf16: {use_bf16}\n")

    optimizer.zero_grad()

    while step < args.max_steps:
        epoch += 1
        print(f"--- Epoch {epoch} ---")
        step_at_epoch_start = step

        for batch in loader:
            if step >= args.max_steps:
                break

            # With every label ignored, cross_entropy averages over zero
            # targets and returns NaN, which would corrupt the weights on the
            # next optimizer step. Skip such batches entirely.
            n_supervised = int((batch["labels"] != -100).sum())
            if n_supervised == 0:
                skipped_batches += 1
                continue

            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)

            lr = get_lr(step, args.warmup_steps, args.max_steps, args.lr, args.min_lr)
            for g in optimizer.param_groups:
                g["lr"] = lr

            n_loops = sample_loop_depth(cfg.max_loop_iters, input_ids.shape[0])

            with torch.autocast(device_type="cuda", dtype=dtype, enabled=use_bf16):
                logits = model(input_ids, n_loops=n_loops)
                loss = F.cross_entropy(
                    logits.view(-1, cfg.vocab_size),
                    labels.view(-1),
                    ignore_index=-100,
                )
                loss_scaled = loss / args.grad_accum

            loss_scaled.backward()
            running_loss += loss.item()
            tokens_seen += n_supervised

            if (step + 1) % args.grad_accum == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()

            step += 1

            if step % args.log_interval == 0:
                avg = running_loss / args.log_interval
                # Cap the exponent so a diverged loss cannot overflow exp().
                ppl = math.exp(min(avg, 20))
                rho = model.get_spectral_radius()
                dt = time.time() - log_t0
                print(f"step {step:>5d}/{args.max_steps} | loss {avg:.4f} | ppl {ppl:.1f} | "
                      f"lr {lr:.2e} | rho(A) {rho:.4f} | depth {n_loops} | "
                      f"{tokens_seen:,} tokens | {dt:.1f}s")
                running_loss = 0.0
                log_t0 = time.time()

            if step % args.save_interval == 0 and step > 0:
                save_checkpoint(model, cfg, step, tokens_seen, args)

        if step == step_at_epoch_start:
            # No progress in a whole epoch means the outer loop would spin
            # forever (e.g. fewer examples than batch_size with drop_last).
            raise RuntimeError(
                "No trainable batch in a full pass over the data: the loader is "
                "empty or no example contains supervised (assistant) tokens."
            )

    # --- Final save ---
    elapsed = time.time() - t0
    print(f"\nSFT complete! {step} steps in {elapsed:.0f}s ({elapsed/3600:.1f}h)")
    print(f"Final rho(A): {model.get_spectral_radius():.6f}")
    if skipped_batches:
        print(f"Skipped {skipped_batches} batches without supervised tokens")

    final = {
        "step": step,
        "tokens_seen": tokens_seen,
        "model_state_dict": model.state_dict(),
        "config": asdict(cfg),
        "training": "sft",
        "base_model": args.base_model,
    }
    final_path = "ultron_sec_final.pt"
    torch.save(final, final_path)
    print(f"Saved: {final_path}")

    if args.hub_model_id:
        try:
            api = HfApi()
            api.upload_file(
                path_or_fileobj=final_path,
                path_in_repo="ultron_sec_final.pt",
                repo_id=args.hub_model_id,
            )
            config_path = "config.json"
            with open(config_path, "w") as f:
                json.dump(asdict(cfg), f, indent=2, default=str)
            api.upload_file(
                path_or_fileobj=config_path,
                path_in_repo="config.json",
                repo_id=args.hub_model_id,
            )
            print(f"Pushed to {args.hub_model_id}")
        except Exception as e:
            # Best effort: the local file is already saved.
            print(f"Push failed: {e}")

    print("Done!")
|
|
|
|
def save_checkpoint(model, cfg, step, tokens_seen, args):
    """Write an intermediate checkpoint and optionally push it to the Hub.

    The checkpoint is saved locally as ``ultron_sec_sft_step{step}.pt``. When
    ``args.hub_model_id`` is set it is uploaded under ``checkpoints/`` in
    that repository, and the local file is removed only after the upload
    succeeded. If no Hub repository is configured, or the upload fails, the
    local file is kept so the checkpoint is never lost.

    Args:
        model: Module whose ``state_dict()`` is stored.
        cfg: Dataclass instance stored via ``dataclasses.asdict``.
        step: Current training step, used in the file name.
        tokens_seen: Running count of supervised tokens, stored as metadata.
        args: Namespace providing ``hub_model_id``.
    """
    ckpt = {
        "step": step,
        "tokens_seen": tokens_seen,
        "model_state_dict": model.state_dict(),
        "config": asdict(cfg),
    }
    path = f"ultron_sec_sft_step{step}.pt"
    torch.save(ckpt, path)
    print(f" Checkpoint: {path}")

    if not args.hub_model_id:
        return

    try:
        api = HfApi()
        api.upload_file(
            path_or_fileobj=path,
            path_in_repo=f"checkpoints/{path}",
            repo_id=args.hub_model_id,
        )
    except Exception as e:
        # Best effort: keep the local copy, it is now the only one.
        print(f" Push failed: {e}")
    else:
        print(f" Pushed to {args.hub_model_id}")
        # The Hub holds the checkpoint now; free the local disk space.
        if os.path.exists(path):
            os.remove(path)
|
|
|
|
def main():
    """Parse the command-line options and start SFT training."""
    parser = argparse.ArgumentParser(description="Ultron-Sec SFT Training")
    parser.add_argument(
        "--base_model",
        type=str,
        default="trojan0x/ultron-sec-cpt",
        help="CPT model to fine-tune from",
    )
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=os.environ.get("HUB_MODEL_ID", "trojan0x/ultron-sec"),
        help="Where to push the SFT model",
    )

    # Plain numeric hyper-parameters: (flag, type, default).
    numeric_options = (
        ("--max_steps", int, 3000),
        ("--batch_size", int, 4),
        ("--grad_accum", int, 8),
        ("--lr", float, 2.5e-5),
        ("--min_lr", float, 2.5e-6),
        ("--warmup_steps", int, 100),
        ("--log_interval", int, 10),
        ("--save_interval", int, 1000),
    )
    for flag, value_type, default in numeric_options:
        parser.add_argument(flag, type=value_type, default=default)

    parser.add_argument(
        "--data_limit",
        type=int,
        default=None,
        help="Limit examples per dataset (for testing)",
    )

    train(parser.parse_args())


if __name__ == "__main__":
    main()
|
|