# heap-trm / agent / train_universal.py
# Author: amarck
# Commit 22374d1: Add heaptrm package: v2 harness, CLI, pwntools integration, CVE tests
#!/usr/bin/env python3
"""
train_universal.py - Train with allocator-agnostic grid on GPU.
Uses UniversalGridEncoder (relationship-based features) and trains on RTX 4090.
"""
import sys
import os
import json
import subprocess
import tempfile
import time
import random
import re
import numpy as np
import torch
import torch.nn.functional as F
from pathlib import Path
from collections import deque
# Repo root, assuming this file lives in <root>/agent/.
ROOT = Path(__file__).resolve().parent.parent
# Make the sibling agent/ and model/ packages importable without installation.
sys.path.insert(0, str(ROOT / "agent"))
sys.path.insert(0, str(ROOT / "model"))
from simple_agent import SimpleHeapTRM, OP_MALLOC, OP_FREE, OP_WRITE_FREED, SIZES
from universal_grid import UniversalGridEncoder
# Target binary and the LD_PRELOAD harness that dumps heap state as JSONL.
BINARY = ROOT / "ctf" / "vuln_heap"
HARNESS = ROOT / "harness" / "heapgrid_harness.so"
# Train/evaluate on GPU when available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def run_and_dump(commands):
    """Run the vulnerable binary under the harness and collect heap states.

    The command strings are fed to the binary on stdin (followed by "5",
    the quit command), with LD_PRELOAD pointing at the heapgrid harness so
    each heap operation appends a JSON state line to HEAPGRID_OUT.

    Args:
        commands: list of command strings understood by the target binary.

    Returns:
        List of parsed JSON state dicts, one per observed heap operation
        (empty if the harness produced no dump).
    """
    # mkstemp instead of the deprecated, race-prone tempfile.mktemp().
    fd, dump_path = tempfile.mkstemp(suffix=".jsonl")
    os.close(fd)
    env = os.environ.copy()
    env["LD_PRELOAD"] = str(HARNESS)
    env["HEAPGRID_OUT"] = dump_path
    input_str = "\n".join(commands) + "\n5\n"
    states = []
    try:
        subprocess.run([str(BINARY)], input=input_str.encode(),
                       env=env, capture_output=True, timeout=10)
        with open(dump_path) as f:
            for line in f:
                line = line.strip()
                if line:
                    states.append(json.loads(line))
    finally:
        # Always remove the dump file, even if the run times out or a
        # dump line fails to parse (the original leaked it on exceptions).
        if os.path.exists(dump_path):
            os.unlink(dump_path)
    return states
def gen_exploit(n_pre=None):
    """Build a command/label sequence exhibiting the use-after-free pattern.

    Shape: n_pre allocations, two frees, one write into the most recently
    freed chunk, then up to two fresh "drain" allocations.

    Args:
        n_pre: number of leading allocations; random in [2, 6] if None.

    Returns:
        (commands, labels): binary command strings and OP_* labels.
    """
    size = random.choice(SIZES)
    slots = random.sample(range(16), 8)
    if n_pre is None:
        n_pre = random.randint(2, 6)
    # First n_pre sampled slots receive allocations.
    alloc_slots = slots[:n_pre]
    commands = [f"1 {s} {size}" for s in alloc_slots]
    labels = [OP_MALLOC] * n_pre
    # Free two of the allocated chunks.
    free_targets = random.sample(alloc_slots, 2)
    commands += [f"4 {s}" for s in free_targets]
    labels += [OP_FREE] * 2
    # Write attacker bytes into the second freed chunk (use-after-free).
    commands.append(f"2 {free_targets[1]} {'41' * 8}")
    labels.append(OP_WRITE_FREED)
    # Drain: re-allocate into at most two untouched slots.
    drain = [s for s in slots if s not in alloc_slots][:2]
    commands += [f"1 {s} {size}" for s in drain]
    labels += [OP_MALLOC] * len(drain)
    return commands, labels
def gen_benign():
    """Build a benign malloc/free sequence (no writes to freed memory).

    Returns:
        (commands, labels): binary command strings and OP_MALLOC/OP_FREE
        labels, 4-15 operations with a ~55% bias toward allocation.
    """
    commands, labels = [], []
    live = {}  # slot -> size for chunks currently allocated
    for _ in range(random.randint(4, 15)):
        # Must allocate when nothing is live; otherwise allocate ~55% of the time.
        if not live or random.random() < 0.55:
            open_slots = [s for s in range(16) if s not in live]
            if not open_slots:
                break  # all 16 slots occupied
            slot = random.choice(open_slots)
            size = random.choice(SIZES)
            commands.append(f"1 {slot} {size}")
            labels.append(OP_MALLOC)
            live[slot] = size
        else:
            slot = random.choice(list(live.keys()))
            commands.append(f"4 {slot}")
            labels.append(OP_FREE)
            del live[slot]
    return commands, labels
def collect_data(n_exploit=500, n_benign=250):
    """Generate supervised training data from exploit and benign replays.

    Each generated command sequence is run through the instrumented binary;
    every observed heap state is encoded with UniversalGridEncoder and
    paired with the label of the action the policy *should* take next.

    Args:
        n_exploit: number of use-after-free exploit sequences to replay.
        n_benign: number of benign malloc/free sequences to replay.

    Returns:
        (grids, labels): stacked numpy grid encodings and int64 OP_* labels.
    """
    all_grids, all_labels = [], []
    print(f" Generating {n_exploit} exploit sequences...")
    for i in range(n_exploit):
        commands, op_labels = gen_exploit()
        states = run_and_dump(commands)
        if len(states) < 3:
            continue  # too few observed states to be useful
        encoder = UniversalGridEncoder()  # fresh action history per sequence
        n_free_seen = 0
        write_done = False
        for j, state in enumerate(states):
            op = state.get("operation", "malloc")
            if "free" in op.lower():
                n_free_seen += 1
            chunks = state.get("chunks", [])
            # assumes harness encodes state==1 as allocated, state==2 as
            # freed — TODO confirm against the harness source.
            n_alloc = sum(1 for c in chunks if c.get("state") == 1)
            n_freed = sum(1 for c in chunks if c.get("state") == 2)
            # Heuristic next-action labels matching the M...M F F W M...
            # exploit shape: start freeing once >=2 chunks are allocated and
            # we are >=30% into the sequence; write right after 2 frees.
            if n_free_seen == 0:
                next_label = OP_FREE if (n_alloc >= 2 and j >= len(states) * 0.3) else OP_MALLOC
            elif n_free_seen >= 2 and not write_done:
                next_label = OP_WRITE_FREED
                write_done = True
            else:
                next_label = OP_MALLOC
            grid = encoder.encode(state)
            all_grids.append(grid)
            all_labels.append(next_label)
            # Record what actually happened so the encoder history is real,
            # independent of the training label above.
            actual_op = OP_FREE if "free" in op.lower() else OP_MALLOC
            encoder.record_action(actual_op, state.get("target_size", 0))
    print(f" Generating {n_benign} benign sequences...")
    for i in range(n_benign):
        commands, _ = gen_benign()
        states = run_and_dump(commands)
        encoder = UniversalGridEncoder()
        for j, state in enumerate(states):
            op = state.get("operation", "malloc")
            # Benign labels come straight from the next observed operation.
            if j + 1 < len(states):
                next_op = states[j + 1].get("operation", "malloc")
                next_label = OP_FREE if "free" in next_op.lower() else OP_MALLOC
            else:
                next_label = OP_MALLOC  # last state: default to malloc
            grid = encoder.encode(state)
            all_grids.append(grid)
            all_labels.append(next_label)
            actual_op = OP_FREE if "free" in op.lower() else OP_MALLOC
            encoder.record_action(actual_op, state.get("target_size", 0))
    return np.stack(all_grids), np.array(all_labels, dtype=np.int64)
def train_gpu(model, X, y, epochs=200, lr=1e-3, bs=128):
    """Train `model` on (X, y) with AdamW, cosine LR decay, and grad clipping.

    Args:
        model: torch module mapping long grid tensors to class logits.
        X: numpy array of encoded grids (integer features).
        y: numpy array of int64 OP_* labels.
        epochs: training epochs.
        lr: initial AdamW learning rate.
        bs: minibatch size.
    """
    model.to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    inputs = torch.from_numpy(X).long().to(DEVICE)
    targets = torch.from_numpy(y).long().to(DEVICE)
    n_samples = len(inputs)
    for epoch in range(1, epochs + 1):
        model.train()
        # Fresh shuffle each epoch, generated on-device.
        order = torch.randperm(n_samples, device=DEVICE)
        loss_sum, n_hits, n_batches = 0.0, 0, 0
        for start in range(0, n_samples, bs):
            batch = order[start:start + bs]
            logits = model(inputs[batch])
            loss = F.cross_entropy(logits, targets[batch])
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            loss_sum += loss.item()
            n_hits += (logits.argmax(1) == targets[batch]).sum().item()
            n_batches += 1
        scheduler.step()
        if epoch == 1 or epoch % 20 == 0:
            print(f" Epoch {epoch:3d} | loss={loss_sum/n_batches:.4f} | acc={n_hits/n_samples:.3f}")
def evaluate_live(model, n_trials=50, max_steps=20, temperature=0.3):
    """Drive the live binary with the trained policy and count exploits.

    Hybrid policy: the model picks malloc/free actions from the encoded
    heap state, but a fixed rule forces the use-after-free write once two
    or more frees have happened since the last allocation.

    Args:
        model: trained policy network (moved to DEVICE, set to eval mode).
        n_trials: independent episodes to run against the binary.
        max_steps: actions attempted per episode.
        temperature: softmax sampling temperature; 0 means greedy argmax.

    Returns:
        (n_uaf, n_correct_seq, seqs): trials with a confirmed UAF write
        (payload 0x4141... observed in a chunk's fd field), trials whose
        op string matches M+F+WM*, and the per-trial op strings.
    """
    model.eval()
    model.to(DEVICE)
    n_uaf = 0
    n_correct_seq = 0
    seqs = []
    for trial in range(n_trials):
        # mkstemp avoids the deprecated, race-prone tempfile.mktemp().
        fd, dump_path = tempfile.mkstemp(suffix=".jsonl")
        os.close(fd)
        env = os.environ.copy()
        env["LD_PRELOAD"] = str(HARNESS)
        env["HEAPGRID_OUT"] = dump_path
        proc = subprocess.Popen([str(BINARY)], stdin=subprocess.PIPE,
                                stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)
        try:
            encoder = UniversalGridEncoder()
            slots = {}
            freed = set()
            ops = []
            did_write = False
            n_frees_since_alloc = 0
            for step in range(max_steps):
                time.sleep(0.01)  # give the harness time to flush its dump line
                state = {"chunks": []}
                try:
                    with open(dump_path) as f:
                        lines = f.readlines()
                    if lines:
                        state = json.loads(lines[-1].strip())
                # Narrowed from a bare except: only "no dump yet" or a
                # partially written last line fall back to the empty state.
                except (OSError, json.JSONDecodeError):
                    pass
                grid = encoder.encode(state)
                # Hybrid: rule triggers W after 2+ frees
                if freed and not did_write and n_frees_since_alloc >= 2:
                    op = OP_WRITE_FREED
                else:
                    x = torch.from_numpy(grid).long().unsqueeze(0).to(DEVICE)
                    with torch.no_grad():
                        logits = model(x)
                    if temperature == 0:
                        op = logits.argmax(1).item()
                    else:
                        probs = F.softmax(logits / temperature, dim=1)
                        op = torch.multinomial(probs, 1).item()
                free_slots = [s for s in range(8) if s not in slots]
                alloc_slots = [s for s, v in slots.items() if v and s not in freed]
                cmd = None
                size = 0
                if op == OP_MALLOC and free_slots:
                    s = random.choice(free_slots)
                    size = random.choice(SIZES)
                    cmd = f"1 {s} {size}"
                    slots[s] = True
                    ops.append("M")
                    n_frees_since_alloc = 0
                elif op == OP_FREE and alloc_slots:
                    s = random.choice(alloc_slots)
                    cmd = f"4 {s}"
                    slots[s] = False
                    freed.add(s)
                    ops.append("F")
                    n_frees_since_alloc += 1
                elif op == OP_WRITE_FREED and freed:
                    s = random.choice(list(freed))
                    cmd = f"2 {s} {'41' * 8}"
                    did_write = True
                    ops.append("W")
                else:
                    ops.append("x")  # chosen op not applicable in current state
                if cmd:
                    try:
                        proc.stdin.write((cmd + "\n").encode())
                        proc.stdin.flush()
                    except (BrokenPipeError, OSError):
                        break  # target crashed/exited: end this episode
                encoder.record_action(op, size)
            # Ask the binary to quit cleanly; kill it if it hangs or the
            # pipe is already closed.
            try:
                proc.stdin.write(b"5\n")
                proc.stdin.flush()
                proc.wait(timeout=2)
            except (OSError, subprocess.TimeoutExpired):
                proc.kill()
            if did_write:
                # Scan the full dump for our payload in any chunk's fd
                # field: proof the write landed in freed metadata.
                try:
                    with open(dump_path) as f:
                        for line in f:
                            state = json.loads(line.strip())
                            for c in state.get("chunks", []):
                                if c.get("fd", 0) == 0x4141414141414141:
                                    n_uaf += 1
                                    break
                            else:
                                continue
                            break
                except (OSError, json.JSONDecodeError):
                    pass
                except ValueError:
                    pass  # e.g. empty/truncated dump line
            seq = "".join(ops)
            # Canonical exploit shape: mallocs, frees, write, optional re-mallocs.
            if re.match(r"M+F+WM*", seq.replace("x", "")):
                n_correct_seq += 1
            seqs.append(seq)
        finally:
            # Original leaked the child process and dump file on exceptions.
            proc.kill()
            if os.path.exists(dump_path):
                os.unlink(dump_path)
    return n_uaf, n_correct_seq, seqs
def main():
    """End-to-end pipeline: collect data, train on GPU, evaluate live."""
    print(f"Device: {DEVICE}")
    print("\n=== Collecting data ===")
    X, y = collect_data(n_exploit=500, n_benign=250)
    print(f"Data: {len(X)} samples | M={sum(y==0)} F={sum(y==1)} W={sum(y==2)}")
    print("\n=== Training on GPU ===")
    model = SimpleHeapTRM(hidden_dim=128, n_outer=2, n_inner=3)
    print(f"Params: {sum(p.numel() for p in model.parameters()):,}")
    train_gpu(model, X, y, epochs=300, lr=1e-3, bs=128)
    # Fix: derive counts/percentages from one variable instead of the
    # hard-coded 50 and "* 2" the original used (they silently break if
    # the trial count changes).
    n_live = 50
    print(f"\n=== Live evaluation (hybrid, {n_live} trials) ===")
    n_uaf, n_seq, seqs = evaluate_live(model, n_trials=n_live, temperature=0.3)
    print(f"UAF writes: {n_uaf}/{n_live} ({n_uaf/n_live*100:.0f}%)")
    print(f"Correct sequences: {n_seq}/{n_live} ({n_seq/n_live*100:.0f}%)")
    print("Samples:")
    for s in seqs[:10]:
        print(f" {s}")
    # Best-of-5: a trial counts as a success if any of 5 single-episode
    # attempts produces a UAF write.
    n_bo5_trials = 30
    print(f"\n=== Best-of-5 ({n_bo5_trials} trials) ===")
    n_bo5 = 0
    for trial in range(n_bo5_trials):
        found = any(
            evaluate_live(model, n_trials=1, temperature=0.5)[0] > 0
            for _ in range(5)
        )
        if found:
            n_bo5 += 1
    print(f"Best-of-5: {n_bo5}/{n_bo5_trials} ({n_bo5/n_bo5_trials*100:.0f}%)")
if __name__ == "__main__":
    main()