krystv committed on
Commit
a89ce99
·
verified ·
1 Parent(s): 0cf1113

Upload cartel_block.py

Browse files
Files changed (1) hide show
  1. cartel_block.py +79 -0
cartel_block.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CARTEL Backbone: Hybrid SSM + RWKV + LTC block.
3
+ """
4
+ import math
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ try:
9
+ from .ssm_block import SimplifiedMambaBlock
10
+ from .ltc_gate import LTCGate
11
+ except ImportError:
12
+ from ssm_block import SimplifiedMambaBlock
13
+ from ltc_gate import LTCGate
14
+
15
+
16
class RWKVBlock(nn.Module):
    """RWKV-style block for spatial reasoning.

    Pre-norm residual layer made of two sub-layers:

    1. A multi-head recurrent mixing step. Every head carries a
       ``head_dim x head_dim`` state ``S`` that is updated position by
       position as ``S_t = sigmoid(rwkv_alpha) * S_{t-1} + k_t^T v_t`` and
       read out as ``q_t @ (beta * S_t^T + k_t)`` (``k_t`` is broadcast over
       the rows of the state).
    2. A position-wise feed-forward network ``dim -> 4 * dim -> dim`` with
       GELU in between.

    Args:
        dim: Size of the last (feature) axis of the input.
        n_head: Number of heads; must be positive and divide ``dim`` evenly.

    Raises:
        ValueError: If ``n_head`` is not positive or does not divide ``dim``.
    """

    def __init__(self, dim: int, n_head: int = 4):
        super().__init__()
        # Fail at construction time instead of with an opaque reshape error
        # deep inside forward().
        if n_head <= 0 or dim % n_head != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive n_head ({n_head})"
            )
        self.dim = dim
        self.n_head = n_head
        self.head_dim = dim // n_head
        self.linear_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)
        # Per-head decay logit; squashed with a sigmoid in forward().
        self.rwkv_alpha = nn.Parameter(torch.ones(n_head) * 0.5)
        # Per-head weight of the running state in the read-out (starts at 0).
        self.beta = nn.Parameter(torch.zeros(n_head))
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )
        # NOTE: not read by forward(); kept so the module's state_dict keys
        # stay unchanged for existing checkpoints.
        self.time_mix = nn.Parameter(torch.ones(dim) * 0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the recurrent mixing step and the feed-forward network.

        Args:
            x: Tensor of shape ``(batch, seq_len, dim)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        B, L, D = x.shape
        if L == 0:
            # Nothing to mix; torch.cat() below would reject an empty list.
            return x

        residual = x
        h = self.norm1(x)
        q, k, v = self.linear_qkv(h).chunk(3, dim=-1)

        # (B, L, D) -> (B, n_head, L, head_dim)
        head_shape = (B, L, self.n_head, self.head_dim)
        q = q.reshape(head_shape).transpose(1, 2)
        k = k.reshape(head_shape).transpose(1, 2)
        v = v.reshape(head_shape).transpose(1, 2)

        alpha = torch.sigmoid(self.rwkv_alpha.view(1, self.n_head, 1, 1))
        beta = self.beta.view(1, self.n_head, 1, 1)

        # Running per-head state, allocated at its real (head_dim, head_dim)
        # shape rather than relying on broadcasting during the first update.
        wkv = torch.zeros(
            B, self.n_head, self.head_dim, self.head_dim,
            device=h.device, dtype=h.dtype,
        )
        outs = []
        for t in range(L):
            qt = q[:, :, t:t + 1, :]
            kt = k[:, :, t:t + 1, :]
            vt = v[:, :, t:t + 1, :]
            # Decay the old state, then add the outer product k_t^T v_t.
            wkv = alpha * wkv + kt.transpose(-2, -1) @ vt
            # Read-out: (1, hd) @ (hd, hd) -> (1, hd) per head.
            outs.append(qt @ (beta * wkv.transpose(-2, -1) + kt))

        out = torch.cat(outs, dim=2)                 # (B, n_head, L, head_dim)
        out = out.transpose(1, 2).reshape(B, L, D)   # merge heads back
        out = self.out_proj(out)

        x = residual + out
        x = x + self.ffn(self.norm2(x))
        return x
62
+
63
+
64
class CARTELBlock(nn.Module):
    """One CARTEL layer = SSM + RWKV + LTC merge.

    The input is fed to an SSM branch and an RWKV branch in parallel. The two
    results are concatenated on the last axis, projected back to ``dim`` by a
    linear layer, passed through the LTC gate, and added to the input as a
    residual.

    Args:
        dim: Feature size of the input; also the output size of ``merge``.
        d_state: Forwarded to ``SimplifiedMambaBlock`` as ``d_state``.
        expand: Forwarded to ``SimplifiedMambaBlock`` as ``expand``.
        n_head: Number of heads for the RWKV branch. Defaults to 4, which is
            the value ``RWKVBlock`` used implicitly before this parameter
            was exposed.
    """

    def __init__(self, dim: int, d_state: int = 16, expand: int = 2,
                 n_head: int = 4):
        super().__init__()
        self.ssm = SimplifiedMambaBlock(dim, d_state=d_state, expand=expand)
        self.rwkv = RWKVBlock(dim, n_head=n_head)
        self.ltc = LTCGate(dim)
        # Consumes the concatenation of both branch outputs (2 * dim features).
        self.merge = nn.Linear(dim * 2, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run both branches, merge them, gate the result, and add ``x``.

        Args:
            x: Input tensor whose last axis has size ``dim``.

        Returns:
            ``ltc(merge(cat(ssm(x), rwkv(x)))) + x``.
        """
        # NOTE(review): assumes both branches return tensors shaped like x so
        # that the concatenation has 2 * dim features - confirm for
        # SimplifiedMambaBlock, which is defined outside this file.
        x_ssm = self.ssm(x)
        x_rwkv = self.rwkv(x)
        stacked = torch.cat([x_ssm, x_rwkv], dim=-1)
        merged = self.merge(stacked)
        gated = self.ltc(merged)
        return gated + x