gate369 committed on
Commit
c7eb2e2
·
verified ·
1 Parent(s): 6cdb8f8

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +213 -18
README.md CHANGED
@@ -54,25 +54,220 @@ tokenizer = AutoTokenizer.from_pretrained(model_path)
54
  # For convenience, the model definition is included in the training script.
55
  # Here we provide a minimal loading snippet assuming you have the model class.
56
 
57
# Define model config (must match the saved config.json)
class ModelConfig:
    """Model hyperparameters; values must mirror the checkpoint's config.json."""
    vocab_size = 50257           # tokenizer vocabulary size
    emb_dim = 768                # embedding / model width
    hidden_dim = 2048            # FFN intermediate size
    num_layers = 12              # number of transformer layers
    num_heads = 12               # number of query heads
    num_kv_heads = 4             # number of key/value heads (grouped-query attention)
    max_seq_len = 1024           # maximum context length
    window_size = 1024           # sliding-window span
    sliding_window_ratio = 0.75  # fraction of layers using the sliding-window mask
    rope_theta = 10000.0         # RoPE frequency base
    dtype = torch.float16        # parameter dtype
    bias = False                 # no bias in linear layers
    dropout = 0.0                # dropout probability
72
-
73
- # Instantiate model (you need the model class definition, e.g., TinyAya)
74
- # Here we assume you have the TinyAya class from the training script.
75
- # If not, copy the class definition from the training script into this cell.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
# Instantiate the model and load the trained weights from the checkpoint dir.
model = TinyAya(ModelConfig())
# weights_only=True prevents arbitrary-code execution via pickle when loading
# a checkpoint; a plain state_dict needs nothing more.
state_dict = torch.load(
    os.path.join(model_path, "pytorch_model.bin"),
    map_location="cpu",
    weights_only=True,
)
model.load_state_dict(state_dict)
 
54
  # For convenience, the model definition is included in the training script.
55
  # Here we provide a minimal loading snippet assuming you have the model class.
56
 
57
# ------------------------------------------------------------------------------
# Configuration (scaled to ~150M for L4 GPU)
# ------------------------------------------------------------------------------
class ModelConfig:
    """Model hyperparameters (must match the saved config.json)."""
    vocab_size = 50257               # will be updated from tokenizer
    emb_dim = 768                    # embedding dimension
    hidden_dim = 2048                # intermediate size (FFN) - reduced
    num_layers = 12                  # number of transformer layers - reduced
    num_heads = 12                   # number of query heads - reduced
    num_kv_heads = 4                 # number of key/value heads (GQA)
    max_seq_len = 1024               # shorter sequence length to save memory
    window_size = 1024               # sliding window size (match max_seq_len)
    sliding_window_ratio = 0.75      # fraction of layers with sliding window
    rope_theta = 10000.0             # base for RoPE
    dtype = torch.float16            # use mixed precision
    bias = False                     # no bias in linear layers
    dropout = 0.0                    # no dropout mentioned
    gradient_checkpointing = True    # enable to save memory
75
+
76
# ------------------------------------------------------------------------------
# Helper modules (unchanged)
# ------------------------------------------------------------------------------
class CohereLayerNorm(nn.Module):
    """Bias-free LayerNorm (learnable scale only), computed in float32."""

    def __init__(self, emb_dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(emb_dim))

    def forward(self, x):
        # Normalize in float32 for stability, then cast back to the input dtype.
        orig_dtype = x.dtype
        xf = x.to(torch.float32)
        mu = xf.mean(dim=-1, keepdim=True)
        var = (xf - mu).pow(2).mean(dim=-1, keepdim=True)
        normed = (xf - mu) * torch.rsqrt(var + self.eps)
        return (self.weight.to(torch.float32) * normed).to(orig_dtype)
93
+
94
+
95
class FeedForward(nn.Module):
    """Position-wise SwiGLU feed-forward network."""

    def __init__(self, config):
        super().__init__()
        self.fc1 = nn.Linear(config.emb_dim, config.hidden_dim, bias=config.bias)
        self.fc2 = nn.Linear(config.emb_dim, config.hidden_dim, bias=config.bias)
        self.fc3 = nn.Linear(config.hidden_dim, config.emb_dim, bias=config.bias)

    def forward(self, x):
        # SwiGLU: silu-gated branch multiplied by a linear branch, then projected.
        gate = F.silu(self.fc1(x))
        return self.fc3(gate * self.fc2(x))
108
+
109
+
110
def precompute_rope_freqs(dim, max_seq_len, theta=10000.0, dtype=torch.float32):
    """Build rotary position-embedding tables.

    Returns (sin, cos), each of shape (max_seq_len, dim).
    """
    assert dim % 2 == 0, "dim must be even"
    half = torch.arange(0, dim, 2, dtype=dtype)[:(dim // 2)]
    inv_freq = 1.0 / (theta ** (half / dim))
    angles = torch.outer(torch.arange(max_seq_len, dtype=dtype), inv_freq)  # (max_seq_len, dim//2)
    table = torch.cat((angles, angles), dim=-1)                             # (max_seq_len, dim)
    return table.sin(), table.cos()
118
+
119
+
120
def rotate_half(x):
    """Map [x1, x2] (split on the last dim) to [-x2, x1]."""
    first, second = x.chunk(2, dim=-1)
    return torch.cat((-second, first), dim=-1)


def apply_rotary_emb(x, cos, sin):
    """Apply rotary position embeddings.

    x:        (batch, seq_len, num_heads, head_dim)
    cos, sin: (seq_len, head_dim); broadcast over batch and head dims.
    """
    cos = cos[None, :, None, :]  # (1, seq_len, 1, head_dim)
    sin = sin[None, :, None, :]  # (1, seq_len, 1, head_dim)
    return x * cos + rotate_half(x) * sin
135
+
136
+
137
class GroupedQueryAttention(nn.Module):
    """Multi-head attention with grouped-query KV heads (GQA) and an optional
    sliding-window causal mask.

    NOTE(review): RoPE is applied only in sliding-window layers; full-attention
    layers get no positional encoding. This mirrors the original code and looks
    intentional (NoPE on global layers) — confirm against the training script.
    """

    def __init__(self, config, layer_id):
        super().__init__()
        self.num_heads = config.num_heads
        self.num_kv_heads = config.num_kv_heads
        self.head_dim = config.emb_dim // config.num_heads
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.wq = nn.Linear(config.emb_dim, config.num_heads * self.head_dim, bias=config.bias)
        self.wk = nn.Linear(config.emb_dim, config.num_kv_heads * self.head_dim, bias=config.bias)
        self.wv = nn.Linear(config.emb_dim, config.num_kv_heads * self.head_dim, bias=config.bias)
        self.wo = nn.Linear(config.num_heads * self.head_dim, config.emb_dim, bias=config.bias)

        # The first `num_layers * sliding_window_ratio` layers use the window mask.
        num_sliding = int(config.num_layers * config.sliding_window_ratio)
        self.use_sliding = (layer_id < num_sliding)

        self.window_size = config.window_size
        self.max_seq_len = config.max_seq_len
        self.rope_theta = config.rope_theta
        # Plain attributes (not buffers): rebuilt lazily on the input's device.
        self.rope_sin, self.rope_cos = None, None

    def init_rope(self, max_seq_len, device):
        """Lazily (re)build the RoPE tables when a longer sequence arrives."""
        if self.rope_sin is not None and self.rope_sin.shape[0] >= max_seq_len:
            return
        sin, cos = precompute_rope_freqs(
            self.head_dim, max_seq_len, theta=self.rope_theta, dtype=torch.float32
        )
        self.rope_sin = sin.to(device)
        self.rope_cos = cos.to(device)

    def forward(self, x, mask=None):
        """x: (batch, seq_len, emb_dim) -> (batch, seq_len, emb_dim).

        If `mask` is None, a causal (and, for sliding layers, windowed)
        additive mask is built internally.
        """
        batch, seq_len, _ = x.shape
        device = x.device

        if self.use_sliding:
            self.init_rope(seq_len, device)

        xq = self.wq(x).view(batch, seq_len, self.num_heads, self.head_dim)
        xk = self.wk(x).view(batch, seq_len, self.num_kv_heads, self.head_dim)
        xv = self.wv(x).view(batch, seq_len, self.num_kv_heads, self.head_dim)

        if self.use_sliding:
            xq = apply_rotary_emb(xq, self.rope_cos[:seq_len], self.rope_sin[:seq_len])
            xk = apply_rotary_emb(xk, self.rope_cos[:seq_len], self.rope_sin[:seq_len])

        # Expand KV heads so every query head has a matching key/value head.
        xk = xk.repeat_interleave(self.num_queries_per_kv, dim=2)
        xv = xv.repeat_interleave(self.num_queries_per_kv, dim=2)

        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        scores = torch.matmul(xq, xk.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            scores = scores + mask
        else:
            mask = torch.full((seq_len, seq_len), float('-inf'), device=device)
            mask = torch.triu(mask, diagonal=1)
            if self.use_sliding:
                # FIX: vectorized replacement for the original O(seq_len^2)
                # Python loop. Masks positions j <= i - window_size, identical
                # to `mask[i, :max(0, i - window_size + 1)] = -inf`.
                too_old = torch.ones(
                    seq_len, seq_len, dtype=torch.bool, device=device
                ).tril(diagonal=-self.window_size)
                mask = mask.masked_fill(too_old, float('-inf'))
            scores = scores + mask

        # Softmax in float32 for numerical stability, then back to input dtype.
        probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(xq.dtype)
        out = torch.matmul(probs, xv)
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.wo(out)
213
+
214
+
215
class ParallelTransformerBlock(nn.Module):
    """Decoder block in the parallel layout: attention and MLP both read the
    same normalized input, and their outputs are summed into the residual."""

    def __init__(self, config, layer_id):
        super().__init__()
        self.norm = CohereLayerNorm(config.emb_dim)
        self.attn = GroupedQueryAttention(config, layer_id)
        self.mlp = FeedForward(config)

    def forward(self, x, mask=None):
        # Single shared pre-norm feeds both branches.
        normed = self.norm(x)
        return x + self.attn(normed, mask=mask) + self.mlp(normed)
229
+
230
+
231
class TinyAya(nn.Module):
    """Tiny Aya ~150M decoder-only language model.

    Token embedding -> stack of parallel transformer blocks -> final norm ->
    LM head. The LM head weight is tied to the token embedding.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.token_embedding = nn.Embedding(config.vocab_size, config.emb_dim)
        self.layers = nn.ModuleList([
            ParallelTransformerBlock(config, i) for i in range(config.num_layers)
        ])
        self.norm = CohereLayerNorm(config.emb_dim)
        self.lm_head = nn.Linear(config.emb_dim, config.vocab_size, bias=False)
        # Weight tying: the LM head reuses the embedding matrix.
        self.lm_head.weight = self.token_embedding.weight

        if config.gradient_checkpointing:
            self.gradient_checkpointing_enable()

    def gradient_checkpointing_enable(self):
        # Trades compute for memory: activations are recomputed in backward.
        self._gradient_checkpointing = True

    def forward(self, input_ids, mask=None):
        """Return logits of shape (batch, seq_len, vocab_size)."""
        x = self.token_embedding(input_ids)
        for layer in self.layers:
            if self.training and getattr(self, '_gradient_checkpointing', False):
                # FIX: pass use_reentrant=False — the reentrant default is
                # deprecated and handles None/keyword inputs less robustly.
                x = torch.utils.checkpoint.checkpoint(
                    layer, x, mask, use_reentrant=False
                )
            else:
                x = layer(x, mask=mask)
        x = self.norm(x)
        logits = self.lm_head(x)
        return logits

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=50, temperature=1.0):
        """Autoregressive multinomial sampling.

        FIX: the original left the model permanently in eval mode; the
        previous train/eval mode is now restored on exit.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Crop the context to the model's maximum sequence length.
                logits = self(input_ids[:, -self.config.max_seq_len:])
                next_token_logits = logits[:, -1, :] / temperature
                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                input_ids = torch.cat([input_ids, next_token], dim=-1)
        finally:
            if was_training:
                self.train()
        return input_ids
271
# Instantiate the model and load the trained weights from the checkpoint dir.
model = TinyAya(ModelConfig())
# weights_only=True prevents arbitrary-code execution via pickle when loading
# a checkpoint; a plain state_dict needs nothing more.
state_dict = torch.load(
    os.path.join(model_path, "pytorch_model.bin"),
    map_location="cpu",
    weights_only=True,
)
model.load_state_dict(state_dict)