# heap-trm / agent / train_with_demos.py
# Uploaded by amarck — commit 22374d1:
# "Add heaptrm package: v2 harness, CLI, pwntools integration, CVE tests"
#!/usr/bin/env python3
"""
train_with_demos.py - Train HeapPolicyTRM using expert demonstrations + self-play.
Phase 1: Imitation learning on known exploit sequences (how2heap-style)
Phase 2: Fine-tune with REINFORCE self-play
This bootstraps the policy so it doesn't have to discover exploits from scratch.
"""
import sys
import torch
import torch.nn.functional as F
import numpy as np
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
sys.path.insert(0, str(Path(__file__).parent.parent / "simulator"))
sys.path.insert(0, str(Path(__file__).parent.parent / "model"))
from heap_sim import HeapSimulator
from policy import HeapPolicyTRM, encode_action, decode_action, TOTAL_ACTIONS
from search import (
train_selfplay, beam_search, execute_action,
get_valid_actions, generate_episode,
)
# ============================================================
# EXPERT DEMONSTRATIONS
# ============================================================
def demo_tcache_poison() -> list:
    """Expert action trace for tcache poisoning through a use-after-free.

    Allocates three same-size chunks (the third is a guard), frees two to
    populate the tcache bin, corrupts the freed chunk's fd pointer with a
    UAF write, then allocates twice so the second malloc returns the
    poisoned address.
    """
    size = 0x40

    def alloc(slot):
        return {"op": "malloc", "size": size, "slot": slot}

    def release(slot):
        return {"op": "free", "slot": slot}

    return [
        alloc(0),
        alloc(1),
        alloc(2),  # guard
        release(0),
        release(1),
        {"op": "write_freed", "slot": 1, "target_slot": 2},  # poison fd
        alloc(3),  # gets slot 1's chunk
        alloc(4),  # gets poisoned addr
    ]
def demo_tcache_poison_v2() -> list:
    """Tcache-poisoning variant: 0x30-byte chunks, frees in reverse order."""
    def malloc(slot):
        return {"op": "malloc", "size": 0x30, "slot": slot}

    steps = [malloc(s) for s in (0, 1, 2)]
    steps.append({"op": "free", "slot": 1})
    steps.append({"op": "free", "slot": 0})
    # Corrupt the most recently freed chunk's fd to point at the guard chunk.
    steps.append({"op": "write_freed", "slot": 0, "target_slot": 2})
    steps.extend(malloc(s) for s in (3, 4))
    return steps
def demo_tcache_poison_v3() -> list:
    """Shorter tcache-poisoning variant: 0x50-byte chunks, no guard chunk."""
    ops = []
    for slot in (0, 1):
        ops.append({"op": "malloc", "size": 0x50, "slot": slot})
    ops.append({"op": "free", "slot": 0})
    ops.append({"op": "free", "slot": 1})
    # Point the freed chunk's fd back at slot 0's chunk.
    ops.append({"op": "write_freed", "slot": 1, "target_slot": 0})
    for slot in (2, 3):
        ops.append({"op": "malloc", "size": 0x50, "slot": slot})
    return ops
def demo_double_free() -> list:
    """Setup prefix for a double free (two allocations, two frees).

    The second free of the same slot is omitted on purpose: a real exploit
    would first clear the tcache key via a UAF write, but the simulator does
    not model the key check (pre-glibc-2.29 behavior), so a direct re-free
    can follow this setup without extra steps.
    """
    prologue = [{"op": "malloc", "size": 0x40, "slot": s} for s in (0, 1)]
    epilogue = [{"op": "free", "slot": s} for s in (0, 1)]
    return prologue + epilogue
def generate_demo_variants(base_demo, n_variants=20) -> list:
    """Return the unmodified base demo plus ``n_variants`` randomized copies.

    Each copy is given a single uniform chunk size drawn from the
    tcache-friendly range and a random, relationship-preserving remapping of
    slot ids (the same old id always maps to the same new id, so UAF
    relationships between steps survive the shuffle).

    Note: the returned list has ``n_variants + 1`` entries because the
    pristine base demo is included first.
    """
    import random
    size_pool = [0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80]
    out = [base_demo()]
    for _ in range(n_variants):
        steps = base_demo()
        # One size per variant keeps every chunk in the same tcache bin.
        chosen = random.choice(size_pool)
        for step in steps:
            if "size" in step:
                step["size"] = chosen
        # Consistent slot remapping via a shuffled pool of fresh slot ids.
        free_slots = list(range(8))
        random.shuffle(free_slots)
        remap = {}
        for step in steps:
            for field in ("slot", "target_slot"):
                if field in step:
                    prev = step[field]
                    if prev not in remap:
                        remap[prev] = free_slots.pop(0)
                    step[field] = remap[prev]
        out.append(steps)
    return out
# ============================================================
# IMITATION LEARNING
# ============================================================
def collect_demo_data(demos: list) -> tuple:
    """Replay expert demos through the simulator and collect training data.

    Each demo runs in a fresh ``HeapSimulator``. For every step the
    grid-encoded state *before* the action is recorded together with the
    encoded action index; the reward for a step is 1.0 whenever the
    simulator reports at least one achieved exploit primitive after the
    step, else 0.0.

    Args:
        demos: list of demo sequences, each a list of action dicts whose
            keys are accepted as keyword arguments by ``encode_action``.

    Returns:
        Tuple ``(states, actions, rewards)`` of stacked numpy arrays
        (actions as int64 indices, rewards as float32).

    Raises:
        ValueError: if ``demos`` yields no steps (``np.stack`` on an
            empty list).
    """
    states = []
    actions = []
    rewards = []
    for demo in demos:
        # Every demo starts from a pristine heap.
        sim = HeapSimulator()
        for step in demo:
            grid = sim.state_to_grid()
            action_idx = encode_action(**step)
            states.append(grid)
            actions.append(action_idx)
            # Advance the simulator; the per-action success flag is ignored —
            # the reward signal depends only on achieved primitives.
            execute_action(sim, action_idx)
            prims = sim.check_primitives()
            rewards.append(1.0 if any(prims.values()) else 0.0)
    return (
        np.stack(states),
        np.array(actions, dtype=np.int64),
        np.array(rewards, dtype=np.float32),
    )
def train_imitation(
    model: HeapPolicyTRM,
    states: np.ndarray,
    actions: np.ndarray,
    epochs: int = 100,
    batch_size: int = 32,
    lr: float = 1e-3,
) -> float:
    """Behavioral cloning: fit the policy to expert (state, action) pairs.

    Minimizes cross-entropy between the policy logits and the expert
    actions (the model's value-head output is ignored), using AdamW with
    weight decay 0.01, gradient-norm clipping at 1.0, and a fresh shuffle
    of the dataset each epoch. Progress is printed on the first epoch and
    every 10th epoch thereafter.

    Returns:
        The best per-epoch training accuracy observed.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    X = torch.from_numpy(states).long()
    y = torch.from_numpy(actions).long()
    n = len(X)
    model.train()
    best_acc = 0.0
    for epoch in range(1, epochs + 1):
        order = torch.randperm(n)
        running_loss = 0.0
        n_correct = 0
        n_batches = 0
        for start in range(0, n, batch_size):
            batch = order[start:start + batch_size]
            targets = y[batch]
            logits, _value = model(X[batch])
            loss = F.cross_entropy(logits, targets)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            running_loss += loss.item()
            n_correct += (logits.argmax(dim=1) == targets).sum().item()
            n_batches += 1
        acc = n_correct / n
        best_acc = max(best_acc, acc)
        if epoch == 1 or epoch % 10 == 0:
            print(f" Epoch {epoch:3d} | loss={running_loss / n_batches:.4f} | acc={acc:.3f} "
                  f"| best_acc={best_acc:.3f}")
    return best_acc
# ============================================================
# MAIN
# ============================================================
def main():
    """End-to-end training pipeline: demos -> imitation -> eval -> self-play."""
    # --- Phase 1: build the expert-demonstration corpus -------------------
    print("=== Phase 1: Generating expert demonstrations ===")
    all_demos = []
    # 30 randomized variants of each hand-written tcache-poison recipe
    # (generate_demo_variants also prepends the unmodified base demo).
    all_demos += generate_demo_variants(demo_tcache_poison, 30)
    all_demos += generate_demo_variants(demo_tcache_poison_v2, 30)
    all_demos += generate_demo_variants(demo_tcache_poison_v3, 30)
    print(f"Generated {len(all_demos)} demo sequences")
    # Replay every demo through the simulator to get (state, action) pairs.
    states, actions, rewards = collect_demo_data(all_demos)
    print(f"Collected {len(states)} (state, action) pairs")
    print(f"Achieved exploit in {(rewards > 0).sum()} steps")

    # --- Phase 2: behavioral cloning on the collected pairs ---------------
    print("\n=== Phase 2: Imitation learning ===")
    model = HeapPolicyTRM(hidden_dim=128, n_outer=2, n_inner=3)
    print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
    best_acc = train_imitation(model, states, actions, epochs=100, lr=1e-3)
    print(f"Best imitation accuracy: {best_acc:.3f}")

    # --- Phase 3: sampling-based evaluation of the cloned policy ----------
    # Can the model reproduce exploits on its own?
    print("\n=== Phase 3: Testing learned policy ===")
    n_test = 50
    n_achieved = 0
    lengths = []
    for i in range(n_test):
        # temperature=0.3: sample close to the policy's argmax actions.
        episode = generate_episode(model, goal="tcache_poison",
                                   max_steps=20, temperature=0.3)
        if episode["achieved"]:
            n_achieved += 1
            lengths.append(episode["n_steps"])
    print(f"Policy achieves tcache_poison: {n_achieved}/{n_test} "
          f"({n_achieved/n_test*100:.0f}%)")
    if lengths:
        print(f"Average steps when successful: {np.mean(lengths):.1f}")

    # --- Phase 4: beam search from a fresh simulator state ----------------
    print("\n=== Phase 4: Beam search ===")
    sim = HeapSimulator()
    result = beam_search(model, sim, goal="tcache_poison",
                         beam_width=16, max_steps=20, temperature=0.3)
    if result and result.primitives.get("tcache_poison"):
        print(f"Beam search found exploit in {len(result.actions)} steps:")
        for i, a in enumerate(result.actions):
            print(f" Step {i}: {decode_action(a)}")
    else:
        print("Beam search did not find exploit")

    # --- Phase 5: REINFORCE self-play fine-tuning -------------------------
    print("\n=== Phase 5: Self-play fine-tuning ===")
    # NOTE(review): `stats` is unused; train_selfplay is presumably run for
    # its side effect of updating `model` in place — confirm in search.py.
    stats = train_selfplay(
        model,
        goal="tcache_poison",
        n_episodes=200,
        max_steps=20,
        lr=1e-4,
        print_every=50,
    )

    # --- Final evaluation: 100 near-greedy episodes (temperature 0.1) -----
    print("\n=== Final evaluation ===")
    n_achieved = 0
    for i in range(100):
        episode = generate_episode(model, goal="tcache_poison",
                                   max_steps=20, temperature=0.1)
        if episode["achieved"]:
            n_achieved += 1
    # "{n_achieved}%" is valid only because exactly 100 episodes are run.
    print(f"Final success rate: {n_achieved}/100 ({n_achieved}%)")

    # --- Persist the trained policy ---------------------------------------
    # NOTE(review): parent.parent / "agent" assumes this file lives in the
    # agent/ directory (making the path equivalent to parent / "checkpoints");
    # verify if the file is ever relocated.
    output_dir = Path(__file__).parent.parent / "agent" / "checkpoints"
    output_dir.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), output_dir / "policy_model.pt")
    print(f"Model saved to {output_dir / 'policy_model.pt'}")


if __name__ == "__main__":
    main()