"""
monitor.py - HeapMonitor: the main user-facing API.

Usage:
    from heaptrm import HeapMonitor

    # Scan a binary
    m = HeapMonitor()
    result = m.scan("./target", args=["arg1"])
    print(result.verdict)      # "EXPLOIT", "SUSPICIOUS" or "CLEAN"
    print(result.confidence)   # 0.0-1.0
    print(result.corruptions)  # list of detected corruption events

    # Attach to pwntools process
    from pwn import process
    p = process("./target")
    m = HeapMonitor.attach(p)
    m.check()  # check current heap state

    # Live monitoring
    m = HeapMonitor.live("./target")
    for event in m.stream():
        print(event)  # real-time heap events
"""

import json
import os
import subprocess
import tempfile
import time
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional

import numpy as np
import torch

from .classifier.model import HeapTRM
from .classifier.grid import state_to_grid, load_dump


@dataclass
class CorruptionEvent:
    """One corruption record reported by the harness for a heap state."""

    step: int
    type: str  # "metadata_corrupt", "uaf_write", "double_free", "overflow"
    chunk_idx: int
    detail: str


@dataclass
class ScanResult:
    """Outcome of classifying a sequence of heap states."""

    verdict: str  # "EXPLOIT", "SUSPICIOUS", "CLEAN"
    confidence: float  # max exploit probability
    n_states: int  # total heap states observed
    n_flagged: int  # states classified as exploit
    corruptions: List[CorruptionEvent]  # detected corruption events
    exploit_states: List[int]  # indices of flagged states
    raw_probs: List[float] = field(default_factory=list)


class HeapMonitor:
    """Heap exploit monitor using TRM classifier + LD_PRELOAD instrumentation."""

    # Thresholds
    # NOTE(review): EXPLOIT_THRESHOLD is not referenced by the verdict logic
    # visible in this file; it is kept because it is a public class attribute.
    EXPLOIT_THRESHOLD = 0.7
    SUSPICIOUS_THRESHOLD = 0.3

    # Confidence reported when the harness itself recorded a corruption.
    HARNESS_CONFIDENCE = 0.95
    # The classifier alone yields EXPLOIT only when the peak probability
    # reaches ML_EXPLOIT_PROB and more than ML_EXPLOIT_FRACTION of all states
    # are flagged.
    ML_EXPLOIT_PROB = 0.9
    ML_EXPLOIT_FRACTION = 0.8

    def __init__(self, model_path: Optional[str] = None, device: str = "auto"):
        """Build the classifier, optionally load weights, and locate the harness.

        Args:
            model_path: path to a saved state dict. When omitted, the model
                keeps its freshly constructed weights.
            device: torch device string, or "auto" to pick CUDA when available.

        Raises:
            FileNotFoundError: if the harness shared object cannot be found
                or built.
        """
        if device == "auto":
            self.device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        self.model = HeapTRM(hidden_dim=128, n_outer=2, n_inner=3)
        if model_path:
            if Path(model_path).exists():
                self.model.load_state_dict(
                    torch.load(model_path, map_location=self.device,
                               weights_only=True))
            else:
                # An explicitly requested checkpoint that is missing used to
                # be ignored silently; keep the fallback but make it visible.
                warnings.warn(
                    f"Model file not found: {model_path}; "
                    "using uninitialized classifier weights",
                    RuntimeWarning,
                    stacklevel=2,
                )
        self.model.to(self.device).eval()

        # Find harness .so
        self._harness_path = self._find_harness()

    def _find_harness(self) -> str:
        """Locate the compiled harness .so file, compiling it if necessary.

        Returns:
            Absolute path to the harness shared object.

        Raises:
            FileNotFoundError: if no prebuilt harness exists and it could not
                be compiled from source.
        """
        here = Path(__file__).parent
        candidates = [
            here / "harness" / "heapgrid_v2.so",
            here.parent / "harness" / "heapgrid_harness.so",
        ]
        for candidate in candidates:
            if candidate.exists():
                return str(candidate.resolve())

        # Try to compile it
        build_log = ""
        src = here / "harness" / "heapgrid_v2.c"
        if src.exists():
            out = src.with_suffix(".so")
            try:
                completed = subprocess.run(
                    ["gcc", "-shared", "-fPIC", "-O2", "-o", str(out),
                     str(src), "-ldl", "-pthread"],
                    capture_output=True,
                    check=False,
                )
            except OSError as err:
                # gcc missing or not executable.
                raise FileNotFoundError(
                    f"Could not find or build heapgrid harness: {err}"
                ) from err
            if out.exists():
                return str(out.resolve())
            build_log = completed.stderr.decode(errors="replace").strip()

        message = "Could not find or build heapgrid harness"
        if build_log:
            message = f"{message}: {build_log}"
        raise FileNotFoundError(message)

    def scan(self, binary: str, args: Optional[List[str]] = None,
             stdin_data: Optional[bytes] = None,
             timeout: int = 30) -> ScanResult:
        """
        Run a binary with heap instrumentation and classify its heap behavior.

        Args:
            binary: path to the target binary
            args: command line arguments
            stdin_data: data to pipe to stdin
            timeout: max runtime in seconds

        Returns:
            ScanResult with verdict, confidence, and corruption events
        """
        # A private temporary directory replaces the race-prone
        # tempfile.mktemp() and is removed even if loading the dump fails.
        with tempfile.TemporaryDirectory(prefix="heaptrm_") as tmp_dir:
            dump_path = Path(tmp_dir) / "heap_dump.jsonl"

            env = os.environ.copy()
            env["LD_PRELOAD"] = self._harness_path
            env["HEAPGRID_OUT"] = str(dump_path)

            try:
                subprocess.run(
                    [binary, *(args or [])],
                    input=stdin_data,
                    env=env,
                    capture_output=True,
                    timeout=timeout,
                )
            except subprocess.TimeoutExpired:
                # Best effort: analyze whatever was dumped before the timeout.
                pass

            # Load the dump; the file only exists if the harness wrote it.
            states = load_dump(dump_path) if dump_path.exists() else []

        return self._analyze(states)

    def analyze_dump(self, dump_path: str) -> ScanResult:
        """Analyze a pre-existing heap dump file."""
        return self._analyze(load_dump(Path(dump_path)))

    def _decide_verdict(self, has_corruption: bool, max_prob: float,
                        n_flagged: int, n_states: int):
        """Map harness evidence and classifier output to (verdict, confidence)."""
        # Primary: corruption events from v2 harness (zero false positives)
        # Secondary: ML classifier (higher false positive rate)
        if has_corruption:
            # Hard evidence from harness - definitive
            return "EXPLOIT", self.HARNESS_CONFIDENCE
        if (max_prob >= self.ML_EXPLOIT_PROB
                and n_flagged > n_states * self.ML_EXPLOIT_FRACTION):
            # ML: only EXPLOIT if very high confidence AND vast majority flagged
            return "EXPLOIT", max_prob
        if max_prob >= self.SUSPICIOUS_THRESHOLD:
            return "SUSPICIOUS", max_prob
        return "CLEAN", 1.0 - max_prob

    def _analyze(self, states: list) -> ScanResult:
        """Classify a sequence of heap states.

        Args:
            states: heap state records as returned by load_dump; each is read
                with ``.get("corruptions")`` / ``.get("step")`` and encoded by
                state_to_grid.

        Returns:
            ScanResult; a CLEAN result with zero counts when states is empty.
        """
        if not states:
            return ScanResult("CLEAN", 0.0, 0, 0, [], [])

        # Extract corruption events from v2 harness data
        corruptions = [
            CorruptionEvent(
                step=state.get("step", 0),
                type=c.get("type", "unknown"),
                chunk_idx=c.get("chunk_idx", -1),
                detail=c.get("detail", ""),
            )
            for state in states
            for c in state.get("corruptions", [])
        ]

        # Encode to grids and classify
        grids = np.stack([state_to_grid(s) for s in states])
        X = torch.from_numpy(grids).long().to(self.device)

        with torch.no_grad():
            logits = self.model(X)
            # Column 1 is used as the exploit-class probability.
            probs = torch.softmax(logits, dim=1)[:, 1].cpu().numpy()
            preds = logits.argmax(dim=1).cpu().numpy()

        max_prob = float(probs.max())
        flagged = [int(i) for i in np.flatnonzero(preds == 1)]

        verdict, confidence = self._decide_verdict(
            bool(corruptions), max_prob, len(flagged), len(states))

        return ScanResult(
            verdict=verdict,
            confidence=confidence,
            n_states=len(states),
            n_flagged=len(flagged),
            corruptions=corruptions,
            exploit_states=flagged,
            raw_probs=probs.tolist(),
        )

    @classmethod
    def attach(cls, process, **kwargs) -> "HeapMonitor":
        """
        Attach to a pwntools process.

        Usage:
            from pwn import process
            p = process("./target")
            m = HeapMonitor.attach(p)

        Args:
            process: the process object to remember on the monitor.
            **kwargs: forwarded to the HeapMonitor constructor.
        """
        # pwntools integration: read from process's heap dump
        monitor = cls(**kwargs)
        monitor._pwntools_proc = process
        return monitor

    def check(self) -> ScanResult:
        """Check the attached pwntools process's current heap state.

        Reads the dump named by the HEAPGRID_OUT environment variable of the
        current process (default "heap_dump.jsonl").

        Raises:
            RuntimeError: if no process was attached via HeapMonitor.attach().
        """
        if not hasattr(self, "_pwntools_proc"):
            raise RuntimeError("No process attached. Use HeapMonitor.attach()")

        # Read the dump file
        dump_path = os.environ.get("HEAPGRID_OUT", "heap_dump.jsonl")
        if os.path.exists(dump_path):
            return self.analyze_dump(dump_path)
        return ScanResult("CLEAN", 0.0, 0, 0, [], [])