# Source: heap-trm / agent/simple_agent.py
# Commit 22374d1: "Add heaptrm package: v2 harness, CLI, pwntools integration, CVE tests"
#!/usr/bin/env python3
"""
simple_agent.py - Simplified action-prediction agent.
Instead of predicting exact (op, size, slot) in one shot from 128 actions,
decompose into:
Step 1: Predict operation type (4: malloc, free, write_freed, noop)
Step 2: For the chosen op, pick parameters using simple heuristics
This makes imitation learning tractable (4-class instead of 128-class).
The model learns the HIGH-LEVEL STRATEGY (when to alloc vs free vs UAF-write),
not the low-level details (which slot, which size).
"""
import sys
import torch
import torch.nn.functional as F
import numpy as np
import random
import copy
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent / "simulator"))
sys.path.insert(0, str(Path(__file__).parent.parent / "model"))
from heap_sim import HeapSimulator
from trm_heap import RMSNorm, SwiGLU, RecursionBlock
# 4 operation types the policy chooses between each step.
OP_MALLOC = 0        # allocate a chunk into a free slot
OP_FREE = 1          # free a currently-allocated slot
OP_WRITE_FREED = 2   # use-after-free write from a freed chunk into another slot
OP_NOOP = 3          # do nothing this step
N_OPS = 4            # number of classes in the model's output head
# Candidate allocation request sizes (0x20..0x80) used when OP_MALLOC fires.
SIZES = [0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80]
class SimpleHeapTRM(torch.nn.Module):
    """Tiny recursive model that maps a heap-state grid to one of N_OPS ops.

    Two learned latent streams (``y``/``z``) are refined by nested recursion
    blocks for ``n_outer`` x ``n_inner`` iterations, then mean-pooled into a
    4-way classification head.
    """

    def __init__(self, vocab_size=64, hidden_dim=128, seq_len=512,
                 n_outer=2, n_inner=3):
        super().__init__()
        # NOTE: attribute creation order is kept as-is so random init draws
        # the same values under a fixed seed.
        self.embed = torch.nn.Embedding(vocab_size, hidden_dim)
        self.y_init = torch.nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02)
        self.z_init = torch.nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02)
        self.block_z = RecursionBlock(hidden_dim)
        self.block_y = RecursionBlock(hidden_dim)
        self.pos_embed = torch.nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02)
        self.out_norm = RMSNorm(hidden_dim)
        self.n_outer = n_outer
        self.n_inner = n_inner
        # 4-class output head (one logit per op type).
        self.head = torch.nn.Linear(hidden_dim, N_OPS)

    def forward(self, x):
        batch = x.shape[0]
        # Flatten the grid to a token sequence and add positional embeddings.
        tokens = self.embed(x.reshape(batch, -1)) + self.pos_embed
        y_state = self.y_init.expand(batch, -1, -1)
        z_state = self.z_init.expand(batch, -1, -1)
        # Nested refinement: z integrates the input each inner step,
        # y is updated once per outer step from the refined z.
        for _ in range(self.n_outer):
            for _ in range(self.n_inner):
                z_state = z_state + self.block_z(tokens + y_state + z_state)
            y_state = y_state + self.block_y(y_state + z_state)
        # Normalize, mean-pool over the sequence, classify.
        return self.head(self.out_norm(y_state).mean(dim=1))
def execute_op(sim: HeapSimulator, op_type: int) -> bool:
    """Execute one op in the simulator, picking parameters automatically.

    Slot/size choices are made with ``random.choice`` so the policy only has
    to learn WHICH op to take, not its low-level arguments. Returns True on
    success, False if the op was inapplicable or the simulator rejected it.
    """
    occupied = list(sim.slots.keys())
    empty = [s for s in range(8) if s not in sim.slots]

    # Slots whose chunk is freed but still tracked -> use-after-free candidates.
    # The user pointer sits 16 bytes past the chunk header in this simulator.
    dangling = []
    for s in occupied:
        hdr = sim.chunks.get(sim.slots[s] - 16)
        if hdr and not hdr.allocated:
            dangling.append(s)
    live = [s for s in occupied if s not in dangling]

    if op_type == OP_MALLOC:
        if not empty:
            return False
        # Choose slot first, then size (order preserved for RNG stability).
        slot = random.choice(empty)
        size = random.choice(SIZES)
        return sim.malloc(size, slot=slot) is not None

    if op_type == OP_FREE:
        if not live:
            return False
        slot = random.choice(live)
        return sim.free(user_addr=sim.slots[slot], slot=slot)

    if op_type == OP_WRITE_FREED:
        if not dangling or len(occupied) < 2:
            return False
        src = random.choice(dangling)
        others = [s for s in occupied if s != src]
        if not others:
            return False
        dst = random.choice(others)
        return sim.write_to_freed(sim.slots[src], sim.slots[dst])

    # OP_NOOP (or anything unrecognized) does nothing.
    return False
# ============================================================
# EXPERT DEMOS (as operation type sequences)
# ============================================================
# Canonical tcache-poisoning attack expressed purely as high-level op types;
# execute_op() fills in the concrete slots/sizes at demo-generation time.
TCACHE_POISON_OPS = [
    OP_MALLOC,       # alloc A
    OP_MALLOC,       # alloc B
    OP_MALLOC,       # alloc C (guard)
    OP_FREE,         # free A -> tcache
    OP_FREE,         # free B -> tcache
    OP_WRITE_FREED,  # UAF: corrupt B's fd
    OP_MALLOC,       # alloc from tcache (gets B)
    OP_MALLOC,       # alloc from tcache (gets poisoned addr!)
]
def generate_demos(n_variants=200) -> tuple:
    """Roll the expert op sequence through fresh simulators.

    Records a (grid, op_type) pair BEFORE each op executes, so labels are
    the expert's action in the observed state. Returns (states, labels) as
    a stacked uint grid array and an int64 label vector.
    """
    states, labels = [], []
    for _ in range(n_variants):
        sim = HeapSimulator()
        for op in TCACHE_POISON_OPS:
            states.append(sim.state_to_grid())
            labels.append(op)
            execute_op(sim, op)
    return np.stack(states), np.array(labels, dtype=np.int64)
def train(model, X, y, epochs=200, lr=1e-3, bs=64):
    """Imitation-train `model` with 4-class cross-entropy on (X, y) demos.

    X: integer state grids (numpy), y: op-type labels (numpy int). Shuffles
    each epoch, clips gradient norm to 1.0, and logs loss/accuracy at epoch 1
    and every 20th epoch. Trains in place; returns None.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    inputs = torch.from_numpy(X).long()
    targets = torch.from_numpy(y).long()
    n_samples = len(inputs)

    for epoch in range(1, epochs + 1):
        model.train()
        order = torch.randperm(n_samples)  # one fresh shuffle per epoch
        loss_sum = 0.0
        n_correct = 0
        n_batches = 0
        for start in range(0, n_samples, bs):
            batch_idx = order[start:start + bs]
            logits = model(inputs[batch_idx])
            loss = F.cross_entropy(logits, targets[batch_idx])
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            loss_sum += loss.item()
            n_correct += (logits.argmax(1) == targets[batch_idx]).sum().item()
            n_batches += 1
        if epoch == 1 or epoch % 20 == 0:
            acc = n_correct / n_samples
            print(f" Epoch {epoch:3d} | loss={loss_sum/n_batches:.4f} | acc={acc:.3f}")
def evaluate_policy(model, goal="duplicate_alloc", n_trials=200, max_steps=20):
    """Roll out the greedy (argmax) policy; count runs that reach `goal`.

    Each trial starts a fresh simulator and plays up to `max_steps` ops,
    stopping early once `sim.check_primitives()` reports the goal primitive.
    Returns (n_successes, list of episode lengths for successful trials).
    """
    model.eval()
    successes = 0
    episode_lengths = []
    for _ in range(n_trials):
        sim = HeapSimulator()
        for step in range(max_steps):
            obs = torch.from_numpy(sim.state_to_grid()).long().unsqueeze(0)
            with torch.no_grad():
                op = model(obs).argmax(1).item()
            execute_op(sim, op)
            if sim.check_primitives().get(goal, False):
                successes += 1
                episode_lengths.append(step + 1)
                break
    return successes, episode_lengths
def main():
    """End-to-end demo: generate expert data, imitation-train, evaluate."""
    print("=== Generating demos ===")
    X, y = generate_demos(300)
    print(f"Data: {len(X)} samples")
    print(f"Op distribution: malloc={sum(y==0)}, free={sum(y==1)}, "
          f"write_freed={sum(y==2)}, noop={sum(y==3)}")

    print("\n=== Training (4-class op prediction) ===")
    model = SimpleHeapTRM(hidden_dim=128, n_outer=2, n_inner=3)
    params = sum(p.numel() for p in model.parameters())
    print(f"Parameters: {params:,}")
    train(model, X, y, epochs=200, lr=1e-3)

    print("\n=== Evaluating policy ===")
    n_trials = 200
    n_achieved, lengths = evaluate_policy(model, n_trials=n_trials, max_steps=20)
    # FIX: the report previously hard-coded the trial count ("/200") and the
    # percentage divisor ("n_achieved/2"); derive both from n_trials so the
    # printed numbers stay correct if the trial count changes.
    print(f"Success rate: {n_achieved}/{n_trials} "
          f"({100 * n_achieved / n_trials:.0f}%)")
    if lengths:
        print(f"Avg steps: {np.mean(lengths):.1f}, min={min(lengths)}, max={max(lengths)}")

    # Best-of-k evaluation: sample ops stochastically (softmax at
    # temperature 0.5) and count a trial as success if ANY attempt reaches
    # the duplicate-alloc primitive within 20 steps.
    print("\n=== Best-of-10 evaluation ===")
    n_bo_trials = 100
    n_attempts = 10
    n_achieved_bo10 = 0
    for _ in range(n_bo_trials):
        found = False
        for _ in range(n_attempts):
            sim = HeapSimulator()
            for _ in range(20):
                x = torch.from_numpy(sim.state_to_grid()).long().unsqueeze(0)
                with torch.no_grad():
                    logits = model(x)
                probs = F.softmax(logits / 0.5, dim=1)
                op = torch.multinomial(probs, 1).item()
                execute_op(sim, op)
                if sim.check_primitives().get("duplicate_alloc"):
                    found = True
                    break
            if found:
                break
        if found:
            n_achieved_bo10 += 1
    # FIX: percentage was printed as the raw count (valid only for exactly
    # 100 trials); compute it from n_bo_trials instead.
    print(f"Best-of-10 success: {n_achieved_bo10}/{n_bo_trials} "
          f"({100 * n_achieved_bo10 / n_bo_trials:.0f}%)")


if __name__ == "__main__":
    main()