Mask BEFORE exp to avoid inf*0=NaN in bf16
Browse files- modeling_nemotron_h.py +10 -6
modeling_nemotron_h.py
CHANGED
|
@@ -482,9 +482,11 @@ class NemotronHMamba2Mixer(nn.Module):
|
|
| 482 |
# Decay matrix L via cumsum difference — replaces segment_sum which
|
| 483 |
# expanded [chunk] → [chunk, chunk] via O(n^2) broadcast.
|
| 484 |
# Math: L[i,j] = exp(A_cumsum[i] - A_cumsum[j]) for j <= i
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
|
|
|
|
|
|
| 488 |
|
| 489 |
# Contract ssm_state via einsum FIRST — avoids materializing the
|
| 490 |
# [chunk, chunk, heads, state] outer product (was 68GB in fp32).
|
|
@@ -510,11 +512,13 @@ class NemotronHMamba2Mixer(nn.Module):
|
|
| 510 |
states = torch.cat([previous_states, states], dim=1)
|
| 511 |
|
| 512 |
# Inter-chunk decay via cumsum difference (n_chunks is small, ~16)
|
|
|
|
| 513 |
chunk_cumA = torch.cumsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)), dim=-1)
|
| 514 |
n_plus1 = chunk_cumA.shape[-1]
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
n_plus1, n_plus1, device=
|
|
|
|
| 518 |
decay_chunk = decay_chunk.transpose(1, 3)
|
| 519 |
|
| 520 |
# Contract n_chunks+1 via einsum
|
|
|
|
| 482 |
# Decay matrix L via cumsum difference — replaces segment_sum which
|
| 483 |
# expanded [chunk] → [chunk, chunk] via O(n^2) broadcast.
|
| 484 |
# Math: L[i,j] = exp(A_cumsum[i] - A_cumsum[j]) for j <= i
|
| 485 |
+
# Mask BEFORE exp to avoid inf*0=NaN in bf16 (upper triangle overflows)
|
| 486 |
+
L_arg = A_cumsum[..., :, None] - A_cumsum[..., None, :]
|
| 487 |
+
causal_mask = torch.tril(torch.ones(
|
| 488 |
+
self.chunk_size, self.chunk_size, device=L_arg.device, dtype=torch.bool))
|
| 489 |
+
L = torch.exp(L_arg.masked_fill(~causal_mask, float('-inf')))
|
| 490 |
|
| 491 |
# Contract ssm_state via einsum FIRST — avoids materializing the
|
| 492 |
# [chunk, chunk, heads, state] outer product (was 68GB in fp32).
|
|
|
|
| 512 |
states = torch.cat([previous_states, states], dim=1)
|
| 513 |
|
| 514 |
# Inter-chunk decay via cumsum difference (n_chunks is small, ~16)
|
| 515 |
+
# Mask BEFORE exp to avoid inf*0=NaN in bf16
|
| 516 |
chunk_cumA = torch.cumsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)), dim=-1)
|
| 517 |
n_plus1 = chunk_cumA.shape[-1]
|
| 518 |
+
decay_arg = chunk_cumA[..., :, None] - chunk_cumA[..., None, :]
|
| 519 |
+
chunk_mask = torch.tril(torch.ones(
|
| 520 |
+
n_plus1, n_plus1, device=decay_arg.device, dtype=torch.bool))
|
| 521 |
+
decay_chunk = torch.exp(decay_arg.masked_fill(~chunk_mask, float('-inf')))
|
| 522 |
decay_chunk = decay_chunk.transpose(1, 3)
|
| 523 |
|
| 524 |
# Contract n_chunks+1 via einsum
|