| import torch | |
| from torch import nn | |
| from typing import Literal | |
class LogitMixingAttention(nn.Module):
    """Multi-head attention variant that mixes key/value content into the query.

    Two modes are supported:

    * ``'scaled'``: the query is replaced by ``lg(q + factor * (k + v))`` and
      used in scaled dot-product attention over ``k`` and ``v``.
    * ``'kfactor'``: a position-wise gate. ``softmax`` is taken over each
      head's feature axis of ``q + factor * k / sqrt(dim_k)`` and applied to
      ``lg(v)`` with a residual (``gate * v + v``). No mixing across sequence
      positions happens in this mode, so ``causal_mask`` has no effect on it.

    Both modes finish with the output projection ``lo``.

    Args:
        embed_dim: Size of the last dimension of ``q``, ``k`` and ``v``.
        num_head: Number of heads; must be positive and divide ``embed_dim``.
        mode: ``'scaled'`` or ``'kfactor'``. Must be given explicitly.
        factor_init: Standard deviation of the zero-mean normal used to
            initialise the learnable ``factor`` of shape ``(1, 1, embed_dim)``.
        causal_mask: If true, ``'scaled'`` mode hides keys that come after the
            query position.

    Raises:
        RuntimeError: If ``mode`` is not ``'scaled'`` or ``'kfactor'``.
        ValueError: If ``num_head`` is not positive or does not divide
            ``embed_dim``.
    """

    _MODES = ('scaled', 'kfactor')

    def __init__(self, embed_dim, num_head, mode: Literal['scaled', 'kfactor'] = None,
                 factor_init=1e-4, causal_mask=False):
        super().__init__()
        # Reject a missing *or* unknown mode right away instead of at the
        # first forward pass.
        if mode not in self._MODES:
            raise RuntimeError("the mode must be specified kfactor/scaled")
        # Fail fast: otherwise the head split dies later with an opaque
        # reshape error (or a ZeroDivisionError when num_head is 0).
        if num_head < 1 or embed_dim % num_head != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"num_head (got {num_head})"
            )
        self.mode = mode
        self.causal_mask = causal_mask
        # Layer creation order is kept so seeded initialisation is unchanged.
        self.lg = nn.Linear(embed_dim, embed_dim, bias=False)
        self.lo = nn.Linear(embed_dim, embed_dim, bias=False)
        self.factor = nn.Parameter(
            data=torch.normal(mean=0, std=factor_init, size=(1, 1, embed_dim))
        )
        self.num_head = num_head
        self.dim_k = embed_dim // num_head

    def split_head(self, x: torch.Tensor):
        """Reshape ``(b, s, embed_dim)`` into ``(b, num_head, s, dim_k)``."""
        b, s, _ = x.shape
        return x.reshape(b, s, self.num_head, self.dim_k).permute(0, 2, 1, 3)

    def _merge_head(self, x: torch.Tensor):
        """Inverse of ``split_head``: ``(b, num_head, s, dim_k)`` -> ``(b, s, embed_dim)``."""
        b, _, s, _ = x.shape
        return x.permute(0, 2, 1, 3).reshape(b, s, self.num_head * self.dim_k)

    def scaled_attention(self, q, k, v):
        """Scaled dot-product attention driven by the mixed query.

        ``q``, ``k`` and ``v`` must share one ``(b, s, embed_dim)`` shape
        because they are summed element-wise before the ``lg`` projection.
        """
        seq_len = q.shape[1]
        mixed_q = self.split_head(self.lg(q + (self.factor * (k + v))))
        k = self.split_head(k)
        v = self.split_head(v)
        score = torch.matmul(mixed_q, k.transpose(-1, -2)) / self.dim_k ** 0.5
        if self.causal_mask:
            # Boolean mask built directly on the score's device; True marks a
            # key index greater than the query index.  masked_fill keeps the
            # score dtype, unlike adding a float32 mask, which would promote
            # half-precision scores and break the matmul with ``v``.
            pos = torch.arange(seq_len, device=score.device)
            future = pos.unsqueeze(0) > pos.unsqueeze(1)
            score = score.masked_fill(future, float('-inf'))
        weights = nn.functional.softmax(score, dim=-1)
        context = torch.matmul(weights, v)
        return self.lo(self._merge_head(context))

    def kfactor_attention(self, q, k, v):
        """Position-wise gating of ``lg(v)``; see the class docstring."""
        # NOTE(review): by operator precedence only the ``factor * k`` term is
        # divided by sqrt(dim_k); ``q`` is not scaled.  Kept as-is so existing
        # weights keep their meaning - confirm whether this is intended.
        gate = q + (self.factor * k) / self.dim_k ** 0.5
        gate = self.split_head(gate)
        v = self.split_head(self.lg(v))
        # Softmax runs over the per-head feature axis, not over positions.
        gate = nn.functional.softmax(gate, dim=-1)
        gated = (gate * v) + v
        return self.lo(self._merge_head(gated))

    def forward(self, q, k, v):
        """Dispatch to the attention variant selected by ``self.mode``."""
        if self.mode == 'scaled':
            return self.scaled_attention(q, k, v)
        if self.mode == 'kfactor':
            return self.kfactor_attention(q, k, v)
        # Only reachable if ``self.mode`` was reassigned after construction.
        raise RuntimeError("the mode must be specified kfactor/scaled")
| class LCM (nn.Module) : | |
| def __init__ (self,embed_dim,drop_rate) : | |
| super().__init__() | |
| self.step1 = nn.Linear(embed_dim,embed_dim,bias=False) | |
| self.step2 = nn.Linear(embed_dim,embed_dim,bias=False) | |
| self.gelu1 = nn.GELU(approximate='tanh') | |
| self.gelu2 = nn.GELU(approximate='tanh') | |
| self.mg = nn.Linear(embed_dim,embed_dim,bias=False) | |
| self.tanh = nn.Tanh() | |
| self.norm = nn.RMSNorm(embed_dim) | |
| self.drop = nn.Dropout(drop_rate) | |
| def forward(self,x): | |
| z = self.norm(x) | |
| step1 = self.step1(z) | |
| step1 = self.gelu1(step1) | |
| step2 = self.step2(z) | |
| step2 = self.gelu2(step2) | |
| mx = step1 + step2 | |
| mx = self.drop(mx) | |
| mx = self.mg(mx) | |
| mx = self.tanh(mx) | |
| return x + mx | |
class GlobalRouterBlock(nn.Module):
    """Map the final sequence position of each sample to expert logits.

    Args:
        embed_dim: Feature width of the incoming tensor.
        hidden_dim: Width of the intermediate projection.
        num_expert: Number of logits produced per sample.
    """

    def __init__(self, embed_dim, hidden_dim, num_expert):
        super().__init__()
        # Attribute names (including the capital ``L``) are state_dict keys
        # and are therefore left exactly as they are.
        self.Linear1 = nn.Linear(embed_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, num_expert)

    def forward(self, x):
        """Return ``(batch, num_expert)`` logits from a 3-D input ``x``."""
        last_position = x[:, -1, :]
        # Two stacked linear maps with no activation in between.
        return self.linear2(self.Linear1(last_position))
class TransformersBlock(nn.Module):
    """Pre-norm ``kfactor`` attention sub-layer followed by an ``LCM`` block.

    ``forward`` computes ``x + dropout(attention(norm(x)))`` and passes the
    result through ``lcm``.

    Args:
        embed_dim: Feature width of the input and output.
        drop_rate: Dropout probability, used after attention and handed to
            the ``LCM`` block.
        head_dim: Target width of one attention head.  The head count is
            ``embed_dim // head_dim``, but never less than one.
    """

    def __init__(self, embed_dim, drop_rate, head_dim=64):
        super().__init__()
        # Clamp to one head: for embed_dim < head_dim the integer division
        # yields 0, which previously crashed with ZeroDivisionError inside
        # the attention constructor.
        num_head = max(1, embed_dim // head_dim)
        self.attention = LogitMixingAttention(
            embed_dim=embed_dim, num_head=num_head, mode='kfactor'
        )
        self.norm = nn.RMSNorm(embed_dim)
        self.dropout = nn.Dropout(drop_rate)
        self.lcm = LCM(embed_dim=embed_dim, drop_rate=drop_rate)

    def forward(self, x):
        """Apply the attention residual and then the LCM block to ``x``."""
        normed = self.norm(x)
        # Self-attention: the same normalised tensor serves as q, k and v.
        attn = self.dropout(self.attention(normed, normed, normed))
        return self.lcm(x + attn)
class LCTLM(nn.Module):
    """One routed stage holding four ``TransformersBlock`` candidates.

    Exactly one candidate runs per call, selected by ``idx``.

    Args:
        embed_dim: Feature width handed to every candidate block.
        drop_rate: Dropout probability handed to every candidate block.
    """

    def __init__(self, embed_dim, drop_rate):
        super().__init__()
        # Registered as block1..block4, in that order, so parameter names and
        # seeded initialisation match the hand-written assignments.
        for number in range(1, 5):
            setattr(self, f"block{number}", TransformersBlock(embed_dim, drop_rate))

    def forward(self, x, idx):
        """Run the candidate chosen by ``idx`` on ``x``.

        ``idx`` equal to 0, 1 or 2 selects ``block1``, ``block2`` or
        ``block3``; every other value falls through to ``block4``.
        """
        for position, candidate in enumerate((self.block1, self.block2, self.block3)):
            if idx == position:
                return candidate(x)
        return self.block4(x)
class LCTLM2(nn.Module):
    """Language model built from eight routed ``LCTLM`` stages.

    Token ids are embedded, scaled by ``sqrt(embed_dim)`` and summed with a
    learned positional embedding.  A ``GlobalRouterBlock`` yields 32 logits
    from which one of four candidate blocks is chosen for each of the eight
    stages.  The stage output goes through a feed-forward network and a final
    projection to vocabulary logits.

    Args:
        vocab_size: Number of rows in the token embedding and size of the
            output logits.
        embed_dim: Model width.
        drop_rate: Dropout probability handed to every stage.
        maxpos: Number of positions in the positional embedding table, i.e.
            the longest supported sequence.
        temperature: Divisor applied to the router logits before top-k.
    """

    def __init__(self, vocab_size=30001, embed_dim=640, drop_rate=0.1, maxpos=250, temperature=1.2):
        super().__init__()
        self.temperature = temperature
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_embedding = nn.Embedding(maxpos, embed_dim)
        # Registered as lctlm1..lctlm8 in order, keeping the parameter names
        # and the seeded initialisation order of the explicit assignments.
        for number in range(1, 9):
            setattr(self, f"lctlm{number}", LCTLM(embed_dim, drop_rate))
        self.global_ffn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4, bias=False),
            nn.GELU(approximate='tanh'),
            nn.Linear(embed_dim * 4, embed_dim, bias=False)
        )
        self.routers = GlobalRouterBlock(embed_dim=embed_dim, hidden_dim=128, num_expert=8 * 4)
        self.fn = nn.Linear(embed_dim, vocab_size, bias=False)
        self.scale = embed_dim ** 0.5

    def forward(self, x):
        """Return vocabulary logits of shape ``(batch, seq, vocab_size)``.

        Args:
            x: 2-D tensor of token ids, ``(batch, seq)``.

        Raises:
            IndexError: If the sequence is longer than the positional
                embedding table.
        """
        _, seq_len = x.shape
        max_len = self.pos_embedding.num_embeddings
        if seq_len > max_len:
            # Checked up front so the failure is explicit instead of an
            # out-of-range lookup inside the embedding layer.
            raise IndexError(
                f"sequence length {seq_len} exceeds the maximum supported "
                f"length {max_len}"
            )

        hidden = self.embedding(x) * self.scale
        positions = torch.arange(seq_len, device=hidden.device)
        hidden = hidden + self.pos_embedding(positions).unsqueeze(0)

        # Dividing by a positive temperature does not change the top-k order.
        logits = self.routers(hidden) / self.temperature
        # NOTE(review): only indices are used, so no gradient reaches the
        # router through this path - confirm that is intended.
        _, top = torch.topk(logits, k=8)
        # NOTE(review): the routing of batch element 0 is applied to the
        # whole batch - confirm that is intended.
        # ``// 8`` folds the 32 logit indices into candidate ids 0..3.
        # ``tolist`` moves all eight choices to the host in one transfer, so
        # the stages compare plain ints instead of 0-dim tensors.
        choices = (top[0] // 8).tolist()

        stages = (
            self.lctlm1, self.lctlm2, self.lctlm3, self.lctlm4,
            self.lctlm5, self.lctlm6, self.lctlm7, self.lctlm8,
        )
        for stage, choice in zip(stages, choices):
            hidden = stage(hidden, choice)

        hidden = self.global_ffn(hidden)
        return self.fn(hidden)