| |
"""
simple_agent.py - Simplified action-prediction agent.

Instead of predicting an exact (op, size, slot) tuple in one shot from 128
actions, decompose the problem into:
Step 1: Predict the operation type (4 classes: malloc, free, write_freed, noop)
Step 2: For the chosen op, pick parameters using simple heuristics

This makes imitation learning tractable (4-class instead of 128-class).
The model learns the HIGH-LEVEL STRATEGY (when to alloc vs free vs UAF-write),
not the low-level details (which slot, which size).
"""
|
|
| import sys |
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| import random |
| import copy |
| from pathlib import Path |
|
|
| sys.path.insert(0, str(Path(__file__).parent.parent / "simulator")) |
| sys.path.insert(0, str(Path(__file__).parent.parent / "model")) |
|
|
| from heap_sim import HeapSimulator |
| from trm_heap import RMSNorm, SwiGLU, RecursionBlock |
|
|
| |
# Operation-type codes — the 4 classes the model predicts.
OP_MALLOC = 0       # allocate a chunk into an empty slot
OP_FREE = 1         # free a currently-allocated slot
OP_WRITE_FREED = 2  # write into a freed chunk (use-after-free primitive)
OP_NOOP = 3         # do nothing
N_OPS = 4           # number of operation classes (size of the model head)

# Candidate allocation sizes; execute_op samples one uniformly for OP_MALLOC.
SIZES = [0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80]
|
|
|
|
class SimpleHeapTRM(torch.nn.Module):
    """Tiny recursive model that predicts the operation type (4 classes).

    Embeds a tokenized heap-state grid, adds a learned positional embedding,
    refines two latent states (z inner, y outer) through recursion blocks,
    mean-pools over the sequence, and emits N_OPS logits.
    """

    def __init__(self, vocab_size=64, hidden_dim=128, seq_len=512,
                 n_outer=2, n_inner=3):
        super().__init__()
        init_scale = 0.02
        # NOTE: parameter/submodule creation order is kept identical to the
        # reference implementation so seeded initialization is reproducible.
        self.embed = torch.nn.Embedding(vocab_size, hidden_dim)
        self.y_init = torch.nn.Parameter(
            torch.randn(1, seq_len, hidden_dim) * init_scale)
        self.z_init = torch.nn.Parameter(
            torch.randn(1, seq_len, hidden_dim) * init_scale)
        self.block_z = RecursionBlock(hidden_dim)
        self.block_y = RecursionBlock(hidden_dim)
        self.pos_embed = torch.nn.Parameter(
            torch.randn(1, seq_len, hidden_dim) * init_scale)
        self.out_norm = RMSNorm(hidden_dim)
        self.n_outer = n_outer
        self.n_inner = n_inner
        # Classification head: pooled hidden state -> one logit per op type.
        self.head = torch.nn.Linear(hidden_dim, N_OPS)

    def forward(self, x):
        """Return (B, N_OPS) logits for a batch of integer token grids."""
        batch = x.shape[0]
        h = self.embed(x.reshape(batch, -1)) + self.pos_embed
        y = self.y_init.expand(batch, -1, -1)
        z = self.z_init.expand(batch, -1, -1)
        for _ in range(self.n_outer):
            # Inner loop refines z against the (fixed) input and current y;
            # outer loop then updates y from the refined z.
            for _ in range(self.n_inner):
                z = z + self.block_z(h + y + z)
            y = y + self.block_y(y + z)
        pooled = self.out_norm(y).mean(dim=1)
        return self.head(pooled)
|
|
|
|
def execute_op(sim: HeapSimulator, op_type: int) -> bool:
    """Execute `op_type` on `sim`, choosing concrete parameters automatically.

    Slots are partitioned into three groups:
      * empty    — slot indices (0..7) with no tracked pointer,
      * live     — occupied slots whose chunk is still allocated,
      * dangling — occupied slots whose chunk has been freed (UAF pointers).
    Parameters (slot, size, target) are then sampled uniformly from the
    legal candidates for the requested operation.

    Returns True if the simulator accepted the operation, False when no
    legal parameterization exists (or for OP_NOOP / unknown codes).
    """
    occupied = list(set(sim.slots.keys()))
    empty = [s for s in range(8) if s not in sim.slots]

    # Classify occupied slots by whether their underlying chunk is freed.
    dangling = []
    for slot in occupied:
        # User pointer minus 16 -> chunk header address.
        # NOTE(review): assumes a 16-byte chunk header — confirm against
        # HeapSimulator's layout.
        header_addr = sim.slots[slot] - 16
        meta = sim.chunks.get(header_addr)
        if meta and not meta.allocated:
            dangling.append(slot)

    live = [s for s in occupied if s not in dangling]

    if op_type == OP_MALLOC:
        if not empty:
            return False
        slot = random.choice(empty)
        size = random.choice(SIZES)
        return sim.malloc(size, slot=slot) is not None

    if op_type == OP_FREE:
        if not live:
            return False
        victim = random.choice(live)
        return sim.free(user_addr=sim.slots[victim], slot=victim)

    if op_type == OP_WRITE_FREED:
        # Needs a dangling source and at least one distinct target slot.
        if not dangling or len(occupied) < 2:
            return False
        src = random.choice(dangling)
        candidates = [s for s in occupied if s != src]
        if not candidates:
            return False
        tgt = random.choice(candidates)
        return sim.write_to_freed(sim.slots[src], sim.slots[tgt])

    # OP_NOOP and any unrecognized code: do nothing.
    return False
|
|
|
|
| |
| |
| |
|
|
# Scripted demonstration sequence used to generate imitation-learning data.
# The op names suggest a tcache-poisoning style attack: allocate three
# chunks, free two, corrupt a freed chunk via a UAF write, then reallocate
# twice so the allocator can hand back an overlapping/duplicate chunk.
# NOTE(review): the exact effect depends on HeapSimulator's free-list
# handling — confirm against the simulator before relying on this reading.
TCACHE_POISON_OPS = [
    OP_MALLOC,
    OP_MALLOC,
    OP_MALLOC,
    OP_FREE,
    OP_FREE,
    OP_WRITE_FREED,
    OP_MALLOC,
    OP_MALLOC,
]
|
|
|
|
def generate_demos(n_variants=200) -> tuple:
    """Roll out the scripted attack `n_variants` times and record demos.

    Each rollout replays TCACHE_POISON_OPS in a fresh HeapSimulator,
    recording the (state grid, op label) pair observed *before* each step
    is executed.

    Returns:
        (states, labels): stacked np array of grids and an int64 array of
        op codes, each of length n_variants * len(TCACHE_POISON_OPS).
    """
    grids = []
    taken = []

    for _ in range(n_variants):
        sim = HeapSimulator()
        for op in list(TCACHE_POISON_OPS):
            grids.append(sim.state_to_grid())
            taken.append(op)
            execute_op(sim, op)

    return np.stack(grids), np.array(taken, dtype=np.int64)
|
|
|
|
| def train(model, X, y, epochs=200, lr=1e-3, bs=64): |
| """Train with cross-entropy on 4-class op prediction.""" |
| opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01) |
| X_t = torch.from_numpy(X).long() |
| y_t = torch.from_numpy(y).long() |
| n = len(X_t) |
|
|
| for ep in range(1, epochs + 1): |
| model.train() |
| perm = torch.randperm(n) |
| total_loss = 0 |
| correct = 0 |
| nb = 0 |
|
|
| for i in range(0, n, bs): |
| idx = perm[i:i+bs] |
| logits = model(X_t[idx]) |
| loss = F.cross_entropy(logits, y_t[idx]) |
| opt.zero_grad() |
| loss.backward() |
| torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) |
| opt.step() |
| total_loss += loss.item() |
| correct += (logits.argmax(1) == y_t[idx]).sum().item() |
| nb += 1 |
|
|
| if ep % 20 == 0 or ep == 1: |
| acc = correct / n |
| print(f" Epoch {ep:3d} | loss={total_loss/nb:.4f} | acc={acc:.3f}") |
|
|
|
|
def evaluate_policy(model, goal="duplicate_alloc", n_trials=200, max_steps=20):
    """Greedily roll out the learned policy and count trials reaching `goal`.

    Each trial steps a fresh HeapSimulator up to `max_steps` times, executing
    the argmax op predicted from the current state grid, and checks
    sim.check_primitives() for `goal` after every step.

    Returns:
        (n_achieved, lengths): number of successful trials, and the list of
        step counts (1-based) at which each success was reached.
    """
    model.eval()
    successes = 0
    steps_to_goal = []

    for _ in range(n_trials):
        sim = HeapSimulator()
        for step in range(max_steps):
            obs = torch.from_numpy(sim.state_to_grid()).long().unsqueeze(0)
            with torch.no_grad():
                chosen_op = model(obs).argmax(1).item()

            execute_op(sim, chosen_op)

            if sim.check_primitives().get(goal, False):
                successes += 1
                steps_to_goal.append(step + 1)
                break

    return successes, steps_to_goal
|
|
|
|
def main():
    """End-to-end pipeline: generate demos, train the model, evaluate policy.

    Runs three stages and prints progress for each:
      1. Generate (state, op) demos from the scripted attack sequence.
      2. Train SimpleHeapTRM on 4-class op prediction.
      3. Evaluate greedily, then with best-of-10 temperature sampling.
    """
    print("=== Generating demos ===")
    X, y = generate_demos(300)
    print(f"Data: {len(X)} samples")
    print(f"Op distribution: malloc={sum(y==0)}, free={sum(y==1)}, "
          f"write_freed={sum(y==2)}, noop={sum(y==3)}")

    print("\n=== Training (4-class op prediction) ===")
    model = SimpleHeapTRM(hidden_dim=128, n_outer=2, n_inner=3)
    params = sum(p.numel() for p in model.parameters())
    print(f"Parameters: {params:,}")
    train(model, X, y, epochs=200, lr=1e-3)

    print("\n=== Evaluating policy ===")
    n_trials = 200
    n_achieved, lengths = evaluate_policy(model, n_trials=n_trials, max_steps=20)
    # FIX: percentage was hard-coded as n_achieved/2 (only correct for
    # exactly 200 trials); derive it from n_trials so the print stays
    # correct if the trial count changes. Output is identical today.
    print(f"Success rate: {n_achieved}/{n_trials} "
          f"({100.0 * n_achieved / n_trials:.0f}%)")
    if lengths:
        print(f"Avg steps: {np.mean(lengths):.1f}, min={min(lengths)}, max={max(lengths)}")

    print("\n=== Best-of-10 evaluation ===")
    n_bo10_trials = 100
    n_attempts = 10
    n_achieved_bo10 = 0
    for _ in range(n_bo10_trials):
        found = False
        for _ in range(n_attempts):
            sim = HeapSimulator()
            for _ in range(20):
                grid = sim.state_to_grid()
                x = torch.from_numpy(grid).long().unsqueeze(0)
                with torch.no_grad():
                    logits = model(x)
                # Temperature 0.5 sampling diversifies rollouts across attempts.
                probs = F.softmax(logits / 0.5, dim=1)
                op = torch.multinomial(probs, 1).item()
                execute_op(sim, op)
                if sim.check_primitives().get("duplicate_alloc"):
                    found = True
                    break
            if found:
                break
        if found:
            n_achieved_bo10 += 1

    # Same fix as above: percent derived from the trial count (integer
    # division matches the original's implicit count==percent at 100 trials).
    print(f"Best-of-10 success: {n_achieved_bo10}/{n_bo10_trials} "
          f"({100 * n_achieved_bo10 // n_bo10_trials}%)")
|
|
|
|
# Script entry point: run the full demo-generation / train / eval pipeline.
if __name__ == "__main__":
    main()
|
|