SlimFactory committed on
Commit
e65ee65
·
verified ·
1 Parent(s): 7471f73

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "adaptive_routing": true,
3
+ "architectures": [
4
+ "SlimMoEForCausalLM"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_slim_moe.SlimMoEConfig",
8
+ "AutoModelForCausalLM": "modeling_slim_moe.SlimMoEForCausalLM"
9
+ },
10
+ "bos_token_id": 2,
11
+ "dim": 768,
12
+ "dropout": 0.1,
13
+ "dtype": "float32",
14
+ "eos_token_id": 3,
15
+ "hidden_dim": 1536,
16
+ "max_seq_len": 2048,
17
+ "model_type": "slim_moe",
18
+ "num_experts": 4,
19
+ "num_heads": 12,
20
+ "num_hidden_layers": 16,
21
+ "pad_token_id": 0,
22
+ "transformers_version": "4.57.1",
23
+ "vocab_size": 50257
24
+ }
configuration_slim_moe.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
class SlimMoEConfig(PretrainedConfig):
    """Configuration for the SlimMoE mixture-of-experts causal language model.

    Args:
        vocab_size: Vocabulary size of the token embedding.
        dim: Hidden size of the model.
        num_hidden_layers: Number of transformer blocks.
        num_heads: Number of attention heads per block.
        hidden_dim: Inner size of the dense FFN and of each expert MLP.
        num_experts: Number of experts in each MoE layer.
        max_seq_len: Maximum supported sequence length.
        dropout: Dropout probability used throughout the model.
        adaptive_routing: Whether MoE layers use complexity-based routing.
    """

    model_type = "slim_moe"

    def __init__(
        self,
        vocab_size: int = 50257,
        dim: int = 768,
        num_hidden_layers: int = 12,
        num_heads: int = 12,
        hidden_dim: int = 2048,
        num_experts: int = 4,
        max_seq_len: int = 2048,
        dropout: float = 0.1,
        adaptive_routing: bool = True,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.dim = dim
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.num_experts = num_experts
        self.max_seq_len = max_seq_len
        self.dropout = dropout
        self.adaptive_routing = adaptive_routing

        # Enable automatic weight tying by the framework: this tells
        # PreTrainedModel's post_init to tie lm_head to the input embedding.
        # NOTE(review): super().__init__ may re-read tie_word_embeddings from
        # kwargs and overwrite this — confirm an explicit caller override is
        # the intended winner.
        self.tie_word_embeddings = True

        super().__init__(**kwargs)

    @classmethod
    def for_250m(cls, vocab_size: int = 50257, max_seq_len: int = 2048, dropout: float = 0.1):
        """Preset configuration for the ~250M-parameter model.

        Uses: dim=768, layers=16, heads=12, hidden_dim=1536, experts=4.

        FIX: the previous docstring claimed "approximately 280-290M
        parameters, safely under 250M", which was self-contradictory; the
        shipped float32 checkpoint is ~1.0 GB, i.e. roughly 250M parameters.
        """
        return cls(
            vocab_size=vocab_size,
            dim=768,
            num_hidden_layers=16,
            num_heads=12,
            hidden_dim=1536,
            num_experts=4,
            max_seq_len=max_seq_len,
            dropout=dropout
        )
generation_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 2,
4
+ "eos_token_id": [
5
+ 3
6
+ ],
7
+ "pad_token_id": 0,
8
+ "transformers_version": "4.57.1"
9
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c0caf89a527e2dfec740efde1c03606b273fc5a04bc9b7381a771c3dc5fcd55a
3
+ size 1011802248
modeling_slim_moe.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel, GenerationMixin
4
+ from transformers.utils import logging
5
+ from transformers.modeling_outputs import CausalLMOutputWithPast
6
+
7
+ from .configuration_slim_moe import SlimMoEConfig
8
+ from .slim_moe_transformer import SlimMOETransformer
9
+
10
+ logger = logging.get_logger(__name__)
11
+
12
+ # AutoConfig.register('slim_moe', SlimMoEConfig)
13
+ # CONFIG_MAPPING.register("slim_moe", SlimMoEConfig)
14
+
15
+
16
class SlimMoEModel(PreTrainedModel):
    """Base ``PreTrainedModel`` wrapper for the SlimMoE architecture.

    Supplies the config class, the weight-initialization scheme, and the
    sharding/checkpointing hints the Hugging Face framework expects. The
    actual layers are defined by subclasses.
    """

    config_class = SlimMoEConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["SlimMoETransformerBlock"]

    def _init_weights(self, module):
        """Initialize a single submodule with a GPT-style scheme."""
        # Fall back to 0.02 when the config defines no initializer_range.
        std = getattr(self.config, 'initializer_range', 0.02)
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            # nn.Embedding has no bias; nn.Linear may have bias=False.
            bias = getattr(module, 'bias', None)
            if bias is not None:
                torch.nn.init.zeros_(bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.ones_(module.weight)
            torch.nn.init.zeros_(module.bias)
33
+
34
+ # MODEL_MAPPING.register(SlimMoEConfig, SlimMoEModel)
35
+
36
class SlimMoEForCausalLM(SlimMoEModel, GenerationMixin):
    """Causal-LM head on top of the SlimMOETransformer backbone.

    The backbone lives under ``self.transformer``; a top-level ``lm_head``
    projects hidden states to vocabulary logits and shares its weight with
    the token embedding. ``forward`` returns loss + logits only — no KV
    cache is produced, so generation recomputes the full prefix each step.
    """

    def __init__(self, config):
        super().__init__(config)

        # Backbone returns {'last_hidden_state', 'aux_loss'} (see
        # SlimMOETransformer.forward); it deliberately has no lm_head.
        self.transformer = SlimMOETransformer(
            vocab_size=config.vocab_size,
            dim=config.dim,
            num_layers=config.num_hidden_layers,
            num_heads=config.num_heads,
            hidden_dim=config.hidden_dim,
            num_experts=config.num_experts,
            max_seq_len=config.max_seq_len,
            dropout=config.dropout,
            adaptive_routing=getattr(config, 'adaptive_routing', True)
        )

        # --- FIX: Define the lm_head at the top level of this model ---
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)

        # Initialize weights and apply final processing (including weight tying)
        self.post_init()

        # Explicitly re-tie after post_init so lm_head and the embedding
        # share one tensor even if the framework tie did not fire.
        self.lm_head.weight = self.transformer.token_embedding.weight

        # Tell the serialization machinery these two keys alias each other.
        self._dynamic_tied_weights_keys = ['lm_head.weight', 'transformer.token_embedding.weight']

        # Initialize aux_loss for logging
        self.aux_loss = 0.0

        # Auxiliary loss coefficient (can be modified after initialization)
        self.aux_loss_coefficient = getattr(config, 'aux_loss_coefficient', 0.01)

    @classmethod
    def from_pretrained_with_tokenizer(cls, model_path: str, tokenizer_path: str = None):
        """
        Load model from pretrained and optionally use a custom tokenizer.

        Args:
            model_path: Path to the pretrained model
            tokenizer_path: Path to custom tokenizer (if None, uses default)

        Returns:
            model, tokenizer tuple
        """
        from transformers import AutoTokenizer

        model = cls.from_pretrained(model_path, trust_remote_code=True)

        if tokenizer_path:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            # Update vocab size if needed — only warns; the mismatch is NOT
            # fixed here, so downstream indexing errors remain possible.
            if tokenizer.vocab_size != model.config.vocab_size:
                print(f"Warning: Tokenizer vocab size ({tokenizer.vocab_size}) != "
                      f"model vocab size ({model.config.vocab_size})")
                print("   Consider retraining model with matching vocab size")
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_path)

        return model, tokenizer

    def get_input_embeddings(self):
        # Embedding table lives on the backbone.
        return self.transformer.token_embedding

    def set_input_embeddings(self, value):
        self.transformer.token_embedding = value

    def get_output_embeddings(self):
        # --- FIX: Return the top-level lm_head ---
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        # --- FIX: Set the top-level lm_head ---
        self.lm_head = new_embeddings

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Run the backbone, project to logits, and optionally compute loss.

        Extra kwargs (e.g. from generate()) are accepted and ignored.
        """
        # 1. Get hidden states from the base transformer
        transformer_outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        hidden_states = transformer_outputs['last_hidden_state']

        # 2. Project hidden states to logits
        logits = self.lm_head(hidden_states)

        # 3. Calculate loss if labels are provided
        loss = None
        if labels is not None:
            # Standard next-token shift: position t predicts token t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # NOTE(review): no ignore_index is set, so pad positions in
            # labels contribute to the loss unless masked upstream — confirm.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))

            # Add auxiliary loss from MOE layers
            if self.training:
                aux_loss = transformer_outputs['aux_loss']
                # Store aux_loss for logging (accessible via model.aux_loss)
                self.aux_loss = aux_loss.item() if isinstance(aux_loss, torch.Tensor) else aux_loss
                loss = loss + self.aux_loss_coefficient * aux_loss
            else:
                self.aux_loss = 0.0

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
        )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # No KV cache: feed the full sequence every decoding step.
        return {
            "input_ids": input_ids,
            "attention_mask": kwargs.get("attention_mask"),
        }
150
+
151
+ # AutoModelForCausalLM.register(SlimMoEConfig, SlimMoEForCausalLM)
152
+
153
+ # MODEL_FOR_CAUSAL_LM_MAPPING.register(SlimMoEConfig, SlimMoEForCausalLM)
154
+
155
+
156
def create_moe_causal_lm(vocab_size: int = 50257):
    """
    Create a SlimMoEForCausalLM model with approximately 250M parameters.

    Args:
        vocab_size: Size of the tokenizer vocabulary (default: 50257).

    Returns:
        A full CausalLM model (not just the transformer) built from the
        ~250M-parameter preset configuration.
    """
    from .configuration_slim_moe import SlimMoEConfig

    # FIX: the preset classmethod is named `for_250m` (see
    # configuration_slim_moe.py); the previous call to
    # `SlimMoEConfig.for_300m` raised AttributeError at runtime.
    config = SlimMoEConfig.for_250m(vocab_size=vocab_size)
    model = SlimMoEForCausalLM(config)

    return model
168
+
169
+
slim_moe_transformer.py ADDED
@@ -0,0 +1,490 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ from typing import Optional, Tuple, List
6
+ import warnings
7
+
8
+
9
class RotaryPositionEmbedding(nn.Module):
    """Rotary position embedding (RoPE) frequency-table provider.

    Precomputes one inverse frequency per 2-D dimension pair; each forward
    call returns cos/sin tables sized to the input's sequence length.
    No learned parameters — the frequency buffer is non-persistent.
    """

    def __init__(self, dim: int, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.base = base
        # One frequency per dimension pair: the rotation acts on (even, odd)
        # pairs, so only dim // 2 frequencies are needed.
        pair_index = torch.arange(0, dim, 2).float()
        self.register_buffer('inv_freq', 1.0 / (base ** (pair_index / dim)), persistent=False)

    def forward(self, x: torch.Tensor, seq_dim: int = -2) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (cos, sin) tables of shape [seq_len, dim // 2] for *x*."""
        length = x.shape[seq_dim]
        positions = torch.arange(length, device=x.device, dtype=self.inv_freq.dtype)

        # Outer product via broadcasting: angles[p, f] = p * inv_freq[f].
        angles = positions[:, None] * self.inv_freq[None, :]

        return angles.cos().to(x.dtype), angles.sin().to(x.dtype)
33
+
34
+
35
def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate *x* by the given position tables (interleaved RoPE variant).

    Args:
        x:   [batch, num_heads, seq_len, head_dim] query or key tensor.
        cos: [seq_len, head_dim // 2] cosine table.
        sin: [seq_len, head_dim // 2] sine table.

    Returns:
        Tensor of the same shape and dtype as *x*, with each adjacent
        (even, odd) feature pair rotated by its position-dependent angle.
    """
    b, h, s, d = x.shape

    # Treat consecutive feature pairs as complex numbers: even index = real
    # part, odd index = imaginary part.
    pairs = x.view(b, h, s, d // 2, 2)
    real, imag = pairs.unbind(dim=-1)

    # Broadcast the tables over batch and head dimensions.
    cos_t = cos[None, None, :, :]
    sin_t = sin[None, None, :, :]

    # Complex multiplication by e^{i*theta}.
    rotated_real = real * cos_t - imag * sin_t
    rotated_imag = real * sin_t + imag * cos_t

    # Re-interleave the pairs back into the head dimension.
    out = torch.stack((rotated_real, rotated_imag), dim=-1).flatten(start_dim=-2)
    return out.type_as(x)
61
+
62
+
63
class VariableGroupedQueryAttention(nn.Module):
    """Variable Grouped Query Attention with layer-specific head grouping and optional RoPE/NoPE.

    Each layer chooses its own number of KV heads: early layers get fewer
    (more compression), later layers more (more detail). Layers constructed
    with use_rope=False apply no positional embedding at all (NoPE).
    """

    def __init__(self, dim: int, num_heads: int = 8, layer_idx: int = 0,
                 num_layers: int = 12, variable_groups: bool = True,
                 use_rope: bool = True):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads  # assumes dim % num_heads == 0 — TODO confirm
        self.scale = self.head_dim ** -0.5
        self.variable_groups = variable_groups
        self.layer_idx = layer_idx
        self.num_layers = num_layers
        self.use_rope = use_rope

        # Variable group calculation - different KV heads for each layer
        if variable_groups:
            # Create progressive pattern: more KV heads in deeper layers
            # Early layers: fewer KV heads (more compression)
            # Later layers: more KV heads (more detail)

            # Normalized layer position (0 to 1)
            layer_ratio = layer_idx / max(1, num_layers - 1)

            # Calculate KV heads with progressive scaling
            # Start with fewer KV heads (e.g., 2-3) and increase toward end
            min_kv_heads = max(1, num_heads // 6)  # Minimum 1/6 of heads
            max_kv_heads = max(2, num_heads // 3)  # Maximum 1/3 of heads

            # Progressive scaling: early layers use fewer, later use more
            raw_kv_heads = int(min_kv_heads + (max_kv_heads - min_kv_heads) * layer_ratio)

            # Ensure it's a valid divisor of num_heads. The search order
            # matters: first walk DOWN from raw_kv_heads, then walk UP to
            # max_kv_heads; the first hit wins.
            self.num_kv_heads = raw_kv_heads
            if self.num_heads % self.num_kv_heads != 0:
                # Find the nearest valid num_kv_heads
                for i in range(self.num_kv_heads, 0, -1):
                    if self.num_heads % i == 0:
                        self.num_kv_heads = i
                        break
                # If that didn't work, try going up
                if self.num_heads % self.num_kv_heads != 0:
                    for i in range(self.num_kv_heads + 1, max_kv_heads + 1):
                        if self.num_heads % i == 0:
                            self.num_kv_heads = i
                            break
        else:
            self.num_kv_heads = max(2, num_heads // 2)

        # Final validation
        assert self.num_heads % self.num_kv_heads == 0, \
            f"Layer {layer_idx}: num_heads ({num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})"

        # Query projections
        self.q_proj = nn.Linear(dim, dim, bias=False)

        # Key-Value projections with grouped attention
        self.k_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)

        # Output projection
        self.out_proj = nn.Linear(dim, dim, bias=False)

        # RoPE - only create if using positional embeddings
        # NoPE layers (every 4th layer) skip positional embeddings entirely
        if self.use_rope:
            self.rope = RotaryPositionEmbedding(self.head_dim)
        else:
            self.rope = None

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply attention to x [batch, seq_len, dim]; returns the same shape.

        attention_mask, if given, is an ADDITIVE bias broadcastable to
        [batch, heads, seq, seq] (large negative values at masked positions).
        """
        batch_size, seq_len, _ = x.shape

        # Project queries, keys, values
        q = self.q_proj(x)  # [batch, seq_len, dim]
        k = self.k_proj(x)  # [batch, seq_len, num_kv_heads * head_dim]
        v = self.v_proj(x)  # [batch, seq_len, num_kv_heads * head_dim]

        # Reshape for multi-head attention
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Apply RoPE to queries and keys (NoPE layers skip this)
        # NoPE layers rely on causal attention mask for positional information
        if self.use_rope and self.rope is not None:
            # Tables are built from q's seq length, which matches k's here.
            cos, sin = self.rope(q)
            q = apply_rotary_pos_emb(q, cos, sin)
            k = apply_rotary_pos_emb(k, cos, sin)
        # else: NoPE - no positional embeddings applied, causal mask provides ordering

        # Expand KV heads for grouped query attention
        if self.num_kv_heads != self.num_heads:
            k = k.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
            v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)

        # Compute attention scores
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # Apply attention mask if provided
        if attention_mask is not None:
            attn_scores = attn_scores + attention_mask

        # Softmax in float32 for numerical stability, then cast back.
        attn_weights = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)

        # Apply attention to values
        attn_output = torch.matmul(attn_weights, v)

        # Reshape and project back
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.dim
        )

        return self.out_proj(attn_output)
178
+
179
+
180
class Expert(nn.Module):
    """One feed-forward expert used inside the MoE layer.

    A position-wise two-layer MLP (dim -> hidden_dim -> dim) with GELU
    activation and dropout after each linear projection.
    """

    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        stages = (
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )
        # Attribute stays named `net` so checkpoint keys remain compatible.
        self.net = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project *x* through the expert MLP; input shape is preserved."""
        return self.net(x)
195
+
196
+
197
class MOELayer(nn.Module):
    """Mixture of Experts Layer with adaptive routing based on input complexity.

    Top-2 routing over `num_experts` Expert MLPs. When adaptive_routing is
    enabled, a small learned "complexity" score per token rescales the gate
    logits (temperature) and the training-time gating noise. After each
    forward pass, the load-balancing auxiliary loss is stored on
    ``self.aux_loss`` for the caller to collect.
    """

    def __init__(self, dim: int, hidden_dim: int, num_experts: int = 4,
                 capacity_factor: float = 1.0, noisy_gating: bool = True,
                 adaptive_routing: bool = True):
        super().__init__()
        self.dim = dim
        self.num_experts = num_experts
        # NOTE(review): capacity_factor is stored but never used below —
        # there is no capacity limit per expert in this implementation.
        self.capacity_factor = capacity_factor
        self.noisy_gating = noisy_gating
        self.adaptive_routing = adaptive_routing

        # Create experts
        self.experts = nn.ModuleList([
            Expert(dim, hidden_dim) for _ in range(num_experts)
        ])

        # Standard gate network
        self.gate = nn.Linear(dim, num_experts)

        # NOVEL: Adaptive complexity-based routing
        # Learns to route tokens based on their complexity/importance
        if adaptive_routing:
            # Complexity encoder: estimates how "complex" each token representation is
            self.complexity_encoder = nn.Sequential(
                nn.Linear(dim, dim // 4),
                nn.GELU(),
                nn.Linear(dim // 4, 1),
                nn.Sigmoid()  # Output: 0 (simple) to 1 (complex)
            )

            # Adaptive temperature: dynamically adjusts expert selection based on complexity
            # NOTE(review): complexity_proj/complexity_bias are computed in
            # forward but their result is unused (dead code) — see below.
            self.complexity_proj = nn.Linear(dim, 1)

            # Learnable bias for complexity-aware routing
            self.complexity_bias = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route each token to its top-2 experts; returns x-shaped output.

        Side effect: sets ``self.aux_loss`` (load-balancing loss tensor).
        """
        batch_size, seq_len, dim = x.shape

        # Flatten for expert routing
        x_flat = x.reshape(-1, dim)
        # NOTE(review): num_tokens is assigned but never used.
        num_tokens = x_flat.shape[0]

        # Compute standard gate scores
        gate_scores = self.gate(x_flat)

        # NOVEL: Adaptive routing based on token complexity
        if self.adaptive_routing:
            # Estimate complexity of each token (0 = simple, 1 = complex)
            complexity_scores = self.complexity_encoder(x_flat)  # [num_tokens, 1]

            # Compute adaptive temperature based on complexity
            # Complex tokens get lower temperature (sharper distribution)
            # Simple tokens get higher temperature (softer distribution)
            # NOTE(review): complexity_temp is never used — adaptive_temp
            # below is derived from complexity_scores only, so
            # complexity_proj/complexity_bias receive no gradient. Confirm
            # whether this line was meant to feed adaptive_temp.
            complexity_temp = self.complexity_proj(x_flat) + self.complexity_bias
            # Temperature in range [0.5, 2.0] - inverse relationship with complexity
            adaptive_temp = 0.5 + 1.5 * (1.0 - complexity_scores.squeeze(-1))

            # Apply adaptive temperature scaling to gate scores
            # Lower temp = sharper = focus on fewer experts
            # Higher temp = softer = distribute more evenly
            scaled_scores = gate_scores / (adaptive_temp.unsqueeze(-1) + 1e-8)

            if self.noisy_gating and self.training:
                # Reduced noise for complex tokens (they should be more confident)
                noise_scale = (1.0 / self.num_experts) * (1.0 - complexity_scores.squeeze(-1) * 0.5)
                noise = torch.randn_like(gate_scores) * noise_scale.unsqueeze(-1)
                scaled_scores = scaled_scores + noise
        else:
            scaled_scores = gate_scores
            if self.noisy_gating and self.training:
                noise = torch.randn_like(gate_scores) * (1.0 / self.num_experts)
                scaled_scores = scaled_scores + noise

        # Get top-2 experts using adaptive scores
        top_k = 2
        top_scores, top_indices = torch.topk(scaled_scores, k=top_k, dim=-1)
        # Renormalize only the selected scores into mixing weights.
        top_gates = F.softmax(top_scores, dim=-1, dtype=torch.float32).to(x_flat.dtype)

        # Create placeholder for final output
        final_output = torch.zeros_like(x_flat)

        # Compute auxiliary loss for load balancing (use original gate_scores, not scaled)
        self.aux_loss = self._load_balancing_loss(gate_scores, top_indices)

        # Route tokens to experts
        for i in range(top_k):
            # Process tokens for the i-th choice expert
            expert_indices = top_indices[:, i]
            gate_values = top_gates[:, i].unsqueeze(-1)

            for expert_idx, expert in enumerate(self.experts):
                token_indices = (expert_indices == expert_idx).nonzero(as_tuple=True)[0]

                if token_indices.numel() > 0:
                    selected_tokens = x_flat[token_indices]
                    selected_gates = gate_values[token_indices]

                    expert_output = expert(selected_tokens)
                    # Scatter-accumulate the gated expert output back into place.
                    final_output.index_add_(0, token_indices, expert_output * selected_gates)

        # Reshape back to original dimensions
        return final_output.reshape(batch_size, seq_len, dim)

    def _load_balancing_loss(self, gate_scores: torch.Tensor, top_indices: torch.Tensor) -> torch.Tensor:
        """Compute load balancing auxiliary loss (zero outside training)."""
        if not self.training:
            return torch.tensor(0.0, device=gate_scores.device)

        # Compute fraction of tokens routed to each expert (based on top-1 choice)
        top1_indices = top_indices[:, 0]
        expert_mask = F.one_hot(top1_indices, num_classes=self.num_experts).float()
        routing_fraction = expert_mask.mean(dim=0)

        # Compute fraction of gate probability for each expert
        gate_prob = F.softmax(gate_scores, dim=-1)
        gate_fraction = gate_prob.mean(dim=0)

        # Load balancing loss
        load_balance_loss = self.num_experts * torch.sum(routing_fraction * gate_fraction)

        return load_balance_loss
321
+
322
+
323
class SlimMoETransformerBlock(nn.Module):
    """One SlimMoE block: pre-norm attention, a dense FFN, then an MoE FFN.

    All three sub-layers are residual. Positional handling alternates:
    every fourth block (layer_idx % 4 == 3) is a NoPE layer and skips
    rotary embeddings entirely.
    """

    def __init__(self, dim: int, num_heads: int, hidden_dim: int,
                 num_experts: int = 4, dropout: float = 0.1,
                 layer_idx: int = 0, num_layers: int = 12,
                 adaptive_routing: bool = True):
        super().__init__()
        self.dim = dim
        self.adaptive_routing = adaptive_routing

        self.attn_norm = nn.LayerNorm(dim)

        # RoPE everywhere except every fourth block (0-indexed 3, 7, 11, ...).
        self.attention = VariableGroupedQueryAttention(
            dim, num_heads, layer_idx=layer_idx,
            num_layers=num_layers, variable_groups=True,
            use_rope=(layer_idx % 4 != 3)
        )

        # Dense feed-forward sub-layer (runs before the MoE sub-layer).
        self.dense_ffn_norm = nn.LayerNorm(dim)
        self.dense_ffn = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

        # Mixture-of-experts sub-layer.
        self.moe_norm = nn.LayerNorm(dim)
        self.moe = MOELayer(dim, hidden_dim, num_experts, adaptive_routing=adaptive_routing)

        # Residual dropout shared by the attention and MoE branches.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the three residual sub-layers; x is [batch, seq, dim]."""
        # Sub-layer 1: variable grouped-query attention (pre-norm, residual).
        x = x + self.dropout(self.attention(self.attn_norm(x), attention_mask))

        # Sub-layer 2: dense feed-forward (its own internal dropout; no
        # extra residual dropout here, matching the original behavior).
        x = x + self.dense_ffn(self.dense_ffn_norm(x))

        # Sub-layer 3: mixture-of-experts feed-forward (residual, dropped out).
        x = x + self.dropout(self.moe(self.moe_norm(x)))

        return x
381
+
382
+
383
class SlimMOETransformer(nn.Module):
    """Complete MOE Transformer with Variable Grouped Query Attention and RoPE.

    Plain nn.Module backbone (no lm_head): forward returns a dict with
    'last_hidden_state' and the summed MoE 'aux_loss'. The CausalLM wrapper
    in modeling_slim_moe.py adds the vocabulary projection.
    """

    def __init__(self, vocab_size: int = 50257, dim: int = 768, num_layers: int = 12,
                 num_heads: int = 12, hidden_dim: int = 2048, num_experts: int = 4,
                 max_seq_len: int = 2048, dropout: float = 0.1, adaptive_routing: bool = True):
        super().__init__()

        self.vocab_size = vocab_size
        self.dim = dim
        self.num_layers = num_layers
        # NOTE(review): max_seq_len is stored but not enforced anywhere in
        # forward — sequences longer than this are not rejected here.
        self.max_seq_len = max_seq_len

        self.token_embedding = nn.Embedding(vocab_size, dim)
        self.dropout = nn.Dropout(dropout)
        # Each block gets its layer_idx so it can pick its KV-head count and
        # its RoPE/NoPE role.
        self.layers = nn.ModuleList([
            SlimMoETransformerBlock(
                dim=dim,
                num_heads=num_heads,
                hidden_dim=hidden_dim,
                num_experts=num_experts,
                dropout=dropout,
                layer_idx=i,
                num_layers=num_layers,
                adaptive_routing=adaptive_routing
            ) for i in range(num_layers)
        ])
        self.norm = nn.LayerNorm(dim)

        # --- FIX: Remove the lm_head from the base transformer model ---
        # self.lm_head = nn.Linear(dim, vocab_size, bias=False)
        # The CausalLM wrapper will handle the final projection.

        self.apply(self._init_weights)
        self._calculate_parameters()  # This will now show a smaller number

    def _init_weights(self, module):
        """Initialize weights"""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def _calculate_parameters(self):
        # Prints the parameter count as a construction-time side effect.
        total_params = sum(p.numel() for p in self.parameters())
        print(f"Total Parameters: {total_params:,}")

    def forward(self, input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None) -> dict:  # Note: labels are ignored here now

        # NOTE(review): batch_size is unpacked but unused.
        batch_size, seq_len = input_ids.shape

        # Additive causal mask: -finfo.max above the diagonal, 0 elsewhere.
        causal_mask = torch.triu(
            torch.full((1, 1, seq_len, seq_len), -torch.finfo(torch.get_default_dtype()).max, device=input_ids.device),
            diagonal=1
        )

        if attention_mask is not None:
            # attention_mask: 1 = keep, 0 = pad -> large negative additive bias.
            # NOTE(review): at positions that are both future AND padded the
            # two -finfo.max terms sum to -inf — softmax tolerates this, but
            # confirm it is intended.
            padding_mask = (1.0 - attention_mask.unsqueeze(1).unsqueeze(2)) * -torch.finfo(
                torch.get_default_dtype()).max
            extended_attention_mask = causal_mask + padding_mask
        else:
            extended_attention_mask = causal_mask

        # Scaled embedding lookup (classic Transformer sqrt(d) scaling).
        x = self.token_embedding(input_ids) * math.sqrt(self.dim)
        x = self.dropout(x)

        # Accumulate each block's MoE load-balancing loss during training.
        total_aux_loss = 0.0
        for layer in self.layers:
            x = layer(x, extended_attention_mask)
            if self.training:
                total_aux_loss += layer.moe.aux_loss

        x = self.norm(x)

        # --- FIX: Return hidden states and aux loss, not logits ---
        return {
            'last_hidden_state': x,
            'aux_loss': total_aux_loss
        }
470
+
471
+
472
def create_moe_model(vocab_size: int = 50257) -> SlimMOETransformer:
    """
    Create a MOE model with approximately 300M parameters.

    Preset: dim=768, num_layers=16, num_heads=12, hidden_dim=1536,
    num_experts=4 — roughly 280-290M parameters, safely under 300M.

    Args:
        vocab_size: Vocabulary size of the token embedding.

    Returns:
        A freshly initialized SlimMOETransformer backbone (no lm_head).
    """
    preset = dict(
        vocab_size=vocab_size,
        dim=768,
        num_layers=16,
        num_heads=12,
        hidden_dim=1536,
        num_experts=4,
        max_seq_len=2048,
        dropout=0.1,
    )
    return SlimMOETransformer(**preset)
special_tokens_map.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<bos>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<eos>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "mask_token": {
17
+ "content": "<mask>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "pad_token": {
24
+ "content": "<pad>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "unk_token": {
31
+ "content": "<unk>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ }
37
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:04f0f8df41eb4be0afb1b647ce8d2fb9c3d4bec06eeb3c2f33ca6a8cf1e88f79
3
+ size 247574399
tokenizer_config.json ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<pad>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<unk>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "<bos>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "<eos>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "4": {
36
+ "content": "<mask>",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "bos_token": "<bos>",
45
+ "clean_up_tokenization_spaces": false,
46
+ "eos_token": "<eos>",
47
+ "extra_special_tokens": {},
48
+ "mask_token": "<mask>",
49
+ "max_length": 2048,
50
+ "model_max_length": 2048,
51
+ "pad_to_multiple_of": null,
52
+ "pad_token": "<pad>",
53
+ "pad_token_type_id": 0,
54
+ "padding_side": "right",
55
+ "stride": 0,
56
+ "tokenizer_class": "PreTrainedTokenizerFast",
57
+ "truncation_side": "right",
58
+ "truncation_strategy": "longest_first",
59
+ "unk_token": "<unk>"
60
+ }
trainer_state.json ADDED
The diff for this file is too large to render. See raw diff
 
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:15f83cb43be7eb54abe0f8888da955fddedcf18793acdd4eaec1ce9d35581e01
3
+ size 5816