mjerome89 committed
Commit e5d3914 · verified · 1 Parent(s): 66071cc

Upload model.py

Files changed (1):
  1. model.py +348 -645
model.py CHANGED
@@ -9,13 +9,17 @@ if not hasattr(torch.library, 'wrap_triton'):
9
  import torch._dynamo
10
  torch._dynamo.config.capture_scalar_outputs = True
11
 
 
12
  import torch.nn as nn
13
  import torch.nn.functional as F
 
 
 
 
 
14
  from dataclasses import dataclass
15
- from typing import Optional, Tuple, Union
16
-
17
- from transformers import PreTrainedModel, PretrainedConfig
18
- from transformers.modeling_outputs import MaskedLMOutput, BaseModelOutputWithPast, SequenceClassifierOutput
19
 
20
  import bert_padding
21
  from attention import FlexBertUnpadRopeAttention
@@ -24,157 +28,28 @@ from torch.distributed import init_process_group, destroy_process_group
24
  from torch.nn.parallel import DistributedDataParallel as DDP
25
  import torch.distributed as dist
26
 
27
- try:
28
- from liger_kernel.transformers import LigerLayerNorm
29
- LayerNormClass = LigerLayerNorm
30
- except ImportError:
31
- LayerNormClass = nn.LayerNorm
32
-
33
-
34
-
35
- # ==============================================================================
36
- # HuggingFace-compatible Configuration
37
- # ==============================================================================
38
-
39
- class CustomTransformerConfig(PretrainedConfig):
40
- """
41
- Configuration class for CustomTransformer model.
42
-
43
- This class stores the configuration of a CustomTransformer model and is compatible
44
- with HuggingFace's transformers library. It replaces the old ModelConfig dataclass.
45
- """
46
- model_type = "custom_transformer"
47
-
48
- # auto_map tells HF which classes to use when loading with AutoModel/AutoConfig
49
- auto_map = {
50
- "AutoConfig": "model.CustomTransformerConfig",
51
- "AutoModel": "model.CustomTransformerModel",
52
- "AutoModelForMaskedLM": "model.CustomTransformerForMaskedLM",
53
- "AutoModelForSequenceClassification": "model.CustomTransformerForSequenceClassification",
54
- }
55
-
56
- def __init__(
57
- self,
58
- vocab_size: int = 50368,
59
- num_dims: int = 768,
60
- num_heads: int = 12,
61
- num_kv_heads: int = 12,
62
- num_layers: int = 12,
63
- ffn_hidden_dims: int = 1536,
64
- layernorm_eps: float = 1e-6,
65
- attention_probs_dropout_prob: float = 0.1,
66
- attn_qkv_bias: bool = False,
67
- attn_out_bias: bool = False,
68
- attn_out_dropout_prob: float = 0.0,
69
- global_attn_every_n_layers: int = 3,
70
- sliding_window: int = 128,
71
- rotary_emb_base: int = 10000,
72
- context_len: int = 128,
73
- use_cache: bool = False,
74
- use_flash: bool = True,
75
- use_moe: bool = True,
76
- moe_num_experts: int = 15,
77
- moe_routed_experts: int = 1,
78
- moe_eps: float = 1e-6,
79
- moe_aux_loss_coef: float = 0.01,
80
- moe_shared_experts: int = 1,
81
- use_lossfreebalance: bool = True,
82
- pad_token_id: int = 0,
83
- bos_token_id: int = 1,
84
- eos_token_id: int = 2,
85
- mask_token_id: int = 3,
86
- rope_theta: float = 1e5,
87
- ffn_dim_multiplier: Optional[int] = None,
88
- rotary_emb_dim: Optional[int] = None,
89
- local_attn_rotary_emb_base: int = -1,
90
- local_attn_rotary_emb_dim: Optional[int] = None,
91
- rotary_emb_scale_base: Optional[float] = None,
92
- rotary_emb_interleaved: bool = False,
93
- use_fa2: Optional[bool] = None,
94
- deterministic_fa2: bool = False,
95
- use_sdpa_attn_mask: bool = False,
96
- num_labels: int = 2,
97
- classifier_dropout: Optional[float] = None,
98
- **kwargs
99
- ):
100
- """Initialize CustomTransformerConfig."""
101
- super().__init__(
102
- pad_token_id=pad_token_id,
103
- bos_token_id=bos_token_id,
104
- eos_token_id=eos_token_id,
105
- **kwargs
106
- )
107
-
108
- self.vocab_size = vocab_size
109
- self.num_dims = num_dims
110
- self.num_heads = num_heads
111
- self.num_kv_heads = num_kv_heads
112
- self.num_layers = num_layers
113
- self.ffn_hidden_dims = ffn_hidden_dims
114
- self.layernorm_eps = layernorm_eps
115
- self.attention_probs_dropout_prob = attention_probs_dropout_prob
116
- self.attn_qkv_bias = attn_qkv_bias
117
- self.attn_out_bias = attn_out_bias
118
- self.attn_out_dropout_prob = attn_out_dropout_prob
119
- self.global_attn_every_n_layers = global_attn_every_n_layers
120
- self.sliding_window = sliding_window
121
- self.rotary_emb_base = rotary_emb_base
122
- self.context_len = context_len
123
- self.use_cache = use_cache
124
- self.use_flash = use_flash
125
- self.use_moe = use_moe
126
- self.moe_num_experts = moe_num_experts
127
- self.moe_routed_experts = moe_routed_experts
128
- self.moe_eps = moe_eps
129
- self.moe_aux_loss_coef = moe_aux_loss_coef
130
- self.moe_shared_experts = moe_shared_experts
131
- self.use_lossfreebalance = use_lossfreebalance
132
- self.mask_token_id = mask_token_id
133
- self.rope_theta = rope_theta
134
- self.ffn_dim_multiplier = ffn_dim_multiplier
135
- self.rotary_emb_dim = rotary_emb_dim
136
- self.local_attn_rotary_emb_base = local_attn_rotary_emb_base
137
- self.local_attn_rotary_emb_dim = local_attn_rotary_emb_dim
138
- self.rotary_emb_scale_base = rotary_emb_scale_base
139
- self.rotary_emb_interleaved = rotary_emb_interleaved
140
- self.use_fa2 = use_fa2
141
- self.deterministic_fa2 = deterministic_fa2
142
- self.use_sdpa_attn_mask = use_sdpa_attn_mask
143
- self.num_labels = num_labels
144
- self.classifier_dropout = classifier_dropout
145
-
146
- # Derived attributes for compatibility with attention module
147
- self.hidden_size = num_dims
148
- self.num_attention_heads = num_heads
149
- self.embedding_size = num_dims
150
-
151
- # Mirror old ModelConfig.__post_init__
152
- if self.use_fa2 is None:
153
- self.use_fa2 = self.use_flash
154
-
155
 
156
- # Keep ModelConfig as a thin alias for backward compatibility with existing training scripts
157
  @dataclass
158
  class ModelConfig:
159
  vocab_size: int
160
 
161
- num_dims: int
162
- num_heads: int
163
- num_kv_heads: int
164
- num_layers: int
165
- ffn_hidden_dims: int
166
 
167
- context_len: int
168
- use_cache: bool
169
- use_flash: bool
170
- use_moe: bool
171
 
172
- moe_num_experts: int
173
- moe_routed_experts: int
174
- moe_eps: float = 1e-6
175
- moe_aux_loss_coef: float = 0.01
176
- moe_shared_experts: int = 0
177
- use_lossfreebalance: bool = False
178
 
179
  layernorm_eps: float = 1e-6
180
  rope_theta: float = 1e5
@@ -198,7 +73,7 @@ class ModelConfig:
198
  num_attention_heads: Optional[int] = None
199
  embedding_size: Optional[int] = None
200
 
201
- ffn_dim_multiplier: Optional[int] = None
202
 
203
  def __post_init__(self):
204
  if self.hidden_size is None:
@@ -211,53 +86,206 @@ class ModelConfig:
211
  self.use_fa2 = self.use_flash
212
 
213
 
214
- # ==============================================================================
215
- # Model Layers
216
- # ==============================================================================
217
 
218
- class FlexBertUnpadAttention(nn.Module):
219
- """Thin wrapper that preserves the state_dict key path: block.attention.attn.*
220
 
221
- In ModernBERT-style global unpadding the data is already (total_nnz, dim) so
222
- this wrapper just forwards directly to FlexBertUnpadRopeAttention without
223
- any pad/unpad work. cu_seqlens, max_seqlen, indices, and attn_mask are
224
- passed through from the Transformer level.
225
- """
226
  def __init__(self, config, layer_id: Optional[int] = None):
227
  super().__init__()
228
  self.attn = FlexBertUnpadRopeAttention(config=config, layer_id=layer_id)
229
 
230
- def forward(
231
- self,
232
- hidden_states: torch.Tensor,
233
- cu_seqlens: torch.Tensor,
234
- max_seqlen: int,
235
- indices: torch.Tensor,
236
- attn_mask: torch.Tensor,
237
- ) -> torch.Tensor:
238
- """Forward on already-unpadded data.
239
-
240
- Args:
241
- hidden_states: (total_nnz, dim)
242
- cu_seqlens: (batch + 1,)
243
- max_seqlen: int
244
- indices: (total_nnz,)
245
- attn_mask: (batch, seq_len)
246
-
247
- Returns:
248
- (total_nnz, dim)
249
- """
250
- return self.attn(
251
  hidden_states=hidden_states,
252
  cu_seqlens=cu_seqlens,
253
  max_seqlen=max_seqlen,
254
  indices=indices,
255
  attn_mask=attn_mask,
256
  )
 
257
 
258
 
259
  class FeedForward(nn.Module):
260
- """Default Feed Forward Layer. Works on both 2D (total_nnz, dim) and 3D inputs."""
 
 
261
  def __init__(self, config):
262
  super().__init__()
263
 
@@ -267,54 +295,40 @@ class FeedForward(nn.Module):
267
  self.w2 = nn.Linear(self.hidden_dim, config.num_dims, bias=False)
268
  self.w3 = nn.Linear(config.num_dims, self.hidden_dim, bias=False)
269
  self.act = nn.GELU()
270
-
271
  def forward(self, x: torch.Tensor):
272
  return self.w2(self.act(self.w1(x)) * self.w3(x)), None
273
 
274
 
275
- class FFNwMoE(nn.Module):
276
  """
277
  Feed Forward with MoE with optional shared experts.
278
- Works on 2D (total_nnz, dim) unpadded inputs.
279
-
280
- Uses batched_mm (torch.bmm) for expert dispatch. Expert weights are stored
281
- as stacked nn.Parameters: (num_experts, out_dim, in_dim). Old checkpoints
282
- with per-expert nn.Linear weights are automatically converted at load time
283
- via _load_from_state_dict.
284
-
285
  Returns after forward:
286
  output: Combined outputs from experts
287
  aux_loss: Auxiliary loss tensor or routing metadata
288
  """
289
- def __init__(self, config):
290
  super().__init__()
291
  self.hidden_dim = config.ffn_hidden_dims
292
- self.num_dims = config.num_dims
293
 
294
- self.moe_routed_experts = config.moe_routed_experts
295
  self.moe_aux_loss_coef = config.moe_aux_loss_coef
296
  self.moe_eps = config.moe_eps
297
  self.moe_shared_experts = config.moe_shared_experts
298
  self.num_experts = config.moe_num_experts
299
 
300
- self.use_lossfreebalance = config.use_lossfreebalance
301
 
302
- self.router = nn.Linear(config.num_dims, self.num_experts, bias=False)
303
-
304
- # Stacked expert weights — the actual trainable parameters
305
- # w1: projects dim -> hidden (gate)
306
- # w2: projects hidden -> dim (down)
307
- # w3: projects dim -> hidden (up)
308
- self.w1_stacked = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, config.num_dims))
309
- self.w2_stacked = nn.Parameter(torch.empty(self.num_experts, config.num_dims, self.hidden_dim))
310
- self.w3_stacked = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, config.num_dims))
311
-
312
- # Initialize
313
- for i in range(self.num_experts):
314
- nn.init.kaiming_uniform_(self.w1_stacked.data[i])
315
- nn.init.kaiming_uniform_(self.w2_stacked.data[i])
316
- nn.init.kaiming_uniform_(self.w3_stacked.data[i])
317
 
318
  # shared experts (for DeepSeekMoE)
319
  self.shared_experts = nn.ModuleList()
320
  for _ in range(self.moe_shared_experts):
@@ -324,23 +338,17 @@ class FFNwMoE(nn.Module):
324
  nn.Linear(self.hidden_dim, config.num_dims, bias=False),
325
  nn.Linear(config.num_dims, self.hidden_dim, bias=False)
326
  ]))
327
-
328
- # Auxiliary-loss-free load balancing strategy for MoE (DeepSeek)
329
  if self.use_lossfreebalance:
330
  self.expert_biases = nn.Parameter(torch.zeros(self.num_experts))
331
-
332
  def forward(self, x: torch.Tensor):
333
- # x can be (total_nnz, dim) or (batch, seq_len, dim)
334
- input_shape = x.shape
335
- if x.ndim == 3:
336
- c_batch_size, c_context_len, c_dim = input_shape
337
- x_flat = x.view(-1, c_dim)
338
- else:
339
- x_flat = x
340
- c_dim = x.shape[-1]
341
 
342
  router_out = self.router(x_flat)
343
- router_probs = F.softmax(router_out, dim=-1)
344
 
345
  _, topk_indices = router_out.topk(self.moe_routed_experts, dim=-1)
346
  self.last_topk_indices = topk_indices.detach()
@@ -349,13 +357,12 @@ class FFNwMoE(nn.Module):
349
 
350
  output = self._compute_expert_outputs(x_flat, topk_indices, topk_probs, router_probs)
351
 
352
- if x.ndim == 3:
353
- output = output.view(c_batch_size, c_context_len, c_dim)
354
-
355
- return output, aux_loss
356
 
357
  def _compute_aux_loss(self, router_out, router_probs, topk_indices):
358
- """Computes the auxiliary loss based on whether loss-free balancing is used or not."""
 
 
359
  if not self.use_lossfreebalance:
360
  topk_probs, _ = router_probs.topk(self.moe_routed_experts, dim=-1)
361
  expert_mask = F.one_hot(topk_indices[:, 0], self.num_experts).float()
@@ -363,80 +370,47 @@ class FFNwMoE(nn.Module):
363
  router_prob_mean = router_probs.mean(dim=0)
364
  aux_loss = self.moe_aux_loss_coef * torch.sum(density * router_prob_mean) * self.num_experts
365
 
366
- else:
367
  router_out = router_out + self.expert_biases
368
- router_probs = torch.sigmoid(router_out)
369
  topk_probs = router_probs.gather(-1, topk_indices)
370
  topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
371
 
 
372
  aux_loss = (router_probs, topk_indices)
373
  return aux_loss, topk_probs
374
 
375
  def _compute_expert_outputs(self, x_flat, topk_indices, topk_probs, router_probs):
376
- """Compute expert outputs using sort-based dispatch with stacked weights.
377
-
378
- Sort tokens by expert, slice contiguous chunks, run each expert via
379
- matmul on the stacked weight tensors. No weight duplication, minimal
380
- memory overhead.
381
  """
382
- num_tokens, dim = x_flat.shape
383
-
384
- # Flatten top-k: (num_tokens * top_k,)
385
- flat_expert_ids = topk_indices.view(-1)
386
- flat_probs = topk_probs.view(-1)
387
- flat_token_ids = torch.arange(num_tokens, device=x_flat.device).unsqueeze(1).expand(-1, self.moe_routed_experts).reshape(-1)
388
-
389
- # Sort by expert id for contiguous batching
390
- sorted_expert_ids, sort_indices = flat_expert_ids.sort(stable=True)
391
- sorted_token_ids = flat_token_ids[sort_indices]
392
- sorted_probs = flat_probs[sort_indices]
393
-
394
- # Gather sorted input tokens
395
- sorted_x = x_flat[sorted_token_ids] # (num_tokens * top_k, dim)
396
-
397
- # Find expert boundaries
398
- expert_counts = torch.bincount(sorted_expert_ids, minlength=self.num_experts)
399
- expert_offsets = torch.zeros(self.num_experts + 1, dtype=torch.long, device=x_flat.device)
400
- torch.cumsum(expert_counts, dim=0, out=expert_offsets[1:])
401
-
402
- # Run each expert on its contiguous slice using stacked weights
403
- sorted_output = torch.zeros_like(sorted_x)
404
- for expert_id in range(self.num_experts):
405
- start = expert_offsets[expert_id].item()
406
- end = expert_offsets[expert_id + 1].item()
407
- if start == end:
408
- continue
409
- expert_input = sorted_x[start:end] # (n_tokens, dim)
410
- # Use stacked weights directly: w1[expert_id] is (hidden, dim)
411
- h1 = F.linear(expert_input, self.w1_stacked[expert_id]) # (n, hidden)
412
- h3 = F.linear(expert_input, self.w3_stacked[expert_id]) # (n, hidden)
413
- h = F.gelu(h1) * h3
414
- sorted_output[start:end] = F.linear(h, self.w2_stacked[expert_id]) # (n, dim)
415
-
416
- # Weight by router probabilities
417
- sorted_output = sorted_output * sorted_probs.unsqueeze(-1)
418
-
419
- # Scatter back to original token positions
420
  output = torch.zeros_like(x_flat)
421
- output.scatter_add_(0, sorted_token_ids.unsqueeze(-1).expand_as(sorted_output), sorted_output)
422
 
423
- # Shared experts (for DeepSeekMoE) — unchanged
424
  for shared_expert_id in range(self.moe_shared_experts):
425
  w1, w2, w3 = self.shared_experts[shared_expert_id]
426
  expert_output = w2(F.gelu(w1(x_flat)) * w3(x_flat))
427
  output = output + expert_output
428
-
429
  return output
430
 
431
 
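A minimal, self-contained toy run of the sort-based dispatch described in the docstring above. Token counts and expert assignments are made up; only the sort/offset bookkeeping is shown, not the expert matmuls:

import torch

# 4 tokens routed top-1 to experts [2, 0, 2, 1] (hypothetical assignment)
topk_indices = torch.tensor([[2], [0], [2], [1]])
flat_expert_ids = topk_indices.view(-1)
sorted_expert_ids, sort_order = flat_expert_ids.sort(stable=True)
counts = torch.bincount(sorted_expert_ids, minlength=3)   # tokens per expert
offsets = torch.zeros(4, dtype=torch.long)
offsets[1:] = torch.cumsum(counts, dim=0)
print(sorted_expert_ids.tolist())   # [0, 1, 2, 2]
print(sort_order.tolist())          # [1, 3, 0, 2] -> original token positions
print(offsets.tolist())             # [0, 1, 2, 4]: expert e owns slice offsets[e]:offsets[e+1]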
432
  class Block(nn.Module):
433
- """Transformer block operating on unpadded (total_nnz, dim) tensors.
434
-
435
- Receives unpadding metadata (cu_seqlens, max_seqlen, indices, attn_mask)
436
- from the Transformer level and passes them to attention. Norms and FFN
437
- operate directly on the 2D unpadded tensor, avoiding wasted compute on
438
- padding tokens.
439
- """
440
  def __init__(self, config, layer_id: Optional[int] = None):
441
  super().__init__()
442
  self.is_first_block = (layer_id == 0)
@@ -447,53 +421,32 @@ class Block(nn.Module):
447
  else:
448
  self.ffn = FeedForward(config)
449
 
450
- self.norm_attention = LayerNormClass(config.num_dims, eps=config.layernorm_eps)
451
- self.norm_ffn = LayerNormClass(config.num_dims, eps=config.layernorm_eps)
452
 
453
- def forward(self, x, cu_seqlens, max_seqlen, indices, attn_mask):
454
- """
455
- Args:
456
- x: (total_nnz, dim) - unpadded hidden states
457
- cu_seqlens: (batch + 1,)
458
- max_seqlen: int
459
- indices: (total_nnz,)
460
- attn_mask: (batch, seq_len)
461
-
462
- Returns:
463
- x: (total_nnz, dim)
464
- aux_loss: auxiliary loss from MoE or None
465
- """
466
  if self.is_first_block:
467
  attn_in = x
468
  else:
469
  attn_in = self.norm_attention(x)
470
-
471
  x = x + self.attention(
472
- attn_in,
473
- cu_seqlens=cu_seqlens,
474
- max_seqlen=max_seqlen,
475
- indices=indices,
476
- attn_mask=attn_mask,
477
- )
478
-
479
  ffn_out, aux_loss = self.ffn(
480
  self.norm_ffn(x)
481
- )
482
  x = x + ffn_out
483
  return x, aux_loss
 
484
 
485
-
486
- # ==============================================================================
487
- # Core Transformer (nn.Module backbone used inside HF wrappers)
488
- # ==============================================================================
489
-
490
- class Transformer(nn.Module):
491
- """ModernBERT-style Transformer: unpad once before embeddings, repad once at
492
- the end. All blocks, norms, and FFNs operate on (total_nnz, dim) tensors,
493
- avoiding wasted compute on padding tokens.
494
- """
495
- def __init__(self, config):
496
  super().__init__()
 
 
497
 
498
  self.vocab_size = config.vocab_size
499
  self.num_dims = config.num_dims
@@ -503,112 +456,79 @@ class Transformer(nn.Module):
503
  self.use_lossfreebalance = config.use_lossfreebalance and self.use_moe
504
 
505
  self.num_layers = config.num_layers
506
-
 
 
 
507
  hidden_dim = 4 * config.num_dims
 
 
 
 
508
 
509
- self.tokens_embedding = nn.Embedding(config.vocab_size, config.num_dims)
510
- self.norm_embeddings = LayerNormClass(config.num_dims, eps=config.layernorm_eps)
 
511
 
512
  self.blocks = nn.ModuleList()
513
  for layer_id in range(self.num_layers):
514
  self.blocks.append(Block(config, layer_id=layer_id))
515
 
516
- self.norm = LayerNormClass(config.num_dims, eps=config.layernorm_eps)
517
- self.ll_head = nn.Linear(config.num_dims, config.vocab_size, bias=False)
 
518
 
519
  self.tokens_embedding.weight = self.ll_head.weight
 
 
520
 
521
- def _unpad(self, input_ids, attention_mask):
522
- """Compute unpadding metadata and unpad input_ids before embedding.
523
 
524
- Unpads input_ids (cheap 1D integer indexing) so that embedding and
525
- all subsequent layers only process real tokens.
526
 
527
- Args:
528
- input_ids: (batch, seq_len)
529
- attention_mask: (batch, seq_len) or None
530
 
531
- Returns:
532
- input_ids_unpadded: (total_nnz,)
533
- indices: (total_nnz,)
534
- cu_seqlens: (batch + 1,)
535
- max_seqlen: int
536
- attn_mask: (batch, seq_len)
537
- batch_size: int
538
- seq_len: int
539
- """
540
- batch_size, seq_len = input_ids.shape
541
 
542
- if attention_mask is None:
543
- attn_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.int32)
544
- else:
545
- attn_mask = attention_mask.to(dtype=torch.int32)
546
-
547
- # Unpad input_ids using the same bert_padding logic but on (batch, seq_len, 1)
548
- # so we can reuse unpad_input which expects 3D
549
- input_ids_3d = input_ids.unsqueeze(-1).float() # (batch, seq_len, 1)
550
- input_ids_unpadded, indices, cu_seqlens, max_seqlen = bert_padding.unpad_input(input_ids_3d, attn_mask)
551
- input_ids_unpadded = input_ids_unpadded.squeeze(-1).long() # (total_nnz,)
552
-
553
- return input_ids_unpadded, indices, cu_seqlens, max_seqlen, attn_mask, batch_size, seq_len
554
-
555
- def forward(
556
- self,
557
- x: torch.Tensor,
558
- targets: Optional[torch.Tensor] = None,
559
- start_pos: int = 0,
560
- attention_mask: Optional[torch.Tensor] = None,
561
- ):
562
- batch_size, seq_len = x.shape
563
-
564
- # Unpad input_ids before embedding — only embed real tokens
565
- x_unpadded, indices, cu_seqlens, max_seqlen, attn_mask, batch_size, seq_len = self._unpad(x, attention_mask)
566
-
567
- # Embed only real tokens (total_nnz, dim)
568
- x = self.tokens_embedding(x_unpadded)
569
  x = self.norm_embeddings(x)
570
-
 
 
 
 
571
  total_aux_loss = 0
572
 
573
  for block in self.blocks:
574
- x, aux_loss = block(
575
- x,
576
- cu_seqlens=cu_seqlens,
577
- max_seqlen=max_seqlen,
578
- indices=indices,
579
- attn_mask=attn_mask,
580
- )
581
  if self.use_moe and not self.use_lossfreebalance:
582
  total_aux_loss += aux_loss
583
-
584
  x = self.norm(x)
585
-
586
- # Repad once — back to (batch, seq_len, dim) for the LM head / loss
587
- x = bert_padding.pad_input(x, indices, batch_size, seq_len)
588
-
589
  logits = self.ll_head(x)
590
-
 
591
  if targets is None:
592
  loss = None
593
  ce_loss = None
594
  else:
595
  c_batch_size, c_context_len, c_dim = logits.shape
596
- logits = logits.view(c_batch_size * c_context_len, c_dim)
597
- targets = targets.view(c_batch_size * c_context_len)
598
  ce_loss = F.cross_entropy(logits, targets)
599
-
600
- if self.use_moe and not self.use_lossfreebalance:
601
- loss = ce_loss + total_aux_loss
602
- else:
603
  loss = ce_loss
604
  ce_loss = aux_loss
605
 
606
  return logits, loss, ce_loss
607
 
608
  @torch.no_grad()
609
- def generate(self, x: torch.Tensor, max_tokens: int, temperature: float = 1.0, top_k: int = 50,
610
  use_cache: bool = False):
611
- """Generate text from x up to max_tokens."""
 
 
612
  for c_tkn_pos in range(max_tokens):
613
  if use_cache:
614
  if c_tkn_pos == 0:
@@ -629,265 +549,48 @@ class Transformer(nn.Module):
629
  return x
630
 
631
 
632
- # ==============================================================================
633
- # HuggingFace PreTrainedModel Wrappers
634
- # ==============================================================================
635
-
636
- class CustomTransformerPreTrainedModel(PreTrainedModel):
637
- """Base class for CustomTransformer models."""
638
- config_class = CustomTransformerConfig
639
- base_model_prefix = "transformer"
640
- supports_gradient_checkpointing = False
641
- _no_split_modules = ["Block"]
642
-
643
- def _init_weights(self, module):
644
- """Initialize weights - handled by model itself."""
645
- pass
646
-
647
-
648
- class CustomTransformerModel(CustomTransformerPreTrainedModel):
649
- """The bare CustomTransformer Model outputting raw hidden-states."""
650
-
651
- def __init__(self, config: CustomTransformerConfig):
652
- super().__init__(config)
653
- self.config = config
654
-
655
- self.transformer = Transformer(config)
656
-
657
- self.post_init()
658
-
659
- def get_input_embeddings(self):
660
- return self.transformer.tokens_embedding
661
-
662
- def set_input_embeddings(self, value):
663
- self.transformer.tokens_embedding = value
664
-
665
- def forward(
666
- self,
667
- input_ids: Optional[torch.LongTensor] = None,
668
- attention_mask: Optional[torch.FloatTensor] = None,
669
- position_ids: Optional[torch.LongTensor] = None,
670
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
671
- inputs_embeds: Optional[torch.FloatTensor] = None,
672
- use_cache: Optional[bool] = None,
673
- output_attentions: Optional[bool] = None,
674
- output_hidden_states: Optional[bool] = None,
675
- return_dict: Optional[bool] = None,
676
- ) -> Union[Tuple, BaseModelOutputWithPast]:
677
- """Forward pass returning raw hidden states."""
678
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
679
-
680
- # Unpad input_ids before embedding
681
- x_unpadded, indices, cu_seqlens, max_seqlen, attn_mask, batch_size, seq_len = self.transformer._unpad(input_ids, attention_mask)
682
-
683
- # Embed only real tokens
684
- x = self.transformer.tokens_embedding(x_unpadded)
685
- x = self.transformer.norm_embeddings(x)
686
-
687
- for block in self.transformer.blocks:
688
- x, _ = block(x, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, indices=indices, attn_mask=attn_mask)
689
-
690
- x = self.transformer.norm(x)
691
-
692
- # Repad once
693
- hidden_states = bert_padding.pad_input(x, indices, batch_size, seq_len)
694
-
695
- if not return_dict:
696
- return (hidden_states,)
697
-
698
- return BaseModelOutputWithPast(
699
- last_hidden_state=hidden_states,
700
- past_key_values=None,
701
- hidden_states=None,
702
- attentions=None,
703
- )
704
-
705
-
706
- class CustomTransformerForMaskedLM(CustomTransformerPreTrainedModel):
707
- """CustomTransformer Model with a masked language modeling head on top."""
708
- _tied_weights_keys = ["transformer.ll_head.weight", "transformer.tokens_embedding.weight"]
709
-
710
- def __init__(self, config: CustomTransformerConfig):
711
- super().__init__(config)
712
- self.config = config
713
-
714
- self.transformer = Transformer(config)
715
-
716
- self.post_init()
717
-
718
- def get_input_embeddings(self):
719
- return self.transformer.tokens_embedding
720
-
721
- def set_input_embeddings(self, value):
722
- self.transformer.tokens_embedding = value
723
-
724
- def get_output_embeddings(self):
725
- return self.transformer.ll_head
726
-
727
- def set_output_embeddings(self, new_embeddings):
728
- self.transformer.ll_head = new_embeddings
729
-
730
- def forward(
731
- self,
732
- input_ids: Optional[torch.LongTensor] = None,
733
- attention_mask: Optional[torch.FloatTensor] = None,
734
- position_ids: Optional[torch.LongTensor] = None,
735
- head_mask: Optional[torch.FloatTensor] = None,
736
- inputs_embeds: Optional[torch.FloatTensor] = None,
737
- labels: Optional[torch.LongTensor] = None,
738
- output_attentions: Optional[bool] = None,
739
- output_hidden_states: Optional[bool] = None,
740
- return_dict: Optional[bool] = None,
741
- ) -> Union[Tuple, MaskedLMOutput]:
742
- """Forward pass for masked language modeling."""
743
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
744
-
745
- logits, model_loss, ce_loss = self.transformer(
746
- input_ids, targets=labels, start_pos=0, attention_mask=attention_mask
747
- )
748
-
749
- masked_lm_loss = None
750
- if labels is not None:
751
- masked_lm_loss = model_loss
752
-
753
- if not return_dict:
754
- output = (logits,)
755
- return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
756
-
757
- return MaskedLMOutput(
758
- loss=masked_lm_loss,
759
- logits=logits,
760
- hidden_states=None,
761
- attentions=None,
762
- )
763
-
764
-
765
- class CustomTransformerForSequenceClassification(CustomTransformerPreTrainedModel):
766
- """CustomTransformer Model with a sequence classification head on top."""
767
-
768
- def __init__(self, config: CustomTransformerConfig):
769
- super().__init__(config)
770
- self.num_labels = config.num_labels
771
- self.config = config
772
-
773
- self.transformer = Transformer(config)
774
-
775
- # Classification head
776
- classifier_dropout = (
777
- config.classifier_dropout
778
- if config.classifier_dropout is not None
779
- else config.attention_probs_dropout_prob
780
- )
781
- self.dropout = nn.Dropout(classifier_dropout)
782
- self.classifier = nn.Linear(config.num_dims, config.num_labels)
783
-
784
- self._init_classifier_weights()
785
- self.post_init()
786
-
787
- def _init_classifier_weights(self):
788
- std = 0.02
789
- if isinstance(self.classifier, nn.Linear):
790
- self.classifier.weight.data.normal_(mean=0.0, std=std)
791
- if self.classifier.bias is not None:
792
- self.classifier.bias.data.zero_()
793
-
794
- def forward(
795
- self,
796
- input_ids: Optional[torch.LongTensor] = None,
797
- attention_mask: Optional[torch.FloatTensor] = None,
798
- position_ids: Optional[torch.LongTensor] = None,
799
- head_mask: Optional[torch.FloatTensor] = None,
800
- inputs_embeds: Optional[torch.FloatTensor] = None,
801
- labels: Optional[torch.LongTensor] = None,
802
- output_attentions: Optional[bool] = None,
803
- output_hidden_states: Optional[bool] = None,
804
- return_dict: Optional[bool] = None,
805
- ) -> Union[Tuple, SequenceClassifierOutput]:
806
- """Forward pass for sequence classification."""
807
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
808
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
809
- output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
810
-
811
- # Unpad input_ids before embedding
812
- x_unpadded, indices, cu_seqlens, max_seqlen, attn_mask, batch_size, seq_len = self.transformer._unpad(input_ids, attention_mask)
813
-
814
- # Embed only real tokens
815
- x = self.transformer.tokens_embedding(x_unpadded)
816
- x = self.transformer.norm_embeddings(x)
817
-
818
- # Collect hidden states if requested (repad each for the output tuple)
819
- all_hidden_states = () if output_hidden_states else None
820
-
821
- if output_hidden_states:
822
- all_hidden_states = all_hidden_states + (bert_padding.pad_input(x, indices, batch_size, seq_len),)
823
-
824
- for block in self.transformer.blocks:
825
- x, _ = block(x, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, indices=indices, attn_mask=attn_mask)
826
-
827
- if output_hidden_states:
828
- all_hidden_states = all_hidden_states + (bert_padding.pad_input(x, indices, batch_size, seq_len),)
829
-
830
- x = self.transformer.norm(x)
831
-
832
- # Repad once
833
- hidden_states = bert_padding.pad_input(x, indices, batch_size, seq_len)
834
-
835
- # Use [CLS] token representation (first token) for classification
836
- pooled_output = hidden_states[:, 0, :]
837
- pooled_output = self.dropout(pooled_output)
838
- logits = self.classifier(pooled_output)
839
-
840
- loss = None
841
- if labels is not None:
842
- if self.config.problem_type is None:
843
- if self.num_labels == 1:
844
- self.config.problem_type = "regression"
845
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
846
- self.config.problem_type = "single_label_classification"
847
- else:
848
- self.config.problem_type = "multi_label_classification"
849
-
850
- if self.config.problem_type == "regression":
851
- loss_fct = nn.MSELoss()
852
- if self.num_labels == 1:
853
- loss = loss_fct(logits.squeeze(), labels.squeeze())
854
- else:
855
- loss = loss_fct(logits, labels)
856
- elif self.config.problem_type == "single_label_classification":
857
- loss_fct = nn.CrossEntropyLoss()
858
- loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
859
- elif self.config.problem_type == "multi_label_classification":
860
- loss_fct = nn.BCEWithLogitsLoss()
861
- loss = loss_fct(logits, labels)
862
-
863
- if not return_dict:
864
- output = (logits,) + (all_hidden_states,) + (None,)
865
- return ((loss,) + output) if loss is not None else output
866
-
867
- return SequenceClassifierOutput(
868
- loss=loss,
869
- logits=logits,
870
- hidden_states=all_hidden_states,
871
- attentions=None,
872
- )
873
-
874
-
875
- # ==============================================================================
876
- # Auto-registration
877
- # ==============================================================================
878
-
879
- try:
880
- from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM, AutoModelForSequenceClassification
881
-
882
- AutoConfig.register("custom_transformer", CustomTransformerConfig)
883
- AutoModel.register(CustomTransformerConfig, CustomTransformerModel)
884
- AutoModelForMaskedLM.register(CustomTransformerConfig, CustomTransformerForMaskedLM)
885
- AutoModelForSequenceClassification.register(CustomTransformerConfig, CustomTransformerForSequenceClassification)
886
- except Exception:
887
- pass
888
-
889
-
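For context, the auto_map entries and the registration block above are what allow this checkpoint to be loaded through the Auto classes. A rough usage sketch; the checkpoint path is a placeholder, not a real repo:

from transformers import AutoConfig, AutoModelForMaskedLM

config = AutoConfig.from_pretrained("path/to/checkpoint", trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained("path/to/checkpoint", trust_remote_code=True)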
890
  def main():
891
  pass
892
 
893
 
The listing below is the right-hand (new) side of the same diff: model.py as added in this commit.
9
  import torch._dynamo
10
  torch._dynamo.config.capture_scalar_outputs = True
11
 
12
+ import numpy as np
13
  import torch.nn as nn
14
  import torch.nn.functional as F
15
+ import random
16
+ import time
17
+ import math
18
+ import inspect
19
+ import os
20
  from dataclasses import dataclass
21
+ from huggingface_hub import PyTorchModelHubMixin
22
+ from typing import Optional
 
 
23
 
24
  import bert_padding
25
  from attention import FlexBertUnpadRopeAttention
 
28
  from torch.nn.parallel import DistributedDataParallel as DDP
29
  import torch.distributed as dist
30
 
31
 
 
32
  @dataclass
33
  class ModelConfig:
34
  vocab_size: int
35
 
36
+ num_dims: int # number of dimensions
37
+ num_heads: int # number of query heads
38
+ num_kv_heads: int # number of key/value heads
39
+ num_layers: int # total transformer layers
40
+ ffn_hidden_dims: int # hidden dimension for FFN/FFNwMoE
41
 
42
+ context_len: int # maximum context length
43
+ use_cache: bool # enable KV-caching
44
+ use_flash: bool # use Flash Attention
45
+ use_moe: bool # enable mixture-of-experts
46
 
47
+ moe_num_experts: int # total number of experts
48
+ moe_routed_experts: int # number of experts per token (top_k)
49
+ moe_eps: float = 1e-6 # epsilon for router stability
50
+ moe_aux_loss_coef: float = 0.01 # coefficient for auxiliary loss
51
+ moe_shared_experts: int = 0 # number of shared experts (DeepSeekMoE)
52
+ use_lossfreebalance: bool = False # use Auxiliary-loss-free load balancing strategy for mixture-of-experts from DeepSeek https://arxiv.org/pdf/2408.15664
53
 
54
  layernorm_eps: float = 1e-6
55
  rope_theta: float = 1e5
 
73
  num_attention_heads: Optional[int] = None
74
  embedding_size: Optional[int] = None
75
 
76
+ ffn_dim_multiplier: Optional[int] = None # optional multiplier to compute ffn_hidden_dims
77
 
78
  def __post_init__(self):
79
  if self.hidden_size is None:
 
86
  self.use_fa2 = self.use_flash
87
 
88
 
 
 
 
89
 
90
+ # Helper function for grouped-query attention: repeat each kv head n_times to match the query heads
91
+ def repeat_kv(vct: torch.Tensor, n_times: int):
92
+ c_batch_size, c_context_len, num_kv_heads, c_dim = vct.shape
93
+ if n_times == 1:
94
+ return vct
95
+ else:
96
+ return (
97
+ vct[:, :, :, None, :]
98
+ .expand(c_batch_size, c_context_len, num_kv_heads, n_times, c_dim)
99
+ .reshape(c_batch_size, c_context_len, num_kv_heads * n_times, c_dim)
100
+ )
101
 
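A quick shape check for repeat_kv above (toy sizes, assuming the function defined above is in scope): each kv head is duplicated n_times so keys/values line up with the query heads in grouped-query attention.

import torch

vct = torch.randn(2, 3, 2, 4)                    # (batch, seq_len, num_kv_heads, head_dim)
out = repeat_kv(vct, n_times=3)
print(out.shape)                                 # torch.Size([2, 3, 6, 4])
print(torch.equal(out[:, :, 0], out[:, :, 1]))   # True: both are copies of kv head 0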
102
+
103
+ class Rotary(nn.Module):
104
+ def __init__(self, config):
105
+ super(Rotary, self).__init__()
106
+
107
+ inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, config.num_dims // config.num_heads, 2).float() / (config.num_dims // config.num_heads)))
108
+ self.register_buffer('inv_freq', inv_freq, persistent=False)
109
+ self.seq_len_saved = None
110
+ self.cos_saved = None
111
+ self.sin_saved = None
112
+
113
+ def forward(self, x, seq_dim=1):
114
+ seq_len = x.size(seq_dim)
115
+ # Only recompute the cosine and sine matrices if the sequence length has changed.
116
+ if seq_len != self.seq_len_saved:
117
+ self.seq_len_saved = seq_len
118
+ pos = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
119
+ # Compute the outer product between positions and inverse frequencies.
120
+ freqs = torch.einsum("i,j->ij", pos, self.inv_freq) # (seq_len, inv_freq.shape[0])
121
+ # Duplicate the freqs along the last dimension to create pairs.
122
+ emb = torch.cat((freqs, freqs), dim=-1)
123
+ self.cos_saved = emb.cos()
124
+ self.sin_saved = emb.sin()
125
+
126
+ return self.cos_saved, self.sin_saved
127
+
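A small standalone sketch of how the cos/sin tensors produced by Rotary are meant to be applied, using the same half-splitting convention as rotate_half/apply_rotary_pos further down (shapes are toy values):

import torch

def rotate_half(x):
    half = x.shape[-1] // 2
    return torch.cat([-x[..., half:], x[..., :half]], dim=-1)

head_dim, seq_len = 8, 4
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
pos = torch.arange(seq_len, dtype=torch.float32)
freqs = torch.einsum("i,j->ij", pos, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)             # (seq_len, head_dim)
q = torch.randn(1, seq_len, head_dim)
q_rot = q * emb.cos() + rotate_half(q) * emb.sin()
print(q_rot.shape)                                   # torch.Size([1, 4, 8])
print(torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1)))  # True: each (i, i+half) pair is rotated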
128
+
129
+ class Layernorm(torch.nn.Module):
130
+ def __init__(self, config):
131
+ super().__init__()
132
+ self.g = nn.Parameter(torch.ones(config.num_dims))
133
+ self.eps = config.layernorm_eps
134
+
135
+ def _norm(self, x):
136
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
137
+
138
+ def forward(self, x):
139
+ return self.g * self._norm(x.float()).type_as(x)
140
+
141
+
142
+ class GroupedQueryAttention(nn.Module):
143
+ def __init__(self, config):
144
+ super().__init__()
145
+ self.config = config
146
+ self.use_cache = config.use_cache
147
+ self.use_flash = config.use_flash
148
+
149
+ self.num_heads = config.num_heads
150
+ self.num_kv_heads = config.num_heads if config.num_kv_heads is None else config.num_kv_heads
151
+
152
+ self.num_rep = self.num_heads // self.num_kv_heads
153
+ self.head_dim = config.num_dims // self.num_heads
154
+
155
+ self.wq = nn.Linear(config.num_dims, config.num_dims, bias=False)
156
+ nn.init.normal_(self.wq.weight, mean=0, std=1/math.sqrt(config.num_dims))
157
+ self.wk = nn.Linear(config.num_dims, self.num_kv_heads * self.head_dim, bias=False)
158
+ nn.init.normal_(self.wk.weight, mean=0, std=1/math.sqrt(config.num_dims))
159
+ self.wv = nn.Linear(config.num_dims, self.num_kv_heads * self.head_dim, bias=False)
160
+ nn.init.normal_(self.wv.weight, mean=0, std=1/math.sqrt(config.num_dims))
161
+
162
+ self.wo = nn.Linear(config.num_dims, config.num_dims, bias=False)
163
+
164
+ self.cache_k = None
165
+ self.cache_v = None
166
+
167
+
168
+ def rotate_half(self, x):
169
+ half = x.shape[-1] // 2
170
+ first_half, second_half = x[..., :half], x[..., half:]
171
+ return torch.cat([-second_half, first_half], dim=-1)
172
+
173
+
174
+ def apply_rotary_pos(self, q, k, cos, sin):
175
+ q_rot = q * cos + self.rotate_half(q) * sin
176
+ k_rot = k * cos + self.rotate_half(k) * sin
177
+ return q_rot, k_rot
178
+
179
+ def update_kv_cache(self, batch_size, start_pos, context_len, keys, values, device):
180
+ # Initialize cache if not exist
181
+ if self.cache_k is None:
182
+ self.cache_k = torch.zeros(
183
+ (batch_size, self.config.context_len, self.num_kv_heads, self.head_dim),
184
+ device=device
185
+ )
186
+ self.cache_v = torch.zeros(
187
+ (batch_size, self.config.context_len, self.num_kv_heads, self.head_dim),
188
+ device=device
189
+ )
190
+
191
+ # Update cache
192
+ self.cache_k[:batch_size, start_pos:start_pos + context_len] = keys
193
+ self.cache_v[:batch_size, start_pos:start_pos + context_len] = values
194
+
195
+ return (self.cache_k[:batch_size, :start_pos + context_len],
196
+ self.cache_v[:batch_size, :start_pos + context_len])
197
+
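A toy illustration (hypothetical sizes) of the cache update pattern above: each step writes the new keys/values at [start_pos : start_pos + context_len] and returns the prefix [: start_pos + context_len] for attention.

import torch

max_len, kv_heads, head_dim = 8, 2, 4
cache_k = torch.zeros(1, max_len, kv_heads, head_dim)
start_pos = 3
new_k = torch.ones(1, 1, kv_heads, head_dim)      # keys for one newly generated token
cache_k[:, start_pos:start_pos + 1] = new_k
print(cache_k[:, :start_pos + 1].shape)           # torch.Size([1, 4, 2, 4])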
198
+
199
+ def forward(self, x, cos, sin, start_pos = 0):
200
+ c_batch_size, c_context_len, c_dim = x.shape # c_context_len == 1 when decoding with the KV cache
201
+
202
+ if self.use_cache and c_context_len == 1:
203
+ # Cache branch
204
+ q = self.wq(x[:, -1, :])
205
+ k = self.wk(x[:, -1, :])
206
+ v = self.wv(x[:, -1, :])
207
+
208
+ q = q.view(c_batch_size, c_context_len, self.num_heads, self.head_dim).transpose(1, 2) # B, T, qh, hs
209
+ k = k.view(c_batch_size, c_context_len, self.num_kv_heads, self.head_dim).transpose(1, 2) # B, T, kh, hs
210
+ v = v.view(c_batch_size, c_context_len, self.num_kv_heads, self.head_dim).transpose(1, 2) # B, T, vh, hs
211
+
212
+ # freqs_complex = freqs_complex[-1:]
213
+ # queries = apply_rotary_pos(q, freqs_complex, device=x.device)
214
+ # keys = apply_rotary_pos(k, freqs_complex, device=x.device)
215
+
216
+ queries, keys = self.apply_rotary_pos(q, k, cos, sin)
217
+ keys, v = self.update_kv_cache(batch_size=c_batch_size, start_pos=start_pos, context_len=c_context_len, keys=keys, values=v, device=x.device)
218
+
219
+ else:
220
+ # Non-cache branch (process the entire sequence normally)
221
+ q = self.wq(x)
222
+ k = self.wk(x)
223
+ v = self.wv(x)
224
+
225
+ q = q.view(c_batch_size, c_context_len, self.num_heads, self.head_dim).transpose(1, 2) # B, qh, T, hs
226
+ k = k.view(c_batch_size, c_context_len, self.num_kv_heads, self.head_dim).transpose(1, 2) # B, kh, T, hs
227
+ v = v.view(c_batch_size, c_context_len, self.num_kv_heads, self.head_dim).transpose(1, 2) # B, vh, T, hs
228
+
229
+ queries, keys = self.apply_rotary_pos(q, k, cos, sin)
230
+
231
+ # queries = apply_rotary_pos(q, freqs_complex, device=x.device)
232
+ # keys = apply_rotary_pos(k, freqs_complex, device=x.device)
233
+
234
+ if self.use_cache: _k, _v = self.update_kv_cache(batch_size=c_batch_size, start_pos=start_pos, context_len=c_context_len, keys=keys, values=v, device=x.device)
235
+
236
+ if self.use_flash:
237
+ output = F.scaled_dot_product_attention(queries, keys, v, is_causal=True, enable_gqa=True)
238
+
239
+ else: # Calculate Grouped Query Attention manually
240
+ keys = repeat_kv(keys, self.num_rep)
241
+ values = repeat_kv(v, self.num_rep)
242
+
243
+ attention = torch.matmul(queries, keys.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
244
+
245
+ if self.use_cache and x.shape[1] == 1:
246
+ total_length = keys.size(2)
247
+ # For autoregressive generation, the query (which is at the latest position) should only attend to keys at indices <= current token.
248
+ # Create a mask: allowed positions are indices < total_length (i.e. all in the cache)
249
+ mask = torch.arange(total_length, device=attention.device).unsqueeze(0) <= (start_pos + x.shape[1] - 1)
250
+ mask = mask.unsqueeze(0).unsqueeze(0) # shape: (1, 1, 1, total_length)
251
+ attention = attention.masked_fill(~mask, float("-inf"))
252
+ attention = F.softmax(attention, dim=-1)
253
+ output = torch.matmul(attention, values)
254
+
255
+ else: # Do not use kv_cache
256
+ attention = torch.tril(attention[:, :, :c_context_len, :c_context_len])
257
+ attention = attention.masked_fill(attention == 0, float("-inf"))
258
+
259
+ attention = F.softmax(attention, dim=-1).type_as(queries)
260
+ output = torch.matmul(attention, values)
261
+
262
+ output = output.transpose(2, 1).contiguous().view(c_batch_size, c_context_len, c_dim)
263
+ return self.wo(output)
264
+
265
+
266
+ class FlexBertUnpadAttention(nn.Module):
267
  def __init__(self, config, layer_id: Optional[int] = None):
268
  super().__init__()
269
  self.attn = FlexBertUnpadRopeAttention(config=config, layer_id=layer_id)
270
 
271
+ def forward(self, x: torch.Tensor):
272
+ batch_size, seq_len, _ = x.shape
273
+ attn_mask = torch.ones((batch_size, seq_len), device=x.device, dtype=torch.int32)
274
+ hidden_states, indices, cu_seqlens, max_seqlen = bert_padding.unpad_input(x, attn_mask)
275
+ attn_out = self.attn(
276
  hidden_states=hidden_states,
277
  cu_seqlens=cu_seqlens,
278
  max_seqlen=max_seqlen,
279
  indices=indices,
280
  attn_mask=attn_mask,
281
  )
282
+ return bert_padding.pad_input(attn_out, indices, batch_size, seq_len)
283
 
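A minimal sketch of the unpad -> attend -> pad roundtrip used in the forward above, relying on the repo-local bert_padding helpers (flash-attn style unpad_input/pad_input; tensor sizes are made up):

import torch
import bert_padding   # local module, imported at the top of model.py

batch, seq_len, dim = 2, 4, 8
hidden = torch.randn(batch, seq_len, dim)
attn_mask = torch.tensor([[1, 1, 1, 0],
                          [1, 1, 0, 0]], dtype=torch.int32)   # 3 + 2 real tokens
unpadded, indices, cu_seqlens, max_seqlen = bert_padding.unpad_input(hidden, attn_mask)
print(unpadded.shape)    # (5, 8): only the real tokens survive
print(cu_seqlens)        # tensor([0, 3, 5], ...): per-sequence boundaries
repadded = bert_padding.pad_input(unpadded, indices, batch, seq_len)
print(repadded.shape)    # back to (2, 4, 8), padded positions restored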
284
 
285
  class FeedForward(nn.Module):
286
+ """
287
+ Default Feed Forward Layer.
288
+ """
289
  def __init__(self, config):
290
  super().__init__()
291
 
 
295
  self.w2 = nn.Linear(self.hidden_dim, config.num_dims, bias=False)
296
  self.w3 = nn.Linear(config.num_dims, self.hidden_dim, bias=False)
297
  self.act = nn.GELU()
 
298
  def forward(self, x: torch.Tensor):
299
  return self.w2(self.act(self.w1(x)) * self.w3(x)), None
300
 
301
 
302
+ class FFNwMoE(nn.Module):
303
  """
304
  Feed Forward with MoE with optional shared experts.
305
  Returns after forward:
306
  output: Combined outputs from experts
307
  aux_loss: Auxiliary loss tensor or routing metadata
308
  """
309
+ def __init__(self, config: ModelConfig):
310
  super().__init__()
311
  self.hidden_dim = config.ffn_hidden_dims
 
312
 
313
+ self.moe_routed_experts = config.moe_routed_experts # top_k
314
  self.moe_aux_loss_coef = config.moe_aux_loss_coef
315
  self.moe_eps = config.moe_eps
316
  self.moe_shared_experts = config.moe_shared_experts
317
  self.num_experts = config.moe_num_experts
318
 
319
+ self.use_lossfreebalance = config.use_lossfreebalance
320
321
 
322
+ self.router = nn.Linear(config.num_dims, self.num_experts, bias=False)
323
+ self.experts = nn.ModuleList()
324
+ for _ in range(self.num_experts):
325
+ self.experts.append(
326
+ nn.ModuleList([
327
+ nn.Linear(config.num_dims, self.hidden_dim, bias=False),
328
+ nn.Linear(self.hidden_dim, config.num_dims, bias=False),
329
+ nn.Linear(config.num_dims, self.hidden_dim, bias=False)
330
+ ]))
331
+
332
  # shared experts (for DeepSeekMoE)
333
  self.shared_experts = nn.ModuleList()
334
  for _ in range(self.moe_shared_experts):
 
338
  nn.Linear(self.hidden_dim, config.num_dims, bias=False),
339
  nn.Linear(config.num_dims, self.hidden_dim, bias=False)
340
  ]))
341
+
342
+ # Auxiliary-loss-free load balancing strategy for mixture-of-experts from DeepSeek https://arxiv.org/pdf/2408.15664
343
  if self.use_lossfreebalance:
344
  self.expert_biases = nn.Parameter(torch.zeros(self.num_experts))
345
+
346
  def forward(self, x: torch.Tensor):
347
+ c_batch_size, c_context_len, c_dim = x.shape
348
+ x_flat = x.view(-1, c_dim) # (c_batch_size * c_context_len, c_dim)
349
 
350
  router_out = self.router(x_flat)
351
+ router_probs = F.softmax(router_out, dim=-1)
352
 
353
  _, topk_indices = router_out.topk(self.moe_routed_experts, dim=-1)
354
  self.last_topk_indices = topk_indices.detach()
 
357
 
358
  output = self._compute_expert_outputs(x_flat, topk_indices, topk_probs, router_probs)
359
 
360
+ return output.view(c_batch_size, c_context_len, c_dim), aux_loss
 
 
 
361
 
362
  def _compute_aux_loss(self, router_out, router_probs, topk_indices):
363
+ """
364
+ Computes the auxiliary loss based on whether loss-free balancing is used or not.
365
+ """
366
  if not self.use_lossfreebalance:
367
  topk_probs, _ = router_probs.topk(self.moe_routed_experts, dim=-1)
368
  expert_mask = F.one_hot(topk_indices[:, 0], self.num_experts).float()
 
370
  router_prob_mean = router_probs.mean(dim=0)
371
  aux_loss = self.moe_aux_loss_coef * torch.sum(density * router_prob_mean) * self.num_experts
372
 
373
+ else: # if use_lossfreebalance
374
  router_out = router_out + self.expert_biases
375
+ router_probs = torch.sigmoid(router_out) # from https://arxiv.org/pdf/2408.15664 paper
376
  topk_probs = router_probs.gather(-1, topk_indices)
377
  topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
378
 
379
+ # In the case of Auxiliary-loss-free load balancing we pass router_probs, topk_indices as aux_loss for further calculations
380
  aux_loss = (router_probs, topk_indices)
381
  return aux_loss, topk_probs
382
 
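For the standard (non-loss-free) branch above, a self-contained sketch of the auxiliary balancing loss: the fraction of tokens assigned to each expert times the mean router probability, summed and scaled. The density line itself is elided by this hunk, so its exact form is an assumption here:

import torch
import torch.nn.functional as F

def aux_balance_loss(router_probs, topk_indices, num_experts, coef=0.01):
    expert_mask = F.one_hot(topk_indices[:, 0], num_experts).float()  # top-1 assignment per token
    density = expert_mask.mean(dim=0)              # assumed: fraction of tokens sent to each expert
    router_prob_mean = router_probs.mean(dim=0)    # mean router probability per expert
    return coef * torch.sum(density * router_prob_mean) * num_experts

router_logits = torch.randn(8, 4)                  # 8 tokens, 4 experts (toy sizes)
router_probs = F.softmax(router_logits, dim=-1)
_, topk_indices = router_logits.topk(1, dim=-1)
print(aux_balance_loss(router_probs, topk_indices, num_experts=4))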
383
  def _compute_expert_outputs(self, x_flat, topk_indices, topk_probs, router_probs):
384
  """
385
+ Compute the output of the experts and shared experts if needed
386
+ """
387
  output = torch.zeros_like(x_flat)
 
388
 
389
+ for i in range(self.moe_routed_experts):
390
+ expert_index = topk_indices[:, i]
391
+ expert_probs = topk_probs[:, i]
392
+
393
+ for expert_id in range(self.num_experts):
394
+ idx = (expert_id == expert_index).nonzero().squeeze()
395
+
396
+ if idx.numel() == 0:
397
+ continue
398
+ x_for_expert = x_flat[idx]
399
+ w1, w2, w3 = self.experts[expert_id]
400
+
401
+ expert_output = w2(F.gelu(w1(x_for_expert)) * w3(x_for_expert))
402
+ output[idx] += expert_output * expert_probs[idx].unsqueeze(-1)
403
+
404
+ # shared experts (for DeepSeekMoE)
405
  for shared_expert_id in range(self.moe_shared_experts):
406
  w1, w2, w3 = self.shared_experts[shared_expert_id]
407
  expert_output = w2(F.gelu(w1(x_flat)) * w3(x_flat))
408
  output = output + expert_output
409
+
410
  return output
411
 
412
 
413
  class Block(nn.Module):
414
  def __init__(self, config, layer_id: Optional[int] = None):
415
  super().__init__()
416
  self.is_first_block = (layer_id == 0)
 
421
  else:
422
  self.ffn = FeedForward(config)
423
 
 
 
424
 
425
+ self.norm_attention = nn.LayerNorm(config.num_dims, eps=config.layernorm_eps)
426
+ self.norm_ffn = nn.LayerNorm(config.num_dims, eps=config.layernorm_eps)
427
+
428
+ def forward(self, x, start_pos):
429
+ _ = start_pos
430
  if self.is_first_block:
431
  attn_in = x
432
  else:
433
  attn_in = self.norm_attention(x)
 
434
  x = x + self.attention(
435
+ attn_in
436
+ )
437
+
 
 
 
 
438
  ffn_out, aux_loss = self.ffn(
439
  self.norm_ffn(x)
440
+ )
441
  x = x + ffn_out
442
  return x, aux_loss
443
+
444
 
445
+ class Transformer(nn.Module, PyTorchModelHubMixin): # extend PyTorchModelHubMixin so weights can be saved as safetensors
446
+ def __init__(self, config: ModelConfig):
447
  super().__init__()
448
+ if isinstance(config, dict):
449
+ config = ModelConfig(**config)
450
 
451
  self.vocab_size = config.vocab_size
452
  self.num_dims = config.num_dims
 
456
  self.use_lossfreebalance = config.use_lossfreebalance and self.use_moe
457
 
458
  self.num_layers = config.num_layers
459
+
460
+ # Calculation of hidden_dim for FFN/FFNwMoE
461
+ # multiple_of = 4
462
+ # ffn_dim_multiplier = config.ffn_dim_multiplier
463
  hidden_dim = 4 * config.num_dims
464
+ # hidden_dim = int(2 * config.num_dims / 3)
465
+
466
+ # if ffn_dim_multiplier is not None:
467
+ # hidden_dim = int(ffn_dim_multiplier * hidden_dim)
468
 
469
+ # config.ffn_hidden_dims = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
470
+ self.tokens_embedding = nn.Embedding(self.vocab_size, self.num_dims)
471
+ self.norm_embeddings = nn.LayerNorm(config.num_dims, eps=config.layernorm_eps)
472
 
473
  self.blocks = nn.ModuleList()
474
  for layer_id in range(self.num_layers):
475
  self.blocks.append(Block(config, layer_id=layer_id))
476
 
477
+ self.norm = nn.LayerNorm(config.num_dims, eps=config.layernorm_eps)
478
+ self.ll_head = nn.Linear(self.num_dims, self.vocab_size, bias=False)
479
+
480
 
481
  self.tokens_embedding.weight = self.ll_head.weight
482
+ # torch.nn.init.normal_(self.ll_head.weight, mean=0.0, std=0.02)
483
+ # torch.nn.init.normal_(self.tokens_embedding.weight, mean=0.0, std=0.02)
484
 
485
+ # self.freqs_complex = None # precompute_theta_pos_frequencies(self.num_dims // self.num_heads, self.context_len * 2, device=config.device)
 
486
 
 
 
487
 
 
 
 
488
489
 
490
+ def forward(self, x: torch.Tensor, targets: Optional[torch.Tensor] = None, start_pos: int = 0):
491
+ x = self.tokens_embedding(x)
492
  x = self.norm_embeddings(x)
493
+
494
+ # if self.freqs_complex == None:
495
+ # self.freqs_complex = precompute_theta_pos_frequencies(self.num_dims // self.num_heads, self.context_len * 2, device=x.device)
496
+ # freqs_complex = self.freqs_complex[start_pos:start_pos + seq_len]
497
+
498
  total_aux_loss = 0
499
 
500
  for block in self.blocks:
501
+ x, aux_loss = block(x, start_pos=start_pos)
502
  if self.use_moe and not self.use_lossfreebalance:
503
  total_aux_loss += aux_loss
504
+
505
  x = self.norm(x)
 
 
 
 
506
  logits = self.ll_head(x)
507
+
508
+
509
  if targets is None:
510
  loss = None
511
  ce_loss = None
512
  else:
513
  c_batch_size, c_context_len, c_dim = logits.shape
514
+ logits = logits.view(c_batch_size*c_context_len, c_dim)
515
+ targets = targets.view(c_batch_size*c_context_len)
516
  ce_loss = F.cross_entropy(logits, targets)
517
+
518
+ if self.use_moe and not self.use_lossfreebalance: loss = ce_loss + total_aux_loss # here ce_loss is the loss without the aux_loss term
519
+ else: # with Auxiliary-loss-free load balancing, (router_probs, topk_indices) are passed through as ce_loss
520
+ # This branch also covers the case where MoE is not used
521
  loss = ce_loss
522
  ce_loss = aux_loss
523
 
524
  return logits, loss, ce_loss
525
 
526
  @torch.no_grad()
527
+ def generate(self, x: torch.Tensor, max_tokens: int, temperature: float = 1.0, top_k: int = 50,
528
  use_cache: bool = False):
529
+ """
530
+ Generate text from x up to max_tokens
531
+ """
532
  for c_tkn_pos in range(max_tokens):
533
  if use_cache:
534
  if c_tkn_pos == 0:
 
549
  return x
550
 
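The body of generate() is largely elided in this diff, so the following is only a hedged sketch of what one temperature/top-k sampling step typically looks like with these parameters:

import torch
import torch.nn.functional as F

logits = torch.randn(1, 50304)                 # logits for the last position (toy values)
temperature, top_k = 1.0, 50
topk_vals, topk_idx = (logits / temperature).topk(top_k, dim=-1)
probs = F.softmax(topk_vals, dim=-1)
next_token = topk_idx.gather(-1, torch.multinomial(probs, num_samples=1))
print(next_token.shape)                        # torch.Size([1, 1])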
551
 
552
  def main():
553
+ # config = ModelConfig(
554
+ # device = 'cuda' if torch.cuda.is_available() else 'cpu',
555
+ # vocab_size = 50304,
556
+
557
+ # num_dims = 1024,
558
+ # num_heads = 16,
559
+ # num_kv_heads = 4,
560
+ # num_layers = 16,
561
+ # ffn_hidden_dims = 1024 * 4,
562
+
563
+ # layernorm_eps = 1e-6,
564
+ # rope_theta = 1e5,
565
+
566
+ # context_len = 1024,
567
+
568
+ # use_cache = False,
569
+ # use_flash = False,
570
+ # use_moe = False,
571
+
572
+ # moe_num_experts = 6,
573
+ # moe_routed_experts = 1,
574
+ # moe_eps = 1e-6,
575
+ # moe_aux_loss_coef = 0.01,
576
+ # moe_shared_experts = 0,
577
+ # use_lossfreebalance = False,
578
+
579
+ # )
580
+
581
+
582
+ # device = 'cuda' if torch.cuda.is_available() else 'cpu'
583
+ # SEED = 1337
584
+
585
+ # torch.manual_seed(SEED)
586
+ # if device == 'cuda':
587
+ # torch.cuda.manual_seed(SEED)
588
+
589
+ # model = Transformer(config)
590
+ # model = model.to(device)
591
+ # model = torch.compile(model)
592
+
593
+ # print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')
594
  pass
595
 
596