empero-ai commited on
Commit
543064e
·
verified ·
1 Parent(s): af052ef

Mask BEFORE exp to avoid inf*0=NaN in bf16

Browse files
Files changed (1) hide show
  1. modeling_nemotron_h.py +10 -6
modeling_nemotron_h.py CHANGED
@@ -482,9 +482,11 @@ class NemotronHMamba2Mixer(nn.Module):
482
  # Decay matrix L via cumsum difference — replaces segment_sum which
483
  # expanded [chunk] → [chunk, chunk] via O(n^2) broadcast.
484
  # Math: L[i,j] = exp(A_cumsum[i] - A_cumsum[j]) for j <= i
485
- L = torch.exp(A_cumsum[..., :, None] - A_cumsum[..., None, :])
486
- L = L * torch.tril(torch.ones(
487
- self.chunk_size, self.chunk_size, device=L.device, dtype=L.dtype))
 
 
488
 
489
  # Contract ssm_state via einsum FIRST — avoids materializing the
490
  # [chunk, chunk, heads, state] outer product (was 68GB in fp32).
@@ -510,11 +512,13 @@ class NemotronHMamba2Mixer(nn.Module):
510
  states = torch.cat([previous_states, states], dim=1)
511
 
512
  # Inter-chunk decay via cumsum difference (n_chunks is small, ~16)
 
513
  chunk_cumA = torch.cumsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)), dim=-1)
514
  n_plus1 = chunk_cumA.shape[-1]
515
- decay_chunk = torch.exp(chunk_cumA[..., :, None] - chunk_cumA[..., None, :])
516
- decay_chunk = decay_chunk * torch.tril(torch.ones(
517
- n_plus1, n_plus1, device=decay_chunk.device, dtype=decay_chunk.dtype))
 
518
  decay_chunk = decay_chunk.transpose(1, 3)
519
 
520
  # Contract n_chunks+1 via einsum
 
482
  # Decay matrix L via cumsum difference — replaces segment_sum which
483
  # expanded [chunk] → [chunk, chunk] via O(n^2) broadcast.
484
  # Math: L[i,j] = exp(A_cumsum[i] - A_cumsum[j]) for j <= i
485
+ # Mask BEFORE exp to avoid inf*0=NaN in bf16 (upper triangle overflows)
486
+ L_arg = A_cumsum[..., :, None] - A_cumsum[..., None, :]
487
+ causal_mask = torch.tril(torch.ones(
488
+ self.chunk_size, self.chunk_size, device=L_arg.device, dtype=torch.bool))
489
+ L = torch.exp(L_arg.masked_fill(~causal_mask, float('-inf')))
490
 
491
  # Contract ssm_state via einsum FIRST — avoids materializing the
492
  # [chunk, chunk, heads, state] outer product (was 68GB in fp32).
 
512
  states = torch.cat([previous_states, states], dim=1)
513
 
514
  # Inter-chunk decay via cumsum difference (n_chunks is small, ~16)
515
+ # Mask BEFORE exp to avoid inf*0=NaN in bf16
516
  chunk_cumA = torch.cumsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)), dim=-1)
517
  n_plus1 = chunk_cumA.shape[-1]
518
+ decay_arg = chunk_cumA[..., :, None] - chunk_cumA[..., None, :]
519
+ chunk_mask = torch.tril(torch.ones(
520
+ n_plus1, n_plus1, device=decay_arg.device, dtype=torch.bool))
521
+ decay_chunk = torch.exp(decay_arg.masked_fill(~chunk_mask, float('-inf')))
522
  decay_chunk = decay_chunk.transpose(1, 3)
523
 
524
  # Contract n_chunks+1 via einsum