# Source: heap-trm / agent / multi_technique.py (uploaded by amarck)
# Commit 22374d1: "Add heaptrm package: v2 harness, CLI, pwntools integration, CVE tests"
#!/usr/bin/env python3
"""
multi_technique.py - Train one model on multiple exploit techniques.
Techniques:
1. Tcache poison (UAF fd corruption): M+ F F W_UAF M+
2. Off-by-one overlap (null byte overflow): M M M W_OVF F M
3. Coalesce abuse (free adjacent, alloc large): M M M F F M_large
Action space (4 ops):
0: MALLOC
1: FREE
2: WRITE_UAF (edit freed chunk - corrupt fd pointer)
3: WRITE_OVERFLOW (edit allocated chunk with max data - triggers off-by-one)
The model must learn WHICH technique to apply based on heap state.
"""
import sys
import os
import json
import subprocess
import tempfile
import time
import random
import re
import numpy as np
import torch
import torch.nn.functional as F
from pathlib import Path
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT / "agent"))
sys.path.insert(0, str(ROOT / "model"))
from universal_grid import UniversalGridEncoder
from trm_heap import RMSNorm, SwiGLU, RecursionBlock
BINARY = ROOT / "ctf" / "vuln_heap"
HARNESS = ROOT / "harness" / "heapgrid_harness.so"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
OP_MALLOC = 0
OP_FREE = 1
OP_WRITE_UAF = 2
OP_WRITE_OVF = 3
N_OPS = 4
SIZES = [0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x78, 0x80]
class MultiTechTRM(torch.nn.Module):
    """Tiny recursive model (TRM) classifying the next heap op (N_OPS classes).

    The input grid is flattened, embedded, and combined with a positional
    embedding; two recursion blocks then refine a fast latent state ``z``
    (inner loop) and a slower state ``y`` (outer loop). The mean-pooled,
    normalized ``y`` feeds a linear classification head.
    """

    def __init__(self, vocab_size=64, hidden_dim=128, seq_len=512,
                 n_outer=2, n_inner=3):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden_dim)
        # Learned initial latent states, broadcast over the batch in forward().
        self.y_init = torch.nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02)
        self.z_init = torch.nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02)
        self.block_z = RecursionBlock(hidden_dim)
        self.block_y = RecursionBlock(hidden_dim)
        self.pos_embed = torch.nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02)
        self.out_norm = RMSNorm(hidden_dim)
        self.n_outer = n_outer
        self.n_inner = n_inner
        self.head = torch.nn.Linear(hidden_dim, N_OPS)

    def forward(self, x):
        batch = x.shape[0]
        h = self.embed(x.reshape(batch, -1)) + self.pos_embed
        y = self.y_init.expand(batch, -1, -1)
        z = self.z_init.expand(batch, -1, -1)
        for _ in range(self.n_outer):
            # Several inner refinements of z, then one outer update of y.
            for _ in range(self.n_inner):
                z = z + self.block_z(h + y + z)
            y = y + self.block_y(y + z)
        pooled = self.out_norm(y).mean(dim=1)
        return self.head(pooled)
def run_and_dump(commands):
    """Run BINARY under the heapgrid harness and return its JSONL state dumps.

    Feeds each command line to the target's stdin followed by "5" (quit).
    The LD_PRELOAD'ed harness writes one JSON object per heap operation to
    the path in HEAPGRID_OUT; those objects are parsed and returned in order.

    Args:
        commands: list of command strings (e.g. "1 <slot> <size>").
    Returns:
        List of state dicts (possibly empty if the harness wrote nothing).
    Raises:
        subprocess.TimeoutExpired: if the target runs longer than 10s
            (the temp dump file is still cleaned up).
    """
    # mkstemp instead of the race-prone, deprecated tempfile.mktemp; we only
    # need a unique path, so close the fd immediately.
    fd, dump_path = tempfile.mkstemp(suffix=".jsonl")
    os.close(fd)
    env = os.environ.copy()
    env["LD_PRELOAD"] = str(HARNESS)
    env["HEAPGRID_OUT"] = dump_path
    states = []
    try:
        subprocess.run([str(BINARY)], input=("\n".join(commands) + "\n5\n").encode(),
                       env=env, capture_output=True, timeout=10)
        with open(dump_path) as f:
            for line in f:
                line = line.strip()
                if line:
                    states.append(json.loads(line))
    finally:
        # Always remove the dump file, even if the target times out or
        # emits malformed JSON.
        if os.path.exists(dump_path):
            os.unlink(dump_path)
    return states
# ============================================================
# TECHNIQUE 1: Tcache poison (UAF fd write)
# Pattern: M+ F F W_UAF M+
# ============================================================
def gen_tcache_poison(n_pre=None):
    """Generate a tcache-poison training sequence: M+ F F W_UAF M+.

    Allocates several same-size chunks, frees two of them, performs a UAF
    write on the second freed chunk (corrupting its fd pointer), then
    drains the bin with fresh allocations of the same size.
    """
    size = random.choice([s for s in SIZES if s <= 0x70])  # tcache-eligible sizes
    slots = random.sample(range(16), 8)
    if n_pre is None:
        n_pre = random.randint(2, 5)
    commands, labels = [], []
    alloc_slots = slots[:n_pre]
    for slot in alloc_slots:
        commands.append(f"1 {slot} {size}")
        labels.append(OP_MALLOC)
    targets = random.sample(alloc_slots, 2)
    for victim in targets:
        commands.append(f"4 {victim}")
        labels.append(OP_FREE)
    # Use-after-free write into the most recently freed chunk.
    commands.append(f"2 {targets[1]} {'41' * 8}")
    labels.append(OP_WRITE_UAF)
    for slot in [s for s in slots if s not in alloc_slots][:2]:
        commands.append(f"1 {slot} {size}")
        labels.append(OP_MALLOC)
    return commands, labels, "tcache_poison"
# ============================================================
# TECHNIQUE 2: Off-by-one null byte overlap
# Pattern: M M M W_OVF F M
# Alloc 3 adjacent chunks, overflow A into B, free B (corrupted size), alloc over it
# ============================================================
def gen_off_by_one():
    """Generate an off-by-one overlap sequence: M M M W_OVF F M.

    Three adjacent chunks A, B, C are allocated; a max-length write into A
    clobbers B's header byte, B is freed with the corrupted size, and a
    final allocation lands over the corrupted region.
    """
    size = 0x78  # largest pre-rounding size: the terminating NUL spills into the next chunk
    slots = random.sample(range(16), 6)
    commands = [f"1 {slots[i]} {size}" for i in range(3)]
    labels = [OP_MALLOC, OP_MALLOC, OP_MALLOC]
    # Overflow A: a full 0x78-byte write whose null byte hits B's prev_inuse.
    commands.append(f"2 {slots[0]} {'41' * size}")
    labels.append(OP_WRITE_OVF)
    # Free B while its size field is corrupted.
    commands.append(f"4 {slots[1]}")
    labels.append(OP_FREE)
    # Reallocate over the corrupted region.
    commands.append(f"1 {slots[3]} {size}")
    labels.append(OP_MALLOC)
    return commands, labels, "off_by_one"
# ============================================================
# TECHNIQUE 3: Coalesce confusion
# Pattern: M M M F F M(larger) — free two adjacent, alloc large to get overlap
# ============================================================
def gen_coalesce():
    """Generate a coalesce-abuse sequence: M..M F F M(large).

    Frees two adjacent small chunks so the allocator merges them, then
    requests a larger chunk that overlaps the coalesced region.
    """
    small = random.choice([0x30, 0x40, 0x50])
    slots = random.sample(range(16), 7)
    commands, labels = [], []
    n_pre = random.randint(3, 5)
    for slot in slots[:n_pre]:
        commands.append(f"1 {slot} {small}")
        labels.append(OP_MALLOC)
    # Free the second and third allocations (expected to be adjacent).
    for slot in (slots[1], slots[2]):
        commands.append(f"4 {slot}")
        labels.append(OP_FREE)
    # Request a chunk spanning both freed ones, capped at the 0x80 max.
    big = min(small * 2 + 0x10, 0x80)
    commands.append(f"1 {slots[n_pre]} {big}")
    labels.append(OP_MALLOC)
    return commands, labels, "coalesce"
# ============================================================
# BENIGN sequences
# ============================================================
def gen_benign():
    """Generate a benign malloc/free sequence with no exploit pattern.

    Allocates with ~55% probability (always when nothing is live) and
    frees a random live chunk otherwise; stops early if all 16 slots fill.
    """
    n_ops = random.randint(4, 12)
    commands, labels = [], []
    live = {}  # slot -> size of currently-allocated chunks
    for _ in range(n_ops):
        if not live or random.random() < 0.55:
            open_slots = [s for s in range(16) if s not in live]
            if not open_slots:
                break
            slot = random.choice(open_slots)
            size = random.choice(SIZES)
            commands.append(f"1 {slot} {size}")
            labels.append(OP_MALLOC)
            live[slot] = size
        else:
            slot = random.choice(list(live.keys()))
            commands.append(f"4 {slot}")
            labels.append(OP_FREE)
            del live[slot]
    return commands, labels, "benign"
# ============================================================
# DATA COLLECTION
# ============================================================
def collect_data(n_per_technique=200, n_benign=300):
    """Build a labeled (grid, next-op) dataset covering all techniques.

    Runs each generated command sequence through the instrumented binary,
    encodes every observed heap state as a grid, and labels it with the op
    that SHOULD come next according to that technique's phase script
    (benign sequences are labeled with what the trace actually did next).

    Args:
        n_per_technique: sequences to attempt per exploit technique.
        n_benign: benign sequences to attempt.
    Returns:
        (X, y): stacked grid array and matching int64 label array.
    """
    all_grids, all_labels = [], []
    technique_counts = {}
    generators = [
        ("tcache_poison", gen_tcache_poison, n_per_technique),
        ("off_by_one", gen_off_by_one, n_per_technique),
        ("coalesce", gen_coalesce, n_per_technique),
        ("benign", gen_benign, n_benign),
    ]
    for tech_name, gen_fn, count in generators:
        print(f" Generating {count} {tech_name}...")
        n_ok = 0
        for _ in range(count):
            # All generators share the same () -> (commands, labels, name)
            # signature, so no per-technique dispatch is needed here.
            # (The original special-cased tcache_poison with an identical branch.)
            commands, op_labels, _ = gen_fn()
            states = run_and_dump(commands)
            if len(states) < 2:
                continue  # harness produced too little state to label
            encoder = UniversalGridEncoder()
            # Phase-based labeling: label for state j = op that should happen NEXT.
            n_free_seen = 0
            write_done = False
            for j, state in enumerate(states):
                op = state.get("operation", "malloc")
                if "free" in op.lower():
                    n_free_seen += 1
                chunks = state.get("chunks", [])
                n_alloc = sum(1 for c in chunks if c.get("state") == 1)
                # Determine next label based on technique
                if tech_name == "tcache_poison":
                    if n_free_seen == 0:
                        # Keep mallocing until enough chunks exist and we are
                        # far enough into the trace, then start freeing.
                        label = OP_FREE if (n_alloc >= 2 and j >= len(states) * 0.3) else OP_MALLOC
                    elif n_free_seen >= 2 and not write_done:
                        label = OP_WRITE_UAF
                        write_done = True
                    else:
                        label = OP_MALLOC
                elif tech_name == "off_by_one":
                    if n_alloc >= 3 and n_free_seen == 0 and not write_done:
                        label = OP_WRITE_OVF
                        write_done = True
                    elif write_done and n_free_seen == 0:
                        label = OP_FREE
                    elif n_free_seen >= 1 and write_done:
                        label = OP_MALLOC
                    else:
                        label = OP_MALLOC
                elif tech_name == "coalesce":
                    if n_free_seen == 0 and n_alloc >= 3:
                        label = OP_FREE
                    elif n_free_seen == 1:
                        label = OP_FREE  # free second adjacent
                    elif n_free_seen >= 2:
                        label = OP_MALLOC  # alloc large
                    else:
                        label = OP_MALLOC
                else:  # benign: mirror whatever the trace actually did next
                    if j + 1 < len(states):
                        next_op = states[j + 1].get("operation", "malloc")
                        label = OP_FREE if "free" in next_op.lower() else OP_MALLOC
                    else:
                        label = OP_MALLOC
                grid = encoder.encode(state)
                all_grids.append(grid)
                all_labels.append(label)
                # Feed the encoder the op that actually occurred so its rolling
                # action history stays in sync with the trace.
                actual_op = OP_FREE if "free" in op.lower() else OP_MALLOC
                encoder.record_action(actual_op, state.get("target_size", 0))
            n_ok += 1
        technique_counts[tech_name] = n_ok
    print(f" Technique counts: {technique_counts}")
    return np.stack(all_grids), np.array(all_labels, dtype=np.int64)
def train_gpu(model, X, y, epochs=300, lr=1e-3, bs=128):
    """Train `model` on DEVICE with AdamW, grad clipping and cosine LR decay.

    Data is moved to DEVICE once up front; batches are drawn from a fresh
    random permutation each epoch. Prints loss/accuracy at epoch 1 and
    every 25th epoch thereafter.
    """
    model.to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    X_t = torch.from_numpy(X).long().to(DEVICE)
    y_t = torch.from_numpy(y).long().to(DEVICE)
    n = len(X_t)
    for ep in range(1, epochs + 1):
        model.train()
        order = torch.randperm(n, device=DEVICE)
        total_loss = 0.0
        correct = 0
        nb = 0
        for start in range(0, n, bs):
            batch_idx = order[start:start + bs]
            logits = model(X_t[batch_idx])
            loss = F.cross_entropy(logits, y_t[batch_idx])
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            total_loss += loss.item()
            correct += (logits.argmax(1) == y_t[batch_idx]).sum().item()
            nb += 1
        sched.step()
        if ep == 1 or ep % 25 == 0:
            print(f" Epoch {ep:3d} | loss={total_loss/nb:.4f} | acc={correct/n:.3f}")
def evaluate_live(model, n_trials=50, temperature=0.3):
    """Evaluate with hybrid policy: TRM picks op, rules handle W timing.

    Runs `n_trials` live sessions against the target binary under the
    harness. At each step the model samples an op from its temperature-
    scaled softmax; simple rules veto writes that are impossible in the
    current bookkeeping state (UAF with nothing freed, overflow with
    nothing allocated). Returns counters for achieved UAF writes,
    overflow writes, valid exploit sequences, and the per-trial op string.
    """
    model.eval().to(DEVICE)
    results = {"uaf": 0, "ovf": 0, "correct_seq": 0, "seqs": []}
    for trial in range(n_trials):
        # NOTE(review): tempfile.mktemp is race-prone/deprecated; mkstemp
        # would be safer here.
        dump_path = tempfile.mktemp(suffix=".jsonl")
        env = os.environ.copy()
        env["LD_PRELOAD"] = str(HARNESS)
        env["HEAPGRID_OUT"] = dump_path
        proc = subprocess.Popen([str(BINARY)], stdin=subprocess.PIPE,
                                stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)
        encoder = UniversalGridEncoder()
        # slots: slot -> currently-allocated flag; freed: slots freed at least
        # once; ops: one-letter trace of issued ops (M/F/U/O, x = no-op).
        slots, freed, ops = {}, set(), []
        did_uaf = did_ovf = False
        n_consec_frees = 0  # NOTE(review): written but never read
        for step in range(20):
            time.sleep(0.01)  # give the harness a moment to flush its dump
            state = {"chunks": []}
            try:
                with open(dump_path) as f:
                    lines = f.readlines()
                if lines:
                    # Last line = most recent heap state.
                    state = json.loads(lines[-1].strip())
            except:
                pass  # dump may not exist yet in the first steps
            grid = encoder.encode(state)
            x = torch.from_numpy(grid).long().unsqueeze(0).to(DEVICE)
            with torch.no_grad():
                logits = model(x)
                probs = F.softmax(logits / temperature, dim=1)
                op = torch.multinomial(probs, 1).item()
            # Rule overrides for W timing
            if op == OP_WRITE_UAF and not freed:
                op = OP_MALLOC  # can't UAF without freed chunks
            if op == OP_WRITE_OVF and not any(v for v in slots.values()):
                op = OP_MALLOC  # need allocated chunks
            free_s = [s for s in range(8) if s not in slots]
            alloc_s = [s for s, v in slots.items() if v and s not in freed]
            cmd = None
            size = 0
            if op == OP_MALLOC and free_s:
                s = random.choice(free_s)
                size = random.choice(SIZES)
                cmd = f"1 {s} {size}"
                slots[s] = True
                ops.append("M")
                n_consec_frees = 0
            elif op == OP_FREE and alloc_s:
                s = random.choice(alloc_s)
                cmd = f"4 {s}"
                slots[s] = False
                freed.add(s)
                ops.append("F")
                n_consec_frees += 1
            elif op == OP_WRITE_UAF and freed:
                s = random.choice(list(freed))
                cmd = f"2 {s} {'41' * 8}"
                did_uaf = True
                ops.append("U")
            elif op == OP_WRITE_OVF and alloc_s:
                s = random.choice(alloc_s)
                # NOTE(review): sz is computed but unused — slots values are
                # booleans, not sizes, and the command below always writes 0x78.
                sz = slots.get(s) or 0x78
                cmd = f"2 {s} {'41' * 0x78}"  # max write triggers off-by-one
                did_ovf = True
                ops.append("O")
            else:
                ops.append("x")  # sampled op not applicable this step
            if cmd:
                proc.stdin.write((cmd + "\n").encode())
                proc.stdin.flush()
                encoder.record_action(op, size)
        try:
            proc.stdin.write(b"5\n"); proc.stdin.flush()  # "5" = quit command
            proc.wait(timeout=2)
        except:
            proc.kill()
        # Check results
        if did_uaf:
            try:
                with open(dump_path) as f:
                    for line in f:
                        st = json.loads(line.strip())
                        for c in st.get("chunks", []):
                            # 0x4141... = our 'A' pattern landed in an fd pointer.
                            if c.get("fd", 0) == 0x4141414141414141:
                                results["uaf"] += 1
                                break
                        else:
                            continue  # no poisoned fd in this state; keep scanning
                        break  # count at most one UAF hit per trial
            except:
                pass
        if did_ovf:
            results["ovf"] += 1
        seq = "".join(ops)
        # Valid patterns: M+F+U (tcache poison), M+OF+M (off-by-one), M+FF+M (coalesce)
        if re.search(r"M+F+U", seq.replace("x", "")) or \
           re.search(r"M+OF+M", seq.replace("x", "")) or \
           re.search(r"M+FF+M", seq.replace("x", "")):
            results["correct_seq"] += 1
        results["seqs"].append(seq)
        if os.path.exists(dump_path):
            os.unlink(dump_path)
    return results
def main():
    """End-to-end pipeline: collect data, train the TRM, evaluate live.

    All reporting is derived from a single `n_trials` knob; the original
    hard-coded "*2" percentage math and "/50" strings would silently
    misreport if the trial count ever changed. Output is identical for
    the default of 50 trials.
    """
    print(f"Device: {DEVICE}")
    print("\n=== Collecting multi-technique data ===")
    X, y = collect_data(n_per_technique=300, n_benign=300)
    print(f"Data: {len(X)} samples | M={sum(y==0)} F={sum(y==1)} U={sum(y==2)} O={sum(y==3)}")
    print("\n=== Training on GPU ===")
    model = MultiTechTRM(hidden_dim=128, n_outer=2, n_inner=3)
    print(f"Params: {sum(p.numel() for p in model.parameters()):,}")
    train_gpu(model, X, y, epochs=300, lr=1e-3, bs=128)
    n_trials = 50

    def pct(key):
        # Integer percent; equals the original "count*2" when n_trials == 50.
        return results[key] * 100 // n_trials

    print(f"\n=== Live evaluation ({n_trials} trials) ===")
    results = evaluate_live(model, n_trials=n_trials, temperature=0.3)
    print(f"UAF writes achieved: {results['uaf']}/{n_trials} ({pct('uaf')}%)")
    print(f"Overflow writes: {results['ovf']}/{n_trials} ({pct('ovf')}%)")
    print(f"Valid exploit sequences: {results['correct_seq']}/{n_trials} ({pct('correct_seq')}%)")
    print("\nOp distribution in sequences:")
    all_ops = "".join(results["seqs"])
    for ch, name in [("M", "malloc"), ("F", "free"), ("U", "uaf_write"), ("O", "overflow"), ("x", "skip")]:
        print(f" {name}: {all_ops.count(ch)}")
    print("\nSample sequences:")
    for s in results["seqs"][:15]:
        print(f" {s}")
    # Per-technique analysis
    print("\nPer-technique breakdown:")
    has_uaf = sum(1 for s in results["seqs"] if "U" in s)
    has_ovf = sum(1 for s in results["seqs"] if "O" in s)
    has_both = sum(1 for s in results["seqs"] if "U" in s and "O" in s)
    neither = sum(1 for s in results["seqs"] if "U" not in s and "O" not in s)
    print(f" Used tcache poison (U): {has_uaf}/{n_trials}")
    print(f" Used off-by-one (O): {has_ovf}/{n_trials}")
    print(f" Used both: {has_both}/{n_trials}")
    print(f" Used neither: {neither}/{n_trials}")


if __name__ == "__main__":
    main()