import torch
import torch.nn as nn
import torch.nn.functional as F

from .transformer_block import TransformerBlock
from .config import Config
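
# For reference, Config is expected to expose the attributes used below.
# A hypothetical sketch (the real definitions live in .config):
#
#   class Config:
#       vocab_size = 65       # tokenizer vocabulary size (assumed value)
#       n_embed = 384         # embedding dimension (assumed value)
#       block_size = 256      # maximum context length (assumed value)
#       n_layers = 6          # number of transformer blocks (assumed value)
#       device = "cuda" if torch.cuda.is_available() else "cpu"
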
class PotterGPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_embed = config.n_embed
        self.block_size = config.block_size

        # Token embeddings plus learned positional embeddings.
        self.token_embedding_table = nn.Embedding(config.vocab_size, self.n_embed)
        self.pos_embedding_table = nn.Embedding(self.block_size, self.n_embed)

        # Build n_layers *independent* blocks. Note that
        # [TransformerBlock(config)] * config.n_layers would repeat one
        # instance n_layers times (shared weights), so a comprehension
        # is required here.
        self.blocks = nn.Sequential(
            *[TransformerBlock(config) for _ in range(config.n_layers)],
            nn.LayerNorm(self.n_embed),
        )
        self.lm_head = nn.Linear(self.n_embed, config.vocab_size)
    def forward(self, idx):
        B, T = idx.shape  # batch size, sequence length (T <= block_size)
        token_embs = self.token_embedding_table(idx)          # (B, T, n_embed)
        pos_embs = self.pos_embedding_table(
            torch.arange(T, device=idx.device)                # (T, n_embed)
        )
        x = token_embs + pos_embs   # positions broadcast across the batch
        x = self.blocks(x)          # (B, T, n_embed)
        logits = self.lm_head(x)    # (B, T, vocab_size)
        return logits
    @torch.no_grad()
    def generate(self, idx, total):
        # Autoregressively sample `total` new tokens, appending each to idx.
        for _ in range(total):
            idx_cond = idx[:, -self.block_size:]   # crop to the context window
            logits = self(idx_cond)
            logits = logits[:, -1, :]              # logits at the last position
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)  # sample one token
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
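
# A minimal usage sketch, assuming the real Config from .config. The
# starting context here is a single zero token (e.g. a newline id);
# decoding the output back to text would use the project's tokenizer.
if __name__ == "__main__":
    model = PotterGPT(Config).to(Config.device)
    # (1, 1) tensor holding one start token; generate 100 continuations.
    context = torch.zeros((1, 1), dtype=torch.long, device=Config.device)
    out = model.generate(context, total=100)
    print(out.shape)  # -> torch.Size([1, 101])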