Johnblick187 committed on
Commit
9eb4e00
·
verified ·
1 Parent(s): c8a1ec0

Upload modeling_smartcoder_moe.py

Browse files
Files changed (1) hide show
  1. modeling_smartcoder_moe.py +27 -20
modeling_smartcoder_moe.py CHANGED
@@ -168,6 +168,29 @@ class SmartCoderMoEMLP(nn.Module):
168
  self.experts_proj = nn.Parameter(torch.empty(NE, H, EI))
169
  self.router = nn.Linear(H, NE, bias=False)
170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  def forward(self, x):
172
  B, T, H = x.shape
173
 
@@ -254,14 +277,6 @@ class SmartCoderMoEForCausalLM(PreTrainedModel, GenerationMixin):
254
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
255
  self.post_init()
256
 
257
- def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
258
- remapped = {}
259
- for k, v in state_dict.items():
260
- k = k.replace('experts_fc.weight', 'experts_fc')
261
- k = k.replace('experts_proj.weight', 'experts_proj')
262
- remapped[k] = v
263
- super()._load_from_state_dict(remapped, prefix, *args, **kwargs)
264
-
265
  def get_input_embeddings(self): return self.model.embed_tokens
266
  def get_output_embeddings(self): return self.lm_head
267
 
@@ -328,17 +343,9 @@ def load_smartcoder_moe(model_id="Johnblick187/SmartCoderMoE", dtype=torch.bfloa
328
  for f in sf_files:
329
  state_dict.update(load_file(str(f)))
330
 
331
- # Remap expert keys — safetensors has .weight suffix, our params don't
332
- remapped = {}
333
- for k, v in state_dict.items():
334
- if 'experts_fc.weight' in k:
335
- remapped[k.replace('experts_fc.weight', 'experts_fc')] = v
336
- elif 'experts_proj.weight' in k:
337
- remapped[k.replace('experts_proj.weight', 'experts_proj')] = v
338
- else:
339
- remapped[k] = v
340
- state_dict = remapped
341
-
342
  missing, unexpected = model.load_state_dict(state_dict, strict=False)
343
  if missing:
344
  print(f"Missing: {missing[:3]}{'...' if len(missing)>3 else ''}")
@@ -351,4 +358,4 @@ def load_smartcoder_moe(model_id="Johnblick187/SmartCoderMoE", dtype=torch.bfloa
351
 
352
  from transformers import AutoConfig, AutoModelForCausalLM
353
  AutoConfig.register("smartcoder_moe", SmartCoderMoEConfig)
354
- AutoModelForCausalLM.register(SmartCoderMoEConfig, SmartCoderMoEForCausalLM)
 
168
  self.experts_proj = nn.Parameter(torch.empty(NE, H, EI))
169
  self.router = nn.Linear(H, NE, bias=False)
170
 
171
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
172
+ # Checkpoint stores expert weights with a '.weight' suffix (as if
173
+ # experts_fc/experts_proj were nn.Linear submodules), but they're
174
+ # raw nn.Parameter tensors here (no '.weight' child -- needed for
175
+ # batched bmm across all experts at once, see forward() below).
176
+ # PyTorch's load_state_dict() recursion calls _load_from_state_dict
177
+ # on EVERY submodule in the tree directly (using each module's own
178
+ # class method, not a parent module's override) -- so the remap has
179
+ # to live on THIS class, not on SmartCoderMoEForCausalLM. The
180
+ # previous override sat on the top-level CausalLM class and only
181
+ # ever fired for its own direct params/buffers (it has none), never
182
+ # for this module's recursive call -- silently skipping every
183
+ # expert tensor. That's the actual bug.
184
+ remapped = {}
185
+ for k, v in state_dict.items():
186
+ if k == prefix + "experts_fc.weight":
187
+ remapped[prefix + "experts_fc"] = v
188
+ elif k == prefix + "experts_proj.weight":
189
+ remapped[prefix + "experts_proj"] = v
190
+ else:
191
+ remapped[k] = v
192
+ super()._load_from_state_dict(remapped, prefix, *args, **kwargs)
193
+
194
  def forward(self, x):
195
  B, T, H = x.shape
196
 
 
277
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
278
  self.post_init()
279
 
 
 
 
 
 
 
 
 
280
  def get_input_embeddings(self): return self.model.embed_tokens
281
  def get_output_embeddings(self): return self.lm_head
282
 
 
343
  for f in sf_files:
344
  state_dict.update(load_file(str(f)))
345
 
346
+ # Expert key remap (.weight suffix in the checkpoint vs raw Parameter
347
+ # here) is now handled by SmartCoderMoEMLP._load_from_state_dict
348
+ # itself, so load_state_dict() needs no manual remapping here anymore.
 
 
 
 
 
 
 
 
349
  missing, unexpected = model.load_state_dict(state_dict, strict=False)
350
  if missing:
351
  print(f"Missing: {missing[:3]}{'...' if len(missing)>3 else ''}")
 
358
 
359
  from transformers import AutoConfig, AutoModelForCausalLM
360
  AutoConfig.register("smartcoder_moe", SmartCoderMoEConfig)
361
+ AutoModelForCausalLM.register(SmartCoderMoEConfig, SmartCoderMoEForCausalLM)