File size: 5,670 Bytes
e1fdf4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
import math

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        return self._norm(x.float()).type_as(x) * self.weight

class SSMBlock(nn.Module):
    """

    Una implementación simplificada de una capa de State Space Model (SSM) Selectivo.

    A diferencia del Transformer, esta capa tiene memoria LINEAL.

    """
    def __init__(self, dim: int, state_dim: int = 16, expand: int = 2):
        super().__init__()
        self.dim = dim
        self.state_dim = state_dim
        self.inner_dim = expand * dim
        
        # Proyecciones de entrada
        self.in_proj = nn.Linear(dim, self.inner_dim * 2, bias=False)
        
        # Convolución 1D para capturar contexto local (estilo Mamba)
        self.conv1d = nn.Conv1d(
            in_channels=self.inner_dim,
            out_channels=self.inner_dim,
            kernel_size=4,
            groups=self.inner_dim,
            padding=3
        )
        
        # Parámetros del SSM (Simplificado)
        # Delta (dt): El paso de tiempo reactivo
        self.dt_proj = nn.Linear(self.inner_dim, self.inner_dim, bias=True)
        
        # matrices A y B (Simplificadas para ser selectivas)
        self.A_log = nn.Parameter(torch.log(torch.arange(1, state_dim + 1).float().repeat(self.inner_dim, 1)))
        self.B_proj = nn.Linear(self.inner_dim, state_dim, bias=False)
        self.C_proj = nn.Linear(self.inner_dim, state_dim, bias=False)
        
        # Proyección de salida
        self.out_proj = nn.Linear(self.inner_dim, dim, bias=False)

    def forward(self, x: torch.Tensor):
        # x: [Batch, SeqLen, Dim]
        b, l, d = x.shape
        
        # 1. Proyección inicial y split
        x_and_res = self.in_proj(x) # [B, L, 2 * InnerDim]
        x_inner, res = x_and_res.split(self.inner_dim, dim=-1)
        
        # 2. Convolución local
        x_inner = x_inner.transpose(1, 2) # [B, InnerDim, L]
        x_inner = self.conv1d(x_inner)[:, :, :l] # Recortar padding
        x_inner = x_inner.transpose(1, 2) # [B, L, InnerDim]
        x_inner = F.silu(x_inner) # Activación Swish
        
        # 3. Mecanismo SSM Selectivo (Simplificado)
        # En una implementación real (como Mamba), esto se hace con un kernel de GPU 
        # para ser ultra rápido. Aquí usamos un bucle o aproximación para entenderlo.
        
        dt = F.softplus(self.dt_proj(x_inner)) # [B, L, InnerDim]
        A = -torch.exp(self.A_log) # [InnerDim, StateDim]
        B = self.B_proj(x_inner) # [B, L, StateDim]
        C = self.C_proj(x_inner) # [B, L, StateDim]
        
        # El "Estado Oculto" del modelo (Memory)
        # Aquí es donde ocurre la magia: el estado tiene tamaño fijo [B, InnerDim, StateDim]
        state = torch.zeros(b, self.inner_dim, self.state_dim, device=x.device)
        y = torch.zeros(b, l, self.inner_dim, device=x.device)
        
        # Escaneo Secuencial (Selective Scan)
        # Esto reemplaza a la Atención del Transformer.
        for t in range(l):
            # dt_t: [B, InnerDim]
            dt_t = dt[:, t, :].unsqueeze(-1)
            # x_t: [B, InnerDim]
            x_t = x_inner[:, t, :].unsqueeze(-1)
            # B_t: [B, StateDim]
            B_t = B[:, t, :].unsqueeze(1)
            
            # Discretización (Aproximación de Euler)
            A_bar = torch.exp(A.unsqueeze(0) * dt_t) # [B, InnerDim, StateDim]
            B_bar = dt_t * B_t
            
            # Actualizar Estado: h = A*h + B*x
            state = A_bar * state + B_bar * x_t
            
            # Salida: y = C*h
            C_t = C[:, t, :].unsqueeze(-1) # [B, StateDim, 1]
            y[:, t, :] = torch.matmul(state, C_t).squeeze(-1)
        
        # 4. Combinar con el residuo y proyectar salida
        out = y * F.silu(res)
        return self.out_proj(out)

class TransformerKiller(nn.Module):
    """

    Arquitectura basada en SSM (Mamba-like).

    No usa Atención. Contexto teóricamente infinito.

    """
    def __init__(self, vocab_size: int, dim: int, n_layers: int, state_dim: int = 16):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        
        self.layers = nn.ModuleList([
            nn.ModuleDict({
                'norm': RMSNorm(dim),
                'ssm': SSMBlock(dim, state_dim=state_dim)
            }) for _ in range(n_layers)
        ])
        
        self.norm_f = RMSNorm(dim)
        self.output = nn.Linear(dim, vocab_size, bias=False)

    def forward(self, tokens: torch.Tensor):
        x = self.tok_embeddings(tokens)
        
        for layer in self.layers:
            # Conexión residual + SSM
            x = x + layer['ssm'](layer['norm'](x))
            
        x = self.norm_f(x)
        return self.output(x)

if __name__ == "__main__":
    # Test de dimensiones
    model = TransformerKiller(vocab_size=100, dim=128, n_layers=4)
    test_input = torch.randint(0, 100, (2, 50))
    output = model(test_input)
    print(f"Input shape: {test_input.shape}")
    print(f"SSM Output shape: {output.shape}")
    # Nota como no hay límite de BLOCK_SIZE en el forward, 
    # solo el que dicte tu RAM, pero el coste es Lineal!