heap-trm / heaptrm-cli /src /main.rs
amarck's picture
Risk scoring in heaptrm tool: pre-corruption exploit prediction
702e402
//! heaptrm - Heap exploit observability for LLM-assisted exploitation.
//!
//! Launches a target binary with LD_PRELOAD heap instrumentation,
//! provides a JSON protocol on stdin/stdout for LLM interaction.
//!
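//! Protocol (one JSON object per line in each direction):
//! LLM sends: {"action": "send", "data": "1 0 64\n"}
//! Tool sends: {"heap": {...}, "output": "...", "addresses": [...]}
//! (the heap object carries chunks, bins, corruptions, primitives,
//! a risk score and a text summary; absent fields are omitted)
//!
//! LLM sends: {"action": "recv"}
//! Tool sends: output captured from the binary since the last request
//!
//! LLM sends: {"action": "observe"}
//! Tool sends: current heap state
//!
//! LLM sends: {"action": "check"}
//! Tool sends: heap state after signalling the harness with SIGUSR1
//!
//! LLM sends: {"action": "quit"}
//! Tool kills the binary, removes its dump file and exits (no response is sent)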
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::env;
use std::fs;
use std::io::{self, BufRead, BufReader, Write};
use std::path::PathBuf;
use std::process::{Child, Command, Stdio};
use std::thread;
use std::time::Duration;
// --- Harness output types ---
/// One heap chunk record as deserialized from a line of the harness dump.
///
/// Every field except `idx`, `addr` and `state` falls back to its type's
/// default when the harness omits it.
#[derive(Debug, Deserialize, Clone)]
struct RawChunk {
/// Chunk index as reported by the harness; used to refer to the chunk in output.
idx: usize,
/// Chunk address, kept as the string the harness emitted.
addr: String,
/// Chunk state code; this file treats 1 as allocated and 2 as freed.
state: u8,
/// Chunk size; chunks with size 0 are ignored when grouping by size.
#[serde(default)]
chunk_size: usize,
/// Forward pointer value; 0 and `QUARANTINE_FD` are treated as "no pointer".
#[serde(default)]
fd: u64,
/// Index associated with `fd`; this file treats -2 as "fd points outside the heap".
// NOTE(review): other values are not interpreted here — confirm against the harness.
#[serde(default)]
fd_idx: i32,
/// Non-zero when the harness flagged this chunk as corrupted.
#[serde(default)]
is_corrupted: u8,
/// Non-zero when the harness flagged this chunk as freed more than once.
#[serde(default)]
is_double_freed: u8,
/// Deserialized but not read anywhere in this file.
#[serde(default)]
data_hex: String,
}
/// One corruption event as deserialized from the harness dump.
#[derive(Debug, Deserialize, Clone)]
struct RawCorruption {
/// Corruption kind; arrives under the JSON key `type`.
#[serde(rename = "type")]
corruption_type: String,
/// Index of the affected chunk (signed in the harness output).
chunk_idx: i32,
/// Free-form description supplied by the harness.
detail: String,
}
/// One heap snapshot: a single JSON line of the harness dump file.
#[derive(Debug, Deserialize, Clone)]
struct RawState {
/// Step counter reported by the harness.
step: u32,
/// Name of the operation this snapshot was taken for.
operation: String,
/// Deserialized but not read in this file (the `corruptions` list is used instead).
#[serde(default)]
corruption_count: u32,
/// Corruption events reported with this snapshot; empty when omitted.
#[serde(default)]
corruptions: Vec<RawCorruption>,
/// Chunks tracked in this snapshot; empty when omitted.
#[serde(default)]
chunks: Vec<RawChunk>,
}
// --- Output types ---
/// Serializable, human-oriented view of one chunk.
#[derive(Serialize)]
struct ChunkView {
/// Chunk index copied from the harness record.
index: usize,
/// Chunk address string copied from the harness record.
address: String,
/// Chunk size rendered as lowercase hex with a `0x` prefix.
size: String,
/// Either "allocated" or "freed".
state: String,
/// Forward pointer rendered as hex; left out of the JSON when absent.
#[serde(skip_serializing_if = "Option::is_none")]
fd: Option<String>,
/// True when the harness marked the chunk as corrupted.
corrupted: bool,
}
/// Group of freed chunks that share one chunk size.
#[derive(Serialize)]
struct BinView {
/// Shared chunk size rendered as lowercase hex with a `0x` prefix.
size: String,
/// Number of chunks in the group (length of `entries`).
count: usize,
/// Chunk indices belonging to the group.
entries: Vec<usize>,
}
/// An exploitation primitive inferred from a heap snapshot.
#[derive(Serialize)]
struct Primitive {
/// Short identifier, e.g. "tcache_poison", "double_free" or "corruption_<type>".
name: String,
/// Human-readable explanation of the primitive.
description: String,
/// Indices of the chunks involved.
chunks: Vec<usize>,
}
/// Full analysis of one heap snapshot, serialized under the `heap` key of a response.
#[derive(Serialize)]
struct HeapView {
/// Step counter copied from the snapshot.
step: u32,
/// Operation name copied from the snapshot.
operation: String,
/// Number of chunks whose state code is 1.
allocated: usize,
/// Number of chunks whose state code is 2.
freed: usize,
/// Heuristic risk score produced by `compute_risk`, capped at 1.0.
risk_score: f64,
/// Text explanations for each contribution to `risk_score`.
risk_factors: Vec<String>,
/// Per-chunk views in snapshot order.
chunks: Vec<ChunkView>,
/// Freed chunks grouped by size.
bins: Vec<BinView>,
/// Corruption events as JSON objects with `type`, `chunk` and `detail` keys.
corruptions: Vec<serde_json::Value>,
/// Primitives inferred from the snapshot.
primitives: Vec<Primitive>,
/// Multi-line text summary of the snapshot.
summary: String,
}
/// One protocol response line written to stdout.
///
/// Every field is optional and is left out of the JSON when `None`.
#[derive(Serialize)]
struct Response {
/// Heap analysis, when a snapshot is available.
#[serde(skip_serializing_if = "Option::is_none")]
heap: Option<HeapView>,
/// Text captured from the target's stdout.
#[serde(skip_serializing_if = "Option::is_none")]
output: Option<String>,
/// Hex addresses found in `output`.
#[serde(skip_serializing_if = "Option::is_none")]
addresses: Option<Vec<String>>,
/// Set to `true` once the target process is known to have exited.
#[serde(skip_serializing_if = "Option::is_none")]
exited: Option<bool>,
/// Error message for malformed or unknown requests.
#[serde(skip_serializing_if = "Option::is_none")]
error: Option<String>,
}
/// One protocol request line read from stdin.
#[derive(Deserialize)]
struct Request {
/// Requested action; `main` handles "send", "recv", "observe", "check" and "quit".
action: String,
/// Payload for "send"; defaults to the empty string when omitted.
#[serde(default)]
data: String,
}
/// Forward-pointer value that is treated like "no pointer": chunks whose `fd`
/// equals it are neither shown with an fd nor considered for tcache poisoning.
// NOTE(review): presumably a fill pattern the harness writes into quarantined
// chunks — confirm against the harness source.
const QUARANTINE_FD: u64 = 0xFDFDFDFDFDFDFDFD;
unsafe fn libc_sigusr1(pid: i32) {
libc::kill(pid, 10); // SIGUSR1
}
/// Collect the distinct `0x…` hex tokens found in `text`, sorted ascending.
///
/// A token is `0x` followed by a run of ASCII hex digits. Tokens shorter than
/// eight characters in total (prefix included, i.e. fewer than six digits) are
/// discarded as unlikely to be heap or libc addresses.
fn extract_addresses(text: &str) -> Vec<String> {
    let raw = text.as_bytes();
    let total = raw.len();
    let mut found: Vec<String> = Vec::new();
    let mut pos = 0usize;

    // Needs room for the prefix plus at least one further byte.
    while pos + 2 < total {
        let at_prefix = raw[pos] == b'0' && raw[pos + 1] == b'x';
        if !at_prefix {
            pos += 1;
            continue;
        }

        let begin = pos;
        let digits = raw[begin + 2..]
            .iter()
            .take_while(|b| b.is_ascii_hexdigit())
            .count();
        // Scanning resumes right after the digit run, whether or not the
        // token turns out to be long enough to keep.
        pos = begin + 2 + digits;

        if pos - begin >= 8 {
            found.push(text[begin..pos].to_string());
        }
    }

    found.sort_unstable();
    found.dedup();
    found
}
/// Assemble a protocol response from a heap view and captured target output.
///
/// `output` and `addresses` are only present when `captured` is non-empty
/// (and, for `addresses`, when at least one address was found). `exited` is
/// `Some(true)` once `try_wait` reports that the child has terminated and is
/// omitted otherwise, including when the status query itself fails.
fn make_response(
    heap: Option<HeapView>,
    captured: String,
    child: &mut Child,
) -> Response {
    let (output, addresses) = if captured.is_empty() {
        (None, None)
    } else {
        let found = extract_addresses(&captured);
        let addresses = if found.is_empty() { None } else { Some(found) };
        (Some(captured), addresses)
    };

    let exited = match child.try_wait() {
        Ok(Some(_status)) => Some(true),
        _ => None,
    };

    Response {
        heap,
        output,
        addresses,
        exited,
        error: None,
    }
}
/// Return the longest prefix of `text` that is at most `max_bytes` long and
/// ends on a UTF-8 character boundary.
fn truncate_on_char_boundary(text: &str, max_bytes: usize) -> &str {
    if text.len() <= max_bytes {
        return text;
    }
    let mut end = max_bytes;
    // Index 0 is always a boundary, so this terminates.
    while !text.is_char_boundary(end) {
        end -= 1;
    }
    &text[..end]
}

/// Compute exploit risk score from heap structure.
///
/// Encodes the patterns the TRM learned for pre-corruption prediction:
/// - Multiple same-size freed chunks (tcache/fastbin setup)
/// - Freed chunks adjacent to allocated chunks (overflow/UAF targets)
/// - Corrupted metadata (active exploitation)
///
/// Returns `(score, factors)`: the score is the sum of the individual
/// contributions capped at 1.0, and `factors` holds one explanation per
/// contribution in evaluation order. Sizes are visited in ascending order so
/// the result is deterministic for a given snapshot.
fn compute_risk(state: &RawState) -> (f64, Vec<String>) {
    let mut score: f64 = 0.0;
    let mut factors = Vec::new();

    // Per-size counts of allocated (state 1) and freed (state 2) chunks;
    // chunks reporting size 0 are ignored.
    let mut size_freed: HashMap<usize, usize> = HashMap::new();
    let mut size_alloc: HashMap<usize, usize> = HashMap::new();
    for c in &state.chunks {
        if c.chunk_size == 0 {
            continue;
        }
        match c.state {
            1 => *size_alloc.entry(c.chunk_size).or_default() += 1,
            2 => *size_freed.entry(c.chunk_size).or_default() += 1,
            _ => {}
        }
    }
    // HashMap iteration order is unspecified; iterate sorted keys instead so
    // factor order (and which size wins a "count once" rule) is stable.
    let mut freed_sizes: Vec<usize> = size_freed.keys().copied().collect();
    freed_sizes.sort_unstable();
    let mut alloc_sizes: Vec<usize> = size_alloc.keys().copied().collect();
    alloc_sizes.sort_unstable();

    // Factor 1: Same-size freed chunks (tcache setup); counted once per size.
    for sz in &freed_sizes {
        let count = size_freed[sz];
        if count >= 2 {
            score += 0.25;
            factors.push(format!("{} freed chunks of size 0x{:x} (tcache setup)", count, sz));
        }
    }

    // Factor 2: Freed chunk adjacent to allocated chunk (UAF/overflow target).
    // "Adjacent" means neighbouring positions in the snapshot's chunk list;
    // the previous neighbour is preferred over the next. Counted once.
    for (i, c) in state.chunks.iter().enumerate() {
        if c.state != 2 {
            continue;
        }
        let prev = if i > 0 { state.chunks.get(i - 1) } else { None };
        let next = state.chunks.get(i + 1);
        let neighbor = prev
            .filter(|n| n.state == 1)
            .or_else(|| next.filter(|n| n.state == 1));
        if let Some(n) = neighbor {
            score += 0.15;
            factors.push(format!("freed chunk {} adjacent to allocated chunk {}", c.idx, n.idx));
            break;
        }
    }

    // Factor 3: Multiple allocations of same size (spray pattern); counted once.
    for sz in &alloc_sizes {
        let count = size_alloc[sz];
        if count >= 3 {
            score += 0.15;
            factors.push(format!("{} allocated chunks of size 0x{:x} (heap spray)", count, sz));
            break;
        }
    }

    // Factor 4: Freed chunks with corrupted fd (active tcache poison).
    // Only considered when the harness reported a corruption; counted once.
    if !state.corruptions.is_empty() {
        let poisoned = state
            .chunks
            .iter()
            .find(|c| c.state == 2 && c.fd != 0 && c.fd != QUARANTINE_FD && c.fd_idx == -2);
        if let Some(c) = poisoned {
            score += 0.4;
            factors.push(format!("chunk {} has corrupted fd 0x{:x} (tcache poison)", c.idx, c.fd));
        }
    }

    // Factor 5: Active corruption events.
    if !state.corruptions.is_empty() {
        score += 0.5;
        for corr in &state.corruptions {
            // Truncate on a character boundary: slicing at a fixed byte
            // offset panics when it lands inside a multi-byte character.
            let detail = truncate_on_char_boundary(&corr.detail, 60);
            factors.push(format!("ACTIVE: {} — {}", corr.corruption_type, detail));
        }
    }

    // Factor 6: Both alloc and freed of same size (exploit in progress); counted once.
    for sz in &freed_sizes {
        if let Some(&a) = size_alloc.get(sz) {
            let f = size_freed[sz];
            if f >= 2 && a >= 2 {
                score += 0.2;
                factors.push(format!("size 0x{:x}: {} freed + {} allocated (exploit pattern)", sz, f, a));
                break;
            }
        }
    }

    (score.min(1.0), factors)
}
/// Turn a raw harness snapshot into the serializable `HeapView`.
///
/// Produces per-chunk views, freed-chunk bins grouped by size (ascending),
/// the corruption list as JSON objects, inferred primitives, a text summary,
/// and the risk score from `compute_risk`.
fn analyze_state(state: &RawState) -> HeapView {
    let n_alloc = state.chunks.iter().filter(|c| c.state == 1).count();
    let n_freed = state.chunks.iter().filter(|c| c.state == 2).count();

    let chunks: Vec<ChunkView> = state
        .chunks
        .iter()
        .map(|c| {
            let fd = if c.fd != 0 && c.fd != QUARANTINE_FD {
                Some(format!("0x{:x}", c.fd))
            } else {
                None
            };
            ChunkView {
                index: c.idx,
                address: c.addr.clone(),
                size: format!("0x{:x}", c.chunk_size),
                // NOTE(review): every state other than 1 is labelled "freed",
                // although only state 2 is counted in `freed` — confirm the
                // harness never emits other state codes.
                state: if c.state == 1 { "allocated".into() } else { "freed".into() },
                fd,
                corrupted: c.is_corrupted != 0,
            }
        })
        .collect();

    // Bins: freed chunks grouped by size, emitted in ascending size order so
    // the output does not depend on HashMap iteration order.
    let mut size_bins: HashMap<usize, Vec<usize>> = HashMap::new();
    for c in &state.chunks {
        if c.state == 2 && c.chunk_size > 0 {
            size_bins.entry(c.chunk_size).or_default().push(c.idx);
        }
    }
    let mut grouped: Vec<(usize, Vec<usize>)> = size_bins.into_iter().collect();
    grouped.sort_unstable_by_key(|(sz, _)| *sz);
    let bins: Vec<BinView> = grouped
        .into_iter()
        .map(|(sz, entries)| BinView {
            size: format!("0x{:x}", sz),
            count: entries.len(),
            entries,
        })
        .collect();

    // Corruptions as JSON values
    let corruptions: Vec<serde_json::Value> = state
        .corruptions
        .iter()
        .map(|c| {
            serde_json::json!({
                "type": c.corruption_type,
                "chunk": c.chunk_idx,
                "detail": c.detail,
            })
        })
        .collect();

    // Primitives
    let mut primitives = Vec::new();
    // Only report tcache_poison if corruption was actually detected
    // (safe-linking makes normal tcache fd look like external pointers)
    let has_corruption = !state.corruptions.is_empty();
    for c in &state.chunks {
        if has_corruption && c.state == 2 && c.fd != 0 && c.fd != QUARANTINE_FD && c.fd_idx == -2 {
            primitives.push(Primitive {
                name: "tcache_poison".into(),
                description: format!(
                    "Chunk {} has fd=0x{:x} outside heap. malloc(0x{:x}) returns controlled address.",
                    c.idx, c.fd, c.chunk_size.saturating_sub(0x10)
                ),
                chunks: vec![c.idx],
            });
        }
        if c.is_double_freed != 0 {
            primitives.push(Primitive {
                name: "double_free".into(),
                description: format!("Chunk {} at {} freed multiple times.", c.idx, c.addr),
                chunks: vec![c.idx],
            });
        }
    }
    for corr in &state.corruptions {
        // `chunk_idx` is signed; a negative value cannot name a chunk, so it
        // yields an empty list rather than wrapping to a huge index.
        let involved = if corr.chunk_idx >= 0 {
            vec![corr.chunk_idx as usize]
        } else {
            Vec::new()
        };
        primitives.push(Primitive {
            name: format!("corruption_{}", corr.corruption_type),
            description: corr.detail.clone(),
            chunks: involved,
        });
    }

    // Summary
    let mut summary = format!(
        "Step {}: {} | {} alloc, {} freed",
        state.step, state.operation, n_alloc, n_freed
    );
    for corr in &state.corruptions {
        summary.push_str(&format!("\n!! {}: {}", corr.corruption_type, corr.detail));
    }
    // Corruption-derived primitives are already listed above; name only the rest.
    let prim_names: Vec<&str> = primitives
        .iter()
        .filter(|p| !p.name.starts_with("corruption_"))
        .map(|p| p.name.as_str())
        .collect();
    if !prim_names.is_empty() {
        summary.push_str(&format!("\nPrimitives: {}", prim_names.join(", ")));
    }

    // Compute risk score; only mention it in the summary above 30%.
    let (risk_score, risk_factors) = compute_risk(state);
    if risk_score > 0.3 {
        summary.push_str(&format!(
            "\n⚠ Risk: {:.0}% — {}",
            risk_score * 100.0,
            risk_factors.first().map(|s| s.as_str()).unwrap_or("")
        ));
    }

    HeapView {
        step: state.step,
        operation: state.operation.clone(),
        allocated: n_alloc,
        freed: n_freed,
        risk_score,
        risk_factors,
        chunks,
        bins,
        corruptions,
        primitives,
        summary,
    }
}
/// Locate the LD_PRELOAD harness shared object, building it if necessary.
///
/// First probes a fixed list of relative paths for a prebuilt `.so`. If none
/// is usable, looks for the harness C source and compiles it with `gcc`,
/// writing the `.so` next to the source. Returns the canonical absolute path
/// of the first usable library, or `None` when nothing was found or built.
///
/// A candidate that exists but cannot be canonicalized is skipped rather
/// than aborting the whole search.
fn find_harness() -> Option<PathBuf> {
    const PREBUILT: [&str; 4] = [
        "heapgrid_v2.so",
        "heaptrm/harness/heapgrid_v2.so",
        "harness/heapgrid_harness.so",
        "../heaptrm/harness/heapgrid_v2.so",
    ];
    const SOURCES: [&str; 2] = ["heaptrm/harness/heapgrid_v2.c", "heapgrid_v2.c"];

    for candidate in PREBUILT.iter() {
        let path = PathBuf::from(candidate);
        if path.exists() {
            if let Ok(resolved) = fs::canonicalize(&path) {
                return Some(resolved);
            }
        }
    }

    // Try compile
    for src in SOURCES.iter() {
        let source = PathBuf::from(src);
        if !source.exists() {
            continue;
        }
        let out = source.with_extension("so");
        // A gcc that is missing or exits non-zero counts as "not built".
        let built = Command::new("gcc")
            .args(["-shared", "-fPIC", "-O2", "-o"])
            .arg(&out)
            .arg(&source)
            .args(["-ldl", "-pthread"])
            .status()
            .map(|status| status.success())
            .unwrap_or(false);
        if built {
            if let Ok(resolved) = fs::canonicalize(&out) {
                return Some(resolved);
            }
        }
    }

    None
}
fn main() {
let args: Vec<String> = env::args().collect();
if args.len() < 2 {
eprintln!("heaptrm — heap exploit observability for LLM-assisted exploitation");
eprintln!();
eprintln!("Usage: heaptrm <binary> [args...]");
eprintln!();
eprintln!("Launches <binary> with heap instrumentation. Reads JSON from stdin,");
eprintln!("writes heap observations to stdout.");
eprintln!();
eprintln!("Commands:");
eprintln!(" {{\"action\": \"send\", \"data\": \"...\"}} send data to binary stdin");
eprintln!(" {{\"action\": \"observe\"}} get current heap state");
eprintln!(" {{\"action\": \"check\"}} force heap validation (detects corruption from writes)");
eprintln!(" {{\"action\": \"quit\"}} exit");
std::process::exit(1);
}
let binary = &args[1];
let binary_args = &args[2..];
let harness = find_harness().unwrap_or_else(|| {
eprintln!("Error: Cannot find heapgrid_v2.so");
std::process::exit(1);
});
let dump_path = format!("/tmp/heaptrm_{}.jsonl", std::process::id());
let mut child: Child = Command::new(binary)
.args(binary_args)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::null())
.env("LD_PRELOAD", &harness)
.env("HEAPGRID_OUT", &dump_path)
.spawn()
.unwrap_or_else(|e| { eprintln!("Failed to launch: {}", e); std::process::exit(1); });
let mut child_stdin = child.stdin.take().expect("stdin");
let child_stdout = child.stdout.take().expect("stdout");
// Spawn thread to capture binary's stdout (for address leaks etc)
use std::sync::{Arc, Mutex};
let output_buf = Arc::new(Mutex::new(String::new()));
let output_buf_writer = output_buf.clone();
thread::spawn(move || {
let reader = BufReader::new(child_stdout);
for line in reader.lines() {
if let Ok(line) = line {
let mut buf = output_buf_writer.lock().unwrap();
buf.push_str(&line);
buf.push('\n');
// Cap at 16KB
if buf.len() > 16384 {
let drain = buf.len() - 8192;
buf.drain(..drain);
}
}
}
});
thread::sleep(Duration::from_millis(50));
let stdin = io::stdin();
let stdout = io::stdout();
let mut out = stdout.lock();
let mut last_pos: u64 = 0;
let mut last_state: Option<RawState> = None;
let read_latest = |pos: &mut u64| -> Option<RawState> {
let content = fs::read_to_string(&dump_path).ok()?;
let start = *pos as usize;
if start >= content.len() { return None; }
let new = &content[start..];
let mut last = None;
for line in new.lines() {
if let Ok(s) = serde_json::from_str::<RawState>(line) {
last = Some(s);
}
}
*pos = content.len() as u64;
last
};
for line in stdin.lock().lines() {
let line = match line { Ok(l) => l, Err(_) => break };
if line.trim().is_empty() { continue; }
let req: Request = match serde_json::from_str(&line) {
Ok(r) => r,
Err(e) => {
let r = Response { heap: None, output: None, addresses: None, exited: None, error: Some(format!("Bad JSON: {}", e)) };
writeln!(out, "{}", serde_json::to_string(&r).unwrap()).ok();
out.flush().ok();
continue;
}
};
// Helper: drain captured stdout
let drain_output = || -> String {
let mut buf = output_buf.lock().unwrap();
let s = buf.clone();
buf.clear();
s
};
match req.action.as_str() {
"send" => {
child_stdin.write_all(req.data.as_bytes()).ok();
child_stdin.flush().ok();
thread::sleep(Duration::from_millis(20));
if let Some(s) = read_latest(&mut last_pos) {
last_state = Some(s);
}
let captured = drain_output();
let heap = last_state.as_ref().map(|s| analyze_state(s));
let r = make_response(heap, captured, &mut child);
writeln!(out, "{}", serde_json::to_string(&r).unwrap()).ok();
out.flush().ok();
}
"recv" => {
thread::sleep(Duration::from_millis(10));
let captured = drain_output();
let r = make_response(None, captured, &mut child);
writeln!(out, "{}", serde_json::to_string(&r).unwrap()).ok();
out.flush().ok();
}
"observe" => {
if let Some(s) = read_latest(&mut last_pos) {
last_state = Some(s);
}
let heap = last_state.as_ref().map(|s| analyze_state(s));
let r = make_response(heap, String::new(), &mut child);
writeln!(out, "{}", serde_json::to_string(&r).unwrap()).ok();
out.flush().ok();
}
"check" => {
unsafe { libc_sigusr1(child.id() as i32); }
thread::sleep(Duration::from_millis(30));
if let Some(s) = read_latest(&mut last_pos) {
last_state = Some(s);
}
let heap = last_state.as_ref().map(|s| analyze_state(s));
let r = make_response(heap, String::new(), &mut child);
writeln!(out, "{}", serde_json::to_string(&r).unwrap()).ok();
out.flush().ok();
}
"quit" => {
child.kill().ok();
fs::remove_file(&dump_path).ok();
break;
}
_ => {
let r = Response { heap: None, output: None, addresses: None, exited: None, error: Some(format!("Unknown: {}", req.action)) };
writeln!(out, "{}", serde_json::to_string(&r).unwrap()).ok();
out.flush().ok();
}
}
}
child.kill().ok();
child.wait().ok();
fs::remove_file(&dump_path).ok();
}