# hindi_chatbot / decoder_only_gpt.py
# Anishss333's picture
# Upload 15 files
# a6d9791 verified
import torch
from torch import nn
from feed_forward_nn import feedforward, SwiGLU_FFN
from masked_mha import Masked_MHA
from rms_norm import RMSNorm
import math
# d_model = 512 # main model dimension
# num_heads = 8 # number of heads
# d_ff = 2048 # feedforward hidden dimension
# seq_len = 128 # max input length
# vocab_size = 30000
def generate_subsequent_mask(size):
    """Build a boolean causal mask of shape (1, 1, size, size).

    Entry ``[0, 0, i, j]`` is ``True`` when ``j <= i`` (position ``i`` may
    look at position ``j``) and ``False`` for strictly later positions.
    The two leading singleton axes are added so the mask can broadcast
    against a 4-D tensor.
    """
    # Lower triangle (diagonal included) marks the allowed positions.
    lower_triangle = torch.tril(torch.ones(size, size))
    allowed = lower_triangle.to(torch.bool)
    return allowed[None, None, :, :]  # (1, 1, L, L)
class Decoder_GPT_Block(nn.Module):
    """One pre-norm decoder block: masked attention, then a SwiGLU feed-forward.

    Each sub-layer is applied to an RMS-normalised copy of the running
    activation, passed through dropout, and added back to the activation
    (residual connection).

    Args:
        d_model: width given to both RMSNorm layers, the attention module
            and the feed-forward module.
        d_ff: hidden width given to the feed-forward module.
        num_heads: number of heads given to the attention module.
        seq_len: forwarded to ``Masked_MHA`` as ``max_seq_len``.
        dropout: probability for the shared ``nn.Dropout`` layer.
    """

    def __init__(self, d_model, d_ff, num_heads, seq_len, dropout=0.1):
        super().__init__()
        # Construction order is kept stable so parameter ordering and
        # random initialisation stay reproducible.
        self.swi_glu = SwiGLU_FFN(d_model, d_ff)
        self.masked_mha = Masked_MHA(d_model, num_heads, max_seq_len=seq_len)
        self.rms_norm0 = RMSNorm(d_model)
        self.rms_norm1 = RMSNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Run the block on ``x`` using attention mask ``mask``.

        ``x`` is presumably (B, S, D) and ``mask`` broadcastable to
        (B, heads, S, S) -- the exact contract is set by ``Masked_MHA``.
        Returns a tensor produced by two residual additions onto ``x``.
        """
        # Sub-layer 1: normalise -> masked self-attention -> dropout -> residual add.
        attn_out = self.masked_mha(self.rms_norm0(x), mask)
        x = x + self.dropout(attn_out)

        # Sub-layer 2: normalise -> SwiGLU feed-forward -> dropout -> residual add.
        ffn_out = self.swi_glu(self.rms_norm1(x))
        return x + self.dropout(ffn_out)
class Decoder(nn.Module):
    """Token embedding, a stack of ``Decoder_GPT_Block`` layers and a final RMSNorm.

    Args:
        vocab_size: number of rows in the token embedding table.
        num_layers: how many ``Decoder_GPT_Block`` layers to stack.
        d_model: embedding width, also given to every block and the final norm.
        d_ff: feed-forward hidden width given to every block.
        num_heads: attention head count given to every block.
        seq_len: maximum sequence length; sizes the cached causal mask and is
            forwarded to every block.
        dropout: dropout probability forwarded to every block.
    """

    def __init__(self, vocab_size, num_layers, d_model, d_ff, num_heads, seq_len, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # seq_len and dropout are passed by keyword on purpose: passing
        # `dropout` as the fourth positional argument would land in the
        # block's `seq_len` slot and leave the block on its default dropout.
        self.layers = nn.ModuleList(
            [
                Decoder_GPT_Block(
                    d_model, d_ff, num_heads, seq_len=seq_len, dropout=dropout
                )
                for _ in range(num_layers)
            ]
        )
        # Final normalisation: even though every block normalises its own
        # inputs, the output of the last block can still drift in
        # distribution. Normalising once more keeps the input to whatever
        # follows (LM head or classifier) consistent, which makes it easier
        # to train and keeps gradients stable.
        self.norm = RMSNorm(d_model)
        self.seq_len = seq_len
        # Registered as a buffer so the mask follows the module across
        # devices with .to()/.cuda() without being a trainable parameter.
        self.register_buffer(
            "causal_mask",
            generate_subsequent_mask(seq_len),
        )

    def forward_tokens(self, token_ids):
        """Look up embeddings for ``token_ids`` in the embedding table."""
        return self.embedding(token_ids)

    def forward(self, x, mask=None):
        """Run already-embedded inputs through every block and the final norm.

        Args:
            x: 3-D tensor (B, S, D) of embedded tokens.
            mask: accepted for interface compatibility but currently ignored;
                the cached causal mask, cropped to (1, 1, S, S), is always used.

        Returns:
            The normalised output of the last block.

        Raises:
            ValueError: if S exceeds the ``seq_len`` the causal mask was built for.
        """
        _, S, _ = x.shape
        if S > self.seq_len:
            # Slicing the cached mask cannot grow it, so a longer input would
            # receive a mask of the wrong size further down the stack.
            raise ValueError(
                f"sequence length {S} exceeds maximum supported length {self.seq_len}"
            )
        mask = self.causal_mask[:, :, :S, :S]
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)
class My_GPT_model(nn.Module):
    """Decoder-only language model: ``Decoder`` followed by a tied LM head.

    The LM head is a bias-free linear layer from ``d_model`` to
    ``vocab_size`` whose weight is the same tensor as the decoder's token
    embedding (weight tying).
    """

    def __init__(self, vocab_size, num_layers, d_model, d_ff, num_heads, seq_len, dropout=0.1):
        super().__init__()
        decoder_config = dict(
            vocab_size=vocab_size,
            num_layers=num_layers,
            d_model=d_model,
            d_ff=d_ff,
            num_heads=num_heads,
            seq_len=seq_len,
            dropout=dropout,
        )
        self.decoder = Decoder(**decoder_config)

        # Output projection to vocabulary logits.
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: share one matrix between input embedding and LM head.
        self.lm_head.weight = self.decoder.embedding.weight

    def forward(self, token_ids):
        """Map token ids (B, S) to vocabulary logits (B, S, V)."""
        embedded = self.decoder.forward_tokens(token_ids)  # (B, S, D)
        hidden = self.decoder(embedded)  # (B, S, D)
        return self.lm_head(hidden)  # (B, S, V)
# model = My_GPT_model(
# vocab_size=30000,
# num_layers=6,
# d_model=512,
# d_ff=2048,
# num_heads=8,
# seq_len=128
# )
# tokens = torch.randint(0, 30000, (2, 128))
# logits = model(tokens)
# print(logits.shape)
# # (2, 128, 30000)
# print(tokens)
# print("#################")
# print(logits)