llm-from-scratch / v2 /usta_model.py
Celal Şamil Kartoğlu
v2 implemented
17c0b30
import torch
import torch.nn as nn
from .usta_decoder_block import UstaDecoderBlock
from .usta_embedding import UstaEmbedding
class UstaModel(nn.Module):
def __init__(self,vocab_size,embedding_dim,num_heads,context_length,num_layers,device):
super().__init__()
self.embedding = UstaEmbedding(vocab_size,embedding_dim,context_length,device)
self.layers = nn.Sequential(*[UstaDecoderBlock(embedding_dim,num_heads,context_length,device) for _ in range(num_layers)])
self.lm_head = nn.Linear(embedding_dim,vocab_size,device=device)
self.device = device
def forward(self,x:torch.Tensor):
x = self.embedding(x) # dictionary meaning of the tokens (words)
x = self.layers(x)
x = self.lm_head(x)
return x
def top_p_filtering(self,logits,top_p):
sorted_logits,sorted_indices = torch.sort(logits,descending = True)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits,dim=-1),dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[...,1:] = sorted_indices_to_remove[...,:-1].clone()
sorted_indices_to_remove[...,0] = False
sorted_logits[sorted_indices_to_remove] = -float('inf')
filtered_logits = sorted_logits.clone()
filtered_logits.scatter_(0,sorted_indices,sorted_logits)
return filtered_logits
def generate(self,
x:torch.Tensor,
max_new_tokens:int=3,
temperature:float = 1.0,
top_k:int=64,
top_p:float=1.0
): #top_k,top_p temperature
tokens = x.tolist()
for _ in range(max_new_tokens):
x = x.unsqueeze(0).to(self.device)
out = self.forward(x)
out = out.squeeze(0)
logits = out[-1]
if top_k > 0:
values,indexes = torch.topk(logits,k=top_k)
logits = torch.full_like(logits,-float('inf'))
logits.scatter_(0,indexes,values)
if top_p > 0 and top_p < 1:
logits = self.top_p_filtering(logits,top_p)
if temperature != 1.0 and temperature > 0.0:
logits = logits/temperature
probs = torch.softmax(values,dim=-1)
sample = torch.multinomial(probs,1)
max_index = indexes[sample]
tokens.append(max_index.item())
if max_index == 59 or len(tokens) > 32: # end of sentence token or context length
break
x = torch.tensor(tokens)
return tokens