# heap-trm/simulator/heap_sim.py
# (commit 22374d1: "Add heaptrm package: v2 harness, CLI, pwntools integration, CVE tests")
"""
heap_sim.py - Lightweight ptmalloc2 heap allocator simulator.
Simulates the core glibc heap allocator behavior:
- Chunk metadata (prev_size, size, flags, fd, bk)
- Tcache (per-size LIFO, 7 entries max, sizes 0x20-0x410)
- Fastbins (LIFO, sizes 0x20-0x80)
- Unsorted bin (doubly-linked)
- Forward/backward coalescing
- Top chunk splitting
- Request size -> chunk size conversion
Does NOT simulate:
- Safe linking (we want to model the exploitable state)
- mmap'd chunks
- Thread arenas
- Large bin sorting
This is deliberately simplified to be fast enough for self-play training.
"""
import numpy as np
from typing import Optional, List, Dict, Tuple
from dataclasses import dataclass, field
# Constants (64-bit)
SIZE_SZ = 8  # sizeof(size_t) on a 64-bit target
MALLOC_ALIGN_MASK = 0xF  # chunks are 16-byte aligned
MINSIZE = 0x20  # smallest legal chunk size
TCACHE_MAX_BINS = 64  # covers chunk sizes 0x20..0x410 in 0x10 steps
TCACHE_FILL_COUNT = 7  # max entries per tcache bin before falling through
MAX_FASTBIN_SIZE = 0x80  # largest chunk size that goes to a fastbin
NUM_FASTBINS = 7 # sizes 0x20, 0x30, ..., 0x80
INITIAL_HEAP_BASE = 0x555555559000  # presumably the typical no-ASLR PIE heap base -- TODO confirm
TOP_CHUNK_INITIAL_SIZE = 0x21000  # initial top ("wilderness") chunk size
# Flags (stored in the low 3 bits of a chunk's size field)
PREV_INUSE = 0x1  # previous adjacent chunk is allocated
IS_MMAPPED = 0x2  # chunk was mmap'd (not simulated here)
NON_MAIN_ARENA = 0x4  # chunk belongs to a thread arena (not simulated here)
def request_to_chunk_size(req: int) -> int:
    """Round a malloc() request up to the chunk size the allocator uses.

    Adds room for the size header and rounds up to 16-byte alignment,
    with MINSIZE as the floor.
    """
    padded = req + SIZE_SZ + MALLOC_ALIGN_MASK
    if padded < MINSIZE:
        return MINSIZE
    return padded & ~MALLOC_ALIGN_MASK
def tcache_idx(chunk_size: int) -> int:
    """Index of the tcache bin serving chunks of `chunk_size`."""
    # Bins are 0x10 apart starting at MINSIZE; >> 4 == // 0x10.
    return (chunk_size - MINSIZE) >> 4
def fastbin_idx(chunk_size: int) -> int:
    """Index of the fastbin serving chunks of `chunk_size`."""
    # Same spacing as the tcache bins: one bin per 0x10 bytes above MINSIZE.
    return (chunk_size - MINSIZE) >> 4
@dataclass
class Chunk:
    """One simulated heap chunk: glibc-style header plus sim bookkeeping."""
    addr: int  # address of the chunk header (start of prev_size field)
    prev_size: int = 0
    size: int = 0  # chunk size, flag bits carried in the low 3 bits
    allocated: bool = True
    fd: int = 0  # forward link, meaningful only while the chunk is freed
    bk: int = 0  # backward link, meaningful only while the chunk is freed
    user_data: bytes = b''  # prefix of the user payload
    @property
    def real_size(self) -> int:
        """Chunk size with the three flag bits masked off."""
        return self.size & ~0x7
    @property
    def prev_inuse(self) -> bool:
        """Whether the PREV_INUSE flag bit is set."""
        return (self.size & PREV_INUSE) != 0
    @property
    def user_addr(self) -> int:
        """Pointer handed to the user: header (prev_size + size) skipped."""
        return self.addr + SIZE_SZ * 2
    @property
    def next_chunk_addr(self) -> int:
        """Address of the chunk that immediately follows this one."""
        return self.addr + self.real_size
class HeapSimulator:
    """Simulates the glibc ptmalloc2 heap allocator.

    Allocation search order mirrors glibc: tcache -> fastbin -> unsorted
    bin (first fit) -> top chunk split.  Frees go to tcache (if the bin
    has room), else fastbin (if small enough), else the chunk is
    coalesced with free neighbors and placed in the unsorted bin.

    Deliberately exploitable simplifications (see module docstring):
    no double-free key check, slot entries are not cleared on free
    (UAF-style dangling slots), and no safe-linking of fd pointers.
    """
    def __init__(self, heap_base: int = INITIAL_HEAP_BASE):
        # All addresses below are plain ints; no real memory is modeled.
        self.heap_base = heap_base
        self.chunks: Dict[int, Chunk] = {}  # chunk header addr -> Chunk
        # Tcache: list of chunk addrs per size class; LIFO via append/pop
        self.tcache: List[List[int]] = [[] for _ in range(TCACHE_MAX_BINS)]
        # Fastbins: list of chunk addrs per size class; LIFO via append/pop
        self.fastbins: List[List[int]] = [[] for _ in range(NUM_FASTBINS)]
        # Unsorted bin: list of chunk addrs, scanned first-fit by malloc()
        self.unsorted_bin: List[int] = []
        # Top ("wilderness") chunk, split when no bin can satisfy a request
        self.top_addr = heap_base
        self.top_size = TOP_CHUNK_INITIAL_SIZE
        # Tracking counters (exposed via get_state/state_to_grid)
        self.alloc_count = 0
        self.free_count = 0
        self.step = 0
        self.history: List[dict] = []
        # User-facing slots (like CTF menu: slots 0-15)
        self.slots: Dict[int, int] = {}  # slot_idx -> user_addr
    def _record(self, op: str, **kwargs) -> None:
        """Append one entry to the operation history and advance the step counter."""
        self.history.append({"step": self.step, "op": op, **kwargs})
        self.step += 1
    # ============================================================
    # MALLOC
    # ============================================================
    def malloc(self, req_size: int, slot: Optional[int] = None) -> Optional[int]:
        """Allocate a chunk. Returns user pointer or None.

        Tries, in order: tcache, fastbin, unsorted bin (first fit),
        then splitting the top chunk.  If `slot` is given, the returned
        user pointer is also stored in the CTF-style slot table.

        NOTE(review): a poisoned fd (set via write/write_to_freed) is
        NOT followed here -- the next entry is popped from the Python
        list; the poison is only *detected* by check_primitives().
        """
        chunk_size = request_to_chunk_size(req_size)
        user_addr = None
        # 1. Try tcache (pop from the list tail = most recently freed, LIFO)
        if chunk_size <= 0x410:
            idx = tcache_idx(chunk_size)
            if idx < TCACHE_MAX_BINS and self.tcache[idx]:
                addr = self.tcache[idx].pop()
                chunk = self.chunks[addr]
                chunk.allocated = True
                chunk.fd = 0
                chunk.bk = 0
                user_addr = chunk.user_addr
                self._set_next_prev_inuse(chunk, True)
        # 2. Try fastbin (also LIFO)
        if user_addr is None and chunk_size <= MAX_FASTBIN_SIZE:
            idx = fastbin_idx(chunk_size)
            if idx < NUM_FASTBINS and self.fastbins[idx]:
                addr = self.fastbins[idx].pop()
                chunk = self.chunks[addr]
                chunk.allocated = True
                chunk.fd = 0
                chunk.bk = 0
                user_addr = chunk.user_addr
                self._set_next_prev_inuse(chunk, True)
        # 3. Try unsorted bin (first fit)
        if user_addr is None:
            for i, addr in enumerate(self.unsorted_bin):
                chunk = self.chunks[addr]
                if chunk.real_size >= chunk_size:
                    self.unsorted_bin.pop(i)
                    remainder = chunk.real_size - chunk_size
                    if remainder >= MINSIZE:
                        # Split off the tail as a new free chunk
                        self._split_chunk(chunk, chunk_size)
                    # NOTE(review): when 0 < remainder < MINSIZE the line
                    # below shrinks the chunk to chunk_size and silently
                    # drops the tail bytes; glibc would hand out the whole
                    # chunk instead.  Presumably acceptable for the sim --
                    # confirm this gap in the address space is intended.
                    chunk.allocated = True
                    chunk.size = chunk_size | PREV_INUSE
                    chunk.fd = 0
                    chunk.bk = 0
                    user_addr = chunk.user_addr
                    self._set_next_prev_inuse(chunk, True)
                    break
        # 4. Split top chunk (requires room for the request plus a MINSIZE top)
        if user_addr is None:
            if self.top_size >= chunk_size + MINSIZE:
                addr = self.top_addr
                chunk = Chunk(
                    addr=addr,
                    size=chunk_size | PREV_INUSE,
                    allocated=True,
                )
                self.chunks[addr] = chunk
                self.top_addr = addr + chunk_size
                self.top_size -= chunk_size
                user_addr = chunk.user_addr
            else:
                return None  # OOM
        if slot is not None and user_addr is not None:
            self.slots[slot] = user_addr
        self.alloc_count += 1
        self._record("malloc", size=req_size, chunk_size=chunk_size,
                     user_addr=user_addr, slot=slot)
        return user_addr
    # ============================================================
    # FREE
    # ============================================================
    def free(self, user_addr: int = 0, slot: Optional[int] = None) -> bool:
        """Free a chunk by user pointer or slot. Returns success.

        Destination, in order: tcache (if that bin has < 7 entries),
        fastbin (if chunk size <= 0x80), else coalesce with free
        neighbors and append to the unsorted bin.

        Exploit-friendly by design: there is no check that the chunk is
        currently allocated (double frees succeed), and the slot table
        entry is NOT cleared, leaving a dangling slot (UAF modeling).
        """
        if slot is not None:
            user_addr = self.slots.get(slot, 0)
        if not user_addr:
            return False
        chunk_addr = user_addr - 2 * SIZE_SZ
        chunk = self.chunks.get(chunk_addr)
        if chunk is None:
            return False
        chunk_size = chunk.real_size
        # 1. Try tcache (fd records the previous bin head, mimicking the list link)
        if chunk_size <= 0x410:
            idx = tcache_idx(chunk_size)
            if idx < TCACHE_MAX_BINS and len(self.tcache[idx]) < TCACHE_FILL_COUNT:
                chunk.allocated = False
                chunk.fd = self.tcache[idx][-1] if self.tcache[idx] else 0
                chunk.bk = 0
                self.tcache[idx].append(chunk_addr)
                self.free_count += 1
                self._record("free", user_addr=user_addr, bin="tcache",
                             chunk_size=chunk_size, slot=slot)
                return True
        # 2. Fastbin (only reached when the matching tcache bin is full)
        if chunk_size <= MAX_FASTBIN_SIZE:
            idx = fastbin_idx(chunk_size)
            if idx < NUM_FASTBINS:
                chunk.allocated = False
                chunk.fd = self.fastbins[idx][-1] if self.fastbins[idx] else 0
                self.fastbins[idx].append(chunk_addr)
                self.free_count += 1
                self._record("free", user_addr=user_addr, bin="fastbin",
                             chunk_size=chunk_size, slot=slot)
                return True
        # 3. Coalesce and put in unsorted bin
        chunk.allocated = False
        self._coalesce_and_unsort(chunk)
        self.free_count += 1
        self._record("free", user_addr=user_addr, bin="unsorted",
                     chunk_size=chunk_size, slot=slot)
        return True
    # ============================================================
    # WRITE (simulate vulnerability)
    # ============================================================
    def write(self, user_addr: int, data: bytes, overflow: int = 0) -> bool:
        """Write data to a chunk. overflow > 0 allows OOB write.

        overflow >= 1 models an off-by-one NUL write that clears the low
        byte of the next chunk's size field; overflow >= SIZE_SZ further
        lets the tail of `data` (the bytes past this chunk's payload)
        overwrite the next chunk's fd pointer, but only if that chunk is
        currently freed.
        """
        chunk_addr = user_addr - 2 * SIZE_SZ
        chunk = self.chunks.get(chunk_addr)
        if chunk is None:
            return False
        chunk.user_data = data
        # Simulate overflow: corrupt next chunk's metadata
        if overflow > 0:
            next_addr = chunk.next_chunk_addr
            next_chunk = self.chunks.get(next_addr)
            if next_chunk and overflow >= 1:
                # Off-by-one null byte: corrupt size field's LSB
                next_chunk.size = next_chunk.size & ~0xFF
                # More overflow: can corrupt fd/bk
                if overflow >= SIZE_SZ and not next_chunk.allocated:
                    # Overwrite fd pointer with the first 8 overflowed bytes
                    if len(data) > chunk.real_size - 2 * SIZE_SZ:
                        overflow_data = data[chunk.real_size - 2 * SIZE_SZ:]
                        if len(overflow_data) >= 8:
                            next_chunk.fd = int.from_bytes(
                                overflow_data[:8], 'little')
        self._record("write", user_addr=user_addr,
                     data_len=len(data), overflow=overflow)
        return True
    def write_to_freed(self, user_addr: int, fd_value: int) -> bool:
        """UAF: write fd pointer of a freed chunk.

        Fails (returns False) if the address does not map to a known
        chunk or the chunk is still allocated.
        """
        chunk_addr = user_addr - 2 * SIZE_SZ
        chunk = self.chunks.get(chunk_addr)
        if chunk is None or chunk.allocated:
            return False
        chunk.fd = fd_value
        self._record("write_freed", user_addr=user_addr, fd_value=fd_value)
        return True
    # ============================================================
    # INTERNAL HELPERS
    # ============================================================
    def _set_next_prev_inuse(self, chunk: Chunk, inuse: bool) -> None:
        """Set the PREV_INUSE bit of the chunk following `chunk`.

        Also refreshes the follower's prev_size (kept up to date
        unconditionally here, whereas glibc only uses it when the
        previous chunk is free).
        """
        next_addr = chunk.next_chunk_addr
        next_chunk = self.chunks.get(next_addr)
        if next_chunk:
            if inuse:
                next_chunk.size |= PREV_INUSE
            else:
                next_chunk.size &= ~PREV_INUSE
            next_chunk.prev_size = chunk.real_size
    def _split_chunk(self, chunk: Chunk, new_size: int) -> None:
        """Split chunk into (new_size) and remainder.

        The remainder becomes a new free chunk appended straight to the
        unsorted bin.  The caller is responsible for updating `chunk`'s
        own size field afterwards.
        """
        remainder_addr = chunk.addr + new_size
        remainder_size = chunk.real_size - new_size
        remainder = Chunk(
            addr=remainder_addr,
            prev_size=new_size,
            size=remainder_size | PREV_INUSE,
            allocated=False,
        )
        self.chunks[remainder_addr] = remainder
        self.unsorted_bin.append(remainder_addr)
    def _coalesce_and_unsort(self, chunk: Chunk) -> None:
        """Coalesce with neighbors and put in unsorted bin.

        If the merged region ends exactly at the top chunk it is
        absorbed into top instead of being binned.

        NOTE(review): neighbors sitting in tcache/fastbins are merged
        too (glibc never coalesces those); presumably a deliberate
        simplification -- confirm.
        """
        addr = chunk.addr
        size = chunk.real_size
        # Forward coalesce: merge with next chunk if free
        next_addr = addr + size
        next_chunk = self.chunks.get(next_addr)
        if next_chunk and not next_chunk.allocated:
            # Remove next from whatever bin it's in
            self._remove_from_bins(next_addr)
            size += next_chunk.real_size
            del self.chunks[next_addr]
        # Backward coalesce: merge with prev chunk if free
        if not chunk.prev_inuse and chunk.prev_size > 0:
            prev_addr = addr - chunk.prev_size
            prev_chunk = self.chunks.get(prev_addr)
            if prev_chunk and not prev_chunk.allocated:
                self._remove_from_bins(prev_addr)
                size += prev_chunk.real_size
                del self.chunks[addr]
                # From here on, the surviving chunk is the previous one
                addr = prev_addr
                chunk = prev_chunk
        # Update chunk
        chunk.addr = addr
        chunk.size = size | PREV_INUSE  # prev of coalesced is always inuse
        chunk.allocated = False
        self.chunks[addr] = chunk
        # Check if coalescing into top
        if addr + size == self.top_addr:
            self.top_addr = addr
            self.top_size += size
            if addr in self.chunks:
                del self.chunks[addr]
            return
        self.unsorted_bin.append(addr)
        self._set_next_prev_inuse(chunk, False)
    def _remove_from_bins(self, addr: int) -> None:
        """Remove a chunk addr from whichever bin it's in.

        Checks tcache, then fastbins, then the unsorted bin; returns
        after the first hit (only one copy removed on duplicates).
        """
        for tc in self.tcache:
            if addr in tc:
                tc.remove(addr)
                return
        for fb in self.fastbins:
            if addr in fb:
                fb.remove(addr)
                return
        if addr in self.unsorted_bin:
            self.unsorted_bin.remove(addr)
    # ============================================================
    # STATE OBSERVATION
    # ============================================================
    def get_state(self) -> dict:
        """Return current heap state as a dict (for grid encoding).

        Includes a per-chunk list (sorted by address) with bin
        membership and owning slot, plus bin occupancy counts, top
        chunk info, and the alloc/free counters.
        """
        chunks_info = []
        for addr in sorted(self.chunks.keys()):
            c = self.chunks[addr]
            # Determine which bin it's in (tcache > fastbin > unsorted)
            bin_type = "none"
            if not c.allocated:
                for i, tc in enumerate(self.tcache):
                    if addr in tc:
                        bin_type = f"tcache_{i}"
                        break
                else:
                    for i, fb in enumerate(self.fastbins):
                        if addr in fb:
                            bin_type = f"fastbin_{i}"
                            break
                    else:
                        if addr in self.unsorted_bin:
                            bin_type = "unsorted"
            # Which slot points here? (first match wins)
            slot = None
            for s, ua in self.slots.items():
                if ua == c.user_addr:
                    slot = s
                    break
            chunks_info.append({
                "addr": addr,
                "size": c.real_size,
                "allocated": c.allocated,
                "prev_inuse": c.prev_inuse,
                "fd": c.fd,
                "bk": c.bk,
                "bin": bin_type,
                "slot": slot,
                "user_data_len": len(c.user_data),
            })
        return {
            "step": self.step,
            "n_chunks": len(self.chunks),
            "chunks": chunks_info,
            "tcache_counts": [len(tc) for tc in self.tcache],
            "fastbin_counts": [len(fb) for fb in self.fastbins],
            "unsorted_count": len(self.unsorted_bin),
            "top_addr": self.top_addr,
            "top_size": self.top_size,
            "alloc_count": self.alloc_count,
            "free_count": self.free_count,
        }
    # ============================================================
    # EXPLOIT PRIMITIVE DETECTION
    # ============================================================
    def check_primitives(self) -> dict:
        """Check for achieved exploit primitives.

        Returns a dict of six boolean flags: overlapping_chunks,
        arbitrary_alloc, double_free, tcache_poison, freelist_cycle,
        duplicate_alloc.
        """
        primitives = {
            "overlapping_chunks": False,
            "arbitrary_alloc": False,
            "double_free": False,
            "tcache_poison": False,
            "freelist_cycle": False,
            "duplicate_alloc": False,
        }
        # Check overlapping chunks: two allocated chunks whose regions overlap
        alloc_chunks = [(a, c) for a, c in self.chunks.items() if c.allocated]
        for i, (a1, c1) in enumerate(alloc_chunks):
            for a2, c2 in alloc_chunks[i+1:]:
                end1 = a1 + c1.real_size
                end2 = a2 + c2.real_size
                # Standard interval-overlap test
                if a1 < end2 and a2 < end1:
                    primitives["overlapping_chunks"] = True
        # Check tcache/fastbin for cycles (double free): any repeated addr
        for bins in [self.tcache, self.fastbins]:
            for bin_list in bins:
                if len(bin_list) != len(set(bin_list)):
                    primitives["double_free"] = True
                    primitives["freelist_cycle"] = True
        # Check for tcache poison: fd points outside heap OR fd was manually set
        for tc in self.tcache:
            for addr in tc:
                chunk = self.chunks.get(addr)
                if chunk and chunk.fd != 0:
                    # fd outside [heap_base, end of top] => arbitrary target
                    if chunk.fd < self.heap_base or chunk.fd > self.top_addr + self.top_size:
                        primitives["tcache_poison"] = True
                        primitives["arbitrary_alloc"] = True
                    # Also check if fd points to an allocated chunk (shouldn't be in freelist)
                    fd_chunk = self.chunks.get(chunk.fd - 2 * SIZE_SZ)
                    if fd_chunk is None:
                        # fd points to user_addr of a chunk?
                        for a, c in self.chunks.items():
                            if c.user_addr == chunk.fd and c.allocated:
                                primitives["tcache_poison"] = True
                                break
        # Check for duplicate allocation: two slots point to same user address
        addrs = list(self.slots.values())
        if len(addrs) != len(set(addrs)):
            primitives["duplicate_alloc"] = True
            primitives["arbitrary_alloc"] = True
        return primitives
    def state_to_grid(self) -> np.ndarray:
        """Convert current state to 32x16 grid for TRM input.

        Row i describes the i-th chunk in address order (first 32
        chunks only).  Most cell values are clamped to 0..63; fd/bk
        columns use 1..32 for an in-view chunk index, 33 for a pointer
        outside the first 32 chunks, and 0 for null.
        """
        grid = np.zeros((32, 16), dtype=np.int64)
        state = self.get_state()
        chunks = state["chunks"]
        for i, c in enumerate(chunks[:32]):
            # Col 0: state (1=alloc, 2=freed)
            grid[i, 0] = 1 if c["allocated"] else 2
            # Col 1: size class
            grid[i, 1] = min(63, c["size"] >> 4)
            # Col 2: prev_inuse
            grid[i, 2] = 1 if c["prev_inuse"] else 0
            # Col 3-4: unused flags
            grid[i, 3] = 0
            grid[i, 4] = 0
            # Col 5: fd target (resolve to chunk index; 33 = external)
            if c["fd"] != 0:
                fd_idx = 33  # external
                for j, c2 in enumerate(chunks[:32]):
                    if c2["addr"] == c["fd"]:
                        fd_idx = min(32, j + 1)
                        break
                grid[i, 5] = fd_idx
            else:
                grid[i, 5] = 0
            # Col 6: bk target
            if c["bk"] != 0:
                bk_idx = 33
                for j, c2 in enumerate(chunks[:32]):
                    if c2["addr"] == c["bk"]:
                        bk_idx = min(32, j + 1)
                        break
                grid[i, 6] = bk_idx
            else:
                grid[i, 6] = 0
            # Col 7: bin type encoding (0=none, 1=tcache, 2=fastbin, 3=unsorted)
            bin_str = c["bin"]
            if "tcache" in bin_str:
                grid[i, 7] = 1
            elif "fastbin" in bin_str:
                grid[i, 7] = 2
            elif "unsorted" in bin_str:
                grid[i, 7] = 3
            else:
                grid[i, 7] = 0
            # Col 8: slot (stored 1-based so 0 can mean "no slot")
            grid[i, 8] = min(63, (c["slot"] or 0) + 1) if c["slot"] is not None else 0
            # Col 9: tcache count for this size
            sz = c["size"]
            if sz <= 0x410:
                idx = (sz - MINSIZE) // 0x10
                if idx < TCACHE_MAX_BINS:
                    grid[i, 9] = min(63, state["tcache_counts"][idx])
            # Col 10: has data
            grid[i, 10] = min(63, c["user_data_len"])
            # Col 11: chunk index
            grid[i, 11] = min(63, i)
            # Col 12: alloc_count
            grid[i, 12] = min(63, state["alloc_count"])
            # Col 13: free_count
            grid[i, 13] = min(63, state["free_count"])
            # Col 14: step
            grid[i, 14] = min(63, state["step"])
            # Col 15: size raw (for more granularity)
            grid[i, 15] = min(63, c["size"] >> 3)
        return grid