# LCTLM1 / lctlm1.py
# Hugging Face upload artifact header ("Airin-chan's picture", "Upload 2 files",
# commit 548f078 verified) — kept as a comment so the module parses as Python.
import torch
from torch import nn
from typing import Optional
import torch.nn.functional as F
from tokenizers import Tokenizer
class LCMBlock(nn.Module):
    """LCM (Latent Connected Model) block.

    Runs a layer-normalised input through two parallel linear+GELU
    "perception" branches, sums them, applies dropout, projects the sum
    through a magnitude layer squashed by tanh, and adds the result back
    to the input as a residual connection.
    """

    def __init__(self, d_model: int, drop_rate: float = 0.1):
        """
        args:
            d_model : int
                model (feature) dimension
            drop_rate : float
                dropout probability applied to the summed branches
        """
        super().__init__()
        self.step1 = nn.Linear(d_model, d_model)
        self.step2 = nn.Linear(d_model, d_model)
        self.magnitude = nn.Linear(d_model, d_model)
        self.drop = nn.Dropout(drop_rate)
        self.gelu1 = nn.GELU(approximate='tanh')
        self.gelu2 = nn.GELU(approximate='tanh')
        self.tanh = nn.Tanh()
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        normed = self.norm(x)
        branch_a = self.gelu1(self.step1(normed))
        branch_b = self.gelu2(self.step2(normed))
        mixed = self.drop(branch_a + branch_b)
        return x + self.tanh(self.magnitude(mixed))
class LMLCTBlock(nn.Module):
    """Pre-norm transformer-style block: multi-head self-attention with a
    residual connection, followed by an LCMBlock feed-forward stage.
    """

    def __init__(self, d_model: int, drop_rate: float = 0.1, num_heads: int = 8):
        """
        args:
            d_model : int
                model (embedding) dimension; must be divisible by num_heads
            drop_rate : float
                dropout rate used inside the attention and the LCMBlock
            num_heads : int
                number of attention heads (default 8 preserves the
                previously hard-coded value)
        """
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=num_heads,
            dropout=drop_rate,
            batch_first=True,
        )
        self.norm = nn.LayerNorm(d_model)
        self.lcmblock = LCMBlock(d_model, drop_rate)

    def forward(self, x, mask):
        # Pre-norm: normalise before attention, then add the residual.
        normx = self.norm(x)
        attention, _ = self.attention(normx, normx, normx, attn_mask=mask)
        x = x + attention
        return self.lcmblock(x)
import math
class LMLCT1(nn.Module):
    """Decoder-only language model built from LMLCTBlock layers.

    Token embedding (scaled by sqrt(d_model)) + learned positional
    embedding -> num_layers causally-masked LMLCTBlocks -> GELU MLP ->
    vocabulary projection.
    """

    def __init__(self, d_model=512, vocab_size=30001, num_layers=6, drop_rate=0.1, maxpos=500):
        """
        args:
            d_model : int     model dimension
            vocab_size : int  vocabulary size (index 0 is the padding id)
            num_layers : int  number of LMLCTBlock layers
            drop_rate : float dropout rate
            maxpos : int      maximum supported sequence length
        """
        super().__init__()
        self.d_model = d_model
        self.maxpos = maxpos
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_embedding = nn.Embedding(maxpos, d_model)
        # sqrt(d_model) embedding scale, as in the original Transformer.
        self.scale = math.sqrt(d_model)
        self.decoder_mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(approximate='tanh'),
            nn.Linear(d_model * 4, d_model),
        )
        self.layers = nn.ModuleList([LMLCTBlock(d_model, drop_rate) for _ in range(num_layers)])
        self.out = nn.Linear(d_model, vocab_size)
        # Boolean upper-triangular mask: True above the diagonal blocks
        # attention to future positions.
        mask = torch.triu(torch.ones(maxpos, maxpos), diagonal=1).bool()
        self.register_buffer("causal_mask", mask)

    def forward(self, x):
        """x: (batch, seq) long token ids -> (batch, seq, vocab_size) logits."""
        _, S = x.size()
        # Fail fast with a clear message; previously an oversized sequence
        # crashed inside pos_embedding with an opaque index error.
        if S > self.maxpos:
            raise ValueError(
                f"sequence length {S} exceeds maximum supported length {self.maxpos}"
            )
        pos_idx = torch.arange(S, device=x.device)
        x = self.embedding(x) * self.scale
        x = x + self.pos_embedding(pos_idx).unsqueeze(0)
        mask = self.causal_mask[:S, :S]
        for layer in self.layers:
            x = layer(x, mask=mask)
        x = self.decoder_mlp(x)
        return self.out(x)
def generate_tesk(model: "LMLCT1", texts: str, tokenizer: "Tokenizer", temperature: float = 1.0):
    """Autoregressively sample a continuation of *texts* from *model*.

    args:
        model : LMLCT1
            language model mapping (batch, seq) ids to (batch, seq, vocab) logits
        texts : str
            prompt; "sos " is prepended before encoding
        tokenizer : tokenizers.Tokenizer
            provides .encode(text).ids and .decode(ids)
        temperature : float
            softmax temperature; higher values flatten the distribution

    returns:
        str : decoded sampled tokens (prompt excluded), stripped of whitespace

    Sampling stops when the total sequence reaches 500 tokens or after the
    third pad token (id 0) has been sampled.
    """
    texts = "sos " + texts
    prompt_ids = tokenizer.encode(texts).ids
    start_index = len(prompt_ids)
    input_ids = torch.tensor(prompt_ids, dtype=torch.long).unsqueeze(0)
    zeros_tolerant = 0
    response_tokens = []
    # NOTE(review): caller is expected to have called model.eval();
    # dropout stays active otherwise — confirm against call sites.
    with torch.no_grad():
        # Hard cap keeps the sequence within the model's maxpos=500.
        for _ in range(start_index, 500):
            logits = model(input_ids)[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            next_token_id = next_token.item()
            response_tokens.append(next_token_id)
            input_ids = torch.cat([input_ids, next_token], dim=1)
            if next_token_id == 0:
                zeros_tolerant += 1
                if zeros_tolerant >= 3:
                    # Fix: break immediately; the old top-of-loop check ran
                    # one extra wasted forward pass after the third pad.
                    break
    return tokenizer.decode(response_tokens).strip()