Upload cartel_block.py
Browse files- cartel_block.py +79 -0
cartel_block.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CARTEL Backbone: Hybrid SSM + RWKV + LTC block.
|
| 3 |
+
"""
|
| 4 |
+
import math
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
try:
|
| 9 |
+
from .ssm_block import SimplifiedMambaBlock
|
| 10 |
+
from .ltc_gate import LTCGate
|
| 11 |
+
except ImportError:
|
| 12 |
+
from ssm_block import SimplifiedMambaBlock
|
| 13 |
+
from ltc_gate import LTCGate
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class RWKVBlock(nn.Module):
|
| 17 |
+
"""RWKV-style block for spatial reasoning."""
|
| 18 |
+
def __init__(self, dim: int, n_head: int = 4):
|
| 19 |
+
super().__init__()
|
| 20 |
+
self.dim = dim
|
| 21 |
+
self.n_head = n_head
|
| 22 |
+
self.head_dim = dim // n_head
|
| 23 |
+
self.linear_qkv = nn.Linear(dim, dim * 3, bias=False)
|
| 24 |
+
self.out_proj = nn.Linear(dim, dim, bias=False)
|
| 25 |
+
self.rwkv_alpha = nn.Parameter(torch.ones(n_head) * 0.5)
|
| 26 |
+
self.beta = nn.Parameter(torch.zeros(n_head))
|
| 27 |
+
self.norm1 = nn.LayerNorm(dim)
|
| 28 |
+
self.norm2 = nn.LayerNorm(dim)
|
| 29 |
+
self.ffn = nn.Sequential(
|
| 30 |
+
nn.Linear(dim, dim * 4),
|
| 31 |
+
nn.GELU(),
|
| 32 |
+
nn.Linear(dim * 4, dim),
|
| 33 |
+
)
|
| 34 |
+
self.time_mix = nn.Parameter(torch.ones(dim) * 0.5)
|
| 35 |
+
|
| 36 |
+
def forward(self, x: torch.Tensor):
|
| 37 |
+
residual = x
|
| 38 |
+
x = self.norm1(x)
|
| 39 |
+
qkv = self.linear_qkv(x)
|
| 40 |
+
q, k, v = qkv.chunk(3, dim=-1)
|
| 41 |
+
B, L, D = q.shape
|
| 42 |
+
q = q.reshape(B, L, self.n_head, self.head_dim).transpose(1, 2)
|
| 43 |
+
k = k.reshape(B, L, self.n_head, self.head_dim).transpose(1, 2)
|
| 44 |
+
v = v.reshape(B, L, self.n_head, self.head_dim).transpose(1, 2)
|
| 45 |
+
alpha = torch.sigmoid(self.rwkv_alpha.view(1, self.n_head, 1, 1))
|
| 46 |
+
beta = self.beta.view(1, self.n_head, 1, 1)
|
| 47 |
+
wkv = torch.zeros(B, self.n_head, 1, self.head_dim, device=x.device, dtype=x.dtype)
|
| 48 |
+
outs = []
|
| 49 |
+
for t in range(L):
|
| 50 |
+
kt = k[:, :, t:t+1, :]
|
| 51 |
+
vt = v[:, :, t:t+1, :]
|
| 52 |
+
qt = q[:, :, t:t+1, :]
|
| 53 |
+
wkv = alpha * wkv + kt.transpose(-2, -1) @ vt
|
| 54 |
+
nom = qt @ (beta * wkv.transpose(-2, -1) + kt)
|
| 55 |
+
outs.append(nom)
|
| 56 |
+
out = torch.cat(outs, dim=2)
|
| 57 |
+
out = out.transpose(1, 2).reshape(B, L, D)
|
| 58 |
+
out = self.out_proj(out)
|
| 59 |
+
x = residual + out
|
| 60 |
+
x = x + self.ffn(self.norm2(x))
|
| 61 |
+
return x
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class CARTELBlock(nn.Module):
|
| 65 |
+
"""One CARTEL layer = SSM + RWKV + LTC merge"""
|
| 66 |
+
def __init__(self, dim: int, d_state: int = 16, expand: int = 2):
|
| 67 |
+
super().__init__()
|
| 68 |
+
self.ssm = SimplifiedMambaBlock(dim, d_state=d_state, expand=expand)
|
| 69 |
+
self.rwkv = RWKVBlock(dim)
|
| 70 |
+
self.ltc = LTCGate(dim)
|
| 71 |
+
self.merge = nn.Linear(dim * 2, dim)
|
| 72 |
+
|
| 73 |
+
def forward(self, x: torch.Tensor):
|
| 74 |
+
x_ssm = self.ssm(x)
|
| 75 |
+
x_rwkv = self.rwkv(x)
|
| 76 |
+
stacked = torch.cat([x_ssm, x_rwkv], dim=-1)
|
| 77 |
+
merged = self.merge(stacked)
|
| 78 |
+
gated = self.ltc(merged)
|
| 79 |
+
return gated + x
|