Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from .usta_decoder_block import UstaDecoderBlock | |
| from .usta_embedding import UstaEmbedding | |
class UstaModel(nn.Module):
    """Decoder-only language model: token embedding -> stacked decoder blocks -> LM head.

    Args:
        vocab_size: number of tokens in the vocabulary.
        embedding_dim: width of token embeddings / hidden states.
        num_heads: attention heads per decoder block.
        context_length: maximum sequence length the model supports.
        num_layers: number of stacked decoder blocks.
        device: torch device the parameters are created on.
    """

    def __init__(self, vocab_size, embedding_dim, num_heads, context_length, num_layers, device):
        super().__init__()
        self.embedding = UstaEmbedding(vocab_size, embedding_dim, context_length, device)
        self.layers = nn.Sequential(
            *[
                UstaDecoderBlock(embedding_dim, num_heads, context_length, device)
                for _ in range(num_layers)
            ]
        )
        self.lm_head = nn.Linear(embedding_dim, vocab_size, device=device)
        self.device = device

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map token ids of shape (batch, seq) to logits of shape (batch, seq, vocab_size)."""
        x = self.embedding(x)  # dictionary meaning of the tokens (words)
        x = self.layers(x)
        return self.lm_head(x)

    def top_p_filtering(self, logits: torch.Tensor, top_p: float) -> torch.Tensor:
        """Nucleus (top-p) filtering over a 1-D logits vector.

        Keeps the smallest set of highest-probability tokens whose cumulative
        probability exceeds ``top_p``; all other entries are set to ``-inf``.

        Args:
            logits: 1-D tensor of unnormalized scores over the vocabulary.
            top_p: cumulative-probability threshold in (0, 1).

        Returns:
            A new tensor shaped like ``logits`` with non-nucleus entries at ``-inf``.
        """
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask right so the first token that crosses the threshold is
        # still kept, and always keep at least the single most likely token.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        sorted_logits[sorted_indices_to_remove] = -float("inf")
        # Scatter the masked (sorted) logits back into vocabulary order.
        filtered_logits = torch.full_like(logits, -float("inf"))
        filtered_logits.scatter_(-1, sorted_indices, sorted_logits)
        return filtered_logits

    def generate(
        self,
        x: torch.Tensor,
        max_new_tokens: int = 3,
        temperature: float = 1.0,
        top_k: int = 64,
        top_p: float = 1.0,
    ) -> list:
        """Autoregressively sample up to ``max_new_tokens`` token ids.

        Filtering order: top-k mask, then nucleus (top-p) mask, then
        temperature scaling; the next token is sampled from the softmax of the
        resulting logits. (Previously sampling used only the raw top-k values,
        so top-p and temperature had no effect, and ``top_k <= 0`` raised a
        NameError.)

        Args:
            x: 1-D tensor of prompt token ids.
            max_new_tokens: maximum number of tokens to append.
            temperature: softmax temperature; <1 sharpens, >1 flattens. Values
                of exactly 1.0 or <= 0.0 leave the logits unscaled.
            top_k: keep only the k highest-scoring tokens (0 disables).
            top_p: nucleus threshold; only applied when 0 < top_p < 1.

        Returns:
            The prompt token ids plus the sampled continuation, as a Python list.
        """
        tokens = x.tolist()
        for _ in range(max_new_tokens):
            x = x.unsqueeze(0).to(self.device)  # add a batch dimension
            out = self.forward(x)
            logits = out.squeeze(0)[-1]  # scores for the last position
            if top_k > 0:
                # Clamp k to the vocabulary size so oversized top_k can't crash topk.
                values, indexes = torch.topk(logits, k=min(top_k, logits.size(-1)))
                logits = torch.full_like(logits, -float("inf"))
                logits.scatter_(0, indexes, values)
            if 0 < top_p < 1:
                logits = self.top_p_filtering(logits, top_p)
            if temperature != 1.0 and temperature > 0.0:
                logits = logits / temperature
            # Sample from the fully filtered/scaled distribution; multinomial
            # returns an index into the full vocabulary directly.
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, 1).item()
            tokens.append(next_token)
            if next_token == 59 or len(tokens) > 32:  # end-of-sentence token or context length
                break
            x = torch.tensor(tokens)
        return tokens