#!/usr/bin/env python3
"""
simple_agent.py - Simplified action-prediction agent.
Instead of predicting exact (op, size, slot) in one shot from 128 actions,
decompose into:
Step 1: Predict operation type (4: malloc, free, write_freed, noop)
Step 2: For the chosen op, pick parameters using simple heuristics
This makes imitation learning tractable (4-class instead of 128-class).
The model learns the HIGH-LEVEL STRATEGY (when to alloc vs free vs UAF-write),
not the low-level details (which slot, which size).
"""
import sys
import torch
import torch.nn.functional as F
import numpy as np
import random
import copy
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent / "simulator"))
sys.path.insert(0, str(Path(__file__).parent.parent / "model"))
from heap_sim import HeapSimulator
from trm_heap import RMSNorm, SwiGLU, RecursionBlock
# 4 operation types
OP_MALLOC = 0
OP_FREE = 1
OP_WRITE_FREED = 2
OP_NOOP = 3
N_OPS = 4
SIZES = [0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80]
class SimpleHeapTRM(torch.nn.Module):
    """Tiny recursive model that classifies the next heap operation (4-way).

    The input grid is embedded, offset by a learned positional table, and
    refined by two recursion blocks: an inner loop updates latent ``z``,
    an outer loop updates latent ``y``.  Mean-pooled features feed a
    linear N_OPS-class head.
    """

    def __init__(self, vocab_size=64, hidden_dim=128, seq_len=512,
                 n_outer=2, n_inner=3):
        super().__init__()
        init_scale = 0.02
        self.embed = torch.nn.Embedding(vocab_size, hidden_dim)
        # Learned initial states for the two recurrent latents.
        self.y_init = torch.nn.Parameter(init_scale * torch.randn(1, seq_len, hidden_dim))
        self.z_init = torch.nn.Parameter(init_scale * torch.randn(1, seq_len, hidden_dim))
        self.block_z = RecursionBlock(hidden_dim)
        self.block_y = RecursionBlock(hidden_dim)
        self.pos_embed = torch.nn.Parameter(init_scale * torch.randn(1, seq_len, hidden_dim))
        self.out_norm = RMSNorm(hidden_dim)
        self.n_outer = n_outer
        self.n_inner = n_inner
        # 4-class output head (one logit per operation type).
        self.head = torch.nn.Linear(hidden_dim, N_OPS)

    def forward(self, x):
        batch = x.shape[0]
        h = self.embed(x.reshape(batch, -1)) + self.pos_embed
        y = self.y_init.expand(batch, -1, -1)
        z = self.z_init.expand(batch, -1, -1)
        for _ in range(self.n_outer):
            # Inner refinement of z against the (fixed) input features + y.
            for _ in range(self.n_inner):
                z = z + self.block_z(h + y + z)
            y = y + self.block_y(y + z)
        pooled = self.out_norm(y).mean(dim=1)
        return self.head(pooled)
def execute_op(sim: HeapSimulator, op_type: int) -> bool:
    """Carry out *op_type* on the simulator, auto-selecting parameters.

    Returns True if the simulator accepted the operation; False when no
    legal parameter choice exists (or for OP_NOOP / unknown op types).
    """
    used_slots = set(sim.slots.keys())
    free_slots = [s for s in range(8) if s not in used_slots]
    occupied_slots = list(used_slots)
    # Slots whose backing chunk has been freed while the pointer is still
    # tracked: these are the use-after-free (UAF) write candidates.
    # assumes user pointers sit 16 bytes past the chunk header — matches
    # the original's `addr - 16` offset.
    freed_slots = [
        s for s in occupied_slots
        if (c := sim.chunks.get(sim.slots[s] - 16)) and not c.allocated
    ]
    alloc_slots = [s for s in occupied_slots if s not in freed_slots]

    if op_type == OP_MALLOC:
        if not free_slots:
            return False
        dest = random.choice(free_slots)
        req_size = random.choice(SIZES)
        return sim.malloc(req_size, slot=dest) is not None
    if op_type == OP_FREE:
        if not alloc_slots:
            return False
        victim = random.choice(alloc_slots)
        return sim.free(user_addr=sim.slots[victim], slot=victim)
    if op_type == OP_WRITE_FREED:
        if not freed_slots or len(occupied_slots) < 2:
            return False
        src = random.choice(freed_slots)
        candidates = [s for s in occupied_slots if s != src]
        if not candidates:
            return False
        tgt = random.choice(candidates)
        return sim.write_to_freed(sim.slots[src], sim.slots[tgt])
    # OP_NOOP or an unrecognized op: nothing to do.
    return False
# ============================================================
# EXPERT DEMOS (as operation type sequences)
# ============================================================
TCACHE_POISON_OPS = [
OP_MALLOC, # alloc A
OP_MALLOC, # alloc B
OP_MALLOC, # alloc C (guard)
OP_FREE, # free A -> tcache
OP_FREE, # free B -> tcache
OP_WRITE_FREED, # UAF: corrupt B's fd
OP_MALLOC, # alloc from tcache (gets B)
OP_MALLOC, # alloc from tcache (gets poisoned addr!)
]
def generate_demos(n_variants=200) -> tuple:
    """Roll out the expert op sequence *n_variants* times.

    Each step records (state grid, expert op) BEFORE the op is applied.
    Returns (states, labels): a stacked array of grids and the int64
    op-type label chosen in each state.
    """
    grids = []
    chosen_ops = []
    for _ in range(n_variants):
        sim = HeapSimulator()
        for op in list(TCACHE_POISON_OPS):
            grids.append(sim.state_to_grid())
            chosen_ops.append(op)
            execute_op(sim, op)
    return np.stack(grids), np.array(chosen_ops, dtype=np.int64)
def train(model, X, y, epochs=200, lr=1e-3, bs=64):
    """Imitation-learn the 4-class op policy with AdamW + cross-entropy.

    X: integer state grids (numpy), y: expert op labels (numpy int64).
    Logs mean batch loss and training-set accuracy every 20 epochs.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    inputs = torch.from_numpy(X).long()
    targets = torch.from_numpy(y).long()
    n_samples = len(inputs)
    for epoch in range(1, epochs + 1):
        model.train()
        order = torch.randperm(n_samples)  # fresh shuffle every epoch
        loss_sum = 0
        n_correct = 0
        n_batches = 0
        for start in range(0, n_samples, bs):
            batch_idx = order[start:start + bs]
            logits = model(inputs[batch_idx])
            loss = F.cross_entropy(logits, targets[batch_idx])
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            loss_sum += loss.item()
            n_correct += (logits.argmax(1) == targets[batch_idx]).sum().item()
            n_batches += 1
        if epoch == 1 or epoch % 20 == 0:
            train_acc = n_correct / n_samples
            print(f" Epoch {epoch:3d} | loss={loss_sum/n_batches:.4f} | acc={train_acc:.3f}")
def evaluate_policy(model, goal="duplicate_alloc", n_trials=200, max_steps=20):
    """Greedy-rollout evaluation of the learned policy.

    Runs *n_trials* episodes of up to *max_steps* argmax actions each and
    checks the simulator's primitive flags for *goal* after every action.
    Returns (success_count, list of successful episode lengths).
    """
    model.eval()
    successes = 0
    episode_lengths = []
    for _ in range(n_trials):
        sim = HeapSimulator()
        for step in range(max_steps):
            obs = torch.from_numpy(sim.state_to_grid()).long().unsqueeze(0)
            with torch.no_grad():
                chosen_op = model(obs).argmax(1).item()
            execute_op(sim, chosen_op)
            if sim.check_primitives().get(goal, False):
                successes += 1
                episode_lengths.append(step + 1)
                break
    return successes, episode_lengths
def _sampled_rollout_succeeds(model, attempts=10, max_steps=20,
                              temperature=0.5, goal="duplicate_alloc"):
    """Run up to *attempts* stochastic rollouts; True if any reaches *goal*.

    Actions are sampled from the temperature-scaled softmax of the policy
    logits rather than taken greedily, so repeated attempts explore.
    """
    for _ in range(attempts):
        sim = HeapSimulator()
        for _ in range(max_steps):
            x = torch.from_numpy(sim.state_to_grid()).long().unsqueeze(0)
            with torch.no_grad():
                logits = model(x)
            probs = F.softmax(logits / temperature, dim=1)
            op = torch.multinomial(probs, 1).item()
            execute_op(sim, op)
            if sim.check_primitives().get(goal):
                return True
    return False

def main():
    """Generate demos, train the 4-class op policy, then evaluate it
    greedily and with best-of-10 sampled rollouts."""
    print("=== Generating demos ===")
    X, y = generate_demos(300)
    print(f"Data: {len(X)} samples")
    print(f"Op distribution: malloc={sum(y==0)}, free={sum(y==1)}, "
          f"write_freed={sum(y==2)}, noop={sum(y==3)}")

    print("\n=== Training (4-class op prediction) ===")
    model = SimpleHeapTRM(hidden_dim=128, n_outer=2, n_inner=3)
    params = sum(p.numel() for p in model.parameters())
    print(f"Parameters: {params:,}")
    train(model, X, y, epochs=200, lr=1e-3)

    print("\n=== Evaluating policy ===")
    n_trials = 200
    n_achieved, lengths = evaluate_policy(model, n_trials=n_trials, max_steps=20)
    # BUG FIX: percentage was hard-coded as n_achieved/2, which is only
    # correct when n_trials == 200; derive it from n_trials instead.
    print(f"Success rate: {n_achieved}/{n_trials} "
          f"({100 * n_achieved / n_trials:.0f}%)")
    if lengths:
        print(f"Avg steps: {np.mean(lengths):.1f}, min={min(lengths)}, max={max(lengths)}")

    # Best-of-10: sample 10 stochastic rollouts per trial; count the trial
    # as a success if any rollout reaches the duplicate-alloc primitive.
    print("\n=== Best-of-10 evaluation ===")
    n_bo10_trials = 100
    n_achieved_bo10 = sum(
        1 for _ in range(n_bo10_trials)
        if _sampled_rollout_succeeds(model, attempts=10, max_steps=20)
    )
    # Same fix: percentage derived from the trial count, not assumed 100.
    print(f"Best-of-10 success: {n_achieved_bo10}/{n_bo10_trials} "
          f"({100 * n_achieved_bo10 / n_bo10_trials:.0f}%)")
if __name__ == "__main__":
main()
|