File size: 2,357 Bytes
22374d1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 | """
model.py - TRM classifier for heap exploit detection.
Standalone module with no external dependencies beyond torch + numpy.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pathlib import Path
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain.

    Args:
        dim: Size of the last axis; one gain value is learned per feature.
        eps: Constant added to the mean square before taking the reciprocal
            square root, so an all-zero input does not divide by zero.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        # The gain starts at one, so a fresh layer only rescales by 1 / RMS.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Return ``x`` scaled by the inverse RMS of its last axis and the gain."""
        mean_square = torch.pow(x, 2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalised = x * inv_rms
        return normalised * self.weight
class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``w2(silu(gate) * val)``.

    ``w1`` projects the input to twice its width; the result is split in half
    along the last axis into a gate and a value. Both projections are
    bias-free and the output has the same width as the input.

    Args:
        dim: Width of the last axis of the input and of the output.
    """

    def __init__(self, dim):
        super().__init__()
        # One fused projection produces the gate and the value together.
        self.w1 = nn.Linear(dim, 2 * dim, bias=False)
        self.w2 = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        """Apply the gated projection to ``x`` along its last axis."""
        projected = self.w1(x)
        gate, val = torch.chunk(projected, 2, dim=-1)
        gated = F.silu(gate) * val
        return self.w2(gated)
class RecursionBlock(nn.Module):
    """One refinement step: RMSNorm, then SwiGLU, then a second RMSNorm.

    The block returns only the transformed value; it adds no residual itself.

    Args:
        dim: Width of the last axis, shared by all three sub-layers.
    """

    def __init__(self, dim):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.swiglu = SwiGLU(dim)
        self.norm2 = RMSNorm(dim)

    def forward(self, x):
        """Return ``norm2(swiglu(norm1(x)))``."""
        pre_normed = self.norm1(x)
        mixed = self.swiglu(pre_normed)
        return self.norm2(mixed)
class HeapTRM(nn.Module):
    """Tiny Recursive Model for heap exploit classification.

    Token ids are embedded and summed with a learned positional embedding to
    form a fixed context ``h``. Two learned states, ``z`` and ``y``, are then
    refined: each of the ``n_outer`` rounds runs ``n_inner`` residual updates
    of ``z`` from ``h + y + z``, followed by one residual update of ``y`` from
    ``y + z``. The final ``y`` is normalised, mean-pooled over positions and
    projected to class logits.

    Args:
        vocab_size: Number of rows in the token embedding table.
        hidden_dim: Width of the embeddings and of both states.
        seq_len: Number of tokens per sample after flattening; fixes the size
            of the positional embedding and of the initial states.
        n_outer: Number of outer rounds (one ``y`` update each).
        n_inner: Number of ``z`` updates inside each outer round.
        n_classes: Number of output logits.
    """

    def __init__(self, vocab_size=64, hidden_dim=128, seq_len=512,
                 n_outer=2, n_inner=3, n_classes=2):
        super().__init__()
        self.n_outer = n_outer
        self.n_inner = n_inner
        # Creation order is kept as-is: it fixes both the parameter ordering
        # and the order in which the random generator is consumed.
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.y_init = nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02)
        self.z_init = nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02)
        self.block_z = RecursionBlock(hidden_dim)
        self.block_y = RecursionBlock(hidden_dim)
        self.pos_embed = nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02)
        self.out_norm = RMSNorm(hidden_dim)
        self.out_head = nn.Linear(hidden_dim, n_classes)

    def forward(self, x):
        """Classify a batch of token-id tensors.

        Args:
            x: Integer token ids with the batch on axis 0. All remaining axes
                are flattened and must contain exactly ``seq_len`` tokens.

        Returns:
            Logits of shape ``(batch, n_classes)``.

        Raises:
            ValueError: If the flattened token count differs from the
                ``seq_len`` the model was built with.
        """
        B = x.shape[0]
        tokens = x.reshape(B, -1)

        # Read the expected length from the parameter itself so the check
        # also holds for weights loaded from a checkpoint.
        expected_len = self.pos_embed.shape[1]
        if tokens.shape[1] != expected_len:
            # Without this check a single-token input would silently
            # broadcast against the positional embedding, and any other
            # mismatch would surface as an opaque broadcasting error.
            raise ValueError(
                f"HeapTRM expects {expected_len} tokens per sample after "
                f"flattening, got {tokens.shape[1]} "
                f"(input shape {tuple(x.shape)})"
            )

        h = self.embed(tokens) + self.pos_embed
        y = self.y_init.expand(B, -1, -1)
        z = self.z_init.expand(B, -1, -1)
        for _ in range(self.n_outer):
            for _ in range(self.n_inner):
                # The latent state sees the input context and the current answer.
                z = z + self.block_z(h + y + z)
            # The answer state is refined from the latent state only.
            y = y + self.block_y(y + z)
        pooled = self.out_norm(y).mean(dim=1)
        return self.out_head(pooled)
|