File size: 8,233 Bytes
22374d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
#!/usr/bin/env python3
"""
simple_agent.py - Simplified action-prediction agent.

Instead of predicting exact (op, size, slot) in one shot from 128 actions,
decompose into:
  Step 1: Predict operation type (4: malloc, free, write_freed, noop)
  Step 2: For the chosen op, pick parameters using simple heuristics

This makes imitation learning tractable (4-class instead of 128-class).
The model learns the HIGH-LEVEL STRATEGY (when to alloc vs free vs UAF-write),
not the low-level details (which slot, which size).
"""

import sys
import torch
import torch.nn.functional as F
import numpy as np
import random
import copy
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent / "simulator"))
sys.path.insert(0, str(Path(__file__).parent.parent / "model"))

from heap_sim import HeapSimulator
from trm_heap import RMSNorm, SwiGLU, RecursionBlock

# 4 operation types — the model's entire action space. Parameter details
# (which slot, which size) are chosen heuristically in execute_op().
OP_MALLOC = 0
OP_FREE = 1
OP_WRITE_FREED = 2
OP_NOOP = 3
N_OPS = 4

# Candidate request sizes passed to sim.malloc() when OP_MALLOC is chosen.
SIZES = [0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80]


class SimpleHeapTRM(torch.nn.Module):
    """Recursive model that classifies a heap-state grid into one of the
    four operation types (malloc / free / write_freed / noop).

    Two RecursionBlocks alternately refine a latent state z and an answer
    state y over n_outer * n_inner steps, then a linear head maps the
    mean-pooled answer to N_OPS logits.
    """

    def __init__(self, vocab_size=64, hidden_dim=128, seq_len=512,
                 n_outer=2, n_inner=3):
        super().__init__()
        # NOTE: submodule/parameter creation order is load-bearing for
        # seeded reproducibility — keep it stable.
        self.embed = torch.nn.Embedding(vocab_size, hidden_dim)
        self.y_init = torch.nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02)
        self.z_init = torch.nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02)
        self.block_z = RecursionBlock(hidden_dim)
        self.block_y = RecursionBlock(hidden_dim)
        self.pos_embed = torch.nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02)
        self.out_norm = RMSNorm(hidden_dim)
        self.n_outer = n_outer
        self.n_inner = n_inner

        # 4-class output head.
        self.head = torch.nn.Linear(hidden_dim, N_OPS)

    def forward(self, x):
        """Return (B, N_OPS) logits for a batch of integer token grids."""
        batch = x.shape[0]
        h = self.embed(x.reshape(batch, -1)) + self.pos_embed
        y = self.y_init.expand(batch, -1, -1)
        z = self.z_init.expand(batch, -1, -1)
        # The original nested outer/inner loops collapse to their product:
        # the body is independent of the loop indices.
        for _ in range(self.n_outer * self.n_inner):
            z = z + self.block_z(h + y + z)
            y = y + self.block_y(y + z)
        return self.head(self.out_norm(y).mean(dim=1))


def execute_op(sim: HeapSimulator, op_type: int) -> bool:
    """Perform *op_type* on the simulator, picking parameters randomly.

    Returns True if the underlying simulator call succeeded; False when the
    operation is impossible in the current state (or op_type is OP_NOOP /
    unknown).
    """
    taken = set(sim.slots.keys())
    empty_slots = [s for s in range(8) if s not in taken]
    occupied_slots = list(taken)

    # Slots whose chunk header says "not allocated" — i.e. the pointer is
    # dangling and usable for a UAF write.
    dangling = []
    for s in occupied_slots:
        header = sim.chunks.get(sim.slots[s] - 16)
        if header and not header.allocated:
            dangling.append(s)

    live = [s for s in occupied_slots if s not in dangling]

    if op_type == OP_MALLOC:
        if not empty_slots:
            return False
        # Draw slot first, then size (stable RNG consumption order).
        chosen_slot = random.choice(empty_slots)
        chosen_size = random.choice(SIZES)
        return sim.malloc(chosen_size, slot=chosen_slot) is not None

    if op_type == OP_FREE:
        if not live:
            return False
        victim = random.choice(live)
        return sim.free(user_addr=sim.slots[victim], slot=victim)

    if op_type == OP_WRITE_FREED:
        if not dangling or len(occupied_slots) < 2:
            return False
        src = random.choice(dangling)
        candidates = [s for s in occupied_slots if s != src]
        if not candidates:
            return False
        dst = random.choice(candidates)
        return sim.write_to_freed(sim.slots[src], sim.slots[dst])

    return False


# ============================================================
# EXPERT DEMOS (as operation type sequences)
# ============================================================

# Tcache-poisoning recipe expressed purely as op TYPES; the concrete
# slots and sizes are chosen randomly inside execute_op().
TCACHE_POISON_OPS = [
    OP_MALLOC,       # alloc A
    OP_MALLOC,       # alloc B
    OP_MALLOC,       # alloc C (guard)
    OP_FREE,         # free A -> tcache
    OP_FREE,         # free B -> tcache
    OP_WRITE_FREED,  # UAF: corrupt B's fd
    OP_MALLOC,       # alloc from tcache (gets B)
    OP_MALLOC,       # alloc from tcache (gets poisoned addr!)
]


def generate_demos(n_variants=200) -> tuple:
    """Roll out the expert tcache-poison script *n_variants* times.

    Each step records (state grid, op type) BEFORE executing the op, so the
    label is "what the expert does in this state".

    Returns:
        (states, labels): stacked numpy grids and int64 op-type labels.
    """
    demo_states = []
    demo_labels = []

    for _ in range(n_variants):
        sim = HeapSimulator()
        for op in list(TCACHE_POISON_OPS):
            demo_states.append(sim.state_to_grid())
            demo_labels.append(op)
            execute_op(sim, op)

    return np.stack(demo_states), np.array(demo_labels, dtype=np.int64)


def train(model, X, y, epochs=200, lr=1e-3, bs=64):
    """Fit *model* on (X, y) with AdamW + 4-class cross-entropy.

    X is an integer numpy array of state grids, y the int64 op labels.
    Logs epoch loss and accuracy on epoch 1 and every 20th epoch.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    inputs = torch.from_numpy(X).long()
    targets = torch.from_numpy(y).long()
    n_samples = len(inputs)

    for epoch in range(1, epochs + 1):
        model.train()
        order = torch.randperm(n_samples)
        running_loss = 0.0
        n_correct = 0
        n_batches = 0

        for start in range(0, n_samples, bs):
            batch_idx = order[start:start + bs]
            logits = model(inputs[batch_idx])
            loss = F.cross_entropy(logits, targets[batch_idx])
            optimizer.zero_grad()
            loss.backward()
            # Clip to stabilize the recursive updates.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            running_loss += loss.item()
            n_correct += (logits.argmax(1) == targets[batch_idx]).sum().item()
            n_batches += 1

        if epoch % 20 == 0 or epoch == 1:
            print(f"  Epoch {epoch:3d} | loss={running_loss/n_batches:.4f} "
                  f"| acc={n_correct/n_samples:.3f}")


def evaluate_policy(model, goal="duplicate_alloc", n_trials=200, max_steps=20):
    """Greedy rollout of the learned policy; count trials that reach *goal*.

    Returns:
        (n_achieved, lengths) — number of successful trials and, for each
        success, the number of steps it took.
    """
    model.eval()
    successes = 0
    episode_lengths = []

    for _ in range(n_trials):
        sim = HeapSimulator()
        for step in range(max_steps):
            obs = torch.from_numpy(sim.state_to_grid()).long().unsqueeze(0)
            with torch.no_grad():
                chosen_op = model(obs).argmax(1).item()

            execute_op(sim, chosen_op)

            # Stop as soon as the simulator reports the goal primitive.
            if sim.check_primitives().get(goal, False):
                successes += 1
                episode_lengths.append(step + 1)
                break

    return successes, episode_lengths


def main():
    """End-to-end run: generate demos, train the 4-class op model, evaluate
    greedily and with stochastic best-of-10 sampling."""
    print("=== Generating demos ===")
    X, y = generate_demos(300)
    print(f"Data: {len(X)} samples")
    print(f"Op distribution: malloc={sum(y==0)}, free={sum(y==1)}, "
          f"write_freed={sum(y==2)}, noop={sum(y==3)}")

    print("\n=== Training (4-class op prediction) ===")
    model = SimpleHeapTRM(hidden_dim=128, n_outer=2, n_inner=3)
    params = sum(p.numel() for p in model.parameters())
    print(f"Parameters: {params:,}")
    train(model, X, y, epochs=200, lr=1e-3)

    print("\n=== Evaluating policy ===")
    # Derive the report from one n_trials variable instead of the old
    # hard-coded "/200" and "n_achieved/2" percentage, which silently
    # broke if the trial count was ever changed.
    n_trials = 200
    n_achieved, lengths = evaluate_policy(model, n_trials=n_trials, max_steps=20)
    print(f"Success rate: {n_achieved}/{n_trials} ({100*n_achieved/n_trials:.0f}%)")
    if lengths:
        print(f"Avg steps: {np.mean(lengths):.1f}, min={min(lengths)}, max={max(lengths)}")

    # Best-of-10: sample actions at temperature 0.5 and retry up to 10 times
    # per trial, counting a trial as a success if any attempt reaches the goal.
    print("\n=== Best-of-10 evaluation ===")
    n_achieved_bo10 = 0
    for trial in range(100):
        found = False
        for attempt in range(10):
            sim = HeapSimulator()
            for step in range(20):
                grid = sim.state_to_grid()
                x = torch.from_numpy(grid).long().unsqueeze(0)
                with torch.no_grad():
                    logits = model(x)
                    probs = F.softmax(logits / 0.5, dim=1)
                    op = torch.multinomial(probs, 1).item()
                execute_op(sim, op)
                if sim.check_primitives().get("duplicate_alloc"):
                    found = True
                    break
            if found:
                break
        if found:
            n_achieved_bo10 += 1

    print(f"Best-of-10 success: {n_achieved_bo10}/100 ({n_achieved_bo10}%)")


if __name__ == "__main__":
    main()