brianling16 committed
Commit f3d30f6 · verified · 1 Parent(s): 0b84fe0

Upload transformer.py with huggingface_hub

Files changed (1):
  1. transformer.py +88 -0
transformer.py ADDED
@@ -0,0 +1,88 @@
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiheadSelfAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        # Standard projections
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        B, T, C = x.shape
        H = self.n_heads
        D = self.d_head

        q = self.q_proj(x).view(B, T, H, D).transpose(1, 2)  # (B, H, T, D)
        k = self.k_proj(x).view(B, T, H, D).transpose(1, 2)
        v = self.v_proj(x).view(B, T, H, D).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) / math.sqrt(D)  # (B, H, T, T)
        if attn_mask is not None:
            att = att + attn_mask  # mask should be broadcastable; use -inf on masked positions
        att = F.softmax(att, dim=-1)
        y = att @ v  # (B, H, T, D)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # merge heads back
        y = self.out_proj(y)
        return y

class MLP(nn.Module):  # Fixed: now inherits from nn.Module
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor):
        return self.fc2(self.activation(self.fc1(x)))

class TransformerLayer(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.self_attn = MultiheadSelfAttention(d_model, n_heads)
        self.mlp = MLP(d_model, d_ff)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        # Pre-norm residual blocks
        y = self.self_attn(self.ln1(x), attn_mask)
        x = x + self.dropout(y)
        y = self.mlp(self.ln2(x))
        return x + self.dropout(y)

class Transformer(nn.Module):
    def __init__(self, n_layers: int, d_model: int, n_heads: int, d_ff: int, vocab_size: int, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.d_ff = d_ff
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(2048, d_model)  # simple fixed max length
        self.layers = nn.ModuleList([
            TransformerLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)  # Added missing final LayerNorm
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.lm_head.weight = self.tok_emb.weight  # weight tying

    def forward(self, idx: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        B, T = idx.shape
        # Guard against overrunning the fixed positional table
        assert T <= self.pos_emb.num_embeddings, "sequence length exceeds positional embedding table"
        pos = torch.arange(0, T, device=idx.device).unsqueeze(0)  # (1, T)
        x = self.tok_emb(idx) + self.pos_emb(pos)
        for layer in self.layers:
            x = layer(x, attn_mask)
        x = self.ln_f(x)
        return self.lm_head(x)
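
The forward pass takes an additive attention mask, so causal decoding needs a mask with -inf strictly above the diagonal. Below is a minimal smoke-test sketch, assuming the file is importable as transformer; the model sizes are arbitrary illustrative values, not anything specified in the commit.

import torch
from transformer import Transformer  # assumes transformer.py is on the import path

# Hypothetical small configuration for a quick shape check
model = Transformer(n_layers=2, d_model=64, n_heads=4, d_ff=256, vocab_size=1000)
model.eval()  # disable dropout for a deterministic forward pass

B, T = 2, 16
idx = torch.randint(0, 1000, (B, T))  # random token ids

# Additive causal mask: 0 where attention is allowed, -inf strictly above
# the diagonal; (T, T) broadcasts to (B, H, T, T) inside the attention.
causal_mask = torch.full((T, T), float("-inf")).triu(diagonal=1)

with torch.no_grad():
    logits = model(idx, attn_mask=causal_mask)
print(logits.shape)  # torch.Size([2, 16, 1000])

Since the positional table is fixed at 2048 entries, sequences longer than that will trip the length assertion in forward.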