ULFBERTO commited on
Commit
e1fdf4e
·
verified ·
1 Parent(s): 5348759

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +144 -0
model.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import Optional, Tuple
5
+ import math
6
+
7
+ class RMSNorm(nn.Module):
8
+ def __init__(self, dim: int, eps: float = 1e-6):
9
+ super().__init__()
10
+ self.eps = eps
11
+ self.weight = nn.Parameter(torch.ones(dim))
12
+
13
+ def _norm(self, x):
14
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
15
+
16
+ def forward(self, x):
17
+ return self._norm(x.float()).type_as(x) * self.weight
18
+
19
+ class SSMBlock(nn.Module):
20
+ """
21
+ Una implementación simplificada de una capa de State Space Model (SSM) Selectivo.
22
+ A diferencia del Transformer, esta capa tiene memoria LINEAL.
23
+ """
24
+ def __init__(self, dim: int, state_dim: int = 16, expand: int = 2):
25
+ super().__init__()
26
+ self.dim = dim
27
+ self.state_dim = state_dim
28
+ self.inner_dim = expand * dim
29
+
30
+ # Proyecciones de entrada
31
+ self.in_proj = nn.Linear(dim, self.inner_dim * 2, bias=False)
32
+
33
+ # Convolución 1D para capturar contexto local (estilo Mamba)
34
+ self.conv1d = nn.Conv1d(
35
+ in_channels=self.inner_dim,
36
+ out_channels=self.inner_dim,
37
+ kernel_size=4,
38
+ groups=self.inner_dim,
39
+ padding=3
40
+ )
41
+
42
+ # Parámetros del SSM (Simplificado)
43
+ # Delta (dt): El paso de tiempo reactivo
44
+ self.dt_proj = nn.Linear(self.inner_dim, self.inner_dim, bias=True)
45
+
46
+ # matrices A y B (Simplificadas para ser selectivas)
47
+ self.A_log = nn.Parameter(torch.log(torch.arange(1, state_dim + 1).float().repeat(self.inner_dim, 1)))
48
+ self.B_proj = nn.Linear(self.inner_dim, state_dim, bias=False)
49
+ self.C_proj = nn.Linear(self.inner_dim, state_dim, bias=False)
50
+
51
+ # Proyección de salida
52
+ self.out_proj = nn.Linear(self.inner_dim, dim, bias=False)
53
+
54
+ def forward(self, x: torch.Tensor):
55
+ # x: [Batch, SeqLen, Dim]
56
+ b, l, d = x.shape
57
+
58
+ # 1. Proyección inicial y split
59
+ x_and_res = self.in_proj(x) # [B, L, 2 * InnerDim]
60
+ x_inner, res = x_and_res.split(self.inner_dim, dim=-1)
61
+
62
+ # 2. Convolución local
63
+ x_inner = x_inner.transpose(1, 2) # [B, InnerDim, L]
64
+ x_inner = self.conv1d(x_inner)[:, :, :l] # Recortar padding
65
+ x_inner = x_inner.transpose(1, 2) # [B, L, InnerDim]
66
+ x_inner = F.silu(x_inner) # Activación Swish
67
+
68
+ # 3. Mecanismo SSM Selectivo (Simplificado)
69
+ # En una implementación real (como Mamba), esto se hace con un kernel de GPU
70
+ # para ser ultra rápido. Aquí usamos un bucle o aproximación para entenderlo.
71
+
72
+ dt = F.softplus(self.dt_proj(x_inner)) # [B, L, InnerDim]
73
+ A = -torch.exp(self.A_log) # [InnerDim, StateDim]
74
+ B = self.B_proj(x_inner) # [B, L, StateDim]
75
+ C = self.C_proj(x_inner) # [B, L, StateDim]
76
+
77
+ # El "Estado Oculto" del modelo (Memory)
78
+ # Aquí es donde ocurre la magia: el estado tiene tamaño fijo [B, InnerDim, StateDim]
79
+ state = torch.zeros(b, self.inner_dim, self.state_dim, device=x.device)
80
+ y = torch.zeros(b, l, self.inner_dim, device=x.device)
81
+
82
+ # Escaneo Secuencial (Selective Scan)
83
+ # Esto reemplaza a la Atención del Transformer.
84
+ for t in range(l):
85
+ # dt_t: [B, InnerDim]
86
+ dt_t = dt[:, t, :].unsqueeze(-1)
87
+ # x_t: [B, InnerDim]
88
+ x_t = x_inner[:, t, :].unsqueeze(-1)
89
+ # B_t: [B, StateDim]
90
+ B_t = B[:, t, :].unsqueeze(1)
91
+
92
+ # Discretización (Aproximación de Euler)
93
+ A_bar = torch.exp(A.unsqueeze(0) * dt_t) # [B, InnerDim, StateDim]
94
+ B_bar = dt_t * B_t
95
+
96
+ # Actualizar Estado: h = A*h + B*x
97
+ state = A_bar * state + B_bar * x_t
98
+
99
+ # Salida: y = C*h
100
+ C_t = C[:, t, :].unsqueeze(-1) # [B, StateDim, 1]
101
+ y[:, t, :] = torch.matmul(state, C_t).squeeze(-1)
102
+
103
+ # 4. Combinar con el residuo y proyectar salida
104
+ out = y * F.silu(res)
105
+ return self.out_proj(out)
106
+
107
+ class TransformerKiller(nn.Module):
108
+ """
109
+ Arquitectura basada en SSM (Mamba-like).
110
+ No usa Atención. Contexto teóricamente infinito.
111
+ """
112
+ def __init__(self, vocab_size: int, dim: int, n_layers: int, state_dim: int = 16):
113
+ super().__init__()
114
+ self.tok_embeddings = nn.Embedding(vocab_size, dim)
115
+
116
+ self.layers = nn.ModuleList([
117
+ nn.ModuleDict({
118
+ 'norm': RMSNorm(dim),
119
+ 'ssm': SSMBlock(dim, state_dim=state_dim)
120
+ }) for _ in range(n_layers)
121
+ ])
122
+
123
+ self.norm_f = RMSNorm(dim)
124
+ self.output = nn.Linear(dim, vocab_size, bias=False)
125
+
126
+ def forward(self, tokens: torch.Tensor):
127
+ x = self.tok_embeddings(tokens)
128
+
129
+ for layer in self.layers:
130
+ # Conexión residual + SSM
131
+ x = x + layer['ssm'](layer['norm'](x))
132
+
133
+ x = self.norm_f(x)
134
+ return self.output(x)
135
+
136
+ if __name__ == "__main__":
137
+ # Test de dimensiones
138
+ model = TransformerKiller(vocab_size=100, dim=128, n_layers=4)
139
+ test_input = torch.randint(0, 100, (2, 50))
140
+ output = model(test_input)
141
+ print(f"Input shape: {test_input.shape}")
142
+ print(f"SSM Output shape: {output.shape}")
143
+ # Nota como no hay límite de BLOCK_SIZE en el forward,
144
+ # solo el que dicte tu RAM, pero el coste es Lineal!