#!/usr/bin/env python3
"""
train_universal.py - Train with allocator-agnostic grid on GPU.
Uses UniversalGridEncoder (relationship-based features) and trains on RTX 4090.
"""
import sys
import os
import json
import subprocess
import tempfile
import time
import random
import re
import numpy as np
import torch
import torch.nn.functional as F
from pathlib import Path
from collections import deque

ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT / "agent"))
sys.path.insert(0, str(ROOT / "model"))

from simple_agent import SimpleHeapTRM, OP_MALLOC, OP_FREE, OP_WRITE_FREED, SIZES
from universal_grid import UniversalGridEncoder

BINARY = ROOT / "ctf" / "vuln_heap"
HARNESS = ROOT / "harness" / "heapgrid_harness.so"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def run_and_dump(commands):
    """Run the vulnerable binary under the heap-grid harness, return heap states.

    Feeds *commands* (one per line) plus a trailing "5" (quit) on stdin, then
    parses the JSONL dump the LD_PRELOAD harness writes to HEAPGRID_OUT.

    Returns a list of state dicts; [] if the target hung or dumped nothing.
    """
    # mkstemp instead of the deprecated, race-prone tempfile.mktemp.
    fd, dump_path = tempfile.mkstemp(suffix=".jsonl")
    os.close(fd)  # the harness opens the path itself
    env = os.environ.copy()
    env["LD_PRELOAD"] = str(HARNESS)
    env["HEAPGRID_OUT"] = dump_path
    input_str = "\n".join(commands) + "\n5\n"
    states = []
    try:
        try:
            subprocess.run([str(BINARY)], input=input_str.encode(), env=env,
                           capture_output=True, timeout=10)
        except subprocess.TimeoutExpired:
            # A hung target must not abort the whole data-collection run;
            # salvage whatever states it dumped before being killed.
            pass
        with open(dump_path) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    states.append(json.loads(line))
                except json.JSONDecodeError:
                    # A killed target can truncate its final line mid-write.
                    continue
    except OSError:
        pass  # dump file unreadable -> treat as "no states"
    finally:
        if os.path.exists(dump_path):
            os.unlink(dump_path)
    return states


def gen_exploit(n_pre=None):
    """Generate a canonical UAF command sequence: M+ F F W M* (drain).

    n_pre: number of leading allocations (random 2-6 if None).
    Returns (commands, labels) where labels are per-command op codes.
    """
    size = random.choice(SIZES)
    slots = random.sample(range(16), 8)
    if n_pre is None:
        n_pre = random.randint(2, 6)
    commands = []
    labels = []
    alloc_slots = []
    for i in range(n_pre):
        commands.append(f"1 {slots[i]} {size}")
        labels.append(OP_MALLOC)
        alloc_slots.append(slots[i])
    # Free two of the allocations, then write into the second freed chunk
    # (the use-after-free), then drain with fresh allocations.
    free_targets = random.sample(alloc_slots, 2)
    for s in free_targets:
        commands.append(f"4 {s}")
        labels.append(OP_FREE)
    commands.append(f"2 {free_targets[1]} {'41' * 8}")
    labels.append(OP_WRITE_FREED)
    drain = [s for s in slots if s not in alloc_slots][:2]
    for s in drain:
        commands.append(f"1 {s} {size}")
        labels.append(OP_MALLOC)
    return commands, labels


def gen_benign():
    """Generate a benign malloc/free sequence (no UAF) for negative data.

    Returns (commands, labels); ~55% malloc bias, frees only live slots.
    """
    n_ops = random.randint(4, 15)
    commands, labels = [], []
    allocated = {}
    for _ in range(n_ops):
        if not allocated or random.random() < 0.55:
            free_slots = [s for s in range(16) if s not in allocated]
            if not free_slots:
                break
            s = random.choice(free_slots)
            sz = random.choice(SIZES)
            commands.append(f"1 {s} {sz}")
            labels.append(OP_MALLOC)
            allocated[s] = sz
        else:
            s = random.choice(list(allocated.keys()))
            commands.append(f"4 {s}")
            labels.append(OP_FREE)
            del allocated[s]
    return commands, labels


def collect_data(n_exploit=500, n_benign=250):
    """Collect (grid, next-op label) training pairs from live runs.

    Exploit runs are labelled by a phase heuristic (alloc -> free -> write);
    benign runs are labelled with the actually-observed next operation.
    Returns (X, y): stacked grids and int64 labels.
    """
    all_grids, all_labels = [], []
    print(f" Generating {n_exploit} exploit sequences...")
    for i in range(n_exploit):
        commands, _ = gen_exploit()
        states = run_and_dump(commands)
        if len(states) < 3:
            continue  # target died too early to be useful
        encoder = UniversalGridEncoder()
        n_free_seen = 0
        write_done = False
        for j, state in enumerate(states):
            op = state.get("operation", "malloc")
            if "free" in op.lower():
                n_free_seen += 1
            chunks = state.get("chunks", [])
            n_alloc = sum(1 for c in chunks if c.get("state") == 1)
            # Heuristic phase labels: allocate early, switch to FREE once
            # enough chunks are live and we're past 30% of the trace, emit
            # exactly one WRITE_FREED after the second free.
            if n_free_seen == 0:
                next_label = OP_FREE if (n_alloc >= 2 and j >= len(states) * 0.3) else OP_MALLOC
            elif n_free_seen >= 2 and not write_done:
                next_label = OP_WRITE_FREED
                write_done = True
            else:
                next_label = OP_MALLOC
            grid = encoder.encode(state)
            all_grids.append(grid)
            all_labels.append(next_label)
            actual_op = OP_FREE if "free" in op.lower() else OP_MALLOC
            encoder.record_action(actual_op, state.get("target_size", 0))
    print(f" Generating {n_benign} benign sequences...")
    for i in range(n_benign):
        commands, _ = gen_benign()
        states = run_and_dump(commands)
        encoder = UniversalGridEncoder()
        for j, state in enumerate(states):
            op = state.get("operation", "malloc")
            # Benign labels come straight from the observed next operation.
            if j + 1 < len(states):
                next_op = states[j + 1].get("operation", "malloc")
                next_label = OP_FREE if "free" in next_op.lower() else OP_MALLOC
            else:
                next_label = OP_MALLOC
            grid = encoder.encode(state)
            all_grids.append(grid)
            all_labels.append(next_label)
            actual_op = OP_FREE if "free" in op.lower() else OP_MALLOC
            encoder.record_action(actual_op, state.get("target_size", 0))
    return np.stack(all_grids), np.array(all_labels, dtype=np.int64)


def train_gpu(model, X, y, epochs=200, lr=1e-3, bs=128):
    """Train *model* on (X, y) with AdamW + cosine LR decay + grad clipping.

    X, y are numpy arrays; the full dataset is moved to DEVICE once.
    Prints loss/accuracy every 20 epochs.
    """
    model.to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    X_t = torch.from_numpy(X).long().to(DEVICE)
    y_t = torch.from_numpy(y).long().to(DEVICE)
    n = len(X_t)
    for ep in range(1, epochs + 1):
        model.train()
        perm = torch.randperm(n, device=DEVICE)
        total_loss = 0
        correct = 0
        nb = 0
        for i in range(0, n, bs):
            idx = perm[i:i+bs]
            logits = model(X_t[idx])
            loss = F.cross_entropy(logits, y_t[idx])
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            total_loss += loss.item()
            correct += (logits.argmax(1) == y_t[idx]).sum().item()
            nb += 1
        sched.step()
        if ep % 20 == 0 or ep == 1:
            print(f" Epoch {ep:3d} | loss={total_loss/nb:.4f} | acc={correct/n:.3f}")


def evaluate_live(model, n_trials=50, max_steps=20, temperature=0.3):
    """Drive the live target with the model; count UAF writes and M+F+W M* traces.

    Hybrid policy: a hard rule emits WRITE_FREED after >=2 consecutive frees;
    otherwise the model samples an op (argmax if temperature == 0).
    Returns (n_uaf, n_correct_seq, seqs).
    """
    model.eval()
    model.to(DEVICE)
    n_uaf = 0
    n_correct_seq = 0
    seqs = []
    for trial in range(n_trials):
        fd, dump_path = tempfile.mkstemp(suffix=".jsonl")
        os.close(fd)
        env = os.environ.copy()
        env["LD_PRELOAD"] = str(HARNESS)
        env["HEAPGRID_OUT"] = dump_path
        proc = subprocess.Popen([str(BINARY)], stdin=subprocess.PIPE,
                                stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                                env=env)
        encoder = UniversalGridEncoder()
        slots = {}
        freed = set()
        ops = []
        did_write = False
        n_frees_since_alloc = 0
        for step in range(max_steps):
            time.sleep(0.01)  # give the harness a moment to flush its dump
            state = {"chunks": []}
            try:
                with open(dump_path) as f:
                    lines = f.readlines()
                if lines:
                    state = json.loads(lines[-1].strip())
            except (OSError, json.JSONDecodeError):
                pass  # no dump yet, or a half-written last line
            grid = encoder.encode(state)
            # Hybrid: rule triggers W after 2+ frees
            if freed and not did_write and n_frees_since_alloc >= 2:
                op = OP_WRITE_FREED
            else:
                x = torch.from_numpy(grid).long().unsqueeze(0).to(DEVICE)
                with torch.no_grad():
                    logits = model(x)
                if temperature == 0:
                    op = logits.argmax(1).item()
                else:
                    probs = F.softmax(logits / temperature, dim=1)
                    op = torch.multinomial(probs, 1).item()
            free_slots = [s for s in range(8) if s not in slots]
            alloc_slots = [s for s, v in slots.items() if v and s not in freed]
            cmd = None
            size = 0
            if op == OP_MALLOC and free_slots:
                s = random.choice(free_slots)
                size = random.choice(SIZES)
                cmd = f"1 {s} {size}"
                slots[s] = True
                ops.append("M")
                n_frees_since_alloc = 0
            elif op == OP_FREE and alloc_slots:
                s = random.choice(alloc_slots)
                cmd = f"4 {s}"
                slots[s] = False
                freed.add(s)
                ops.append("F")
                n_frees_since_alloc += 1
            elif op == OP_WRITE_FREED and freed:
                s = random.choice(list(freed))
                cmd = f"2 {s} {'41' * 8}"
                did_write = True
                ops.append("W")
            else:
                ops.append("x")  # chosen op not applicable in current state
            if cmd:
                try:
                    proc.stdin.write((cmd + "\n").encode())
                    proc.stdin.flush()
                except (BrokenPipeError, OSError):
                    break  # target crashed (expected after heap corruption)
            encoder.record_action(op, size)
        try:
            proc.stdin.write(b"5\n")
            proc.stdin.flush()
            proc.wait(timeout=2)
        except (OSError, ValueError, subprocess.TimeoutExpired):
            proc.kill()
            proc.wait()  # reap to avoid a zombie
        if did_write:
            # Success = some chunk's fd was overwritten with our 0x41 pattern.
            try:
                with open(dump_path) as f:
                    for line in f:
                        state = json.loads(line.strip())
                        if any(c.get("fd", 0) == 0x4141414141414141
                               for c in state.get("chunks", [])):
                            n_uaf += 1
                            break
            except (OSError, json.JSONDecodeError):
                pass
        seq = "".join(ops)
        if re.match(r"M+F+WM*", seq.replace("x", "")):
            n_correct_seq += 1
        seqs.append(seq)
        if os.path.exists(dump_path):
            os.unlink(dump_path)
    return n_uaf, n_correct_seq, seqs


def main():
    """Collect data, train on GPU, then evaluate live (single + best-of-5)."""
    print(f"Device: {DEVICE}")
    print("\n=== Collecting data ===")
    X, y = collect_data(n_exploit=500, n_benign=250)
    print(f"Data: {len(X)} samples | M={sum(y==0)} F={sum(y==1)} W={sum(y==2)}")
    print("\n=== Training on GPU ===")
    model = SimpleHeapTRM(hidden_dim=128, n_outer=2, n_inner=3)
    print(f"Params: {sum(p.numel() for p in model.parameters()):,}")
    train_gpu(model, X, y, epochs=300, lr=1e-3, bs=128)
    # Percentages are derived from the trial counts instead of being
    # hard-coded (*2, /30), so changing the counts keeps the report honest.
    n_trials = 50
    print("\n=== Live evaluation (hybrid, 50 trials) ===")
    n_uaf, n_seq, seqs = evaluate_live(model, n_trials=n_trials, temperature=0.3)
    print(f"UAF writes: {n_uaf}/{n_trials} ({n_uaf * 100 // n_trials}%)")
    print(f"Correct sequences: {n_seq}/{n_trials} ({n_seq * 100 // n_trials}%)")
    print("Samples:")
    for s in seqs[:10]:
        print(f" {s}")
    print("\n=== Best-of-5 (30 trials) ===")
    n_bo5_trials = 30
    n_bo5 = 0
    for trial in range(n_bo5_trials):
        for attempt in range(5):
            nw, _, _ = evaluate_live(model, n_trials=1, temperature=0.5)
            if nw > 0:
                n_bo5 += 1
                break
    print(f"Best-of-5: {n_bo5}/{n_bo5_trials} ({n_bo5/n_bo5_trials*100:.0f}%)")


if __name__ == "__main__":
    main()