dilip025 committed · verified
Commit 09d712b · 1 Parent(s): 0f1c2a1

Create model_code/decoder_only_transformer.py

model_code/decoder_only_transformer.py ADDED
import torch
import torch.nn as nn

class DecoderEmbeddings(nn.Module):
    def __init__(self, vocab_size, embed_dim, max_len):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Embedding(max_len, embed_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids):
        seq_len = input_ids.size(1)
        positions = torch.arange(0, seq_len, device=input_ids.device).unsqueeze(0)  # [1, seq_len]
        token_embeddings = self.token_embed(input_ids)  # [batch, seq_len, dim]
        pos_embeddings = self.pos_embed(positions)      # [1, seq_len, dim]
        return self.dropout(token_embeddings + pos_embeddings)

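# A quick shape sketch (illustrative values, not from the original commit):
#   emb = DecoderEmbeddings(vocab_size=100, embed_dim=32, max_len=16)
#   emb(torch.randint(0, 100, (2, 16))).shape  ->  torch.Size([2, 16, 32])
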
def generate_causal_mask(seq_len, device):
    mask = torch.tril(torch.ones(seq_len, seq_len, device=device))  # lower triangular
    return mask == 0  # False = may attend, True = masked out

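# Example: generate_causal_mask(3, "cpu") returns (True = blocked):
#   [[False,  True,  True],
#    [False, False,  True],
#    [False, False, False]]
# i.e. position i may attend only to positions <= i.
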
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, attn_mask=None):
        batch_size, seq_len, embed_dim = x.size()

        # Project to Q, K, V in a single pass
        qkv = self.qkv_proj(x)  # [B, T, 3 * D]
        qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, B, H, T, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]  # each: [B, H, T, head_dim]

        # Scaled dot-product attention scores
        scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)  # [B, H, T, T]

        if attn_mask is not None:
            # attn_mask: [T, T] bool, broadcast over batch and heads
            scores = scores.masked_fill(attn_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
        attn_weights = torch.softmax(scores, dim=-1)  # [B, H, T, T]
        attn_output = attn_weights @ v  # [B, H, T, head_dim]

        # Merge heads back into the embedding dimension
        attn_output = attn_output.transpose(1, 2).contiguous()  # [B, T, H, head_dim]
        attn_output = attn_output.view(batch_size, seq_len, embed_dim)

        return self.out_proj(attn_output)

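# Sanity-check sketch (illustrative sizes): with embed_dim=64, num_heads=8,
#   MultiHeadSelfAttention(64, 8)(torch.randn(2, 10, 64)).shape -> torch.Size([2, 10, 64])
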
class FeedForward(nn.Module):
    def __init__(self, embed_dim, ff_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, embed_dim)
        )

    def forward(self, x):
        return self.net(x)

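# Note: ff_dim is left as a free parameter here; a common convention in
# GPT-style models is ff_dim = 4 * embed_dim, but this file leaves the
# choice to the caller.
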
class DecoderBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.ff = FeedForward(embed_dim, ff_dim)

    def forward(self, x, attn_mask):
        # Self-attention with residual
        attn_out = self.attn(self.ln1(x), attn_mask)
        x = x + attn_out

        # Feedforward with residual
        ff_out = self.ff(self.ln2(x))
        x = x + ff_out

        return x

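# Design note: the block above normalizes *before* each sublayer (pre-LN),
# which is generally reported to train more stably at depth than the
# original post-LN arrangement.
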
class DecoderOnlyTransformer(nn.Module):
    def __init__(self, vocab_size, max_len, embed_dim, num_heads, depth, ff_dim):
        super().__init__()
        self.embedding = DecoderEmbeddings(vocab_size, embed_dim, max_len)

        self.blocks = nn.ModuleList([
            DecoderBlock(embed_dim, num_heads, ff_dim)
            for _ in range(depth)
        ])

        self.ln_final = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)  # Language modeling head

    def forward(self, input_ids):
        """
        input_ids: [B, T] token indices.
        Returns logits of shape [B, T, vocab_size].
        """
        B, T = input_ids.size()
        x = self.embedding(input_ids)  # [B, T, D]

        # Causal mask: True marks positions each token may not attend to
        mask = generate_causal_mask(T, input_ids.device)

        for block in self.blocks:
            x = block(x, attn_mask=mask)

        x = self.ln_final(x)   # [B, T, D]
        logits = self.head(x)  # [B, T, vocab_size]

        return logits
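
# ---------------------------------------------------------------------------
# Minimal usage sketch. Everything below is illustrative and not part of the
# model definition: the hyperparameters are assumptions chosen to run on CPU,
# and greedy_generate is a hypothetical helper, not an API of this file.
# ---------------------------------------------------------------------------

@torch.no_grad()
def greedy_generate(model, input_ids, max_new_tokens, max_len):
    """Greedy decoding sketch: repeatedly append the argmax token."""
    model.eval()
    for _ in range(max_new_tokens):
        context = input_ids[:, -max_len:]  # crop to the positional-embedding limit
        logits = model(context)            # [B, T, vocab_size]
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # [B, 1]
        input_ids = torch.cat([input_ids, next_token], dim=1)
    return input_ids


if __name__ == "__main__":
    # Smoke test with assumed hyperparameters
    model = DecoderOnlyTransformer(
        vocab_size=1000, max_len=128, embed_dim=64,
        num_heads=4, depth=2, ff_dim=256,
    )
    tokens = torch.randint(0, 1000, (2, 16))  # [B=2, T=16]
    print(model(tokens).shape)                # torch.Size([2, 16, 1000])

    out = greedy_generate(model, tokens, max_new_tokens=8, max_len=128)
    print(out.shape)                          # torch.Size([2, 24])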