Update modeling_minimythos_hybrid.py
Browse files
modeling_minimythos_hybrid.py
CHANGED
|
@@ -78,7 +78,12 @@ class RMSNorm(nn.Module):
|
|
| 78 |
self.weight = nn.Parameter(torch.ones(dim))
|
| 79 |
|
| 80 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
|
| 84 |
class ReservoirBlock(nn.Module):
|
|
@@ -270,6 +275,9 @@ class MiniMythosHybridForCausalLM(PreTrainedModel):
|
|
| 270 |
hidden_states.append(x)
|
| 271 |
|
| 272 |
logits = self.lm_head(x)
|
|
|
|
|
|
|
|
|
|
| 273 |
|
| 274 |
loss = None
|
| 275 |
if labels is not None:
|
|
|
|
| 78 |
self.weight = nn.Parameter(torch.ones(dim))
|
| 79 |
|
| 80 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 81 |
+
# Compute RMSNorm in fp32 for numerical stability, then cast the result back to the input dtype.
|
| 82 |
+
orig_dtype = x.dtype
|
| 83 |
+
x_float = x.float()
|
| 84 |
+
var = x_float.pow(2).mean(-1, keepdim=True)
|
| 85 |
+
x_norm = x_float * torch.rsqrt(var + self.eps)
|
| 86 |
+
return (self.weight.float() * x_norm).to(orig_dtype)
|
| 87 |
|
| 88 |
|
| 89 |
class ReservoirBlock(nn.Module):
|
|
|
|
| 275 |
hidden_states.append(x)
|
| 276 |
|
| 277 |
logits = self.lm_head(x)
|
| 278 |
+
# Prevent generation from crashing if a checkpoint produces non-finite logits (NaN or +/-Inf).
|
| 279 |
+
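# NOTE: applied on every forward pass (before the labels branch), so it can also mask non-finite logits during training; it makes inference robust but does not fix instability.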
|
| 280 |
+
logits = torch.nan_to_num(logits, nan=0.0, posinf=1e4, neginf=-1e4)
|
| 281 |
|
| 282 |
loss = None
|
| 283 |
if labels is not None:
|