| import torch
|
| import torch.nn as nn
|
|
|
class MiniGPT(nn.Module):
    """Causally-masked transformer language model.

    Token and learned positional embeddings are summed, passed through a
    stack of ``nn.TransformerEncoderLayer`` blocks under a causal attention
    mask, normalised with a final ``LayerNorm`` and projected to vocabulary
    logits.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 1024,
        n_heads: int = 16,
        n_layers: int = 24,
        max_len: int = 512,
        dropout: float = 0.0,
    ) -> None:
        """Build the embeddings, transformer stack and output head.

        Args:
            vocab_size: Number of rows in the token embedding and number of
                output logits per position.
            d_model: Width of the embeddings and transformer layers.
            n_heads: Number of attention heads per transformer layer.
            n_layers: Number of stacked transformer layers.
            max_len: Number of rows in the positional embedding, i.e. the
                longest sequence ``forward`` accepts.
            dropout: Dropout probability passed to every transformer layer.
        """
        super().__init__()

        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)

        # batch_first=False: the stack consumes (T, B, d_model); forward()
        # transposes on the way in and out.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dropout=dropout,
            batch_first=False,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        self.ln = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def generate_causal_mask(self, T, device):
        """Return a ``(T, T)`` boolean mask that is True strictly above the diagonal.

        Entry ``[i, j]`` is True when ``j > i``; ``forward`` passes this mask
        to the transformer stack so position ``i`` cannot attend to later
        positions.
        """
        return torch.triu(torch.ones(T, T, device=device), diagonal=1).bool()

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Compute vocabulary logits for every position.

        Args:
            input_ids: Integer token ids of shape ``(B, T)``.

        Returns:
            Tensor of shape ``(B, T, vocab_size)``.

        Raises:
            IndexError: If ``T`` exceeds the positional embedding table size.
        """
        _, seq_len = input_ids.shape

        # Fail early with a readable message instead of an opaque
        # out-of-range error from inside the embedding lookup.
        max_len = self.pos_embed.num_embeddings
        if seq_len > max_len:
            raise IndexError(
                f"sequence length {seq_len} exceeds max_len {max_len}"
            )

        device = input_ids.device
        positions = torch.arange(0, seq_len, device=device).unsqueeze(0)
        hidden = self.token_embed(input_ids) + self.pos_embed(positions)

        # (B, T, d_model) -> (T, B, d_model) for the batch_first=False stack.
        hidden = hidden.transpose(0, 1)
        causal_mask = self.generate_causal_mask(seq_len, device)
        hidden = self.transformer(hidden, causal_mask)
        hidden = hidden.transpose(0, 1)

        return self.fc_out(self.ln(hidden))

    def reset_params(self) -> None:
        """Re-initialise the parameters of every submodule.

        Walks all descendants rather than only direct children, because the
        transformer container itself exposes no ``reset_parameters`` and
        would otherwise be skipped entirely. Children are reset before their
        parents so that a parent's own initialiser (e.g. the
        ``_reset_parameters`` of ``nn.MultiheadAttention``) has the final
        word, as it does at construction time.
        """
        # modules() is pre-order, so the reversed list visits every
        # descendant before its ancestors.
        for module in reversed(list(self.modules())):
            if module is self:
                continue
            reset = getattr(module, "reset_parameters", None)
            if reset is None:
                reset = getattr(module, "_reset_parameters", None)
            if callable(reset):
                reset()