""" model.py - TRM classifier for heap exploit detection. Standalone module with no external dependencies beyond torch + numpy. """ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from pathlib import Path class RMSNorm(nn.Module): def __init__(self, dim, eps=1e-6): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def forward(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight class SwiGLU(nn.Module): def __init__(self, dim): super().__init__() self.w1 = nn.Linear(dim, dim * 2, bias=False) self.w2 = nn.Linear(dim, dim, bias=False) def forward(self, x): gate, val = self.w1(x).chunk(2, dim=-1) return self.w2(F.silu(gate) * val) class RecursionBlock(nn.Module): def __init__(self, dim): super().__init__() self.norm1 = RMSNorm(dim) self.swiglu = SwiGLU(dim) self.norm2 = RMSNorm(dim) def forward(self, x): return self.norm2(self.swiglu(self.norm1(x))) class HeapTRM(nn.Module): """Tiny Recursive Model for heap exploit classification.""" def __init__(self, vocab_size=64, hidden_dim=128, seq_len=512, n_outer=2, n_inner=3, n_classes=2): super().__init__() self.n_outer = n_outer self.n_inner = n_inner self.embed = nn.Embedding(vocab_size, hidden_dim) self.y_init = nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02) self.z_init = nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02) self.block_z = RecursionBlock(hidden_dim) self.block_y = RecursionBlock(hidden_dim) self.pos_embed = nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02) self.out_norm = RMSNorm(hidden_dim) self.out_head = nn.Linear(hidden_dim, n_classes) def forward(self, x): B = x.shape[0] h = self.embed(x.reshape(B, -1)) + self.pos_embed y = self.y_init.expand(B, -1, -1) z = self.z_init.expand(B, -1, -1) for _ in range(self.n_outer): for _ in range(self.n_inner): z = z + self.block_z(h + y + z) y = y + self.block_y(y + z) pooled = self.out_norm(y).mean(dim=1) return self.out_head(pooled)