import torch
import torch.nn as nn
class MiniGPT(nn.Module):
    def __init__(self, vocab_size, d_model=456, n_heads=8, n_layers=4, max_len=256):
        super().__init__()
        self.max_len = max_len
        # Learned token and absolute position embeddings
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.ln = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids):
        B, T = input_ids.shape
        assert T <= self.max_len, "sequence length exceeds the position embedding table"
        pos = torch.arange(0, T, device=input_ids.device).unsqueeze(0)  # [1, T]
        x = self.token_embed(input_ids) + self.pos_embed(pos)           # [B, T, D]
        x = x.transpose(0, 1)  # [T, B, D]; the default layers expect sequence-first input
        # Causal mask so each position attends only to itself and earlier tokens (GPT-style)
        causal_mask = torch.triu(
            torch.full((T, T), float('-inf'), device=input_ids.device), diagonal=1
        )
        x = self.transformer(x, mask=causal_mask)
        x = x.transpose(0, 1)  # back to [B, T, D]
        x = self.ln(x)
        return self.fc_out(x)  # logits over the vocabulary, [B, T, vocab_size]

    def reset_params(self):
        # Recurse with self.modules() (not self.children()) so the parameters
        # nested inside the transformer sub-layers are reset as well
        for module in self.modules():
            if hasattr(module, 'reset_parameters'):
                module.reset_parameters()
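A minimal usage sketch for a quick shape check; the vocabulary size, batch size, and token ids below are made up for illustration:

# Smoke test: random token ids in, per-position vocabulary logits out
model = MiniGPT(vocab_size=1000)
input_ids = torch.randint(0, 1000, (2, 16))  # batch of 2 sequences, 16 tokens each
logits = model(input_ids)
print(logits.shape)  # torch.Size([2, 16, 1000])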