import torch
import torch.nn as nn
class SimpleGPT(nn.Module):
    """Small decoder-only (GPT-style) language model.

    Token embeddings and learned positional embeddings are summed, passed
    through a stack of ``nn.TransformerEncoderLayer`` blocks under a causal
    attention mask, normalized with a final LayerNorm, and projected to
    vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids; sizes the token embedding
            table and the output projection.
        block_size: Maximum sequence length; sizes the positional embedding
            table.
        n_embd: Model width (``d_model`` of every transformer layer).
        n_layer: Number of transformer layers.
        n_head: Attention heads per layer (PyTorch requires it to divide
            ``n_embd``).
        dropout: Dropout probability used inside each transformer layer.
    """

    def __init__(self, vocab_size, block_size=8, n_embd=128, n_layer=4, n_head=4,
                 dropout=0.1):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        # batch_first=True because forward() feeds (batch, seq, n_embd)
        # tensors. With the default (False), each layer would read dim 0 as
        # the sequence axis and attend across batch items, not positions.
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=n_embd, nhead=n_head, dropout=dropout, batch_first=True
            )
            for _ in range(n_layer)
        ])
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size)
        self.block_size = block_size

    def forward(self, idx):
        """Compute per-position next-token logits.

        Args:
            idx: Integer token ids of shape ``(batch, seq_len)`` with
                ``seq_len <= block_size``.

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)``. The output at
            position ``i`` depends only on tokens ``0..i`` of the same sample.

        Raises:
            ValueError: If ``seq_len`` exceeds ``block_size``.
        """
        _, t = idx.size()  # unpacking also enforces a 2-D input
        # Explicit check rather than assert so it still fires under `python -O`.
        if t > self.block_size:
            raise ValueError("Sequence too long")
        pos = torch.arange(0, t, dtype=torch.long, device=idx.device)
        # (b, t, n_embd) + (1, t, n_embd): positional embeddings broadcast
        # over the batch.
        x = self.token_emb(idx) + self.pos_emb(pos)[None, :, :]
        # Boolean causal mask: True marks a forbidden (future) key, so each
        # query position attends only to itself and earlier positions.
        causal_mask = torch.ones(t, t, dtype=torch.bool, device=idx.device).triu(1)
        for block in self.blocks:
            x = block(x, src_mask=causal_mask)
        x = self.ln_f(x)
        return self.head(x)