Spaces:
Running
Running
| import torch | |
| from torch import nn | |
| from feed_forward_nn import feedforward, SwiGLU_FFN | |
| from masked_mha import Masked_MHA | |
| from rms_norm import RMSNorm | |
| import math | |
| # d_model = 512 # main model dimension | |
| # num_heads = 8 # number of heads | |
| # d_ff = 2048 # feedforward hidden dimension | |
| # seq_len = 128 # max input length | |
| # vocab_size = 30000 | |
def generate_subsequent_mask(size):
    """Build a boolean lower-triangular (causal) mask.

    Entry ``[0, 0, i, j]`` is ``True`` when ``j <= i`` and ``False`` when
    ``j > i``. Presumably ``True`` marks key positions a query may attend
    to -- confirm against how ``Masked_MHA`` consumes the mask.

    Args:
        size: Side length ``L`` of the square mask.

    Returns:
        A ``torch.bool`` tensor of shape ``(1, 1, L, L)``; the two leading
        singleton axes let it broadcast against 4-D tensors.
    """
    positions = torch.arange(size)
    # Compare column index j (shape (1, L)) with row index i (shape (L, 1)).
    lower_triangle = positions.unsqueeze(0) <= positions.unsqueeze(1)
    return lower_triangle[None, None]  # (1,1,L,L)
class Decoder_GPT_Block(nn.Module):
    """One pre-norm decoder block.

    The block applies two sub-layers in sequence, each as
    ``x = x + dropout(sublayer(rms_norm(x)))``:

    1. masked multi-head self-attention (``Masked_MHA``)
    2. feed-forward network (``SwiGLU_FFN``)

    Args:
        d_model: Model width, passed to both norms and both sub-layers.
        d_ff: Hidden width passed to ``SwiGLU_FFN``.
        num_heads: Number of heads passed to ``Masked_MHA``.
        seq_len: Forwarded to ``Masked_MHA`` as ``max_seq_len``.
        dropout: Probability for the dropout applied to each sub-layer
            output before the residual add.
    """

    def __init__(self, d_model, d_ff, num_heads, seq_len, dropout=0.1):
        super().__init__()
        # Construction order is kept as-is so parameter initialisation
        # consumes the RNG in the same sequence.
        self.swi_glu = SwiGLU_FFN(d_model, d_ff)
        self.masked_mha = Masked_MHA(d_model, num_heads, max_seq_len=seq_len)
        self.rms_norm0 = RMSNorm(d_model)
        self.rms_norm1 = RMSNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Run the block.

        Args:
            x: Input tensor; assumed (B, S, D) -- confirm against caller.
            mask: Attention mask handed unchanged to ``Masked_MHA``.

        Returns:
            Tensor produced by the two residual sub-layers.
        """
        # Sub-layer 1: normalise, attend, then add back onto the residual.
        attn_out = self.masked_mha(self.rms_norm0(x), mask)
        x = x + self.dropout(attn_out)

        # Sub-layer 2: normalise, feed-forward, then add back onto the residual.
        ffn_out = self.swi_glu(self.rms_norm1(x))
        return x + self.dropout(ffn_out)
class Decoder(nn.Module):
    """Token embedding, a stack of ``Decoder_GPT_Block`` layers, and a final RMSNorm.

    Args:
        vocab_size: Number of rows in the token embedding table.
        num_layers: Number of ``Decoder_GPT_Block`` layers to stack.
        d_model: Embedding / hidden width.
        d_ff: Feed-forward hidden width forwarded to every block.
        num_heads: Number of attention heads forwarded to every block.
        seq_len: Maximum supported sequence length. It sizes the cached
            causal mask and is forwarded to every block.
        dropout: Dropout probability forwarded to every block.
    """

    def __init__(self, vocab_size, num_layers, d_model, d_ff, num_heads, seq_len, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # seq_len and dropout are passed by keyword. Passing `dropout`
        # positionally as the fourth argument would land in the block's
        # `seq_len` slot and leave its dropout at the default value.
        self.layers = nn.ModuleList(
            [
                Decoder_GPT_Block(
                    d_model, d_ff, num_heads, seq_len=seq_len, dropout=dropout
                )
                for _ in range(num_layers)
            ]
        )
        # Final normalisation of the stack output. The blocks only normalise
        # the *inputs* of their sub-layers, so the residual stream leaving
        # the last block can still drift (distribution shift). Normalising
        # it once more keeps the output distribution consistent, which makes
        # the following layer (LM head or classifier) easier to train and
        # helps keep gradients stable.
        self.norm = RMSNorm(d_model)
        self.seq_len = seq_len
        # Buffer (not a parameter): follows .to(device) and is saved in the
        # state dict, but is never trained.
        self.register_buffer(
            "causal_mask",
            generate_subsequent_mask(seq_len),
        )

    def forward_tokens(self, token_ids):
        """Look up embeddings for ``token_ids`` (integer ids -> ``d_model`` vectors)."""
        return self.embedding(token_ids)

    def forward(self, x, mask=None):
        """Run the embedded sequence through every block and the final norm.

        Args:
            x: Embedded tokens of shape (B, S, D).
            mask: Optional attention mask passed to every block. When
                ``None`` (the default) the cached causal mask cropped to
                (1, 1, S, S) is used. A supplied mask is assumed to follow
                the same convention as the cached one -- confirm against
                ``Masked_MHA``.

        Returns:
            Tensor of shape (B, S, D).

        Raises:
            ValueError: If no mask is given and ``S`` exceeds ``seq_len``,
                because the cached mask cannot cover the sequence.
        """
        _, S, _ = x.shape
        if mask is None:
            if S > self.seq_len:
                raise ValueError(
                    f"sequence length {S} exceeds the maximum supported "
                    f"length {self.seq_len}"
                )
            mask = self.causal_mask[:, :, :S, :S]
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)
class My_GPT_model(nn.Module):
    """Decoder-only language model mapping token ids to vocabulary logits.

    Wraps a ``Decoder`` and adds a bias-free linear LM head whose weight is
    the very same parameter object as the decoder's embedding table
    (weight tying).

    Args:
        vocab_size: Vocabulary size (embedding rows and logit width).
        num_layers: Number of decoder blocks.
        d_model: Model width.
        d_ff: Feed-forward hidden width.
        num_heads: Number of attention heads.
        seq_len: Maximum sequence length handed to the decoder.
        dropout: Dropout probability handed to the decoder.
    """

    def __init__(self, vocab_size, num_layers, d_model, d_ff, num_heads, seq_len, dropout=0.1):
        super().__init__()
        self.decoder = Decoder(
            vocab_size=vocab_size,
            num_layers=num_layers,
            d_model=d_model,
            d_ff=d_ff,
            num_heads=num_heads,
            seq_len=seq_len,
            dropout=dropout,
        )
        # Projection from hidden states to vocabulary logits.
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: both layers hold a (vocab_size, d_model) matrix, so
        # the head reuses the embedding parameter instead of its own.
        self.lm_head.weight = self.decoder.embedding.weight

    def forward(self, token_ids):
        """Compute logits for a batch of token ids.

        Args:
            token_ids: Integer tensor of shape (B, S).

        Returns:
            Logits of shape (B, S, V).
        """
        embedded = self.decoder.forward_tokens(token_ids)  # (B, S, D)
        hidden = self.decoder(embedded)  # (B, S, D)
        return self.lm_head(hidden)  # (B, S, V)
| # model = My_GPT_model( | |
| # vocab_size=30000, | |
| # num_layers=6, | |
| # d_model=512, | |
| # d_ff=2048, | |
| # num_heads=8, | |
| # seq_len=128 | |
| # ) | |
| # tokens = torch.randint(0, 30000, (2, 128)) | |
| # logits = model(tokens) | |
| # print(logits.shape) | |
| # # (2, 128, 30000) | |
| # print(tokens) | |
| # print("#################") | |
| # print(logits) |