# Commit 22374d1 (amarck): Add heaptrm package: v2 harness, CLI,
# pwntools integration, CVE tests
"""
model.py - TRM classifier for heap exploit detection.
Standalone module with no external dependencies beyond torch + numpy.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pathlib import Path
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean-centering, no bias).

    Scales each vector along the last dimension by the reciprocal of its
    RMS, then applies a learned per-feature gain (initialized to ones).
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # eps keeps the rsqrt finite for all-zero inputs.
        inv_rms = torch.rsqrt(x.square().mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms * self.weight
class SwiGLU(nn.Module):
    """SwiGLU feed-forward unit: a silu-gated linear projection.

    Projects the input to 2*dim, splits the result into gate and value
    halves, multiplies the silu-activated gate against the value, and
    projects back down to dim. Both linear maps are bias-free.
    """

    def __init__(self, dim):
        super().__init__()
        self.w1 = nn.Linear(dim, dim * 2, bias=False)
        self.w2 = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        projected = self.w1(x)
        gate, value = torch.chunk(projected, 2, dim=-1)
        return self.w2(F.silu(gate) * value)
class RecursionBlock(nn.Module):
    """Shared recursion step: RMSNorm -> SwiGLU -> RMSNorm.

    Produces an update term; callers add it residually to their state.
    """

    def __init__(self, dim):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.swiglu = SwiGLU(dim)
        self.norm2 = RMSNorm(dim)

    def forward(self, x):
        h = self.norm1(x)
        h = self.swiglu(h)
        return self.norm2(h)
class HeapTRM(nn.Module):
    """Tiny Recursive Model for heap exploit classification.

    Embeds a token sequence, then iteratively refines a latent state ``z``
    and an answer state ``y`` with two shared recursion blocks: ``n_inner``
    latent updates against the embedded context, followed by one answer
    update, repeated ``n_outer`` times. The mean-pooled answer state is
    normalized and projected to class logits.

    Args:
        vocab_size: size of the token vocabulary.
        hidden_dim: width of the embedding and recursion states.
        seq_len: fixed sequence length; inputs must flatten to this
            (pos_embed and the initial states are sized to it).
        n_outer: number of outer refinement cycles.
        n_inner: latent updates per outer cycle.
        n_classes: number of output classes.
    """

    def __init__(self, vocab_size=64, hidden_dim=128, seq_len=512,
                 n_outer=2, n_inner=3, n_classes=2):
        super().__init__()
        self.n_outer = n_outer
        self.n_inner = n_inner
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        # Learned initial answer (y) and latent (z) states, broadcast over batch.
        self.y_init = nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02)
        self.z_init = nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02)
        self.block_z = RecursionBlock(hidden_dim)
        self.block_y = RecursionBlock(hidden_dim)
        self.pos_embed = nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02)
        self.out_norm = RMSNorm(hidden_dim)
        self.out_head = nn.Linear(hidden_dim, n_classes)

    def forward(self, x):
        """Return ``(batch, n_classes)`` logits for token batch ``x``.

        ``x`` is an integer tensor whose trailing dims flatten to
        ``seq_len`` (required for pos_embed/state broadcasting to line up).
        """
        batch = x.shape[0]
        ctx = self.embed(x.reshape(batch, -1)) + self.pos_embed
        y = self.y_init.expand(batch, -1, -1)
        z = self.z_init.expand(batch, -1, -1)
        for _ in range(self.n_outer):
            # Refine the latent state several times against context + answer...
            for _ in range(self.n_inner):
                z = z + self.block_z(ctx + y + z)
            # ...then fold the refined latent into the answer state once.
            y = y + self.block_y(y + z)
        pooled = self.out_norm(y).mean(dim=1)
        return self.out_head(pooled)