summerstars committed on
Commit
dde2635
·
verified ·
1 Parent(s): 808826f

Update modeling_minimythos_hybrid.py

Browse files
Files changed (1) hide show
  1. modeling_minimythos_hybrid.py +9 -1
modeling_minimythos_hybrid.py CHANGED
@@ -78,7 +78,12 @@ class RMSNorm(nn.Module):
78
  self.weight = nn.Parameter(torch.ones(dim))
79
 
80
  def forward(self, x: torch.Tensor) -> torch.Tensor:
81
- return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
 
 
 
 
 
82
 
83
 
84
  class ReservoirBlock(nn.Module):
@@ -270,6 +275,9 @@ class MiniMythosHybridForCausalLM(PreTrainedModel):
270
  hidden_states.append(x)
271
 
272
  logits = self.lm_head(x)
 
 
 
273
 
274
  loss = None
275
  if labels is not None:
 
78
  self.weight = nn.Parameter(torch.ones(dim))
79
 
80
  def forward(self, x: torch.Tensor) -> torch.Tensor:
81
+ # Compute RMSNorm in fp32 for numerical stability, then cast back.
82
+ orig_dtype = x.dtype
83
+ x_float = x.float()
84
+ var = x_float.pow(2).mean(-1, keepdim=True)
85
+ x_norm = x_float * torch.rsqrt(var + self.eps)
86
+ return (self.weight.float() * x_norm).to(orig_dtype)
87
 
88
 
89
  class ReservoirBlock(nn.Module):
 
275
  hidden_states.append(x)
276
 
277
  logits = self.lm_head(x)
278
+ # Prevent generation from crashing if a checkpoint contains unstable values.
279
+ # This should not hide training issues, but it makes inference robust.
280
+ logits = torch.nan_to_num(logits, nan=0.0, posinf=1e4, neginf=-1e4)
281
 
282
  loss = None
283
  if labels is not None: