# heap-trm / heaptrm / monitor.py
# Last commit c04e9ff by amarck:
#   "Tuned verdict logic: F1=1.0 on CVE patterns, zero false positives"
"""
monitor.py - HeapMonitor: the main user-facing API.
Usage:
from heaptrm import HeapMonitor
# Scan a binary
m = HeapMonitor()
result = m.scan("./target", args=["arg1"])
print(result.verdict) # "EXPLOIT" or "CLEAN"
print(result.confidence) # 0.0-1.0
print(result.corruptions) # list of detected corruption events
# Attach to pwntools process
from pwn import process
p = process("./target")
m = HeapMonitor.attach(p)
m.check() # check current heap state
# Live monitoring
m = HeapMonitor.live("./target")
for event in m.stream():
print(event) # real-time heap events
"""
import json
import os
import subprocess
import tempfile
import time
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional

import numpy as np
import torch

from .classifier.grid import state_to_grid, load_dump
from .classifier.model import HeapTRM
@dataclass
class CorruptionEvent:
    """One heap-corruption report read from a v2 harness dump.

    Attributes, in constructor order:
        step: the ``step`` value of the heap state that carried the report.
        type: kind of corruption, e.g. "metadata_corrupt", "uaf_write",
            "double_free" or "overflow".
        chunk_idx: index of the affected chunk (the monitor uses -1 when the
            harness did not supply one).
        detail: free-form description supplied by the harness.
    """

    step: int
    type: str
    chunk_idx: int
    detail: str
@dataclass
class ScanResult:
    """Outcome of classifying one run's sequence of heap states.

    Attributes, in constructor order:
        verdict: one of "EXPLOIT", "SUSPICIOUS" or "CLEAN".
        confidence: how sure the verdict is, in [0, 1]. As produced by
            HeapMonitor._analyze, this is a fixed value when the harness
            reported corruption, the max exploit probability for ML-based
            EXPLOIT/SUSPICIOUS verdicts, 1 - max probability for CLEAN, and
            0.0 when no states were observed.
        n_states: total number of heap states observed.
        n_flagged: number of states the classifier labelled as exploit.
        corruptions: corruption events reported by the harness.
        exploit_states: indices of the flagged states.
        raw_probs: per-state exploit probabilities (empty by default).
    """

    verdict: str
    confidence: float
    n_states: int
    n_flagged: int
    corruptions: List[CorruptionEvent]
    exploit_states: List[int]
    raw_probs: List[float] = field(default_factory=list)
class HeapMonitor:
    """Heap exploit monitor using TRM classifier + LD_PRELOAD instrumentation.

    Two signals are combined into a verdict:
      * corruption events written by the v2 LD_PRELOAD harness, which are
        treated as definitive evidence, and
      * per-state exploit probabilities from the HeapTRM classifier, which
        on their own only yield EXPLOIT under a strict rule (see _verdict).
    """

    # Thresholds
    # NOTE: EXPLOIT_THRESHOLD is not consulted by _analyze; the tuned ML rule
    # uses ML_EXPLOIT_MIN_PROB instead. It is kept for backward compatibility.
    EXPLOIT_THRESHOLD = 0.7
    # Max exploit probability at or above which a run is at least SUSPICIOUS.
    SUSPICIOUS_THRESHOLD = 0.3
    # An ML-only EXPLOIT verdict needs BOTH a very confident state and a large
    # majority of flagged states, because the classifier alone has false
    # positives.
    ML_EXPLOIT_MIN_PROB = 0.9
    ML_EXPLOIT_MIN_FLAGGED_FRACTION = 0.8
    # Confidence reported when the harness recorded corruption events.
    CORRUPTION_CONFIDENCE = 0.95

    def __init__(self, model_path: Optional[str] = None, device: str = "auto"):
        """Build the classifier and locate the LD_PRELOAD harness.

        Args:
            model_path: optional path to a saved HeapTRM ``state_dict``. When
                it is None, the weights are not loaded from disk. When it is
                given but missing, a RuntimeWarning is emitted and the
                unloaded model is used.
            device: "auto" (CUDA when available, otherwise CPU) or any string
                accepted by ``torch.device``.

        Raises:
            FileNotFoundError: if the harness .so can neither be found nor
                built.
        """
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        self.model = HeapTRM(hidden_dim=128, n_outer=2, n_inner=3)
        if model_path:
            if Path(model_path).exists():
                state_dict = torch.load(
                    model_path, map_location=self.device, weights_only=True)
                self.model.load_state_dict(state_dict)
            else:
                # This used to be ignored silently, which made every ML
                # verdict come from weights that were never loaded.
                warnings.warn(
                    f"HeapTRM weights not found at {model_path!r}; "
                    "continuing with an unloaded classifier",
                    RuntimeWarning,
                    stacklevel=2,
                )
        self.model.to(self.device).eval()

        # Set by attach(); None means no pwntools process is attached.
        self._pwntools_proc = None
        # Find harness .so
        self._harness_path = self._find_harness()

    def _find_harness(self) -> str:
        """Return the absolute path of the heapgrid harness shared object.

        Prebuilt candidates are checked first. Otherwise the function tries to
        compile ``harness/heapgrid_v2.c`` (next to this module) with gcc.

        Raises:
            FileNotFoundError: if no harness exists and building one failed.
                The message includes gcc's diagnostics when available.
        """
        here = Path(__file__).parent
        candidates = [
            here / "harness" / "heapgrid_v2.so",
            here.parent / "harness" / "heapgrid_harness.so",
        ]
        for candidate in candidates:
            if candidate.exists():
                return str(candidate.resolve())

        # Try to compile it
        src = here / "harness" / "heapgrid_v2.c"
        if not src.exists():
            raise FileNotFoundError("Could not find or build heapgrid harness")

        out = src.with_suffix(".so")
        try:
            build = subprocess.run(
                ["gcc", "-shared", "-fPIC", "-O2", "-o", str(out), str(src),
                 "-ldl", "-pthread"],
                capture_output=True,
                text=True,
                errors="replace",  # compiler output may not match the locale
            )
        except OSError as err:  # e.g. gcc is not installed
            raise FileNotFoundError(
                f"Could not find or build heapgrid harness: {err}") from err
        if out.exists():
            return str(out.resolve())
        detail = build.stderr.strip() or f"gcc exited with status {build.returncode}"
        raise FileNotFoundError(
            f"Could not find or build heapgrid harness:\n{detail}")

    def scan(self, binary: str, args: Optional[list] = None,
             stdin_data: Optional[bytes] = None, timeout: int = 30) -> ScanResult:
        """
        Run a binary with heap instrumentation and classify its heap behavior.

        Args:
            binary: path to the target binary.
            args: command line arguments.
            stdin_data: data to pipe to stdin.
            timeout: maximum runtime in seconds. On timeout the child is
                stopped, and any states already dumped are still analyzed.

        Returns:
            ScanResult with verdict, confidence, and corruption events. If the
            harness wrote no dump, the result is CLEAN with confidence 0.0.
            NOTE(review): targets that do not honour LD_PRELOAD (e.g.
            statically linked ones) presumably hit that case. Confirm.
        """
        cmd = [binary] + list(args or [])
        # A private temporary directory avoids the race inherent in
        # tempfile.mktemp: nobody else can pre-create the dump path. The
        # directory and dump are removed even if loading fails.
        with tempfile.TemporaryDirectory(prefix="heaptrm-") as tmpdir:
            dump_path = Path(tmpdir) / "heap_dump.jsonl"
            env = os.environ.copy()
            env["LD_PRELOAD"] = self._harness_path
            env["HEAPGRID_OUT"] = str(dump_path)
            try:
                subprocess.run(
                    cmd, input=stdin_data, env=env,
                    capture_output=True, timeout=timeout,
                )
            except subprocess.TimeoutExpired:
                # Deliberate best effort: analyze whatever the harness
                # managed to write before the run was cut off.
                pass
            states = load_dump(dump_path) if dump_path.exists() else []
        return self._analyze(states)

    def analyze_dump(self, dump_path: str) -> ScanResult:
        """Analyze a pre-existing heap dump file."""
        return self._analyze(load_dump(Path(dump_path)))

    def _analyze(self, states: list) -> ScanResult:
        """Classify a sequence of heap states and build the ScanResult."""
        if not states:
            # Nothing was observed, so there is no evidence either way.
            return ScanResult("CLEAN", 0.0, 0, 0, [], [])

        corruptions = self._extract_corruptions(states)
        probs, preds = self._classify(states)
        flagged = np.flatnonzero(preds == 1).tolist()
        verdict, confidence = self._verdict(
            corruptions, probs, len(flagged), len(states))

        return ScanResult(
            verdict=verdict,
            confidence=confidence,
            n_states=len(states),
            n_flagged=len(flagged),
            corruptions=corruptions,
            exploit_states=flagged,
            raw_probs=probs.tolist(),
        )

    @staticmethod
    def _extract_corruptions(states: list) -> List[CorruptionEvent]:
        """Collect the harness-reported corruption events from every state."""
        return [
            CorruptionEvent(
                step=state.get("step", 0),
                type=c.get("type", "unknown"),
                chunk_idx=c.get("chunk_idx", -1),
                detail=c.get("detail", ""),
            )
            for state in states
            for c in state.get("corruptions", [])
        ]

    def _classify(self, states: list):
        """Run the classifier over all states.

        Returns:
            A ``(probs, preds)`` pair of numpy arrays with one entry per
            state: the softmax probability of class 1 (exploit) and the
            argmax class.
        """
        grids = np.stack([state_to_grid(s) for s in states])
        X = torch.from_numpy(grids).long().to(self.device)
        with torch.no_grad():
            logits = self.model(X)
            probs = torch.softmax(logits, dim=1)[:, 1].cpu().numpy()
            preds = logits.argmax(dim=1).cpu().numpy()
        return probs, preds

    def _verdict(self, corruptions: list, probs, n_flagged: int,
                 n_states: int):
        """Return ``(verdict, confidence)`` for one run.

        Harness corruption events are checked first, because they are treated
        as definitive. After that the ML rules apply, in decreasing severity.
        """
        max_prob = float(probs.max())
        if corruptions:
            return "EXPLOIT", self.CORRUPTION_CONFIDENCE
        if (max_prob >= self.ML_EXPLOIT_MIN_PROB
                and n_flagged > n_states * self.ML_EXPLOIT_MIN_FLAGGED_FRACTION):
            return "EXPLOIT", max_prob
        if max_prob >= self.SUSPICIOUS_THRESHOLD:
            return "SUSPICIOUS", max_prob
        return "CLEAN", 1.0 - max_prob

    @classmethod
    def attach(cls, process, **kwargs):
        """
        Attach to a pwntools process.

        Usage:
            from pwn import process
            p = process("./target")
            m = HeapMonitor.attach(p)

        ``kwargs`` are forwarded to the constructor. No instrumentation is
        injected here: check() only reads an existing dump. Presumably the
        target must already run with the harness preloaded and HEAPGRID_OUT
        set. Confirm with callers.
        """
        monitor = cls(**kwargs)
        monitor._pwntools_proc = process
        return monitor

    def check(self) -> ScanResult:
        """Check the attached pwntools process's current heap state.

        Raises:
            RuntimeError: if no process was attached via attach().
        """
        if getattr(self, "_pwntools_proc", None) is None:
            raise RuntimeError("No process attached. Use HeapMonitor.attach()")
        # NOTE(review): HEAPGRID_OUT is read from *this* process's environment,
        # not the attached process's. The paths only match if the target
        # inherited the same value. Confirm.
        dump_path = os.environ.get("HEAPGRID_OUT", "heap_dump.jsonl")
        if os.path.exists(dump_path):
            return self.analyze_dump(dump_path)
        return self._analyze([])