| |
| """ |
| real_binary_bridge.py - Connect TRM agent to a real binary via LD_PRELOAD harness. |
| |
| Architecture: |
| 1. Launch vuln_heap binary with heapgrid_harness.so (LD_PRELOAD) |
| 2. Agent reads heap state from harness dump file after each command |
| 3. Agent predicts next operation type via TRM |
| 4. Translates to menu commands and sends via stdin pipe |
| 5. Repeats until exploit primitive detected or max steps |
| |
| This is the simulator-to-reality transfer test. |
| """ |
|
|
| import subprocess |
| import os |
| import sys |
| import json |
| import time |
| import random |
| import tempfile |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from pathlib import Path |
|
|
ROOT = Path(__file__).resolve().parent.parent

# Make the sibling project directories importable as top-level modules.
# Splicing this list onto the front of sys.path yields the order
# dataset, model, simulator, agent -- identical to inserting agent,
# simulator, model, dataset one at a time at index 0.
sys.path[:0] = [str(ROOT / sub) for sub in ("dataset", "model", "simulator", "agent")]

# These project imports only resolve because of the sys.path entries above.
from simple_agent import SimpleHeapTRM, OP_MALLOC, OP_FREE, OP_WRITE_FREED, SIZES
from dataset_gen import state_to_grid, load_dump

# Target CTF binary and the LD_PRELOAD harness that writes heap dumps.
BINARY = ROOT / "ctf" / "vuln_heap"
HARNESS = ROOT / "harness" / "heapgrid_harness.so"
|
|
|
|
class RealBinaryEnv:
    """Drive a real binary under the heap harness and expose grid observations.

    The binary is launched with ``LD_PRELOAD`` set to the harness and
    ``HEAPGRID_OUT`` set to a temporary ``.jsonl`` file; heap snapshots are
    read back from that file as JSON lines. Menu commands are written to the
    binary's stdin as newline-terminated text.

    ``slots`` / ``slot_sizes`` record what this wrapper *requested*; they are
    bookkeeping on the Python side, not state read back from the binary.
    """

    def __init__(self, binary=BINARY, harness=HARNESS):
        self.binary = str(binary)
        self.harness = str(harness)
        self.proc = None
        self.dump_file = None
        # Initialised here (not only in start()) so stop() is safe on an
        # environment that was never started.
        self.dump_path = None
        self.slots = {}  # slot -> True while allocated, False once freed
        self.slot_sizes = {}  # slot -> size passed to the most recent malloc
        self.commands_sent = []  # every command line written to stdin, in order
        self._last_dump_lines = 0  # NOTE(review): not read anywhere in this file

    def start(self):
        """Launch the binary with the harness preloaded and reset bookkeeping.

        Raises:
            OSError: if the binary cannot be executed. The temporary dump
                file is removed before the error propagates.
        """
        self.dump_file = tempfile.NamedTemporaryFile(
            suffix=".jsonl", delete=False, mode="w"
        )
        self.dump_path = self.dump_file.name
        # Only the path is needed here; it is handed to the harness via
        # HEAPGRID_OUT below.
        self.dump_file.close()

        env = os.environ.copy()
        env["LD_PRELOAD"] = self.harness
        env["HEAPGRID_OUT"] = self.dump_path

        try:
            self.proc = subprocess.Popen(
                [self.binary],
                stdin=subprocess.PIPE,
                # Nothing in this bridge reads the binary's output. Leaving it
                # on undrained pipes would eventually block the binary once
                # the OS pipe buffer fills, so discard it instead.
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                env=env,
            )
        except OSError:
            # Don't leave the temp dump file behind if the launch failed.
            os.unlink(self.dump_path)
            self.dump_path = None
            raise

        self.slots = {}
        self.slot_sizes = {}
        self.commands_sent = []
        self._last_dump_lines = 0

    def send_command(self, cmd: str):
        """Write one menu command line to the binary's stdin.

        Raises:
            BrokenPipeError: if the binary has already exited.
        """
        self.proc.stdin.write((cmd + "\n").encode())
        self.proc.stdin.flush()
        self.commands_sent.append(cmd)
        # Short pause, presumably to let the binary act on the command before
        # the dump file is read -- TODO confirm this is sufficient.
        time.sleep(0.01)

    def do_malloc(self, slot: int, size: int):
        """Menu option 1: allocate a note of ``size`` bytes in ``slot``."""
        self.send_command(f"1 {slot} {size}")
        self.slots[slot] = True
        self.slot_sizes[slot] = size

    def do_free(self, slot: int):
        """Menu option 4: delete the note in ``slot``."""
        self.send_command(f"4 {slot}")
        self.slots[slot] = False

    def do_edit_uaf(self, slot: int, data_hex: str):
        """Menu option 2: edit a note (also works on freed chunks, i.e. UAF)."""
        self.send_command(f"2 {slot} {data_hex}")

    def do_show(self, slot: int):
        """Menu option 3: show the note in ``slot``."""
        self.send_command(f"3 {slot}")

    def do_exit(self):
        """Menu option 5: ask the binary to exit."""
        self.send_command("5")

    def get_heap_state(self) -> "dict | None":
        """Return the most recent complete heap snapshot from the dump file.

        Lines are scanned from the end. Blank lines and lines that fail to
        parse (for example a record the harness has only partly written) are
        skipped, so a torn final line falls back to the previous snapshot
        instead of yielding ``None``.

        Returns:
            The decoded JSON value of the last parseable line, or ``None`` if
            the file is missing, unreadable, or has no parseable line.
        """
        if not self.dump_path:
            return None
        try:
            with open(self.dump_path, "r") as f:
                lines = f.readlines()
        except (OSError, UnicodeDecodeError):
            return None

        for raw in reversed(lines):
            line = raw.strip()
            if not line:
                continue
            try:
                return json.loads(line)
            except json.JSONDecodeError:
                continue
        return None

    def get_grid(self) -> np.ndarray:
        """Return the current heap state as a grid (all zeros, 32x16, if no state yet)."""
        state = self.get_heap_state()
        if state is None:
            return np.zeros((32, 16), dtype=np.int64)
        return state_to_grid(state)

    def get_all_states(self) -> list:
        """Return every heap state in the dump, or ``[]`` if it cannot be loaded."""
        try:
            return load_dump(Path(self.dump_path))
        except Exception:
            # Deliberately best-effort: a missing or half-written dump is not
            # fatal for callers, which just see no states.
            return []

    def check_duplicate_alloc(self) -> bool:
        """Return True if two live chunks in the latest snapshot share an address.

        Chunks with ``state == 1`` are compared (presumably the harness's
        "in use" marker -- confirm against the harness). Chunks without an
        ``addr`` are ignored so that several missing addresses are not
        mistaken for a duplicate allocation.
        """
        state = self.get_heap_state()
        if state is None:
            return False

        live_addrs = [
            c.get("addr")
            for c in state.get("chunks", [])
            if c.get("state") == 1 and c.get("addr") is not None
        ]
        return len(live_addrs) != len(set(live_addrs))

    def stop(self):
        """Shut the binary down and delete the dump file. Safe to call repeatedly."""
        if self.proc:
            try:
                # EOF on stdin, then give the binary a moment to exit on its own.
                self.proc.stdin.close()
                self.proc.wait(timeout=2)
            except Exception:
                # Broken pipe on close, or the binary did not exit in time.
                self.proc.kill()
                # Reap the killed child so it does not linger as a zombie.
                self.proc.wait()
        if self.dump_path and os.path.exists(self.dump_path):
            os.unlink(self.dump_path)
|
|
|
# Human-readable names indexed by op id. NOTE(review): assumes OP_MALLOC,
# OP_FREE, OP_WRITE_FREED are 0, 1, 2 and the model's 4th output is a no-op
# -- confirm against simple_agent.
_OP_NAMES = ("malloc", "free", "write_freed", "noop")


def _sample_op(model, grid, temperature):
    """Return the op index the model picks for one grid observation.

    ``temperature <= 0`` selects the argmax; otherwise an op is sampled from
    ``softmax(logits / temperature)``.
    """
    x = torch.from_numpy(grid).long().unsqueeze(0)  # prepend a batch dimension
    with torch.no_grad():
        logits = model(x)
        if temperature <= 0:
            return logits.argmax(1).item()
        probs = F.softmax(logits / temperature, dim=1)
        return torch.multinomial(probs, 1).item()


def _apply_op(env, op, step, verbose):
    """Translate one predicted op into a menu command on ``env``.

    Returns True if a command was sent, False if the op had no valid target
    (or is not one of malloc/free/write_freed).

    Raises:
        BrokenPipeError: propagated from ``env`` if the binary has exited.
    """
    if op == OP_MALLOC:
        # Candidate slots: never used yet, or previously freed.
        free_slots = [s for s in range(16) if not env.slots.get(s)]
        if not free_slots:
            return False
        slot = random.choice(free_slots[:8])
        size = random.choice(SIZES)
        env.do_malloc(slot, size)
        if verbose:
            print(f" Step {step}: MALLOC slot={slot} size={hex(size)}")
        return True

    if op == OP_FREE:
        alloc_slots = [s for s, live in env.slots.items() if live]
        if not alloc_slots:
            return False
        slot = random.choice(alloc_slots)
        env.do_free(slot)
        if verbose:
            print(f" Step {step}: FREE slot={slot}")
        return True

    if op == OP_WRITE_FREED:
        freed_slots = [s for s, live in env.slots.items() if not live]
        # Only attempt the UAF edit once at least two slots have been touched.
        if not freed_slots or len(env.slots) < 2:
            return False
        src = random.choice(freed_slots)
        # The payload is a fixed 8-byte 0x41 pattern written into the freed
        # chunk; no target address is encoded.
        env.do_edit_uaf(src, "41" * 8)
        if verbose:
            print(f" Step {step}: WRITE_FREED slot={src} (UAF edit)")
        return True

    return False


def run_agent_on_real_binary(
    model: SimpleHeapTRM,
    max_steps: int = 20,
    temperature: float = 0.3,
    verbose: bool = True,
) -> dict:
    """
    Run the TRM agent against the real vuln_heap binary.

    Each step reads the latest heap grid, asks the model for an op, sends the
    matching menu command, and then checks the dump for a duplicate live
    allocation (the exploit primitive). The binary and its temp dump file are
    always cleaned up, even if the binary dies mid-run.

    Returns a dict with:
        achieved:    True if a duplicate allocation was observed.
        steps:       number of steps executed (equals ``max_steps`` when the
                     run neither succeeded nor ended early).
        commands:    menu commands sent during the run (excluding the exit).
        op_sequence: op name chosen at every executed step, including skips.
        crashed:     True if the binary exited mid-run (broken stdin pipe).
    """
    env = RealBinaryEnv()
    env.start()

    result = {
        "achieved": False,
        "steps": 0,
        "commands": [],
        "op_sequence": [],
        "crashed": False,
    }
    steps_taken = 0

    try:
        # Give the freshly launched binary a moment before the first read.
        time.sleep(0.05)
        model.eval()

        for step in range(max_steps):
            steps_taken = step + 1
            op = _sample_op(model, env.get_grid(), temperature)
            op_name = _OP_NAMES[op]
            result["op_sequence"].append(op_name)

            try:
                sent = _apply_op(env, op, step, verbose)
            except BrokenPipeError:
                # The binary exited (e.g. crashed after heap corruption);
                # nothing more can be sent to it.
                result["crashed"] = True
                if verbose:
                    print(f" Step {step}: binary exited (broken pipe) - stopping")
                break

            if not sent:
                if verbose:
                    print(f" Step {step}: {op_name.upper()} - skipped (no valid target)")
                continue

            # Let the harness flush the new heap snapshot before checking it.
            time.sleep(0.02)

            if env.check_duplicate_alloc():
                result["achieved"] = True
                if verbose:
                    print(f" ** EXPLOIT PRIMITIVE ACHIEVED at step {step + 1}! **")
                break
    finally:
        result["commands"] = env.commands_sent.copy()
        try:
            env.do_exit()
        except BrokenPipeError:
            pass  # the binary is already gone
        env.stop()

    result["steps"] = steps_taken
    return result
|
|
|
|
def main():
    """Train a fresh TRM on synthetic demos, then evaluate it on the real binary.

    Two evaluations are printed:
      * 50 single-shot trials (temperature 0.3, per-step logging), with the
        success rate and step statistics of the successful runs;
      * 20 best-of-10 trials (temperature 0.5), where a trial counts as found
        as soon as any of its up-to-10 attempts reaches the primitive.
    """
    print("=== Loading trained model ===")
    model = SimpleHeapTRM(hidden_dim=128, n_outer=2, n_inner=3)

    # NOTE: despite the banner above, no checkpoint is loaded -- the model is
    # trained from scratch here on freshly generated demonstrations.
    from simple_agent import generate_demos, train

    demo_inputs, demo_labels = generate_demos(300)
    print(f"Training on {len(demo_inputs)} demo samples...")
    train(model, demo_inputs, demo_labels, epochs=100, lr=1e-3)

    print("\n=== Testing on real vuln_heap binary ===")

    total = 50
    winning_steps = []  # step counts of successful trials only
    for trial in range(total):
        print(f"\n--- Trial {trial + 1}/{total} ---")
        outcome = run_agent_on_real_binary(
            model, max_steps=20, temperature=0.3, verbose=True)

        if not outcome["achieved"]:
            print(f" FAILED after {outcome['steps']} steps")
            continue

        winning_steps.append(outcome["steps"])
        print(f" SUCCESS in {outcome['steps']} steps")
        print(f" Op sequence: {' -> '.join(outcome['op_sequence'][:outcome['steps']])}")

    wins = len(winning_steps)
    banner = "=" * 60
    print(f"\n{banner}")
    print("REAL BINARY RESULTS")
    print(banner)
    print(f"Success rate: {wins}/{total} ({wins/total*100:.0f}%)")
    if winning_steps:
        print(f"Avg steps when successful: {np.mean(winning_steps):.1f}")
        print(f"Min steps: {min(winning_steps)}, Max steps: {max(winning_steps)}")

    print("\n=== Best-of-10 evaluation (20 trials) ===")
    hits = 0
    for trial in range(20):
        # any() stops at the first successful attempt, like an explicit break.
        found = any(
            run_agent_on_real_binary(
                model, max_steps=20, temperature=0.5, verbose=False)["achieved"]
            for _ in range(10)
        )
        hits += int(found)
        print(f" Trial {trial+1}: {'FOUND' if found else 'MISS'}")

    print(f"\nBest-of-10 success: {hits}/20 ({hits/20*100:.0f}%)")
|
|
| if __name__ == "__main__": |
| main() |
|
|