#!/usr/bin/env python3
"""
simple_agent.py - Simplified action-prediction agent.

Instead of predicting exact (op, size, slot) in one shot from 128 actions,
decompose into:

  Step 1: Predict operation type (4: malloc, free, write_freed, noop)
  Step 2: For the chosen op, pick parameters using simple heuristics

This makes imitation learning tractable (4-class instead of 128-class).
The model learns the HIGH-LEVEL STRATEGY (when to alloc vs free vs UAF-write),
not the low-level details (which slot, which size).
"""

import sys
import torch
import torch.nn.functional as F
import numpy as np
import random
import copy
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent / "simulator"))
sys.path.insert(0, str(Path(__file__).parent.parent / "model"))

from heap_sim import HeapSimulator
from trm_heap import RMSNorm, SwiGLU, RecursionBlock

# 4 operation types
OP_MALLOC = 0
OP_FREE = 1
OP_WRITE_FREED = 2
OP_NOOP = 3
N_OPS = 4

# Candidate request sizes handed to malloc (all land in small tcache bins).
SIZES = [0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80]


class SimpleHeapTRM(torch.nn.Module):
    """TRM that predicts operation type (4 classes).

    Recursion structure: two latent streams ``z`` (fast/inner) and ``y``
    (slow/outer) are refined by shared RecursionBlocks, then ``y`` is
    norm-pooled and projected to N_OPS logits.
    """

    def __init__(self, vocab_size=64, hidden_dim=128, seq_len=512, n_outer=2, n_inner=3):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden_dim)
        # Learned initial states for the two recursion streams, broadcast
        # over the batch in forward().
        self.y_init = torch.nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02)
        self.z_init = torch.nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02)
        self.block_z = RecursionBlock(hidden_dim)
        self.block_y = RecursionBlock(hidden_dim)
        self.pos_embed = torch.nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02)
        self.out_norm = RMSNorm(hidden_dim)
        self.n_outer = n_outer
        self.n_inner = n_inner
        # 4-class output
        self.head = torch.nn.Linear(hidden_dim, N_OPS)

    def forward(self, x):
        """Map a batch of tokenized heap grids to (B, N_OPS) logits.

        Assumes ``x`` flattens to exactly ``seq_len`` tokens per sample so
        it broadcasts against pos_embed/y_init/z_init — TODO confirm that
        HeapSimulator.state_to_grid() yields seq_len cells.
        """
        B = x.shape[0]
        h = self.embed(x.reshape(B, -1)) + self.pos_embed
        y = self.y_init.expand(B, -1, -1)
        z = self.z_init.expand(B, -1, -1)
        for _ in range(self.n_outer):
            # Inner loop refines z against the input and current y;
            # y is updated once per outer step (standard TRM schedule).
            # NOTE(review): y-update placement reconstructed from collapsed
            # source — confirm against trm_heap's reference model.
            for _ in range(self.n_inner):
                z = z + self.block_z(h + y + z)
            y = y + self.block_y(y + z)
        pooled = self.out_norm(y).mean(dim=1)
        return self.head(pooled)


def execute_op(sim: HeapSimulator, op_type: int) -> bool:
    """Execute an operation with automatic parameter selection.

    The model only chooses the op TYPE; concrete parameters (slot, size,
    UAF source/target) are sampled uniformly from the currently legal
    choices. Returns True if the simulator accepted the operation,
    False if it was inapplicable (or a noop).
    """
    used_slots = set(sim.slots.keys())
    free_slots = [s for s in range(8) if s not in used_slots]
    occupied_slots = list(used_slots)

    # Find freed-but-still-tracked slots (UAF candidates): the slot still
    # holds a user pointer whose underlying chunk is no longer allocated.
    freed_slots = []
    for slot in occupied_slots:
        addr = sim.slots[slot]
        chunk_addr = addr - 16  # user pointer -> chunk header (0x10 metadata)
        chunk = sim.chunks.get(chunk_addr)
        if chunk and not chunk.allocated:
            freed_slots.append(slot)
    alloc_slots = [s for s in occupied_slots if s not in freed_slots]

    if op_type == OP_MALLOC:
        if not free_slots:
            return False
        slot = random.choice(free_slots)
        size = random.choice(SIZES)
        return sim.malloc(size, slot=slot) is not None
    elif op_type == OP_FREE:
        if not alloc_slots:
            return False
        slot = random.choice(alloc_slots)
        return sim.free(user_addr=sim.slots[slot], slot=slot)
    elif op_type == OP_WRITE_FREED:
        # Need a freed source chunk AND at least one other tracked slot
        # whose address we can scribble into the freed chunk's fd.
        if not freed_slots or len(occupied_slots) < 2:
            return False
        src = random.choice(freed_slots)
        targets = [s for s in occupied_slots if s != src]
        if not targets:
            return False
        tgt = random.choice(targets)
        return sim.write_to_freed(sim.slots[src], sim.slots[tgt])
    elif op_type == OP_NOOP:
        # Deliberate no-op: state unchanged.
        return False
    return False


# ============================================================
# EXPERT DEMOS (as operation type sequences)
# ============================================================
TCACHE_POISON_OPS = [
    OP_MALLOC,       # alloc A
    OP_MALLOC,       # alloc B
    OP_MALLOC,       # alloc C (guard)
    OP_FREE,         # free A -> tcache
    OP_FREE,         # free B -> tcache
    OP_WRITE_FREED,  # UAF: corrupt B's fd
    OP_MALLOC,       # alloc from tcache (gets B)
    OP_MALLOC,       # alloc from tcache (gets poisoned addr!)
]


def generate_demos(n_variants=200) -> tuple:
    """Generate demo (state, op_type) pairs by running through simulator.

    Each variant replays the tcache-poison script; randomness inside
    execute_op (slot/size choice) provides state diversity. The state is
    snapshotted BEFORE each op so labels are "what the expert does next".
    """
    states = []
    labels = []
    for _ in range(n_variants):
        sim = HeapSimulator()
        ops = list(TCACHE_POISON_OPS)
        for op in ops:
            grid = sim.state_to_grid()
            states.append(grid)
            labels.append(op)
            execute_op(sim, op)
    return np.stack(states), np.array(labels, dtype=np.int64)


def train(model, X, y, epochs=200, lr=1e-3, bs=64):
    """Train with cross-entropy on 4-class op prediction.

    Args:
        model: SimpleHeapTRM (or any module mapping long grids -> logits).
        X: (N, ...) int grid states; y: (N,) int64 op labels.
        epochs, lr, bs: standard AdamW hyperparameters.
    """
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    X_t = torch.from_numpy(X).long()
    y_t = torch.from_numpy(y).long()
    n = len(X_t)
    for ep in range(1, epochs + 1):
        model.train()
        perm = torch.randperm(n)
        total_loss = 0
        correct = 0
        nb = 0
        for i in range(0, n, bs):
            idx = perm[i:i + bs]
            logits = model(X_t[idx])
            loss = F.cross_entropy(logits, y_t[idx])
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            total_loss += loss.item()
            correct += (logits.argmax(1) == y_t[idx]).sum().item()
            nb += 1
        if ep % 20 == 0 or ep == 1:
            acc = correct / n
            print(f"  Epoch {ep:3d} | loss={total_loss/nb:.4f} | acc={acc:.3f}")


def evaluate_policy(model, goal="duplicate_alloc", n_trials=200, max_steps=20):
    """Run the learned policy greedily and check success rate.

    Returns (n_achieved, lengths) where lengths holds the step count of
    each successful episode.
    """
    model.eval()
    n_achieved = 0
    lengths = []
    for _ in range(n_trials):
        sim = HeapSimulator()
        for step in range(max_steps):
            grid = sim.state_to_grid()
            x = torch.from_numpy(grid).long().unsqueeze(0)
            with torch.no_grad():
                logits = model(x)
            op = logits.argmax(1).item()  # greedy action
            execute_op(sim, op)
            prims = sim.check_primitives()
            if prims.get(goal, False):
                n_achieved += 1
                lengths.append(step + 1)
                break
    return n_achieved, lengths


def main():
    print("=== Generating demos ===")
    X, y = generate_demos(300)
    print(f"Data: {len(X)} samples")
    print(f"Op distribution: malloc={sum(y==0)}, free={sum(y==1)}, "
          f"write_freed={sum(y==2)}, noop={sum(y==3)}")

    print("\n=== Training (4-class op prediction) ===")
    model = SimpleHeapTRM(hidden_dim=128, n_outer=2, n_inner=3)
    params = sum(p.numel() for p in model.parameters())
    print(f"Parameters: {params:,}")
    train(model, X, y, epochs=200, lr=1e-3)

    print("\n=== Evaluating policy ===")
    n_trials = 200
    n_achieved, lengths = evaluate_policy(model, n_trials=n_trials, max_steps=20)
    # BUGFIX: previous version hardcoded "/200" and "n_achieved/2" as the
    # percentage, which silently broke if n_trials changed.
    print(f"Success rate: {n_achieved}/{n_trials} ({100 * n_achieved / n_trials:.0f}%)")
    if lengths:
        print(f"Avg steps: {np.mean(lengths):.1f}, min={min(lengths)}, max={max(lengths)}")

    # Also try with beam-like approach: run 10x and take best
    print("\n=== Best-of-10 evaluation ===")
    n_achieved_bo10 = 0
    for trial in range(100):
        found = False
        for attempt in range(10):
            sim = HeapSimulator()
            for step in range(20):
                grid = sim.state_to_grid()
                x = torch.from_numpy(grid).long().unsqueeze(0)
                with torch.no_grad():
                    logits = model(x)
                # Temperature-0.5 sampling gives diversity across attempts.
                probs = F.softmax(logits / 0.5, dim=1)
                op = torch.multinomial(probs, 1).item()
                execute_op(sim, op)
                if sim.check_primitives().get("duplicate_alloc"):
                    found = True
                    break
            if found:
                break
        if found:
            n_achieved_bo10 += 1
    print(f"Best-of-10 success: {n_achieved_bo10}/100 ({n_achieved_bo10}%)")


if __name__ == "__main__":
    main()