| """ |
| model.py - TRM classifier for heap exploit detection. |
| |
| Standalone module with no external dependencies beyond torch + numpy. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| from pathlib import Path |
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learned per-feature gain.

    Normalizes over the last dimension by the RMS of the features (no mean
    subtraction, unlike LayerNorm) and scales by a learnable weight that
    starts at 1.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps  # added inside rsqrt for numerical stability
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Mean of squares over the feature (last) dimension.
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        normalized = x * torch.rsqrt(mean_sq + self.eps)
        return normalized * self.weight
|
|
|
|
class SwiGLU(nn.Module):
    """SwiGLU feed-forward unit: silu-gated linear projection.

    A single fused linear layer produces both the gate and the value halves;
    the gate passes through SiLU and multiplies the value before the output
    projection. Both projections are bias-free.
    """

    def __init__(self, dim):
        super().__init__()
        # Fused gate + value projection; split into two halves in forward().
        self.w1 = nn.Linear(dim, 2 * dim, bias=False)
        self.w2 = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        projected = self.w1(x)
        gate, value = torch.chunk(projected, 2, dim=-1)
        activated = F.silu(gate) * value
        return self.w2(activated)
|
|
|
|
class RecursionBlock(nn.Module):
    """Norm -> SwiGLU -> norm sandwich used as the recursive update core.

    Applies RMSNorm before and after a SwiGLU feed-forward unit; the caller
    is responsible for any residual connection around this block.
    """

    def __init__(self, dim):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.swiglu = SwiGLU(dim)
        self.norm2 = RMSNorm(dim)

    def forward(self, x):
        h = self.norm1(x)
        h = self.swiglu(h)
        return self.norm2(h)
|
|
|
|
class HeapTRM(nn.Module):
    """Tiny Recursive Model for heap exploit classification.

    Embeds a (flattened) token sequence, then iteratively refines two latent
    states — a fast state ``z`` and a slow state ``y`` — with residual
    recursion blocks. The final ``y`` is mean-pooled and projected to class
    logits.

    Args:
        vocab_size: number of distinct token ids.
        hidden_dim: width of embeddings and latent states.
        seq_len: maximum flattened sequence length (positional table size).
        n_outer: outer refinement steps (updates of ``y``).
        n_inner: inner refinement steps per outer step (updates of ``z``).
        n_classes: number of output classes.
    """

    def __init__(self, vocab_size=64, hidden_dim=128, seq_len=512,
                 n_outer=2, n_inner=3, n_classes=2):
        super().__init__()
        self.n_outer = n_outer
        self.n_inner = n_inner
        self.seq_len = seq_len
        # NOTE: parameter/module registration order is kept stable so that
        # seeded initialization reproduces earlier checkpoints.
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        # Learned initial latent states, broadcast over the batch in forward().
        self.y_init = nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02)
        self.z_init = nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02)
        self.block_z = RecursionBlock(hidden_dim)
        self.block_y = RecursionBlock(hidden_dim)
        self.pos_embed = nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02)
        self.out_norm = RMSNorm(hidden_dim)
        self.out_head = nn.Linear(hidden_dim, n_classes)

    def forward(self, x):
        """Classify a batch of token sequences (or grids).

        Args:
            x: integer token-id tensor of shape ``(B, ...)``; all trailing
               dimensions are flattened to a sequence of length ``L``, which
               must satisfy ``L <= seq_len``.

        Returns:
            Logits of shape ``(B, n_classes)``.

        Raises:
            ValueError: if the flattened length exceeds ``seq_len``.
        """
        B = x.shape[0]
        tokens = x.reshape(B, -1)
        L = tokens.shape[1]
        if L > self.seq_len:
            raise ValueError(
                f"flattened input length {L} exceeds seq_len {self.seq_len}"
            )
        # Slice positional table and latent inits to the actual length; for
        # L == seq_len this is numerically identical to the full tensors.
        h = self.embed(tokens) + self.pos_embed[:, :L]
        y = self.y_init[:, :L].expand(B, -1, -1)
        z = self.z_init[:, :L].expand(B, -1, -1)
        for _ in range(self.n_outer):
            # Inner loop: refine the fast state z against input and y.
            for _ in range(self.n_inner):
                z = z + self.block_z(h + y + z)
            # Outer step: fold the refined z into the slow state y.
            y = y + self.block_y(y + z)
        pooled = self.out_norm(y).mean(dim=1)
        return self.out_head(pooled)
|
|