"""
train_with_demos.py - Train HeapPolicyTRM using expert demonstrations + self-play.

Phase 1: Imitation learning on known exploit sequences (how2heap-style)
Phase 2: Fine-tune with REINFORCE self-play

This bootstraps the policy so it doesn't have to discover exploits from scratch.
"""
|
|
| import sys |
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| from pathlib import Path |
|
|
| sys.path.insert(0, str(Path(__file__).parent)) |
| sys.path.insert(0, str(Path(__file__).parent.parent / "simulator")) |
| sys.path.insert(0, str(Path(__file__).parent.parent / "model")) |
|
|
| from heap_sim import HeapSimulator |
| from policy import HeapPolicyTRM, encode_action, decode_action, TOTAL_ACTIONS |
| from search import ( |
| train_selfplay, beam_search, execute_action, |
| get_valid_actions, generate_episode, |
| ) |
|
|
|
|
| |
| |
| |
|
|
def demo_tcache_poison() -> list:
    """Expert sequence for tcache poisoning via UAF."""
    chunk = 0x40
    steps = [{"op": "malloc", "size": chunk, "slot": s} for s in range(3)]
    steps += [{"op": "free", "slot": 0}, {"op": "free", "slot": 1}]
    # Use-after-free write into freed slot 1, aimed at slot 2.
    steps.append({"op": "write_freed", "slot": 1, "target_slot": 2})
    steps += [{"op": "malloc", "size": chunk, "slot": s} for s in (3, 4)]
    return steps
|
|
|
|
def demo_tcache_poison_v2() -> list:
    """Variant with different sizes."""
    sz = 0x30

    def alloc(slot):
        return {"op": "malloc", "size": sz, "slot": slot}

    steps = [alloc(0), alloc(1), alloc(2)]
    # Free order is reversed relative to the base demo (1 then 0).
    steps.append({"op": "free", "slot": 1})
    steps.append({"op": "free", "slot": 0})
    steps.append({"op": "write_freed", "slot": 0, "target_slot": 2})
    steps.extend([alloc(3), alloc(4)])
    return steps
|
|
|
|
def demo_tcache_poison_v3() -> list:
    """Variant with 0x50 size."""
    def alloc(slot):
        return {"op": "malloc", "size": 0x50, "slot": slot}

    return (
        [alloc(0), alloc(1)]
        + [{"op": "free", "slot": 0}, {"op": "free", "slot": 1}]
        + [{"op": "write_freed", "slot": 1, "target_slot": 0}]
        + [alloc(2), alloc(3)]
    )
|
|
|
|
def demo_double_free() -> list:
    """Double free via UAF (clear tcache key then re-free).

    NOTE(review): the visible sequence only allocates two chunks and frees
    each of them once -- no slot is ever re-freed, so it looks truncated
    relative to its name. Preserved as-is; confirm the intended full
    sequence against the simulator's op set.
    """
    steps = []
    for slot in (0, 1):
        steps.append({"op": "malloc", "size": 0x40, "slot": slot})
    for slot in (0, 1):
        steps.append({"op": "free", "slot": slot})
    return steps
|
|
|
|
def generate_demo_variants(base_demo, n_variants=20, seed=None) -> list:
    """Generate variants of a demo with different sizes and slot assignments.

    Args:
        base_demo: Zero-argument callable returning a fresh demo
            (a list of step dicts).
        n_variants: Number of randomized variants to generate in addition
            to the unmodified base demo.
        seed: Optional seed for reproducible variant generation. When None
            (the default) the module-level random state is used, matching
            the original behavior.

    Returns:
        List of ``n_variants + 1`` demos; element 0 is the unmodified
        base demo.
    """
    import random

    rng = random.Random(seed) if seed is not None else random
    variants = [base_demo()]

    sizes = [0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80]

    for _ in range(n_variants):
        demo = base_demo()

        # One uniform chunk size per variant, so all chunks land in the
        # same bin and the exploit structure is preserved.
        size = rng.choice(sizes)
        for step in demo:
            if "size" in step:
                step["size"] = size

        # Consistently remap slot indices: every occurrence of an original
        # slot (as "slot" or "target_slot") maps to the same new slot, and
        # distinct originals map to distinct new slots.
        slot_map = {}
        available = list(range(8))
        rng.shuffle(available)
        for step in demo:
            for key in ("slot", "target_slot"):
                if key in step:
                    old = step[key]
                    if old not in slot_map:
                        slot_map[old] = available.pop(0)
                    step[key] = slot_map[old]

        variants.append(demo)

    return variants
|
|
|
|
| |
| |
| |
|
|
def collect_demo_data(demos: list) -> tuple:
    """Run demos through the simulator, collect (state, action) pairs.

    Args:
        demos: List of demo sequences, each a list of step dicts accepted
            by ``encode_action``.

    Returns:
        Tuple ``(states, actions, rewards)``:
          - states:  stacked state grids, one per demo step
          - actions: int64 array of encoded action indices
          - rewards: float32 array; 1.0 for steps after which any heap
            primitive is reported active by the simulator, else 0.0
    """
    states = []
    actions = []
    rewards = []

    for demo in demos:
        sim = HeapSimulator()
        for step in demo:
            grid = sim.state_to_grid()
            action_idx = encode_action(**step)

            states.append(grid)
            actions.append(action_idx)

            # Advance the simulator; the success flag is intentionally
            # ignored since expert demos are assumed to be valid sequences.
            execute_action(sim, action_idx)

            # Sparse reward: 1.0 once any exploit primitive is active.
            prims = sim.check_primitives()
            rewards.append(1.0 if any(prims.values()) else 0.0)

    return (
        np.stack(states),
        np.array(actions, dtype=np.int64),
        np.array(rewards, dtype=np.float32),
    )
|
|
|
|
def train_imitation(
    model: "HeapPolicyTRM",
    states: np.ndarray,
    actions: np.ndarray,
    epochs: int = 100,
    batch_size: int = 32,
    lr: float = 1e-3,
) -> float:
    """Train policy via behavioral cloning on expert demos.

    Args:
        model: Policy network; its forward pass must return
            ``(action_logits, value)`` for a batch of integer state grids.
        states: Integer state grids, shape (N, ...).
        actions: Expert action indices, shape (N,).
        epochs: Number of passes over the dataset.
        batch_size: Minibatch size.
        lr: AdamW learning rate.

    Returns:
        Best per-epoch accuracy (fraction of correctly imitated actions).
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)

    X = torch.from_numpy(states).long()
    y = torch.from_numpy(actions).long()
    n = len(X)

    model.train()
    best_acc = 0.0

    for epoch in range(1, epochs + 1):
        perm = torch.randperm(n)  # reshuffle every epoch
        total_loss = 0.0
        correct = 0
        n_batches = 0

        for i in range(0, n, batch_size):
            idx = perm[i:i + batch_size]
            x_batch = X[idx]
            y_batch = y[idx]

            # Value head output is unused for behavioral cloning.
            logits, _ = model(x_batch)
            loss = F.cross_entropy(logits, y_batch)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()
            correct += (logits.argmax(dim=1) == y_batch).sum().item()
            n_batches += 1

        acc = correct / n
        avg_loss = total_loss / n_batches
        best_acc = max(best_acc, acc)

        if epoch % 10 == 0 or epoch == 1:
            print(f" Epoch {epoch:3d} | loss={avg_loss:.4f} | acc={acc:.3f} "
                  f"| best_acc={best_acc:.3f}")

    return best_acc
|
|
|
|
| |
| |
| |
|
|
def main():
    """End-to-end pipeline: demos -> imitation -> eval -> beam search -> self-play."""
    # --- Phase 1: expert demonstrations ---
    print("=== Phase 1: Generating expert demonstrations ===")
    all_demos = []
    for demo_fn in (demo_tcache_poison, demo_tcache_poison_v2,
                    demo_tcache_poison_v3):
        all_demos += generate_demo_variants(demo_fn, 30)
    print(f"Generated {len(all_demos)} demo sequences")

    states, actions, rewards = collect_demo_data(all_demos)
    print(f"Collected {len(states)} (state, action) pairs")
    print(f"Achieved exploit in {(rewards > 0).sum()} steps")

    # --- Phase 2: behavioral cloning ---
    print("\n=== Phase 2: Imitation learning ===")
    model = HeapPolicyTRM(hidden_dim=128, n_outer=2, n_inner=3)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Parameters: {n_params:,}")

    best_acc = train_imitation(model, states, actions, epochs=100, lr=1e-3)
    print(f"Best imitation accuracy: {best_acc:.3f}")

    # --- Phase 3: sample episodes from the cloned policy ---
    print("\n=== Phase 3: Testing learned policy ===")
    n_test = 50
    n_achieved = 0
    lengths = []
    for _ in range(n_test):
        episode = generate_episode(model, goal="tcache_poison",
                                   max_steps=20, temperature=0.3)
        if episode["achieved"]:
            n_achieved += 1
            lengths.append(episode["n_steps"])

    print(f"Policy achieves tcache_poison: {n_achieved}/{n_test} "
          f"({n_achieved/n_test*100:.0f}%)")
    if lengths:
        print(f"Average steps when successful: {np.mean(lengths):.1f}")

    # --- Phase 4: beam search from a fresh simulator state ---
    print("\n=== Phase 4: Beam search ===")
    sim = HeapSimulator()
    result = beam_search(model, sim, goal="tcache_poison",
                         beam_width=16, max_steps=20, temperature=0.3)

    if result and result.primitives.get("tcache_poison"):
        print(f"Beam search found exploit in {len(result.actions)} steps:")
        for step_no, action in enumerate(result.actions):
            print(f" Step {step_no}: {decode_action(action)}")
    else:
        print("Beam search did not find exploit")

    # --- Phase 5: REINFORCE fine-tuning ---
    print("\n=== Phase 5: Self-play fine-tuning ===")
    stats = train_selfplay(
        model,
        goal="tcache_poison",
        n_episodes=200,
        max_steps=20,
        lr=1e-4,
        print_every=50,
    )

    # --- Final evaluation at low temperature ---
    print("\n=== Final evaluation ===")
    n_achieved = 0
    for _ in range(100):
        episode = generate_episode(model, goal="tcache_poison",
                                   max_steps=20, temperature=0.1)
        if episode["achieved"]:
            n_achieved += 1

    print(f"Final success rate: {n_achieved}/100 ({n_achieved}%)")

    # Persist the trained policy for the agent.
    output_dir = Path(__file__).parent.parent / "agent" / "checkpoints"
    output_dir.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), output_dir / "policy_model.pt")
    print(f"Model saved to {output_dir / 'policy_model.pt'}")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|