michaelbzhu committed
Commit 15b5ad2 · verified · 1 Parent(s): c8ddd28

Update modeling.py

Files changed (1): modeling.py (+147, -0)
modeling.py CHANGED
@@ -0,0 +1,147 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend
from transformers import PreTrainedModel

# NOTE: CustomModel below also references CustomConfig, which is not defined in
# this file; it is assumed to be a transformers.PretrainedConfig subclass (with
# d_model, n_heads, d_head, n_vocab, and n_layers fields) defined or imported
# alongside this module.

class RotaryPositionalEncoding(nn.Module):
    """
    Rotary Position Embeddings (RoPE) - efficient implementation
    """
    def __init__(self, d_head, max_seq_len=8192, base=10000.0):
        super().__init__()
        self.d_head = d_head
        self.max_seq_len = max_seq_len
        self.base = base

        # Precompute inverse frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, d_head, 2).float() / d_head))
        self.register_buffer('inv_freq', inv_freq, persistent=False)

        # Precompute cos and sin for the maximum sequence length
        self._precompute_freqs(max_seq_len)

    def _precompute_freqs(self, seq_len):
        """Precompute cos and sin values for positions"""
        t = torch.arange(seq_len, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
        freqs = torch.outer(t, self.inv_freq)  # (seq_len, d_head/2)

        # Create cos and sin embeddings
        freqs_cos = torch.cos(freqs)
        freqs_sin = torch.sin(freqs)

        # Interleave to match the head dimension: (seq_len, d_head)
        self.register_buffer('freqs_cos', freqs_cos.repeat_interleave(2, dim=-1), persistent=False)
        self.register_buffer('freqs_sin', freqs_sin.repeat_interleave(2, dim=-1), persistent=False)

    def rotate_half(self, x):
        """Map each interleaved pair (x0, x1) to (-x1, x0) along the last dim"""
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        return torch.stack([-x2, x1], dim=-1).flatten(-2)

    def forward(self, q, k, start_pos=0):
        """
        Apply rotary embeddings to query and key tensors
        Args:
            q: (batch_size, n_heads, seq_len, d_head)
            k: (batch_size, n_heads, seq_len, d_head)
            start_pos: starting position for caching scenarios
        Returns:
            q_rot, k_rot with rotary embeddings applied
        """
        seq_len = q.shape[2]

        # Slice the precomputed frequencies for this sequence length
        freqs_cos = self.freqs_cos[start_pos:start_pos + seq_len]
        freqs_sin = self.freqs_sin[start_pos:start_pos + seq_len]

        # Apply rotary embeddings
        q_rot = q * freqs_cos + self.rotate_half(q) * freqs_sin
        k_rot = k * freqs_cos + self.rotate_half(k) * freqs_sin

        return q_rot, k_rot

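# Sanity-check sketch (an illustrative addition, not part of the original commit).
# RoPE rotates interleaved pairs (x[2i], x[2i+1]) by position-dependent angles, so
# dot products between rotated queries and keys depend only on relative position:
#
#   rope = RotaryPositionalEncoding(d_head=64)
#   q, k = torch.randn(1, 1, 8, 64), torch.randn(1, 1, 8, 64)
#   q1, k1 = rope(q, k)                # absolute positions 0..7
#   q2, k2 = rope(q, k, start_pos=4)   # same inputs at absolute positions 4..11
#   s1 = (q1[..., 2:3, :] * k1[..., 5:6, :]).sum(-1)  # query pos 2, key pos 5
#   s2 = (q2[..., 2:3, :] * k2[..., 5:6, :]).sum(-1)  # query pos 6, key pos 9
#   torch.testing.assert_close(s1, s2)  # same relative offset -> same score
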
class Attention(nn.Module):
    def __init__(self, d_model, n_heads, d_head):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_head

        self.Wq = nn.Linear(d_model, n_heads * d_head, bias=False)
        self.Wk = nn.Linear(d_model, n_heads * d_head, bias=False)
        self.Wv = nn.Linear(d_model, n_heads * d_head, bias=False)
        self.Wo = nn.Linear(n_heads * d_head, d_model, bias=False)

        # Initialize RoPE
        self.rope = RotaryPositionalEncoding(d_head)

    def forward(self, x):
        # x: (batch_size, seq_len, d_model)
        batch_size, seq_len, d_model = x.shape
        q = self.Wq(x)  # (batch_size, seq_len, n_heads * d_head)
        k = self.Wk(x)
        v = self.Wv(x)

        # Reshape to (batch_size, n_heads, seq_len, d_head)
        q = q.reshape(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        k = k.reshape(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        v = v.reshape(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)

        q, k = self.rope(q, k)
        # Pin SDPA to the flash-attention backend; note this raises if no flash
        # kernel is available for the current device/dtype
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            a = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)  # (batch_size, n_heads, seq_len, d_head)
        a = a.transpose(1, 2)  # (batch_size, seq_len, n_heads, d_head)
        a = a.reshape(batch_size, seq_len, self.n_heads * self.d_head)
        out = self.Wo(a)  # (batch_size, seq_len, d_model)
        return out

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_head):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_head

        self.attn = Attention(d_model, n_heads, d_head)
        self.mlp = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.ReLU(), nn.Linear(4 * d_model, d_model))

        self.norm1 = nn.RMSNorm(d_model)
        self.norm2 = nn.RMSNorm(d_model)

    def forward(self, x):
        x = self.attn(self.norm1(x)) + x
        x = self.mlp(self.norm2(x)) + x
        return x

class GPT(nn.Module):
    def __init__(self, d_model, n_heads, d_head, n_vocab, n_layers):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_head
        self.n_vocab = n_vocab

        self.embed = nn.Embedding(n_vocab, d_model)

        self.blocks = nn.ModuleList([TransformerBlock(d_model, n_heads, d_head) for _ in range(n_layers)])

        self.norm = nn.RMSNorm(d_model)
        self.out_head = nn.Linear(d_model, n_vocab)

    def forward(self, x):
        x = self.embed(x)
        for block in self.blocks:
            x = block(x)
        x = self.out_head(self.norm(x))
        return x

class CustomModel(PreTrainedModel):
    # CustomConfig is referenced but not defined in this file; see the note
    # next to the imports above
    config_class = CustomConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = GPT(config.d_model, config.n_heads, config.d_head, config.n_vocab, config.n_layers)

    def forward(self, tensor):
        # tensor: (batch_size, seq_len) token ids -> (batch_size, seq_len, n_vocab) logits
        return self.model(tensor)
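
# Minimal smoke test (an illustrative addition, not part of the original commit;
# the hyperparameters below are arbitrary). It exercises the plain GPT module
# directly, since CustomModel additionally requires a CustomConfig definition.
# Note that the attention layer pins SDPBackend.FLASH_ATTENTION, so this assumes
# a device/dtype combination for which a flash-attention kernel is available.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = GPT(d_model=128, n_heads=4, d_head=32, n_vocab=1000, n_layers=2)
    tokens = torch.randint(0, 1000, (2, 16))  # (batch_size, seq_len) token ids
    logits = model(tokens)
    print(logits.shape)  # expected: torch.Size([2, 16, 1000])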