Spaces:
Sleeping
Sleeping
File size: 4,748 Bytes
import torch
from torch import nn
from feed_forward_nn import feedforward, SwiGLU_FFN
from masked_mha import Masked_MHA
from rms_norm import RMSNorm
import math
# d_model = 512 # main model dimension
# num_heads = 8 # number of heads
# d_ff = 2048 # feedforward hidden dimension
# seq_len = 128 # max input length
# vocab_size = 30000
def generate_subsequent_mask(size):
    """Build a boolean causal (lower-triangular) mask.

    Entry ``[0, 0, i, j]`` is ``True`` when ``j <= i`` (the diagonal and
    everything below it) and ``False`` for every ``j > i``.

    Args:
        size: Side length ``L`` of the square mask.

    Returns:
        A ``torch.bool`` tensor of shape ``(1, 1, L, L)``. The two leading
        singleton axes are added so the mask can broadcast over extra
        leading dimensions.
    """
    positions = torch.arange(size)
    # Column index (row vector) compared against row index (column vector):
    # broadcasting yields the full L x L "j <= i" table in one step.
    visible = positions.unsqueeze(0) <= positions.unsqueeze(1)
    return visible[None, None]
class Decoder_GPT_Block(nn.Module):
    """One pre-norm decoder block: masked self-attention, then a SwiGLU FFN.

    Both sub-layers follow the same pattern,
    ``x = x + dropout(sublayer(rms_norm(x)))``: the input is normalized
    *before* the sub-layer and the un-normalized residual is carried forward.
    """

    def __init__(self, d_model, d_ff, num_heads, seq_len, dropout=0.1):
        """Create the sub-layers of the block.

        Args:
            d_model: Model width, passed to the attention, FFN and norm layers.
            d_ff: Hidden width passed to ``SwiGLU_FFN``.
            num_heads: Number of attention heads passed to ``Masked_MHA``.
            seq_len: Forwarded to ``Masked_MHA`` as ``max_seq_len``.
            dropout: Probability used by the dropout applied to each
                sub-layer output before it is added to the residual.
        """
        super().__init__()
        # Construction order is kept stable: it determines parameter
        # registration order and the order random initializers are drawn.
        self.swi_glu = SwiGLU_FFN(d_model, d_ff)
        self.masked_mha = Masked_MHA(d_model, num_heads, max_seq_len=seq_len)
        self.rms_norm0 = RMSNorm(d_model)
        self.rms_norm1 = RMSNorm(d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, mask):
        """Run the block on ``x``.

        Args:
            x: Input activations (presumably ``(B, S, d_model)`` -- confirm
                against the caller).
            mask: Attention mask handed unchanged to ``Masked_MHA``.

        Returns:
            The result of both residual sub-layers applied to ``x``.
        """
        # Sub-layer 1: masked multi-head self-attention on the normalized input.
        attn_out = self.masked_mha(self.rms_norm0(x), mask)
        x = x + self.dropout(attn_out)

        # Sub-layer 2: SwiGLU feed-forward on the normalized residual stream.
        ffn_out = self.swi_glu(self.rms_norm1(x))
        return x + self.dropout(ffn_out)
class Decoder(nn.Module):
    """Token embedding, a stack of ``Decoder_GPT_Block`` layers and a final norm."""

    def __init__(self, vocab_size, num_layers, d_model, d_ff, num_heads, seq_len, dropout=0.1):
        """Build the embedding table, the block stack and the causal mask.

        Args:
            vocab_size: Number of rows in the token embedding table.
            num_layers: Number of ``Decoder_GPT_Block`` layers to stack.
            d_model: Model width.
            d_ff: Feed-forward hidden width passed to every block.
            num_heads: Number of attention heads passed to every block.
            seq_len: Maximum sequence length; sizes the cached causal mask
                and is forwarded to every block.
            dropout: Dropout probability forwarded to every block.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # The block signature is (d_model, d_ff, num_heads, seq_len, dropout).
        # seq_len must be passed explicitly, otherwise the dropout value
        # lands in the seq_len slot and the configured dropout is ignored.
        self.layers = nn.ModuleList(
            [
                Decoder_GPT_Block(d_model, d_ff, num_heads, seq_len, dropout=dropout)
                for _ in range(num_layers)
            ]
        )
        # Final norm. Every block adds its sub-layer output to an
        # un-normalized residual stream, so the output of the last block is
        # normalized once more here. This keeps the output distribution
        # consistent, makes whatever consumes it (an LM head or a classifier)
        # easier to train, and helps keep gradients stable.
        self.norm = RMSNorm(d_model)
        self.seq_len = seq_len
        # Cached (1, 1, seq_len, seq_len) causal mask. Registered as a buffer
        # so it follows the module across devices.
        self.register_buffer(
            "causal_mask",
            generate_subsequent_mask(seq_len),
        )

    def forward_tokens(self, token_ids):
        """Look up the embedding vectors for ``token_ids``."""
        return self.embedding(token_ids)

    def forward(self, x, mask=None):
        """Run the block stack and the final norm.

        Args:
            x: Already-embedded input of shape ``(B, S, D)``.
            mask: Optional attention mask passed to every block. When
                ``None``, the cached causal mask is sliced to
                ``(1, 1, S, S)`` and used instead.

        Returns:
            The normalized output of the last block.

        Raises:
            ValueError: If ``x`` is not 3-dimensional, or if no mask is given
                and ``S`` exceeds the length of the cached causal mask.
        """
        _, steps, _ = x.shape
        if mask is None:
            max_steps = self.causal_mask.size(-1)
            if steps > max_steps:
                # Slicing would silently return a smaller mask and fail later
                # with an opaque shape error inside attention.
                raise ValueError(
                    f"sequence length {steps} exceeds the configured maximum of {max_steps}"
                )
            mask = self.causal_mask[:, :, :steps, :steps]
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)
class My_GPT_model(nn.Module):
    """Decoder-only language model: a ``Decoder`` stack plus a tied LM head."""

    def __init__(self, vocab_size, num_layers, d_model, d_ff, num_heads, seq_len, dropout=0.1):
        """Build the decoder and the output projection.

        Args:
            vocab_size: Vocabulary size (embedding rows and logit width).
            num_layers: Number of decoder blocks.
            d_model: Model width.
            d_ff: Feed-forward hidden width.
            num_heads: Number of attention heads.
            seq_len: Maximum sequence length handed to the decoder.
            dropout: Dropout probability handed to the decoder.
        """
        super().__init__()
        self.decoder = Decoder(
            vocab_size=vocab_size,
            num_layers=num_layers,
            d_model=d_model,
            d_ff=d_ff,
            num_heads=num_heads,
            seq_len=seq_len,
            dropout=dropout,
        )
        # Output projection from model width to vocabulary logits.
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: the projection reuses the embedding matrix (both are
        # (vocab_size, d_model)), so the two share a single parameter.
        self.lm_head.weight = self.decoder.embedding.weight

    def forward(self, token_ids):
        """Map token ids to vocabulary logits.

        Args:
            token_ids: Integer token ids of shape ``(B, S)``.

        Returns:
            Logits of shape ``(B, S, vocab_size)``.
        """
        # ids -> embeddings -> decoder stack, both (B, S, D)
        hidden = self.decoder(self.decoder.forward_tokens(token_ids))
        return self.lm_head(hidden)
# model = My_GPT_model(
# vocab_size=30000,
# num_layers=6,
# d_model=512,
# d_ff=2048,
# num_heads=8,
# seq_len=128
# )
# tokens = torch.randint(0, 30000, (2, 128))
# logits = model(tokens)
# print(logits.shape)
# # (2, 128, 30000)
# print(tokens)
# print("#################")
# print(logits)