""" heap_sim.py - Lightweight ptmalloc2 heap allocator simulator. Simulates the core glibc heap allocator behavior: - Chunk metadata (prev_size, size, flags, fd, bk) - Tcache (per-size LIFO, 7 entries max, sizes 0x20-0x410) - Fastbins (LIFO, sizes 0x20-0x80) - Unsorted bin (doubly-linked) - Forward/backward coalescing - Top chunk splitting - Request size -> chunk size conversion Does NOT simulate: - Safe linking (we want to model the exploitable state) - mmap'd chunks - Thread arenas - Large bin sorting This is deliberately simplified to be fast enough for self-play training. """ import numpy as np from typing import Optional, List, Dict, Tuple from dataclasses import dataclass, field # Constants (64-bit) SIZE_SZ = 8 MALLOC_ALIGN_MASK = 0xF MINSIZE = 0x20 TCACHE_MAX_BINS = 64 TCACHE_FILL_COUNT = 7 MAX_FASTBIN_SIZE = 0x80 NUM_FASTBINS = 7 # sizes 0x20, 0x30, ..., 0x80 INITIAL_HEAP_BASE = 0x555555559000 TOP_CHUNK_INITIAL_SIZE = 0x21000 # Flags PREV_INUSE = 0x1 IS_MMAPPED = 0x2 NON_MAIN_ARENA = 0x4 def request_to_chunk_size(req: int) -> int: """Convert malloc request size to actual chunk size.""" if req + SIZE_SZ + MALLOC_ALIGN_MASK < MINSIZE: return MINSIZE return (req + SIZE_SZ + MALLOC_ALIGN_MASK) & ~MALLOC_ALIGN_MASK def tcache_idx(chunk_size: int) -> int: """Tcache bin index for a chunk size.""" return (chunk_size - MINSIZE) // 0x10 def fastbin_idx(chunk_size: int) -> int: """Fastbin index for a chunk size.""" return (chunk_size - MINSIZE) // 0x10 @dataclass class Chunk: """Represents a heap chunk.""" addr: int # address of chunk start (before user data) prev_size: int = 0 size: int = 0 # includes flags in low 3 bits allocated: bool = True fd: int = 0 # forward pointer (freed chunks) bk: int = 0 # backward pointer (freed chunks) user_data: bytes = b'' # first N bytes of user data @property def real_size(self) -> int: return self.size & ~0x7 @property def prev_inuse(self) -> bool: return bool(self.size & PREV_INUSE) @property def user_addr(self) -> int: return 
self.addr + 2 * SIZE_SZ # skip prev_size + size @property def next_chunk_addr(self) -> int: return self.addr + self.real_size class HeapSimulator: """Simulates the glibc ptmalloc2 heap allocator.""" def __init__(self, heap_base: int = INITIAL_HEAP_BASE): self.heap_base = heap_base self.chunks: Dict[int, Chunk] = {} # addr -> Chunk # Tcache: list of chunk addrs per size class self.tcache: List[List[int]] = [[] for _ in range(TCACHE_MAX_BINS)] # Fastbins: list of chunk addrs per size class self.fastbins: List[List[int]] = [[] for _ in range(NUM_FASTBINS)] # Unsorted bin: list of chunk addrs self.unsorted_bin: List[int] = [] # Top chunk self.top_addr = heap_base self.top_size = TOP_CHUNK_INITIAL_SIZE # Tracking self.alloc_count = 0 self.free_count = 0 self.step = 0 self.history: List[dict] = [] # User-facing slots (like CTF menu: slots 0-15) self.slots: Dict[int, int] = {} # slot_idx -> user_addr def _record(self, op: str, **kwargs): self.history.append({"step": self.step, "op": op, **kwargs}) self.step += 1 # ============================================================ # MALLOC # ============================================================ def malloc(self, req_size: int, slot: Optional[int] = None) -> Optional[int]: """Allocate a chunk. Returns user pointer or None.""" chunk_size = request_to_chunk_size(req_size) user_addr = None # 1. Try tcache if chunk_size <= 0x410: idx = tcache_idx(chunk_size) if idx < TCACHE_MAX_BINS and self.tcache[idx]: addr = self.tcache[idx].pop() chunk = self.chunks[addr] chunk.allocated = True chunk.fd = 0 chunk.bk = 0 user_addr = chunk.user_addr self._set_next_prev_inuse(chunk, True) # 2. Try fastbin if user_addr is None and chunk_size <= MAX_FASTBIN_SIZE: idx = fastbin_idx(chunk_size) if idx < NUM_FASTBINS and self.fastbins[idx]: addr = self.fastbins[idx].pop() chunk = self.chunks[addr] chunk.allocated = True chunk.fd = 0 chunk.bk = 0 user_addr = chunk.user_addr self._set_next_prev_inuse(chunk, True) # 3. 
Try unsorted bin (first fit) if user_addr is None: for i, addr in enumerate(self.unsorted_bin): chunk = self.chunks[addr] if chunk.real_size >= chunk_size: self.unsorted_bin.pop(i) remainder = chunk.real_size - chunk_size if remainder >= MINSIZE: # Split self._split_chunk(chunk, chunk_size) chunk.allocated = True chunk.size = chunk_size | PREV_INUSE chunk.fd = 0 chunk.bk = 0 user_addr = chunk.user_addr self._set_next_prev_inuse(chunk, True) break # 4. Split top chunk if user_addr is None: if self.top_size >= chunk_size + MINSIZE: addr = self.top_addr chunk = Chunk( addr=addr, size=chunk_size | PREV_INUSE, allocated=True, ) self.chunks[addr] = chunk self.top_addr = addr + chunk_size self.top_size -= chunk_size user_addr = chunk.user_addr else: return None # OOM if slot is not None and user_addr is not None: self.slots[slot] = user_addr self.alloc_count += 1 self._record("malloc", size=req_size, chunk_size=chunk_size, user_addr=user_addr, slot=slot) return user_addr # ============================================================ # FREE # ============================================================ def free(self, user_addr: int = 0, slot: Optional[int] = None) -> bool: """Free a chunk by user pointer or slot. Returns success.""" if slot is not None: user_addr = self.slots.get(slot, 0) if not user_addr: return False chunk_addr = user_addr - 2 * SIZE_SZ chunk = self.chunks.get(chunk_addr) if chunk is None: return False chunk_size = chunk.real_size # 1. Try tcache if chunk_size <= 0x410: idx = tcache_idx(chunk_size) if idx < TCACHE_MAX_BINS and len(self.tcache[idx]) < TCACHE_FILL_COUNT: chunk.allocated = False chunk.fd = self.tcache[idx][-1] if self.tcache[idx] else 0 chunk.bk = 0 self.tcache[idx].append(chunk_addr) self.free_count += 1 self._record("free", user_addr=user_addr, bin="tcache", chunk_size=chunk_size, slot=slot) return True # 2. 
Fastbin if chunk_size <= MAX_FASTBIN_SIZE: idx = fastbin_idx(chunk_size) if idx < NUM_FASTBINS: chunk.allocated = False chunk.fd = self.fastbins[idx][-1] if self.fastbins[idx] else 0 self.fastbins[idx].append(chunk_addr) self.free_count += 1 self._record("free", user_addr=user_addr, bin="fastbin", chunk_size=chunk_size, slot=slot) return True # 3. Coalesce and put in unsorted bin chunk.allocated = False self._coalesce_and_unsort(chunk) self.free_count += 1 self._record("free", user_addr=user_addr, bin="unsorted", chunk_size=chunk_size, slot=slot) return True # ============================================================ # WRITE (simulate vulnerability) # ============================================================ def write(self, user_addr: int, data: bytes, overflow: int = 0) -> bool: """Write data to a chunk. overflow > 0 allows OOB write.""" chunk_addr = user_addr - 2 * SIZE_SZ chunk = self.chunks.get(chunk_addr) if chunk is None: return False chunk.user_data = data # Simulate overflow: corrupt next chunk's metadata if overflow > 0: next_addr = chunk.next_chunk_addr next_chunk = self.chunks.get(next_addr) if next_chunk and overflow >= 1: # Off-by-one null byte: corrupt size field's LSB next_chunk.size = next_chunk.size & ~0xFF # More overflow: can corrupt fd/bk if overflow >= SIZE_SZ and not next_chunk.allocated: # Overwrite fd pointer if len(data) > chunk.real_size - 2 * SIZE_SZ: overflow_data = data[chunk.real_size - 2 * SIZE_SZ:] if len(overflow_data) >= 8: next_chunk.fd = int.from_bytes( overflow_data[:8], 'little') self._record("write", user_addr=user_addr, data_len=len(data), overflow=overflow) return True def write_to_freed(self, user_addr: int, fd_value: int) -> bool: """UAF: write fd pointer of a freed chunk.""" chunk_addr = user_addr - 2 * SIZE_SZ chunk = self.chunks.get(chunk_addr) if chunk is None or chunk.allocated: return False chunk.fd = fd_value self._record("write_freed", user_addr=user_addr, fd_value=fd_value) return True # 
    # ============================================================
    # INTERNAL HELPERS
    # ============================================================
    def _set_next_prev_inuse(self, chunk: Chunk, inuse: bool):
        """Set the PREV_INUSE bit of the chunk following `chunk`.

        Also refreshes the neighbor's prev_size so physical-adjacency
        bookkeeping stays consistent. No-op if the next address is not a
        tracked chunk (e.g. it is the top chunk).
        """
        next_addr = chunk.next_chunk_addr
        next_chunk = self.chunks.get(next_addr)
        if next_chunk:
            if inuse:
                next_chunk.size |= PREV_INUSE
            else:
                next_chunk.size &= ~PREV_INUSE
            # NOTE: prev_size is written even when inuse=True; real glibc only
            # maintains it for free predecessors, but keeping it always-valid
            # is harmless in the simulation.
            next_chunk.prev_size = chunk.real_size

    def _split_chunk(self, chunk: Chunk, new_size: int):
        """Split `chunk` into (new_size) and a remainder chunk.

        The remainder is registered as a new free chunk and appended to the
        unsorted bin. The caller is responsible for shrinking `chunk.size`
        itself.
        """
        remainder_addr = chunk.addr + new_size
        remainder_size = chunk.real_size - new_size
        remainder = Chunk(
            addr=remainder_addr,
            prev_size=new_size,
            size=remainder_size | PREV_INUSE,
            allocated=False,
        )
        self.chunks[remainder_addr] = remainder
        self.unsorted_bin.append(remainder_addr)

    def _coalesce_and_unsort(self, chunk: Chunk):
        """Coalesce `chunk` with free neighbors and put it in the unsorted bin.

        Order matters here: forward merge first (absorb the next chunk),
        then backward merge (rebind to the previous chunk's address), then
        a possible merge into the top chunk, which returns early without
        touching the unsorted bin.
        """
        addr = chunk.addr
        size = chunk.real_size
        # Forward coalesce: merge with next chunk if free
        next_addr = addr + size
        next_chunk = self.chunks.get(next_addr)
        if next_chunk and not next_chunk.allocated:
            # Remove next from whatever bin it's in
            self._remove_from_bins(next_addr)
            size += next_chunk.real_size
            del self.chunks[next_addr]
        # Backward coalesce: merge with prev chunk if free.
        # prev_size > 0 guards against uninitialized prev_size fields.
        if not chunk.prev_inuse and chunk.prev_size > 0:
            prev_addr = addr - chunk.prev_size
            prev_chunk = self.chunks.get(prev_addr)
            if prev_chunk and not prev_chunk.allocated:
                self._remove_from_bins(prev_addr)
                size += prev_chunk.real_size
                # The freed chunk is absorbed into its predecessor: drop its
                # own entry and continue operating on the predecessor object.
                del self.chunks[addr]
                addr = prev_addr
                chunk = prev_chunk
        # Update chunk
        chunk.addr = addr
        chunk.size = size | PREV_INUSE  # prev of coalesced is always inuse
        chunk.allocated = False
        self.chunks[addr] = chunk
        # Check if coalescing into top: the merged region then vanishes from
        # the chunk map entirely and grows the wilderness instead.
        if addr + size == self.top_addr:
            self.top_addr = addr
            self.top_size += size
            if addr in self.chunks:
                del self.chunks[addr]
            return
        self.unsorted_bin.append(addr)
        self._set_next_prev_inuse(chunk, False)
    def _remove_from_bins(self, addr: int):
        """Remove a chunk addr from whichever bin it's in (first match wins)."""
        for tc in self.tcache:
            if addr in tc:
                tc.remove(addr)
                return
        for fb in self.fastbins:
            if addr in fb:
                fb.remove(addr)
                return
        if addr in self.unsorted_bin:
            self.unsorted_bin.remove(addr)

    # ============================================================
    # STATE OBSERVATION
    # ============================================================
    def get_state(self) -> dict:
        """Return current heap state as a dict (for grid encoding).

        Chunks are reported in address order. Bin membership is resolved by
        scanning tcache, then fastbins, then the unsorted bin (for-else
        chains fall through only when no earlier bin matched).
        """
        chunks_info = []
        for addr in sorted(self.chunks.keys()):
            c = self.chunks[addr]
            # Determine which bin it's in
            bin_type = "none"
            if not c.allocated:
                for i, tc in enumerate(self.tcache):
                    if addr in tc:
                        bin_type = f"tcache_{i}"
                        break
                else:
                    for i, fb in enumerate(self.fastbins):
                        if addr in fb:
                            bin_type = f"fastbin_{i}"
                            break
                    else:
                        if addr in self.unsorted_bin:
                            bin_type = "unsorted"
            # Which slot points here? (first matching slot only)
            slot = None
            for s, ua in self.slots.items():
                if ua == c.user_addr:
                    slot = s
                    break
            chunks_info.append({
                "addr": addr,
                "size": c.real_size,
                "allocated": c.allocated,
                "prev_inuse": c.prev_inuse,
                "fd": c.fd,
                "bk": c.bk,
                "bin": bin_type,
                "slot": slot,
                "user_data_len": len(c.user_data),
            })
        return {
            "step": self.step,
            "n_chunks": len(self.chunks),
            "chunks": chunks_info,
            "tcache_counts": [len(tc) for tc in self.tcache],
            "fastbin_counts": [len(fb) for fb in self.fastbins],
            "unsorted_count": len(self.unsorted_bin),
            "top_addr": self.top_addr,
            "top_size": self.top_size,
            "alloc_count": self.alloc_count,
            "free_count": self.free_count,
        }

    # ============================================================
    # EXPLOIT PRIMITIVE DETECTION
    # ============================================================
    def check_primitives(self) -> dict:
        """Check for achieved exploit primitives.

        Returns a dict of boolean flags. These are heuristics over simulator
        state, not proofs of exploitability.
        """
        primitives = {
            "overlapping_chunks": False,
            "arbitrary_alloc": False,
            "double_free": False,
            "tcache_poison": False,
            "freelist_cycle": False,
            "duplicate_alloc": False,
        }
        # Check overlapping chunks: two allocated chunks whose regions overlap
        # (standard interval-intersection test over all pairs).
        alloc_chunks = [(a, c) for a, c in self.chunks.items() if c.allocated]
        for i, (a1, c1) in enumerate(alloc_chunks):
            for a2, c2 in alloc_chunks[i+1:]:
                end1 = a1 + c1.real_size
                end2 = a2 + c2.real_size
                if a1 < end2 and a2 < end1:
                    primitives["overlapping_chunks"] = True
        # Check tcache/fastbin for cycles (double free): a repeated address
        # in any single bin list means the same chunk was freed twice.
        for bins in [self.tcache, self.fastbins]:
            for bin_list in bins:
                if len(bin_list) != len(set(bin_list)):
                    primitives["double_free"] = True
                    primitives["freelist_cycle"] = True
        # Check for tcache poison: fd points outside heap OR fd was manually set
        for tc in self.tcache:
            for addr in tc:
                chunk = self.chunks.get(addr)
                if chunk and chunk.fd != 0:
                    # fd outside [heap_base, end-of-top] => arbitrary target.
                    if chunk.fd < self.heap_base or chunk.fd > self.top_addr + self.top_size:
                        primitives["tcache_poison"] = True
                        primitives["arbitrary_alloc"] = True
                    # Also check if fd points to an allocated chunk (shouldn't be in freelist)
                    fd_chunk = self.chunks.get(chunk.fd - 2 * SIZE_SZ)
                    if fd_chunk is None:
                        # fd points to user_addr of a chunk?
                        for a, c in self.chunks.items():
                            if c.user_addr == chunk.fd and c.allocated:
                                primitives["tcache_poison"] = True
                                break
        # Check for duplicate allocation: two slots point to same user address
        addrs = list(self.slots.values())
        if len(addrs) != len(set(addrs)):
            primitives["duplicate_alloc"] = True
            primitives["arbitrary_alloc"] = True
        return primitives

    def state_to_grid(self) -> np.ndarray:
        """Convert current state to a 32x16 int64 grid for TRM input.

        One row per chunk (address order, first 32 chunks only); rows past
        the last chunk stay zero. Most features are clamped to [0, 63];
        columns 5/6 use 0 = no pointer, 1..32 = target chunk index + 1,
        33 = pointer outside the visible chunk list.
        NOTE(review): the exact column layout below is the model's input
        contract -- do not reorder without retraining.
        """
        grid = np.zeros((32, 16), dtype=np.int64)
        state = self.get_state()
        chunks = state["chunks"]
        for i, c in enumerate(chunks[:32]):
            # Col 0: state (1=alloc, 2=freed)
            grid[i, 0] = 1 if c["allocated"] else 2
            # Col 1: size class
            grid[i, 1] = min(63, c["size"] >> 4)
            # Col 2: prev_inuse
            grid[i, 2] = 1 if c["prev_inuse"] else 0
            # Col 3-4: unused flags
            grid[i, 3] = 0
            grid[i, 4] = 0
            # Col 5: fd target (resolve to chunk index)
            if c["fd"] != 0:
                fd_idx = 33  # external
                for j, c2 in enumerate(chunks[:32]):
                    if c2["addr"] == c["fd"]:
                        fd_idx = min(32, j + 1)
                        break
                grid[i, 5] = fd_idx
            else:
                grid[i, 5] = 0
            # Col 6: bk target
            if c["bk"] != 0:
                bk_idx = 33
                for j, c2 in enumerate(chunks[:32]):
                    if c2["addr"] == c["bk"]:
                        bk_idx = min(32, j + 1)
                        break
                grid[i, 6] = bk_idx
            else:
                grid[i, 6] = 0
            # Col 7: bin type encoding (0=none, 1=tcache, 2=fastbin, 3=unsorted)
            bin_str = c["bin"]
            if "tcache" in bin_str:
                grid[i, 7] = 1
            elif "fastbin" in bin_str:
                grid[i, 7] = 2
            elif "unsorted" in bin_str:
                grid[i, 7] = 3
            else:
                grid[i, 7] = 0
            # Col 8: slot (slot index + 1; 0 = no slot points here)
            grid[i, 8] = min(63, (c["slot"] or 0) + 1) if c["slot"] is not None else 0
            # Col 9: tcache count for this size
            sz = c["size"]
            if sz <= 0x410:
                idx = (sz - MINSIZE) // 0x10
                if idx < TCACHE_MAX_BINS:
                    grid[i, 9] = min(63, state["tcache_counts"][idx])
            # Col 10: has data
            grid[i, 10] = min(63, c["user_data_len"])
            # Col 11: chunk index
            grid[i, 11] = min(63, i)
            # Col 12: alloc_count
            grid[i, 12] = min(63, state["alloc_count"])
            # Col 13: free_count
            grid[i, 13] = min(63, state["free_count"])
            # Col 14: step
            grid[i, 14] = min(63, state["step"])
            # Col 15: size raw (for more granularity)
            grid[i, 15] = min(63, c["size"] >> 3)
        return grid