File size: 2,357 Bytes
22374d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"""
model.py - TRM classifier for heap exploit detection.

Standalone module with no external dependencies beyond torch + numpy.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pathlib import Path


class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Each vector along the final axis is divided by its RMS (with ``eps``
    added to the mean square for numerical safety) and then scaled by a
    learned per-feature gain.

    Args:
        dim: Size of the last dimension, i.e. the length of the gain vector.
        eps: Constant added to the mean square before the reciprocal sqrt.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        # Learned gain, initialised to ones so the layer starts as pure RMS scaling.
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight


class SwiGLU(nn.Module):
    """Gated feed-forward unit computing ``w2(silu(gate) * value)``.

    A single bias-free projection ``w1`` produces both the gate and the value
    halves (each of width ``dim``). A second bias-free projection ``w2`` maps
    the gated product back to ``dim``.

    Args:
        dim: Input and output feature width.
    """

    def __init__(self, dim):
        super().__init__()
        # Fused projection: first half of the output is the gate, second the value.
        self.w1 = nn.Linear(dim, 2 * dim, bias=False)
        self.w2 = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        projected = self.w1(x)
        gate, value = torch.chunk(projected, 2, dim=-1)
        gated = F.silu(gate) * value
        return self.w2(gated)


class RecursionBlock(nn.Module):
    """RMSNorm -> SwiGLU -> RMSNorm, applied in sequence to the input.

    Note that there is no residual connection inside this block; callers
    add the block's output to their own state.

    Args:
        dim: Feature width shared by all three sub-layers.
    """

    def __init__(self, dim):
        super().__init__()
        # Registration order is kept as norm1, swiglu, norm2.
        self.norm1 = RMSNorm(dim)
        self.swiglu = SwiGLU(dim)
        self.norm2 = RMSNorm(dim)

    def forward(self, x):
        out = x
        for stage in (self.norm1, self.swiglu, self.norm2):
            out = stage(out)
        return out


class HeapTRM(nn.Module):
    """Tiny Recursive Model for heap exploit classification.

    Token ids are embedded and summed with a learned positional embedding to
    give the input ``h``. Two states are then refined by residual updates:

    * the latent ``z`` is updated from ``h + y + z``;
    * the answer ``y`` is updated from ``y + z``.

    This is repeated ``n_outer * n_inner`` times. The normalised ``y`` is
    mean-pooled over positions and projected to ``n_classes`` logits.

    Args:
        vocab_size: Number of token ids accepted by the embedding table.
        hidden_dim: Width of all hidden representations.
        seq_len: Maximum number of tokens per sample. This is the length of
            the learned positional and initial-state tables. Shorter inputs
            use a prefix of those tables.
        n_outer: Number of outer recursion passes.
        n_inner: Number of (z, y) update pairs per outer pass.
        n_classes: Number of output logits.
    """

    def __init__(self, vocab_size=64, hidden_dim=128, seq_len=512,
                 n_outer=2, n_inner=3, n_classes=2):
        super().__init__()
        self.seq_len = seq_len
        self.n_outer = n_outer
        self.n_inner = n_inner
        # Parameter creation order is unchanged so random initialisation and
        # state_dict contents match earlier versions of this class.
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.y_init = nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02)
        self.z_init = nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02)
        self.block_z = RecursionBlock(hidden_dim)
        self.block_y = RecursionBlock(hidden_dim)
        self.pos_embed = nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02)
        self.out_norm = RMSNorm(hidden_dim)
        self.out_head = nn.Linear(hidden_dim, n_classes)

    def forward(self, x):
        """Return class logits for a batch of token ids.

        Args:
            x: Integer tensor of token ids. The first dimension is the batch.
                All remaining dimensions are flattened into one sequence per
                sample, whose length must not exceed ``seq_len``.

        Returns:
            Tensor of shape ``(batch, n_classes)``.

        Raises:
            ValueError: If the flattened per-sample length exceeds ``seq_len``.
        """
        B = x.shape[0]
        tokens = x.reshape(B, -1)
        length = tokens.shape[1]
        if length > self.seq_len:
            raise ValueError(
                f"input has {length} tokens per sample, but the model supports "
                f"at most seq_len={self.seq_len}"
            )
        # Slice the learned tables to the actual length. For length == seq_len
        # this is the full table, i.e. identical to the previous behaviour.
        h = self.embed(tokens) + self.pos_embed[:, :length]
        y = self.y_init[:, :length].expand(B, -1, -1)
        z = self.z_init[:, :length].expand(B, -1, -1)
        for _ in range(self.n_outer):
            for _ in range(self.n_inner):
                z = z + self.block_z(h + y + z)
                y = y + self.block_y(y + z)
        pooled = self.out_norm(y).mean(dim=1)
        return self.out_head(pooled)