andrecornman committed on
Commit
45bc549
·
verified ·
1 Parent(s): 1b5c960

fix init_weights

Browse files
Files changed (1) hide show
  1. modeling_glm2.py +19 -3
modeling_glm2.py CHANGED
@@ -353,7 +353,7 @@ class gLM2PreTrainedModel(PreTrainedModel):
353
  supports_gradient_checkpointing = False
354
 
355
  # https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
356
- def _init_weights(module, initializer_range=0.02):
357
  if isinstance(module, nn.Linear):
358
  nn.init.normal_(module.weight, std=initializer_range)
359
  if module.bias is not None:
@@ -362,7 +362,22 @@ class gLM2PreTrainedModel(PreTrainedModel):
362
  nn.init.normal_(module.weight, std=initializer_range)
363
  if module.padding_idx is not None:
364
  nn.init.zeros_(module.weight[module.padding_idx])
365
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
 
367
  class gLM2Model(gLM2PreTrainedModel):
368
  """gLM2 Model."""
@@ -438,6 +453,7 @@ class gLM2ForEmbedding(gLM2PreTrainedModel):
438
  self.glm2 = gLM2Model(config)
439
  self.pool = MeanPooling()
440
  self.projection = nn.Linear(config.dim, config.projection_dim, bias=False)
 
441
 
442
  def forward(
443
  self,
@@ -466,7 +482,7 @@ class gLM2ForMaskedLM(gLM2PreTrainedModel):
466
 
467
  self.glm2 = gLM2Model(config)
468
  self.lm_head = gLM2LMHead(config)
469
- self.init_weights()
470
 
471
  def forward(
472
  self,
 
353
  supports_gradient_checkpointing = False
354
 
355
  # https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
356
+ def _init_weights(self, module, initializer_range=0.02):
357
  if isinstance(module, nn.Linear):
358
  nn.init.normal_(module.weight, std=initializer_range)
359
  if module.bias is not None:
 
362
  nn.init.normal_(module.weight, std=initializer_range)
363
  if module.padding_idx is not None:
364
  nn.init.zeros_(module.weight[module.padding_idx])
365
+ elif isinstance(module, RotaryEmbedding):
366
+ # Re-calculate the frequencies using the module's stored attributes
367
+ inv_freq = 1.0 / (
368
+ module.base
369
+ ** (
370
+ torch.arange(0, module.dim, 2, device=module.inv_freq.device, dtype=torch.float32)
371
+ / module.dim
372
+ )
373
+ )
374
+ # Force the buffer to update
375
+ with torch.no_grad():
376
+ module.inv_freq.copy_(inv_freq)
377
+ elif isinstance(module, RMSNorm):
378
+ if hasattr(module, "variance_epsilon"):
379
+ with torch.no_grad():
380
+ module.variance_epsilon.fill_(self.config.norm_eps)
381
 
382
  class gLM2Model(gLM2PreTrainedModel):
383
  """gLM2 Model."""
 
453
  self.glm2 = gLM2Model(config)
454
  self.pool = MeanPooling()
455
  self.projection = nn.Linear(config.dim, config.projection_dim, bias=False)
456
+ self.post_init()
457
 
458
  def forward(
459
  self,
 
482
 
483
  self.glm2 = gLM2Model(config)
484
  self.lm_head = gLM2LMHead(config)
485
+ self.post_init()
486
 
487
  def forward(
488
  self,