"""
train_universal.py - Train with allocator-agnostic grid on GPU.

Uses UniversalGridEncoder (relationship-based features) and trains on RTX 4090.
"""
|
|
| import sys |
| import os |
| import json |
| import subprocess |
| import tempfile |
| import time |
| import random |
| import re |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from pathlib import Path |
| from collections import deque |
|
|
| ROOT = Path(__file__).resolve().parent.parent |
| sys.path.insert(0, str(ROOT / "agent")) |
| sys.path.insert(0, str(ROOT / "model")) |
|
|
| from simple_agent import SimpleHeapTRM, OP_MALLOC, OP_FREE, OP_WRITE_FREED, SIZES |
| from universal_grid import UniversalGridEncoder |
|
|
| BINARY = ROOT / "ctf" / "vuln_heap" |
| HARNESS = ROOT / "harness" / "heapgrid_harness.so" |
| DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
|
| def run_and_dump(commands): |
| dump_path = tempfile.mktemp(suffix=".jsonl") |
| env = os.environ.copy() |
| env["LD_PRELOAD"] = str(HARNESS) |
| env["HEAPGRID_OUT"] = dump_path |
| input_str = "\n".join(commands) + "\n5\n" |
| subprocess.run([str(BINARY)], input=input_str.encode(), |
| env=env, capture_output=True, timeout=10) |
| states = [] |
| if os.path.exists(dump_path): |
| with open(dump_path) as f: |
| for line in f: |
| if line.strip(): |
| states.append(json.loads(line.strip())) |
| os.unlink(dump_path) |
| return states |
|
|
|
|
| def gen_exploit(n_pre=None): |
| size = random.choice(SIZES) |
| slots = random.sample(range(16), 8) |
| if n_pre is None: |
| n_pre = random.randint(2, 6) |
| commands = [] |
| labels = [] |
| alloc_slots = [] |
| for i in range(n_pre): |
| commands.append(f"1 {slots[i]} {size}") |
| labels.append(OP_MALLOC) |
| alloc_slots.append(slots[i]) |
| free_targets = random.sample(alloc_slots, 2) |
| for s in free_targets: |
| commands.append(f"4 {s}") |
| labels.append(OP_FREE) |
| commands.append(f"2 {free_targets[1]} {'41' * 8}") |
| labels.append(OP_WRITE_FREED) |
| drain = [s for s in slots if s not in alloc_slots][:2] |
| for s in drain: |
| commands.append(f"1 {s} {size}") |
| labels.append(OP_MALLOC) |
| return commands, labels |
|
|
|
|
| def gen_benign(): |
| n_ops = random.randint(4, 15) |
| commands, labels = [], [] |
| allocated = {} |
| for _ in range(n_ops): |
| if not allocated or random.random() < 0.55: |
| free_slots = [s for s in range(16) if s not in allocated] |
| if not free_slots: |
| break |
| s = random.choice(free_slots) |
| sz = random.choice(SIZES) |
| commands.append(f"1 {s} {sz}") |
| labels.append(OP_MALLOC) |
| allocated[s] = sz |
| else: |
| s = random.choice(list(allocated.keys())) |
| commands.append(f"4 {s}") |
| labels.append(OP_FREE) |
| del allocated[s] |
| return commands, labels |
|
|
|
|
| def collect_data(n_exploit=500, n_benign=250): |
| all_grids, all_labels = [], [] |
|
|
| print(f" Generating {n_exploit} exploit sequences...") |
| for i in range(n_exploit): |
| commands, op_labels = gen_exploit() |
| states = run_and_dump(commands) |
| if len(states) < 3: |
| continue |
|
|
| encoder = UniversalGridEncoder() |
| n_free_seen = 0 |
| write_done = False |
|
|
| for j, state in enumerate(states): |
| op = state.get("operation", "malloc") |
| if "free" in op.lower(): |
| n_free_seen += 1 |
|
|
| chunks = state.get("chunks", []) |
| n_alloc = sum(1 for c in chunks if c.get("state") == 1) |
| n_freed = sum(1 for c in chunks if c.get("state") == 2) |
|
|
| if n_free_seen == 0: |
| next_label = OP_FREE if (n_alloc >= 2 and j >= len(states) * 0.3) else OP_MALLOC |
| elif n_free_seen >= 2 and not write_done: |
| next_label = OP_WRITE_FREED |
| write_done = True |
| else: |
| next_label = OP_MALLOC |
|
|
| grid = encoder.encode(state) |
| all_grids.append(grid) |
| all_labels.append(next_label) |
|
|
| actual_op = OP_FREE if "free" in op.lower() else OP_MALLOC |
| encoder.record_action(actual_op, state.get("target_size", 0)) |
|
|
| print(f" Generating {n_benign} benign sequences...") |
| for i in range(n_benign): |
| commands, _ = gen_benign() |
| states = run_and_dump(commands) |
| encoder = UniversalGridEncoder() |
| for j, state in enumerate(states): |
| op = state.get("operation", "malloc") |
| if j + 1 < len(states): |
| next_op = states[j + 1].get("operation", "malloc") |
| next_label = OP_FREE if "free" in next_op.lower() else OP_MALLOC |
| else: |
| next_label = OP_MALLOC |
| grid = encoder.encode(state) |
| all_grids.append(grid) |
| all_labels.append(next_label) |
| actual_op = OP_FREE if "free" in op.lower() else OP_MALLOC |
| encoder.record_action(actual_op, state.get("target_size", 0)) |
|
|
| return np.stack(all_grids), np.array(all_labels, dtype=np.int64) |
|
|
|
|
| def train_gpu(model, X, y, epochs=200, lr=1e-3, bs=128): |
| model.to(DEVICE) |
| opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01) |
| sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs) |
| X_t = torch.from_numpy(X).long().to(DEVICE) |
| y_t = torch.from_numpy(y).long().to(DEVICE) |
| n = len(X_t) |
|
|
| for ep in range(1, epochs + 1): |
| model.train() |
| perm = torch.randperm(n, device=DEVICE) |
| total_loss = 0 |
| correct = 0 |
| nb = 0 |
| for i in range(0, n, bs): |
| idx = perm[i:i+bs] |
| logits = model(X_t[idx]) |
| loss = F.cross_entropy(logits, y_t[idx]) |
| opt.zero_grad() |
| loss.backward() |
| torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) |
| opt.step() |
| total_loss += loss.item() |
| correct += (logits.argmax(1) == y_t[idx]).sum().item() |
| nb += 1 |
| sched.step() |
| if ep % 20 == 0 or ep == 1: |
| print(f" Epoch {ep:3d} | loss={total_loss/nb:.4f} | acc={correct/n:.3f}") |
|
|
|
|
| def evaluate_live(model, n_trials=50, max_steps=20, temperature=0.3): |
| model.eval() |
| model.to(DEVICE) |
| n_uaf = 0 |
| n_correct_seq = 0 |
| seqs = [] |
|
|
| for trial in range(n_trials): |
| dump_path = tempfile.mktemp(suffix=".jsonl") |
| env = os.environ.copy() |
| env["LD_PRELOAD"] = str(HARNESS) |
| env["HEAPGRID_OUT"] = dump_path |
| proc = subprocess.Popen([str(BINARY)], stdin=subprocess.PIPE, |
| stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env) |
|
|
| encoder = UniversalGridEncoder() |
| slots = {} |
| freed = set() |
| ops = [] |
| did_write = False |
| n_frees_since_alloc = 0 |
|
|
| for step in range(max_steps): |
| time.sleep(0.01) |
| state = {"chunks": []} |
| try: |
| with open(dump_path) as f: |
| lines = f.readlines() |
| if lines: |
| state = json.loads(lines[-1].strip()) |
| except: |
| pass |
|
|
| grid = encoder.encode(state) |
|
|
| |
| if freed and not did_write and n_frees_since_alloc >= 2: |
| op = OP_WRITE_FREED |
| else: |
| x = torch.from_numpy(grid).long().unsqueeze(0).to(DEVICE) |
| with torch.no_grad(): |
| logits = model(x) |
| if temperature == 0: |
| op = logits.argmax(1).item() |
| else: |
| probs = F.softmax(logits / temperature, dim=1) |
| op = torch.multinomial(probs, 1).item() |
|
|
| free_slots = [s for s in range(8) if s not in slots] |
| alloc_slots = [s for s, v in slots.items() if v and s not in freed] |
|
|
| cmd = None |
| size = 0 |
| if op == OP_MALLOC and free_slots: |
| s = random.choice(free_slots) |
| size = random.choice(SIZES) |
| cmd = f"1 {s} {size}" |
| slots[s] = True |
| ops.append("M") |
| n_frees_since_alloc = 0 |
| elif op == OP_FREE and alloc_slots: |
| s = random.choice(alloc_slots) |
| cmd = f"4 {s}" |
| slots[s] = False |
| freed.add(s) |
| ops.append("F") |
| n_frees_since_alloc += 1 |
| elif op == OP_WRITE_FREED and freed: |
| s = random.choice(list(freed)) |
| cmd = f"2 {s} {'41' * 8}" |
| did_write = True |
| ops.append("W") |
| else: |
| ops.append("x") |
|
|
| if cmd: |
| proc.stdin.write((cmd + "\n").encode()) |
| proc.stdin.flush() |
| encoder.record_action(op, size) |
|
|
| try: |
| proc.stdin.write(b"5\n"); proc.stdin.flush() |
| proc.wait(timeout=2) |
| except: |
| proc.kill() |
|
|
| if did_write: |
| try: |
| with open(dump_path) as f: |
| for line in f: |
| state = json.loads(line.strip()) |
| for c in state.get("chunks", []): |
| if c.get("fd", 0) == 0x4141414141414141: |
| n_uaf += 1 |
| break |
| else: |
| continue |
| break |
| except: |
| pass |
|
|
| seq = "".join(ops) |
| if re.match(r"M+F+WM*", seq.replace("x", "")): |
| n_correct_seq += 1 |
| seqs.append(seq) |
| if os.path.exists(dump_path): |
| os.unlink(dump_path) |
|
|
| return n_uaf, n_correct_seq, seqs |
|
|
|
|
| def main(): |
| print(f"Device: {DEVICE}") |
|
|
| print("\n=== Collecting data ===") |
| X, y = collect_data(n_exploit=500, n_benign=250) |
| print(f"Data: {len(X)} samples | M={sum(y==0)} F={sum(y==1)} W={sum(y==2)}") |
|
|
| print("\n=== Training on GPU ===") |
| model = SimpleHeapTRM(hidden_dim=128, n_outer=2, n_inner=3) |
| print(f"Params: {sum(p.numel() for p in model.parameters()):,}") |
| train_gpu(model, X, y, epochs=300, lr=1e-3, bs=128) |
|
|
| print("\n=== Live evaluation (hybrid, 50 trials) ===") |
| n_uaf, n_seq, seqs = evaluate_live(model, n_trials=50, temperature=0.3) |
| print(f"UAF writes: {n_uaf}/50 ({n_uaf*2}%)") |
| print(f"Correct sequences: {n_seq}/50 ({n_seq*2}%)") |
| print("Samples:") |
| for s in seqs[:10]: |
| print(f" {s}") |
|
|
| print("\n=== Best-of-5 (30 trials) ===") |
| n_bo5 = 0 |
| for trial in range(30): |
| found = False |
| for attempt in range(5): |
| nw, _, _ = evaluate_live(model, n_trials=1, temperature=0.5) |
| if nw > 0: |
| found = True |
| break |
| if found: |
| n_bo5 += 1 |
| print(f"Best-of-5: {n_bo5}/30 ({n_bo5/30*100:.0f}%)") |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|