tensorfiend commited on
Commit
f206855
·
verified ·
1 Parent(s): 1d973c7

Upload modeling_dotlm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_dotlm.py +384 -0
modeling_dotlm.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import Optional, Tuple, List, Union
5
+ from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
6
+ from transformers.modeling_outputs import CausalLMOutputWithPast
7
+ from transformers.cache_utils import Cache, DynamicCache
8
+
9
+
10
+ # ── Config ────────────────────────────────────────────────────────────────────
11
+
12
+ class DotLMConfig(PretrainedConfig):
13
+ model_type = "dotlm"
14
+
15
+ def __init__(
16
+ self,
17
+ vocab_size=16384,
18
+ d_model=768,
19
+ hidden_dim=2048,
20
+ num_hidden_layers=24,
21
+ n_heads=6,
22
+ n_kv_heads=2,
23
+ context_len=4096,
24
+ theta_base=10000.0,
25
+ norm_eps=1e-6,
26
+ initializer_range=0.02,
27
+ tie_word_embeddings=True,
28
+ **kwargs
29
+ ):
30
+ super().__init__(**kwargs)
31
+ self.vocab_size = vocab_size
32
+ self.d_model = d_model
33
+ self.hidden_dim = hidden_dim
34
+ self.num_hidden_layers = num_hidden_layers
35
+ self.n_heads = n_heads
36
+ self.n_kv_heads = n_kv_heads
37
+ self.context_len = context_len
38
+ self.theta_base = theta_base
39
+ self.norm_eps = norm_eps
40
+ self.initializer_range = initializer_range
41
+ self.tie_word_embeddings = tie_word_embeddings
42
+ self.use_cache = kwargs.get("use_cache", True)
43
+ self.pad_token_id = kwargs.get("pad_token_id", 0)
44
+ self.bos_token_id = kwargs.get("bos_token_id", None)
45
+ self.eos_token_id = kwargs.get("eos_token_id", 3)
46
+
47
+
48
+ # ── Architecture Components ───────────────────────────────────────────────────
49
+
50
+ def precompute_freqs_cis(dim, context_len, theta_base=10000.0):
51
+ theta = 1.0 / (theta_base ** (torch.arange(0, dim, 2) / dim))
52
+ seq_ids = torch.arange(context_len, dtype=torch.float32)
53
+ m_theta = torch.outer(seq_ids, theta)
54
+ m_theta = torch.cat([m_theta, m_theta], dim=-1)
55
+ return torch.cos(m_theta), torch.sin(m_theta)
56
+
57
+
58
+ class SwiGLU(nn.Module):
59
+ def __init__(self, d_model, hidden_dim):
60
+ super().__init__()
61
+ self.W = nn.Linear(d_model, hidden_dim, bias=False)
62
+ self.V = nn.Linear(d_model, hidden_dim, bias=False)
63
+ self.W2 = nn.Linear(hidden_dim, d_model, bias=False)
64
+ self.silu = nn.SiLU()
65
+
66
+ def forward(self, x):
67
+ return self.W2(self.silu(self.W(x)) * self.V(x))
68
+
69
+
70
+ class RMSNorm(nn.Module):
71
+ def __init__(self, dim, eps=1e-6):
72
+ super().__init__()
73
+ self.eps = eps
74
+ self.scale = nn.Parameter(torch.ones(dim))
75
+
76
+ def forward(self, x):
77
+ x = x * torch.rsqrt(torch.pow(x, 2).mean(dim=-1, keepdim=True) + self.eps)
78
+ return x * self.scale
79
+
80
+
81
+ class RoPE(nn.Module):
82
+ def forward(self, x, cos, sin):
83
+ batch_size, num_heads, seq_len, head_dim = x.shape
84
+ x1, x2 = x[..., : head_dim // 2], x[..., head_dim // 2 :]
85
+ x_rotated = torch.cat([-x2, x1], dim=-1)
86
+ return x * cos + x_rotated * sin
87
+
88
+
89
+ class GroupedQueryAttention(nn.Module):
90
+ def __init__(self, d_model, n_heads, head_dim, n_kv_groups):
91
+ super().__init__()
92
+ self.n_heads = n_heads
93
+ self.head_dim = head_dim
94
+ self.n_kv_groups = n_kv_groups
95
+ self.group_size = n_heads // n_kv_groups
96
+ self.output_dim = n_heads * head_dim
97
+
98
+ self.Wq = nn.Linear(d_model, self.output_dim, bias=False)
99
+ self.Wk = nn.Linear(d_model, n_kv_groups * head_dim, bias=False)
100
+ self.Wv = nn.Linear(d_model, n_kv_groups * head_dim, bias=False)
101
+ self.Wo = nn.Linear(self.output_dim, d_model, bias=False)
102
+ self.q_norm = RMSNorm(head_dim)
103
+ self.k_norm = RMSNorm(head_dim)
104
+ self.rope = RoPE()
105
+
106
+ def forward(self, x, cos, sin, mask=None, past_key_value=None, use_cache=False):
107
+ B, S, _ = x.shape
108
+
109
+ q = self.Wq(x).view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
110
+ k = self.Wk(x).view(B, S, self.n_kv_groups, self.head_dim).transpose(1, 2)
111
+ v = self.Wv(x).view(B, S, self.n_kv_groups, self.head_dim).transpose(1, 2)
112
+
113
+ q, k = self.q_norm(q), self.k_norm(k)
114
+ q, k = self.rope(q, cos, sin), self.rope(k, cos, sin)
115
+
116
+ next_past = None
117
+ if past_key_value is not None:
118
+ if isinstance(past_key_value, Cache):
119
+ # HF DynamicCache: update in-place and get concatenated K/V back.
120
+ k, v = past_key_value.update(k, v, self.layer_idx)
121
+ next_past = past_key_value
122
+ else:
123
+ # Legacy cache format: (k, v) per layer. Some generation paths
124
+ # may pass placeholders like (None, None) on the first step.
125
+ pk, pv = past_key_value
126
+ if pk is not None:
127
+ k = torch.cat([pk, k], dim=2)
128
+ v = torch.cat([pv, v], dim=2)
129
+ next_past = (k, v) if use_cache else None
130
+
131
+ # Cache stores grouped K/V (n_kv_groups heads). We only expand for SDPA.
132
+ kv_k, kv_v = k, v
133
+
134
+ B, G, S_kv, D = kv_k.shape
135
+ k = kv_k.unsqueeze(2).expand(B, G, self.group_size, S_kv, D).reshape(B, self.n_heads, S_kv, D)
136
+ v = kv_v.unsqueeze(2).expand(B, G, self.group_size, S_kv, D).reshape(B, self.n_heads, S_kv, D)
137
+
138
+ # Causal logic for SDPA: if mask is None, we assume causality if prefill
139
+ # But for robustness, we always pass a mask if S > 1
140
+ is_causal = (mask is None and S > 1 and past_key_value is None)
141
+
142
+ out = F.scaled_dot_product_attention(
143
+ q, k, v,
144
+ attn_mask=None if (mask is None or is_causal) else ~mask,
145
+ dropout_p=0.0,
146
+ is_causal=is_causal,
147
+ )
148
+ out = out.transpose(1, 2).reshape(B, S, self.output_dim)
149
+ if use_cache and past_key_value is None:
150
+ # If we're not given a cache, return legacy K/V by default.
151
+ next_past = (kv_k, kv_v)
152
+ return self.Wo(out), next_past
153
+
154
+
155
+ class DotLMBlock(nn.Module):
156
+ def __init__(self, d_model, n_heads, n_kv_heads, hidden_dim, norm_eps=1e-6, layer_idx=None):
157
+ super().__init__()
158
+ head_dim = d_model // n_heads
159
+ self.attention = GroupedQueryAttention(d_model, n_heads, head_dim, n_kv_heads)
160
+ self.attention.layer_idx = layer_idx
161
+ self.feed_forward = SwiGLU(d_model, hidden_dim)
162
+ self.norm1 = RMSNorm(d_model, norm_eps)
163
+ self.norm2 = RMSNorm(d_model, norm_eps)
164
+
165
+ def forward(self, x, cos, sin, mask=None, past_key_value=None, use_cache=False):
166
+ residual = x
167
+ x = self.norm1(x)
168
+ attn_out, next_past = self.attention(x, cos, sin, mask, past_key_value, use_cache)
169
+ x = residual + attn_out
170
+
171
+ residual = x
172
+ x = self.norm2(x)
173
+ x = residual + self.feed_forward(x)
174
+ return x, next_past
175
+
176
+
177
+ # ── Flat HF Wrapper ───────────────────────────────────────────────────────────
178
+
179
+ class DotLMForCausalLM(PreTrainedModel, GenerationMixin):
180
+ config_class = DotLMConfig
181
+ # Let HF know output head is tied to embeddings when enabled.
182
+ _tied_weights_keys = {"head.weight": "embeddor.weight"}
183
+
184
+ def __init__(self, config):
185
+ super().__init__(config)
186
+ self.config = config
187
+
188
+ self.embeddor = nn.Embedding(config.vocab_size, config.d_model)
189
+ self.blocks = nn.ModuleList([
190
+ DotLMBlock(
191
+ config.d_model, config.n_heads, config.n_kv_heads,
192
+ config.hidden_dim, config.norm_eps, layer_idx=i
193
+ )
194
+ for i in range(config.num_hidden_layers)
195
+ ])
196
+ self.norm = RMSNorm(config.d_model, config.norm_eps)
197
+ self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)
198
+
199
+ # Precompute RoPE
200
+ head_dim = config.d_model // config.n_heads
201
+ cos, sin = precompute_freqs_cis(head_dim, config.context_len, config.theta_base)
202
+ self.register_buffer("cos_cache", cos, persistent=False)
203
+ self.register_buffer("sin_cache", sin, persistent=False)
204
+
205
+ # Causal mask placeholder
206
+ mask = torch.triu(torch.ones(config.context_len, config.context_len, dtype=torch.bool), diagonal=1)
207
+ self.register_buffer("causal_mask", mask, persistent=False)
208
+
209
+ self.post_init()
210
+
211
+ def _ensure_rope_and_mask(self):
212
+ """
213
+ `from_pretrained(..., low_cpu_mem_usage=True)` may build the module under
214
+ meta tensors. In that case, our non-persistent buffers can end up as
215
+ meta/zero tensors even though they are deterministic. Recompute them on
216
+ demand.
217
+ """
218
+ need_rope = (
219
+ self.cos_cache.device.type == "meta"
220
+ or self.sin_cache.device.type == "meta"
221
+ or self.cos_cache.numel() == 0
222
+ or self.sin_cache.numel() == 0
223
+ or (self.cos_cache.numel() > 0 and float(self.cos_cache.flatten()[0]) == 0.0)
224
+ )
225
+ need_mask = (
226
+ self.causal_mask.device.type == "meta"
227
+ or self.causal_mask.numel() == 0
228
+ # causal_mask[0, 1] should be True for an upper-triangular mask.
229
+ or (self.causal_mask.numel() > 1 and bool(self.causal_mask[0, 1]) is False)
230
+ )
231
+ if not (need_rope or need_mask):
232
+ return
233
+
234
+ head_dim = self.config.d_model // self.config.n_heads
235
+ cos, sin = precompute_freqs_cis(head_dim, self.config.context_len, self.config.theta_base)
236
+ self._buffers["cos_cache"] = cos
237
+ self._buffers["sin_cache"] = sin
238
+
239
+ mask = torch.triu(
240
+ torch.ones(self.config.context_len, self.config.context_len, dtype=torch.bool), diagonal=1
241
+ )
242
+ self._buffers["causal_mask"] = mask
243
+
244
+ def _init_weights(self, module):
245
+ std = self.config.initializer_range
246
+ if isinstance(module, nn.Linear):
247
+ nn.init.normal_(module.weight, mean=0.0, std=std)
248
+ if module.bias is not None:
249
+ nn.init.zeros_(module.bias)
250
+ elif isinstance(module, nn.Embedding):
251
+ nn.init.normal_(module.weight, mean=0.0, std=std)
252
+
253
+ def tie_weights(self, **kwargs):
254
+ if self.config.tie_word_embeddings:
255
+ self.head.weight = self.embeddor.weight
256
+
257
+ def get_input_embeddings(self):
258
+ return self.embeddor
259
+
260
+ def set_input_embeddings(self, value):
261
+ self.embeddor = value
262
+ self.tie_weights()
263
+
264
+ def get_output_embeddings(self):
265
+ return self.head
266
+
267
+ def set_output_embeddings(self, new_embeddings):
268
+ self.head = new_embeddings
269
+ self.tie_weights()
270
+
271
+ def forward(
272
+ self,
273
+ input_ids: torch.LongTensor = None,
274
+ attention_mask: Optional[torch.Tensor] = None,
275
+ token_type_ids: Optional[torch.LongTensor] = None,
276
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
277
+ labels: Optional[torch.LongTensor] = None,
278
+ use_cache: Optional[bool] = None,
279
+ output_attentions: Optional[bool] = None,
280
+ output_hidden_states: Optional[bool] = None,
281
+ return_dict: Optional[bool] = None,
282
+ **kwargs
283
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
284
+
285
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
286
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
287
+ B, S = input_ids.shape
288
+
289
+ self._ensure_rope_and_mask()
290
+
291
+ # Support both HF Cache (v5+) and legacy tuple-of-layer-caches.
292
+ if use_cache and past_key_values is None:
293
+ past_key_values = DynamicCache()
294
+
295
+ # Positional tracking
296
+ start_pos = 0
297
+ if past_key_values is not None:
298
+ if isinstance(past_key_values, Cache):
299
+ start_pos = past_key_values.get_seq_length()
300
+ else:
301
+ layer0 = past_key_values[0]
302
+ if layer0 is not None and layer0[0] is not None:
303
+ start_pos = layer0[0].shape[2]
304
+
305
+ # Embeddings
306
+ x = self.embeddor(input_ids)
307
+
308
+ # RoPE slicing
309
+ cos = self.cos_cache[start_pos : start_pos + S].to(device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0)
310
+ sin = self.sin_cache[start_pos : start_pos + S].to(device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0)
311
+
312
+ # Masking
313
+ mask = None
314
+ if S > 1:
315
+ mask = self.causal_mask[start_pos : start_pos + S, : start_pos + S].to(device=x.device)
316
+
317
+ next_past_key_values = [] if (use_cache and not isinstance(past_key_values, Cache)) else None
318
+
319
+ # Blocks
320
+ for i, block in enumerate(self.blocks):
321
+ layer_past = None
322
+ if past_key_values is not None:
323
+ if isinstance(past_key_values, Cache):
324
+ layer_past = past_key_values
325
+ else:
326
+ layer_past = past_key_values[i]
327
+ x, new_layer_past = block(
328
+ x, cos, sin, mask=mask, past_key_value=layer_past, use_cache=use_cache
329
+ )
330
+ if next_past_key_values is not None:
331
+ next_past_key_values.append(new_layer_past)
332
+
333
+ # Final head
334
+ logits = self.head(self.norm(x))
335
+ if not self.training:
336
+ # Stability clip
337
+ logits = torch.nan_to_num(logits, nan=0.0, posinf=1e4, neginf=-1e4)
338
+
339
+ loss = None
340
+ if labels is not None:
341
+ shift_logits = logits[..., :-1, :].contiguous()
342
+ shift_labels = labels[..., 1:].contiguous()
343
+ loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
344
+
345
+ if not return_dict:
346
+ return (logits, past_key_values) if use_cache else (logits,)
347
+
348
+ return CausalLMOutputWithPast(
349
+ loss=loss,
350
+ logits=logits,
351
+ past_key_values=past_key_values if isinstance(past_key_values, Cache) else (tuple(next_past_key_values) if use_cache else None)
352
+ )
353
+
354
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
355
+ past_len = 0
356
+ if past_key_values is not None:
357
+ if isinstance(past_key_values, Cache):
358
+ past_len = past_key_values.get_seq_length()
359
+ else:
360
+ layer0 = past_key_values[0] if len(past_key_values) > 0 else None
361
+ if layer0 is not None and layer0[0] is not None:
362
+ past_len = layer0[0].shape[2]
363
+
364
+ # Only slice for incremental decoding once we truly have cached history.
365
+ if past_len > 0:
366
+ input_ids = input_ids[:, -1:]
367
+ return {
368
+ "input_ids": input_ids,
369
+ "past_key_values": past_key_values,
370
+ "attention_mask": kwargs.get("attention_mask", None),
371
+ "token_type_ids": kwargs.get("token_type_ids", None),
372
+ "use_cache": True,
373
+ }
374
+
375
+ def _reorder_cache(self, past_key_values, beam_idx):
376
+ if past_key_values is None:
377
+ return past_key_values
378
+ if isinstance(past_key_values, Cache):
379
+ past_key_values.reorder_cache(beam_idx)
380
+ return past_key_values
381
+ return tuple(
382
+ (k.index_select(0, beam_idx), v.index_select(0, beam_idx))
383
+ for (k, v) in past_key_values
384
+ )