import torch
from torch import nn
from typing import Literal, Optional


class LogitMixingAttention(nn.Module):
    """Attention variant that mixes key/value content into the query logits
    through a learned per-channel factor, in one of two modes."""

    def __init__(self, embed_dim, num_head,
                 mode: Optional[Literal['scaled', 'kfactor']] = None,
                 factor_init=1e-4, causal_mask=False):
        super().__init__()
        self.mode = mode
        if self.mode is None:
            raise ValueError("mode must be specified: 'kfactor' or 'scaled'")
        self.causal_mask = causal_mask
        self.lg = nn.Linear(embed_dim, embed_dim, bias=False)  # mixing projection
        self.lo = nn.Linear(embed_dim, embed_dim, bias=False)  # output projection
        # Learned per-channel mixing factor, initialised near zero so the
        # mixing term starts negligible.
        self.factor = nn.Parameter(
            data=torch.normal(mean=0, std=factor_init, size=(1, 1, embed_dim))
        )
        self.num_head = num_head
        self.dim_k = embed_dim // num_head

    def split_head(self, x: torch.Tensor):
        # (b, s, d) -> (b, num_head, s, dim_k)
        b, s, d = x.shape
        x = torch.reshape(x, (b, s, self.num_head, self.dim_k))
        return x.permute(0, 2, 1, 3)

    def scaled_attention(self, q, k, v):
        b, s, d = q.shape
        if self.causal_mask:
            # Additive causal mask: -1e9 above the diagonal, 0 elsewhere.
            mask = 1 - torch.tril(torch.ones((s, s), device=q.device), diagonal=0)
            mask = mask * -1e9
            mask = mask.unsqueeze(0)  # broadcasts over (b, num_head, s, s)
        else:
            mask = None
        # Mix key/value content into the query before scoring.
        lm = q + (self.factor * (k + v))
        lm = self.lg(lm)
        lm = self.split_head(lm)
        k = self.split_head(k)
        v = self.split_head(v)
        score = torch.matmul(lm, k.transpose(-1, -2)) / self.dim_k ** 0.5
        if mask is not None:
            score = score + mask
        attn = nn.functional.softmax(score, dim=-1)
        attn = torch.matmul(attn, v)
        attn = attn.permute(0, 2, 1, 3)  # back to (b, s, num_head, dim_k)
        attn = torch.reshape(attn, (b, s, d))
        out = self.lo(attn)
        return out

    def kfactor_attention(self, q, k, v):
        b, s, d = q.shape
        # Per-channel gating instead of token-token attention. Note the
        # precedence: only the factor-weighted key term is divided by
        # sqrt(dim_k), not q.
        score = q + (self.factor * k) / self.dim_k ** 0.5
        score = self.split_head(score)
        v = self.lg(v)
        v = self.split_head(v)
        # Softmax over the feature dimension yields channel-wise gates.
        score = nn.functional.softmax(score, dim=-1)
        attn = (score * v) + v  # gated values with a residual on v
        attn = attn.permute(0, 2, 1, 3)
        attn = torch.reshape(attn, (b, s, d))
        out = self.lo(attn)
        return out

    def forward(self, q, k, v):
        if self.mode == 'scaled':
            return self.scaled_attention(q, k, v)
        elif self.mode == 'kfactor':
            return self.kfactor_attention(q, k, v)
        else:
            raise ValueError("mode must be 'kfactor' or 'scaled'")
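

# Hedged usage sketch (not part of the original file): exercise both attention
# modes on random tensors and confirm the output shape matches the input.
# All sizes below are illustrative assumptions.
if __name__ == "__main__":
    b, s, d = 2, 16, 128
    x = torch.randn(b, s, d)
    for mode in ('scaled', 'kfactor'):
        attn = LogitMixingAttention(embed_dim=d, num_head=4, mode=mode,
                                    causal_mask=True)
        assert attn(x, x, x).shape == (b, s, d)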


class LCM(nn.Module):
    """Dual-branch GELU MLP: two parallel projections are summed, merged,
    tanh-bounded, and added residually to the input."""

    def __init__(self, embed_dim, drop_rate):
        super().__init__()
        self.step1 = nn.Linear(embed_dim, embed_dim, bias=False)
        self.step2 = nn.Linear(embed_dim, embed_dim, bias=False)
        self.gelu1 = nn.GELU(approximate='tanh')
        self.gelu2 = nn.GELU(approximate='tanh')
        self.mg = nn.Linear(embed_dim, embed_dim, bias=False)  # merge projection
        self.tanh = nn.Tanh()
        self.norm = nn.RMSNorm(embed_dim)
        self.drop = nn.Dropout(drop_rate)

    def forward(self, x):
        z = self.norm(x)
        step1 = self.gelu1(self.step1(z))
        step2 = self.gelu2(self.step2(z))
        mx = self.drop(step1 + step2)
        mx = self.tanh(self.mg(mx))
        return x + mx


class GlobalRouterBlock(nn.Module):
    """Scores the experts from the last token's representation only."""

    def __init__(self, embed_dim, hidden_dim, num_expert):
        super().__init__()
        self.linear1 = nn.Linear(embed_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, num_expert)

    def forward(self, x):
        x = x[:, -1, :]  # route on the final position
        return self.linear2(self.linear1(x))
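

# Hedged routing sketch (illustrative, not part of the original file): the
# router scores 8 * 4 = 32 (layer, block) experts from the last token;
# LCTLM2.forward below maps each selected expert index to a block index in
# 0..3 by floor division with 8.
if __name__ == "__main__":
    router = GlobalRouterBlock(embed_dim=64, hidden_dim=128, num_expert=32)
    scores = router(torch.randn(2, 10, 64))  # (batch, num_expert) = (2, 32)
    _, top = torch.topk(scores, k=8)         # top-8 expert ids per batch item
    print(top[0] // 8)                       # block indices in 0..3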


class TransformersBlock(nn.Module):
    """Pre-norm block: logit-mixing self-attention followed by an LCM MLP."""

    def __init__(self, embed_dim, drop_rate):
        super().__init__()
        # Head width is fixed at 64, so num_head = embed_dim // 64.
        self.attention = LogitMixingAttention(embed_dim=embed_dim,
                                              num_head=embed_dim // 64,
                                              mode='kfactor')
        self.norm = nn.RMSNorm(embed_dim)
        self.dropout = nn.Dropout(drop_rate)
        self.lcm = LCM(embed_dim=embed_dim, drop_rate=drop_rate)

    def forward(self, x):
        z = self.norm(x)
        x = x + self.dropout(self.attention(z, z, z))
        return self.lcm(x)


class LCTLM(nn.Module):
    """A routed layer holding four candidate blocks; idx picks which one runs."""

    def __init__(self, embed_dim, drop_rate):
        super().__init__()
        self.block1 = TransformersBlock(embed_dim, drop_rate)
        self.block2 = TransformersBlock(embed_dim, drop_rate)
        self.block3 = TransformersBlock(embed_dim, drop_rate)
        self.block4 = TransformersBlock(embed_dim, drop_rate)

    def forward(self, x, idx):
        if idx == 0:
            return self.block1(x)
        elif idx == 1:
            return self.block2(x)
        elif idx == 2:
            return self.block3(x)
        else:
            return self.block4(x)


class LCTLM2(nn.Module):
    """Eight routed LCTLM layers (four candidate blocks each, 32 experts in
    total), followed by a shared FFN and a vocabulary projection."""

    def __init__(self, vocab_size=30001, embed_dim=640, drop_rate=0.1,
                 maxpos=250, temperature=1.2):
        super().__init__()
        self.temperature = temperature
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_embedding = nn.Embedding(maxpos, embed_dim)
        self.lctlm1 = LCTLM(embed_dim, drop_rate)
        self.lctlm2 = LCTLM(embed_dim, drop_rate)
        self.lctlm3 = LCTLM(embed_dim, drop_rate)
        self.lctlm4 = LCTLM(embed_dim, drop_rate)
        self.lctlm5 = LCTLM(embed_dim, drop_rate)
        self.lctlm6 = LCTLM(embed_dim, drop_rate)
        self.lctlm7 = LCTLM(embed_dim, drop_rate)
        self.lctlm8 = LCTLM(embed_dim, drop_rate)
        self.global_ffn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4, bias=False),
            nn.GELU(approximate='tanh'),
            nn.Linear(embed_dim * 4, embed_dim, bias=False),
        )
        # 8 layers x 4 blocks per layer = 32 routable experts.
        self.routers = GlobalRouterBlock(embed_dim=embed_dim, hidden_dim=128,
                                         num_expert=8 * 4)
        self.fn = nn.Linear(embed_dim, vocab_size, bias=False)
        self.scale = embed_dim ** 0.5

    def forward(self, x):
        b, s = x.shape
        x = self.embedding(x) * self.scale
        pos = torch.arange(s, device=x.device)
        x = x + self.pos_embedding(pos).unsqueeze(0)
        # Route once per forward pass: score all 32 experts and keep the
        # top-8. Note only batch element 0's scores are used, so the whole
        # batch shares a single routing path.
        r = self.routers(x) / self.temperature
        _, idx = torch.topk(r, k=8)
        # Expert ids run 0..31; floor division by 8 maps the i-th ranked
        # expert to a block index in 0..3 for layer i+1.
        idx = idx[0] // 8
        x = self.lctlm1(x, idx[0])
        x = self.lctlm2(x, idx[1])
        x = self.lctlm3(x, idx[2])
        x = self.lctlm4(x, idx[3])
        x = self.lctlm5(x, idx[4])
        x = self.lctlm6(x, idx[5])
        x = self.lctlm7(x, idx[6])
        x = self.lctlm8(x, idx[7])
        x = self.global_ffn(x)
        return self.fn(x)  # (b, s, vocab_size) logits
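

# Hedged end-to-end sketch (illustrative, not part of the original file):
# forward a batch of random token ids through the default configuration.
# Batch size and sequence length are assumptions; the sequence length must
# stay at or below maxpos (250).
if __name__ == "__main__":
    model = LCTLM2()
    tokens = torch.randint(0, 30001, (2, 32))  # (batch, seq) token ids
    logits = model(tokens)
    print(logits.shape)  # expected: torch.Size([2, 32, 30001])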