"""
monitor.py - HeapMonitor: the main user-facing API.

Usage:
    from heaptrm import HeapMonitor

    # Scan a binary
    m = HeapMonitor()
    result = m.scan("./target", args=["arg1"])
    print(result.verdict)      # "EXPLOIT", "SUSPICIOUS" or "CLEAN"
    print(result.confidence)   # 0.0-1.0
    print(result.corruptions)  # list of detected corruption events

    # Attach to pwntools process
    from pwn import process
    p = process("./target")
    m = HeapMonitor.attach(p)
    m.check()  # check current heap state

    # Live monitoring
    m = HeapMonitor.live("./target")
    for event in m.stream():
        print(event)  # real-time heap events
"""
|
|
import json
import os
import subprocess
import tempfile
import time
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional

import numpy as np
import torch

from .classifier.model import HeapTRM
from .classifier.grid import state_to_grid, load_dump
|
|
|
|
@dataclass
class CorruptionEvent:
    """A single corruption record taken from one heap state of a dump."""

    # "step" entry of the heap state the record was found in
    # (HeapMonitor._analyze falls back to 0 when the state has none).
    step: int
    # Corruption category string ("unknown" when the record has no "type").
    type: str
    # Index of the affected chunk (-1 when the record has no "chunk_idx").
    chunk_idx: int
    # Free-form description ("" when the record has no "detail").
    detail: str
|
|
|
|
@dataclass
class ScanResult:
    """Outcome of classifying a sequence of heap states."""

    # One of "EXPLOIT", "SUSPICIOUS" or "CLEAN".
    verdict: str
    # Score attached to the verdict, in the 0.0-1.0 range.
    confidence: float
    # Number of heap states that were analysed.
    n_states: int
    # Number of states the classifier predicted as class 1.
    n_flagged: int
    # Corruption records collected from the analysed states.
    corruptions: List[CorruptionEvent]
    # Positions (within the analysed sequence) of the flagged states.
    exploit_states: List[int]
    # Per-state class-1 probability; empty when no model output is available.
    raw_probs: List[float] = field(default_factory=list)
|
|
|
|
class HeapMonitor:
    """Heap exploit monitor using TRM classifier + LD_PRELOAD instrumentation.

    A target binary is run with the heapgrid harness preloaded; the harness
    writes a dump of heap states to the file named by ``HEAPGRID_OUT``. The
    dump is then turned into grids, scored by a ``HeapTRM`` model and combined
    with any corruption records found in the dump to produce a ``ScanResult``.
    """

    # NOTE(review): EXPLOIT_THRESHOLD is not referenced by any method of this
    # class; the model-only EXPLOIT rule uses MODEL_EXPLOIT_MIN_PROB instead.
    # Kept unchanged because external code may read it.
    EXPLOIT_THRESHOLD = 0.7
    # Minimum peak class-1 probability for a "SUSPICIOUS" verdict.
    SUSPICIOUS_THRESHOLD = 0.3

    # Model-only "EXPLOIT" rule: the peak class-1 probability must reach
    # MODEL_EXPLOIT_MIN_PROB *and* strictly more than
    # MODEL_EXPLOIT_MIN_FLAGGED_FRACTION of all states must be flagged.
    MODEL_EXPLOIT_MIN_PROB = 0.9
    MODEL_EXPLOIT_MIN_FLAGGED_FRACTION = 0.8
    # Fixed confidence reported when the dump itself contains corruption records.
    CORRUPTION_CONFIDENCE = 0.95

    def __init__(self, model_path: Optional[str] = None, device: str = "auto"):
        """Initialize with optional pre-trained model.

        Args:
            model_path: path to a saved ``state_dict`` for ``HeapTRM``. When
                omitted, no weights are loaded.
            device: a torch device string, or ``"auto"`` to pick CUDA when
                available and CPU otherwise.

        Raises:
            FileNotFoundError: if the instrumentation harness can neither be
                found nor built (see ``_find_harness``).
        """
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        # Set by attach(); defined here so check() has a definite state to test.
        self._pwntools_proc = None

        self.model = HeapTRM(hidden_dim=128, n_outer=2, n_inner=3)

        if model_path:
            if Path(model_path).exists():
                self.model.load_state_dict(
                    torch.load(model_path, map_location=self.device,
                               weights_only=True))
            else:
                # A mistyped path used to be ignored silently; keep going
                # (backward compatible) but make the situation visible.
                warnings.warn(
                    f"model_path {model_path!r} does not exist; no weights "
                    "were loaded, so predictions come from the model's "
                    "initial parameters",
                    RuntimeWarning,
                    stacklevel=2,
                )

        self.model.to(self.device).eval()

        self._harness_path = self._find_harness()

    def _find_harness(self) -> str:
        """Locate the compiled harness .so file, building it if necessary.

        Looks for a prebuilt shared object next to this module (and one level
        up); if none exists but the C source does, tries to compile it with
        gcc.

        Returns:
            Absolute path of the harness shared object.

        Raises:
            FileNotFoundError: if no harness exists and it could not be built.
                The message carries gcc's diagnostics when a build was tried.
        """
        pkg_dir = Path(__file__).parent
        candidates = [
            pkg_dir / "harness" / "heapgrid_v2.so",
            pkg_dir.parent / "harness" / "heapgrid_harness.so",
        ]
        for candidate in candidates:
            if candidate.exists():
                return str(candidate.resolve())

        message = "Could not find or build heapgrid harness"

        src = pkg_dir / "harness" / "heapgrid_v2.c"
        if src.exists():
            out = src.with_suffix(".so")
            try:
                build = subprocess.run(
                    ["gcc", "-shared", "-fPIC", "-O2", "-o", str(out), str(src),
                     "-ldl", "-pthread"],
                    capture_output=True, text=True, errors="replace",
                )
            except OSError as exc:
                # gcc missing or not executable.
                raise FileNotFoundError(
                    f"{message}: could not run gcc ({exc})") from exc
            if out.exists():
                return str(out.resolve())
            reason = (build.stderr.strip()
                      or f"gcc exited with status {build.returncode}")
            message = f"{message}: {reason}"

        raise FileNotFoundError(message)

    def scan(self, binary: str, args: Optional[List[str]] = None,
             stdin_data: Optional[bytes] = None,
             timeout: int = 30) -> "ScanResult":
        """
        Run a binary with heap instrumentation and classify its heap behavior.

        Args:
            binary: path to the target binary
            args: command line arguments
            stdin_data: data to pipe to stdin
            timeout: max runtime in seconds

        Returns:
            ScanResult with verdict, confidence, and corruption events
        """
        cmd = [binary, *(args or [])]

        # A private temporary directory replaces tempfile.mktemp(): the dump
        # name cannot be pre-created by another user, the file still only
        # exists if the harness wrote it, and cleanup happens even when
        # loading the dump raises.
        with tempfile.TemporaryDirectory(prefix="heaptrm_") as tmp_dir:
            dump_path = Path(tmp_dir) / "heap_dump.jsonl"

            env = os.environ.copy()
            # NOTE: this replaces any LD_PRELOAD inherited from the caller.
            env["LD_PRELOAD"] = self._harness_path
            env["HEAPGRID_OUT"] = str(dump_path)

            try:
                subprocess.run(
                    cmd, input=stdin_data, env=env,
                    capture_output=True, timeout=timeout,
                )
            except subprocess.TimeoutExpired:
                # Deliberate best effort: whatever was dumped before the
                # timeout is still analysed below.
                pass

            states = load_dump(dump_path) if dump_path.exists() else []

        return self._analyze(states)

    def analyze_dump(self, dump_path: str) -> "ScanResult":
        """Analyze a pre-existing heap dump file."""
        return self._analyze(load_dump(Path(dump_path)))

    @staticmethod
    def _collect_corruptions(states: list) -> list:
        """Flatten the "corruptions" records of every state into CorruptionEvents."""
        events = []
        for state in states:
            step = state.get("step", 0)
            # "or []" also tolerates a state whose "corruptions" entry is null.
            for record in state.get("corruptions") or []:
                events.append(CorruptionEvent(
                    step=step,
                    type=record.get("type", "unknown"),
                    chunk_idx=record.get("chunk_idx", -1),
                    detail=record.get("detail", ""),
                ))
        return events

    def _decide_verdict(self, has_corruption: bool, max_prob: float,
                        n_flagged: int, n_states: int) -> tuple:
        """Map the collected evidence to a ``(verdict, confidence)`` pair.

        Rules, in priority order:
            1. Any corruption record -> "EXPLOIT" with CORRUPTION_CONFIDENCE.
            2. Peak probability >= MODEL_EXPLOIT_MIN_PROB and more than
               MODEL_EXPLOIT_MIN_FLAGGED_FRACTION of the states flagged
               -> "EXPLOIT" with the peak probability.
            3. Peak probability >= SUSPICIOUS_THRESHOLD -> "SUSPICIOUS" with
               the peak probability.
            4. Otherwise "CLEAN" with ``1 - peak probability``.
        """
        if has_corruption:
            return "EXPLOIT", self.CORRUPTION_CONFIDENCE
        mostly_flagged = (
            n_flagged > n_states * self.MODEL_EXPLOIT_MIN_FLAGGED_FRACTION)
        if max_prob >= self.MODEL_EXPLOIT_MIN_PROB and mostly_flagged:
            return "EXPLOIT", max_prob
        if max_prob >= self.SUSPICIOUS_THRESHOLD:
            return "SUSPICIOUS", max_prob
        return "CLEAN", 1.0 - max_prob

    def _analyze(self, states: list) -> "ScanResult":
        """Classify a sequence of heap states.

        Args:
            states: heap-state mappings as returned by ``load_dump``.

        Returns:
            A ScanResult; for an empty sequence a "CLEAN" result with zero
            confidence and no model output.
        """
        if not states:
            return ScanResult("CLEAN", 0.0, 0, 0, [], [])

        corruptions = self._collect_corruptions(states)

        # All states are scored in one batch.
        grids = np.stack([state_to_grid(s) for s in states])
        X = torch.from_numpy(grids).long().to(self.device)

        with torch.no_grad():
            logits = self.model(X)
            # Column 1 of the softmax is reported as the per-state probability.
            probs = torch.softmax(logits, dim=1)[:, 1].cpu().numpy()
            preds = logits.argmax(dim=1).cpu().numpy()

        max_prob = float(probs.max())
        # Positions of the states predicted as class 1, as plain ints.
        flagged = np.flatnonzero(preds == 1).tolist()

        verdict, confidence = self._decide_verdict(
            bool(corruptions), max_prob, len(flagged), len(states))

        return ScanResult(
            verdict=verdict,
            confidence=confidence,
            n_states=len(states),
            n_flagged=len(flagged),
            corruptions=corruptions,
            exploit_states=flagged,
            raw_probs=probs.tolist(),
        )

    @classmethod
    def attach(cls, process, **kwargs):
        """
        Attach to a pwntools process.

        Usage:
            from pwn import process
            p = process("./target")
            m = HeapMonitor.attach(p)

        Args:
            process: the process object to remember on the monitor.
            **kwargs: forwarded to the HeapMonitor constructor.

        Returns:
            A new monitor holding a reference to ``process``.
        """
        monitor = cls(**kwargs)
        monitor._pwntools_proc = process
        return monitor

    def check(self) -> "ScanResult":
        """Check the attached pwntools process's current heap state.

        NOTE(review): the process object itself is not queried; the dump is
        read from the path in this process's ``HEAPGRID_OUT`` environment
        variable (default ``heap_dump.jsonl``), so the target must have been
        started with the harness writing to that same path.

        Raises:
            RuntimeError: if no process was attached via ``attach()``.
        """
        # getattr also covers instances created without running __init__.
        if getattr(self, "_pwntools_proc", None) is None:
            raise RuntimeError("No process attached. Use HeapMonitor.attach()")

        dump_path = os.environ.get("HEAPGRID_OUT", "heap_dump.jsonl")
        if os.path.exists(dump_path):
            return self.analyze_dump(dump_path)
        return ScanResult("CLEAN", 0.0, 0, 0, [], [])
|
|