MiniGPT / model.py
CreatedNull's picture
Upload folder using huggingface_hub
4de3b20 verified
raw
history blame
1.11 kB
import torch
import torch.nn as nn
class MiniGPT(nn.Module):
    """Small transformer language model built on ``nn.TransformerEncoder``.

    Token and learned positional embeddings are summed, passed through a
    stack of encoder layers, layer-normalised, and projected to vocabulary
    logits.

    Args:
        vocab_size: Number of rows in the token embedding table and the width
            of the output logits.
        d_model: Embedding / hidden width. ``nn.MultiheadAttention`` requires
            it to be divisible by ``n_heads``.
        n_heads: Number of attention heads in every encoder layer.
        n_layers: Number of stacked encoder layers.
        max_len: Size of the positional embedding table, i.e. the longest
            sequence ``forward`` accepts.
    """

    def __init__(self, vocab_size, d_model=456, n_heads=8, n_layers=4, max_len=256):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        # Encoder layers are left at their default (sequence-first) layout,
        # which is why ``forward`` transposes around the transformer call.
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.ln = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids: torch.Tensor, *, causal: bool = False) -> torch.Tensor:
        """Compute vocabulary logits for a batch of token id sequences.

        Args:
            input_ids: Integer tensor of shape ``[B, T]`` holding token ids.
            causal: When True, apply an upper-triangular attention mask so
                position ``t`` only attends to positions ``<= t``. Defaults
                to False, which is the original unmasked (bidirectional)
                behaviour.
                NOTE(review): autoregressive training/generation normally
                needs ``causal=True``; the default is kept for backward
                compatibility with existing callers and checkpoints.

        Returns:
            Tensor of shape ``[B, T, vocab_size]`` with unnormalised logits.

        Raises:
            IndexError: If ``T`` exceeds the positional embedding table size.
        """
        B, T = input_ids.shape
        max_len = self.pos_embed.num_embeddings
        if T > max_len:
            # Fail early with a readable message instead of an opaque
            # embedding lookup error (or a device-side assert on GPU).
            raise IndexError(
                f"sequence length {T} exceeds max_len {max_len} of the positional embedding"
            )

        pos = torch.arange(0, T, device=input_ids.device).unsqueeze(0)
        x = self.token_embed(input_ids) + self.pos_embed(pos)

        attn_mask = None
        if causal:
            # Boolean mask: True marks (query, key) pairs that must NOT attend,
            # i.e. every key strictly after the query position.
            attn_mask = torch.triu(
                torch.ones(T, T, dtype=torch.bool, device=input_ids.device),
                diagonal=1,
            )

        x = x.transpose(0, 1)  # [T, B, D]
        x = self.transformer(x, mask=attn_mask)
        x = x.transpose(0, 1)  # [B, T, D]
        x = self.ln(x)
        return self.fc_out(x)

    def reset_params(self) -> None:
        """Re-initialise every learnable parameter in the model.

        Walks all nested submodules rather than only the direct children:
        the ``nn.TransformerEncoder`` container has no ``reset_parameters``
        of its own, so a children-only walk would leave the whole transformer
        stack untouched.
        """
        for module in self.modules():
            if module is self:
                continue
            reset = getattr(module, "reset_parameters", None)
            if callable(reset):
                reset()

        # MultiheadAttention keeps its packed projection weights as direct
        # parameters and only exposes a private initialiser for them. Run it
        # after the generic pass so its bias zeroing is applied last, matching
        # the order used at construction time.
        for module in self.modules():
            if isinstance(module, nn.MultiheadAttention):
                mha_reset = getattr(module, "_reset_parameters", None)
                if callable(mha_reset):
                    mha_reset()