openNemo-9B / modeling_nemotron_h.py
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Modified 2026 by L. Lehmann (kodee2k) @ Empero AI (https://empero.org)
# openNemo: Pure-PyTorch Mamba2 (no mamba-ssm / causal-conv1d dependency)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
openNemo — Pure-PyTorch NemotronH model.
Drop-in replacement for NVIDIA's modeling_nemotron_h.py that removes ALL
external CUDA kernel dependencies (mamba-ssm, causal-conv1d). This makes
the model fully compatible with bitsandbytes quantization (4-bit / 8-bit)
and trainable on consumer GPUs with QLoRA.
Changes from original:
- Removed: mamba_ssm imports (selective_state_update, ssd_combined, rmsnorm_fn)
- Removed: causal_conv1d imports (causal_conv1d_fn, causal_conv1d_update)
- Rewrote: MambaRMSNormGated → pure PyTorch (no rmsnorm_fn)
- Rewrote: NemotronHMamba2Mixer.cuda_kernels_forward → removed entirely
- Rewrote: NemotronHMamba2Mixer.torch_forward → optimized chunked scan
- Rewrote: forward() routing → always uses torch_forward (no fast_path check)
- Added: causal_conv1d_naive for causal 1D convolution
- All weights are binary-compatible — load original checkpoints directly.
"""
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from transformers.activations import ACT2FN
from transformers.cache_utils import DynamicCache
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
)
from transformers.utils.import_utils import (
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
)
from .configuration_nemotron_h import NemotronHConfig
logger = logging.get_logger(__name__)
if is_flash_attn_2_available():
from transformers.modeling_flash_attention_utils import _flash_attention_forward
_CHECKPOINT_FOR_DOC = "nvidia/Nemotron-H-56B-Base-8K"
_CONFIG_FOR_DOC = "NemotronHConfig"
# ─── Pure-PyTorch Helpers (replace mamba-ssm / causal-conv1d) ─────────────────
def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int):
"""Padding x tensor with `pad_size` on the seq_len dim (dim=1)."""
pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0)
return F.pad(input_tensor, pad_shape, mode="constant", value=0)
def reshape_into_chunks(input_tensor, pad_size, chunk_size):
"""Pad and reshape into chunks along seq_len dim."""
input_tensor = pad_tensor_by_size(input_tensor, pad_size)
if len(input_tensor.shape) == 3:
return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2])
else:
return input_tensor.reshape(
input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3]
)
def segment_sum(input_tensor):
"""Stable segment sum via cumulative sums and masking."""
chunk_size = input_tensor.size(-1)
input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size)
mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1)
input_tensor = input_tensor.masked_fill(~mask, 0)
tensor_segsum = torch.cumsum(input_tensor, dim=-2)
mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0)
tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf)
return tensor_segsum
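# Added note: for a 1-D input a = [a0, a1, a2], segment_sum returns
#   [[     0, -inf, -inf],
#    [    a1,    0, -inf],
#    [ a1+a2,   a2,    0]]
# i.e. S[i, j] = a_{j+1} + ... + a_i for j < i, 0 on the diagonal, -inf above it,
# so exp(S)[i, j] is the cumulative decay between steps j and i.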
def apply_mask_to_padding_states(hidden_states, attention_mask):
"""Zero out hidden states for padding tokens."""
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
dtype = hidden_states.dtype
hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
return hidden_states
def causal_conv1d_naive(x, weight, bias=None, activation="silu"):
"""
Pure-PyTorch causal 1D depthwise convolution.
x: (batch, channels, seq_len)
weight: (channels, kernel_size)
bias: (channels,) or None
Returns: (batch, channels, seq_len)
"""
channels, kernel_size = weight.shape
# Causal padding: pad left only
x_padded = F.pad(x, (kernel_size - 1, 0))
# Depthwise conv: groups=channels
weight_conv = weight.unsqueeze(1) # (channels, 1, kernel_size)
out = F.conv1d(x_padded, weight_conv, bias=bias, groups=channels)
if activation in ("silu", "swish"):
out = F.silu(out)
return out
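# Added shape note (illustrative, not executed at import):
#   x = torch.randn(2, 8, 16)        # (batch, channels, seq_len)
#   w = torch.randn(8, 4)            # (channels, kernel_size)
#   y = causal_conv1d_naive(x, w)    # -> (2, 8, 16)
# Left-only padding keeps the convolution causal: y[..., t] depends only on x[..., :t + 1].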
def rms_norm_gated(hidden_states, weight, gate=None, eps=1e-5, group_size=None):
"""
Pure-PyTorch gated RMSNorm — replaces mamba_ssm's rmsnorm_fn.
norm_before_gate=False (matching NVIDIA's original): gate first, then normalize.
"""
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
# Gate FIRST (norm_before_gate=False)
if gate is not None:
hidden_states = hidden_states * F.silu(gate.to(torch.float32))
if group_size is not None and group_size < hidden_states.shape[-1]:
# Group-wise RMSNorm
orig_shape = hidden_states.shape
hidden_states = hidden_states.reshape(*orig_shape[:-1], -1, group_size)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + eps)
hidden_states = hidden_states.reshape(orig_shape)
else:
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + eps)
hidden_states = weight.to(torch.float32) * hidden_states
return hidden_states.to(input_dtype)
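# Added note, in formula form: with gate g and weight w,
#   z = x * silu(g)                       (gate applied BEFORE normalization)
#   y = w * z / sqrt(mean(z^2) + eps)     (mean over each group of `group_size` channels)
# computed in fp32 and cast back to the input dtype.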
# ─── Cache ────────────────────────────────────────────────────────────────────
class HybridMambaAttentionDynamicCache(DynamicCache):
"""
Cache for hybrid Mamba-Attention model. Handles both attention KV cache
and Mamba conv/SSM state cache.
"""
def __init__(self, config, batch_size, dtype=torch.float16, device=None):
super().__init__()
self.dtype = dtype
self.hybrid_override_pattern = config.hybrid_override_pattern
self.has_previous_state = False
intermediate_size = config.mamba_num_heads * config.mamba_head_dim
ssm_state_size = config.ssm_state_size
conv_kernel_size = config.conv_kernel
self.conv_kernel_size = conv_kernel_size
self.conv_states = []
self.ssm_states = []
self.transformer_layers = []
for i in range(config.num_hidden_layers):
if self.hybrid_override_pattern[i] == "M":
self.conv_states += [
torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
]
self.ssm_states += [
torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
]
else:
self.conv_states += [torch.tensor([[]] * batch_size, device=device)]
self.ssm_states += [torch.tensor([[]] * batch_size, device=device)]
self.transformer_layers.append(i)
self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
if self.key_cache[layer_idx].shape[-1] == 0:
self.key_cache[layer_idx] = key_states
self.value_cache[layer_idx] = value_states
else:
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)
return self.key_cache[layer_idx], self.value_cache[layer_idx]
def reorder_cache(self, beam_idx):
for layer_idx in range(len(self.key_cache)):
device = self.key_cache[layer_idx].device
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.conv_states[layer_idx].device
self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device))
device = self.ssm_states[layer_idx].device
self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device))
def get_seq_length(self, layer_idx=0):
layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
if len(self.key_cache) <= layer_idx:
return 0
return self.key_cache[layer_idx].shape[-2]
def to_legacy_cache(self):
raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
@classmethod
def from_legacy_cache(cls, past_key_values=None):
raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
def update_conv_state(self, layer_idx, new_conv_state, cache_init=False):
if cache_init:
self.conv_states[layer_idx] = new_conv_state.to(self.conv_states[layer_idx].device)
else:
self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1)
self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states[layer_idx].device)
return self.conv_states[layer_idx]
def update_ssm_state(self, layer_idx, new_ssm_state):
self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device)
return self.ssm_states[layer_idx]
def reset(self):
for s in self.conv_states:
if s.numel() > 0:
s.zero_()
for s in self.ssm_states:
if s.numel() > 0:
s.zero_()
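# Added note on cache layout: for each Mamba ("M") layer,
#   conv_states[i] holds the last `conv_kernel` inputs to that layer's conv1d
#     (rolled left by one and overwritten at every decoding step), and
#   ssm_states[i] holds the SSM hidden state carried across decoding steps;
# attention layers use the usual key_cache / value_cache lists instead.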
# ─── Gated RMSNorm (pure PyTorch) ────────────────────────────────────────────
class MambaRMSNormGated(nn.Module):
def __init__(self, hidden_size, group_size=None, eps=1e-5):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
self.group_size = group_size if group_size is not None else hidden_size
def forward(self, hidden_states, gate=None):
return rms_norm_gated(
hidden_states,
self.weight,
gate=gate,
eps=self.variance_epsilon,
group_size=self.group_size,
)
# ─── Mamba2 Mixer (pure PyTorch — no mamba-ssm) ──────────────────────────────
class NemotronHMamba2Mixer(nn.Module):
"""
Pure-PyTorch Mamba2 SSM mixer. Weight-compatible with the original
NVIDIA implementation but uses no external CUDA kernels.
"""
def __init__(self, config: NemotronHConfig, layer_idx: int):
super().__init__()
self.num_heads = config.mamba_num_heads
self.hidden_size = config.hidden_size
self.ssm_state_size = config.ssm_state_size
self.conv_kernel_size = config.conv_kernel
self.intermediate_size = config.mamba_num_heads * config.mamba_head_dim
self.layer_idx = layer_idx
self.use_conv_bias = config.use_conv_bias
self.activation = config.mamba_hidden_act
self.act = ACT2FN[config.mamba_hidden_act]
self.layer_norm_epsilon = config.layer_norm_epsilon
self.n_groups = config.n_groups
self.head_dim = config.mamba_head_dim
self.chunk_size = config.chunk_size
self.time_step_limit = config.time_step_limit
self.time_step_min = config.time_step_min
self.time_step_max = config.time_step_max
self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
self.conv1d = nn.Conv1d(
in_channels=self.conv_dim,
out_channels=self.conv_dim,
bias=config.use_conv_bias,
kernel_size=config.conv_kernel,
groups=self.conv_dim,
padding=config.conv_kernel - 1,
)
projection_size = self.intermediate_size + self.conv_dim + self.num_heads
self.in_proj = nn.Linear(self.hidden_size, projection_size, bias=config.use_bias)
self.dt_bias = nn.Parameter(torch.ones(self.num_heads))
A = torch.arange(1, self.num_heads + 1)
self.A_log = nn.Parameter(torch.log(A))
self.A_log._no_weight_decay = True
self.norm = MambaRMSNormGated(
self.intermediate_size,
eps=self.layer_norm_epsilon,
group_size=self.intermediate_size // self.n_groups,
)
self.D = nn.Parameter(torch.ones(self.num_heads))
self.D._no_weight_decay = True
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
self.use_bias = config.use_bias
@property
def o_proj(self):
"""Alias for tooling that expects o_proj."""
return self.out_proj
def _single_step_forward(
self,
hidden_states,
cache_params,
attention_mask=None,
):
"""Single token generation step with cache."""
batch_size = hidden_states.shape[0]
groups_time_state_size = self.n_groups * self.ssm_state_size
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
projected_states = self.in_proj(hidden_states)
d_mlp = (
projected_states.shape[-1]
- 2 * self.intermediate_size
- 2 * self.n_groups * self.ssm_state_size
- self.num_heads
) // 2
_, _, gate, hidden_states_B_C, dt = projected_states.squeeze(1).split(
[d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
)
# Conv update (single step)
conv_state = cache_params.conv_states[self.layer_idx]
conv_state = conv_state.roll(shifts=-1, dims=-1)
conv_state[:, :, -1] = hidden_states_B_C
cache_params.conv_states[self.layer_idx] = conv_state
hidden_states_B_C = torch.sum(
conv_state * self.conv1d.weight.squeeze(1), dim=-1
)
if self.use_conv_bias:
hidden_states_B_C = hidden_states_B_C + self.conv1d.bias
hidden_states_B_C = self.act(hidden_states_B_C)
hidden_states_inner, B, C = torch.split(
hidden_states_B_C,
[self.intermediate_size, groups_time_state_size, groups_time_state_size],
dim=-1,
)
# SSM step
A = -torch.exp(self.A_log.float())
A = A[:, None, ...].expand(-1, self.head_dim)[:, :, None].expand(-1, -1, self.ssm_state_size).to(torch.float32)
dt_expanded = dt[:, :, None].expand(-1, -1, self.head_dim)
dt_bias = self.dt_bias[:, None].expand(-1, self.head_dim)
D = self.D[:, None].expand(-1, self.head_dim)
B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
hidden_reshaped = hidden_states_inner.view(batch_size, self.num_heads, self.head_dim)
# Selective state update (pure PyTorch)
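# Recurrence implemented below (per head, channel p, state dim n):
#   dt_t = clamp(softplus(dt + dt_bias), *time_step_limit)
#   h_t  = exp(dt_t * A) * h_{t-1} + dt_t * B_t * x_t
#   y_t  = C_t . h_t + D * x_t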
dt_with_bias = F.softplus(dt_expanded + dt_bias)
dt_with_bias = torch.clamp(dt_with_bias, self.time_step_limit[0], self.time_step_limit[1])
dA = torch.exp(dt_with_bias.unsqueeze(-1) * A.unsqueeze(0))
# Expand B for heads
B_expanded = B[:, :, None, :].expand(-1, -1, self.num_heads // self.n_groups, -1)
B_expanded = B_expanded.reshape(batch_size, self.num_heads, self.ssm_state_size)
dBx = dt_with_bias.unsqueeze(-1) * B_expanded.unsqueeze(2) * hidden_reshaped.unsqueeze(-1)
ssm_state = cache_params.ssm_states[self.layer_idx]
ssm_state = ssm_state.to(dA.device, dtype=dA.dtype)
new_ssm_state = ssm_state * dA + dBx
cache_params.ssm_states[self.layer_idx] = new_ssm_state.to(cache_params.ssm_states[self.layer_idx].dtype)
# Output
C_expanded = C[:, :, None, :].expand(-1, -1, self.num_heads // self.n_groups, -1)
C_expanded = C_expanded.reshape(batch_size, self.num_heads, self.ssm_state_size)
y = (new_ssm_state.to(C_expanded.dtype) * C_expanded.unsqueeze(2)).sum(-1)
y = y + hidden_reshaped * D.unsqueeze(0)
y = y.reshape(batch_size, -1)
y = self.norm(y, gate)
out = self.out_proj(y)[:, None, ...]
return out
# fmt: off
def _chunked_forward(self, input_states, cache_params=None, attention_mask=None):
"""
Full sequence forward pass using chunked SSD scan.
This is the torch_forward from the original, which works correctly
with bitsandbytes quantization.
"""
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
groups_time_state_size = self.n_groups * self.ssm_state_size
# 1. Project
input_states = apply_mask_to_padding_states(input_states, attention_mask)
projected_states = self.in_proj(input_states)
d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size - self.num_heads) // 2
_, _, gate, hidden_states_B_C, dt = projected_states.split(
[d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
)
# 2. Convolution
if cache_params is not None:
hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
conv_states = F.pad(
hidden_states_B_C_transposed,
(self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0),
)
cache_params.update_conv_state(
layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True
)
# Causal conv via nn.Conv1d (left padding is built into the layer), trimmed back to seq_len; pure PyTorch, no causal-conv1d kernel needed
hidden_states_B_C = self.act(
self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)
)
hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
hidden_states, B, C = torch.split(
hidden_states_B_C,
[self.intermediate_size, groups_time_state_size, groups_time_state_size],
dim=-1,
)
# 3. SSM (chunked SSD scan — memory-efficient, native dtype)
# All computation stays in model dtype (bf16) instead of casting to fp32.
# Uses einsum contractions to avoid materializing massive intermediate
# tensors (G_intermediate was 68GB in fp32, now eliminated entirely).
A = -torch.exp(self.A_log.to(dtype))
dt = F.softplus(dt + self.dt_bias)
dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim)
B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size)
C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size)
B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2)
C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2)
pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
# Discretize
hidden_states = hidden_states * dt[..., None]
A = A.to(hidden_states.dtype) * dt
# Reshape into chunks
# hidden_states: [batch, n_chunks, chunk, heads, head_dim]
# A: [batch, n_chunks, chunk, heads]
# B, C: [batch, n_chunks, chunk, heads, ssm_state]
hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]
A = A.permute(0, 3, 1, 2) # [batch, heads, n_chunks, chunk]
A_cumsum = torch.cumsum(A, dim=-1)
# === Intra-chunk (diagonal blocks) — memory-efficient ===
# Decay matrix L via cumsum difference — replaces segment_sum which
# expanded [chunk] → [chunk, chunk] via O(n^2) broadcast.
# Math: L[i,j] = exp(A_cumsum[i] - A_cumsum[j]) for j <= i
# Mask BEFORE exp to avoid inf*0=NaN in bf16 (upper triangle overflows)
L_arg = A_cumsum[..., :, None] - A_cumsum[..., None, :]
causal_mask = torch.tril(torch.ones(
self.chunk_size, self.chunk_size, device=L_arg.device, dtype=torch.bool))
L = torch.exp(L_arg.masked_fill(~causal_mask, float('-inf')))
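# Added note: with the opposite ordering (exp first, mask after), the upper
# triangle of L_arg holds large positive values whose exp() overflows to inf in
# bf16, and the subsequent inf * 0 from masking would produce NaN. Masking
# L_arg to -inf first makes exp() return exactly 0 for those entries.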
# Contract ssm_state via einsum FIRST — avoids materializing the
# [chunk, chunk, heads, state] outer product (was 68GB in fp32).
# Result: [batch, n_chunks, chunk_i, chunk_j, heads] = ~268MB in bf16
G = torch.einsum('bnchs, bnkhs -> bnckh', C, B)
M = G * L.permute(0, 2, 3, 4, 1)
# Batched matmul contracts chunk_j without materializing
# [chunk, chunk, heads, head_dim] (was 17GB).
Y_diag = torch.einsum('bnijh, bnjhd -> bnihd', M, hidden_states)
# === Inter-chunk state recurrence ===
decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None]
# Contract chunk_size via einsum (was 8.6GB materialization)
states = torch.einsum('bnchs, bnchd -> bnhds', B_decay, hidden_states)
if cache_params is not None and cache_params.has_previous_state:
previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device)
else:
previous_states = torch.zeros_like(states[:, :1])
states = torch.cat([previous_states, states], dim=1)
# Inter-chunk decay via cumsum difference (n_chunks is small, ~16)
# Mask BEFORE exp to avoid inf*0=NaN in bf16
chunk_cumA = torch.cumsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)), dim=-1)
n_plus1 = chunk_cumA.shape[-1]
decay_arg = chunk_cumA[..., :, None] - chunk_cumA[..., None, :]
chunk_mask = torch.tril(torch.ones(
n_plus1, n_plus1, device=decay_arg.device, dtype=torch.bool))
decay_chunk = torch.exp(decay_arg.masked_fill(~chunk_mask, float('-inf')))
decay_chunk = decay_chunk.transpose(1, 3)
# Contract n_chunks+1 via einsum
new_states = torch.einsum('bijh, bjhds -> bihds', decay_chunk, states)
states, ssm_state = new_states[:, :-1], new_states[:, -1]
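# At this point states[:, k] is the SSM state at the START of chunk k
# (shape [batch, heads, head_dim, ssm_state]), and ssm_state is the final
# state after the last chunk, cached below for single-step generation.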
# === State → output ===
state_decay_out = torch.exp(A_cumsum)
# Contract ssm_state via einsum (was 8.6GB materialization)
Y_off = torch.einsum('bnchs, bnhds -> bnchd', C, states)
Y_off = Y_off * state_decay_out.permute(0, 2, 3, 1)[..., None]
y = Y_diag + Y_off
y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)
y = y + D_residual
if pad_size > 0:
y = y[:, :seq_len, :, :]
y = y.reshape(batch_size, seq_len, -1)
# Cache state
if ssm_state is not None and cache_params is not None:
cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state)
scan_output = self.norm(y, gate)
contextualized_states = self.out_proj(scan_output.to(dtype))
return contextualized_states
# fmt: on
def forward(
self,
hidden_states,
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
):
# Single-step generation with cache
if cache_params is not None and cache_position is not None and cache_position[0] > 0:
return self._single_step_forward(hidden_states, cache_params, attention_mask)
# Full sequence (training or prefill)
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
dtype = hidden_states.dtype
hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
return self._chunked_forward(hidden_states, cache_params, attention_mask)
# ─── RMSNorm ─────────────────────────────────────────────────────────────────
class NemotronHRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)
# ─── Block / Attention / MLP (unchanged from original) ───────────────────────
class NemotronHBlock(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.residual_in_fp32 = config.residual_in_fp32
self.norm = NemotronHRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
block_type = config.hybrid_override_pattern[layer_idx]
if block_type == "M":
self.mixer = NemotronHMamba2Mixer(config, layer_idx=layer_idx)
elif block_type == "*":
self.mixer = NemotronHAttention(config, layer_idx=layer_idx)
elif block_type == "-":
self.mixer = NemotronHMLP(config, layer_idx=layer_idx)
else:
raise ValueError(f"Unknown block type: {block_type}")
# Properties (not nn.Module assignments) so tooling like PEFT
# can access layer.self_attn / layer.mlp without PyTorch registering
# duplicate submodules that corrupt weight loading.
@property
def self_attn(self):
return self.mixer
@property
def mlp(self):
return self.mixer
def forward(
self,
hidden_states,
attention_mask=None,
position_ids=None,
cache_params=None,
cache_position=None,
):
residual = hidden_states
hidden_states = self.norm(hidden_states)
if self.residual_in_fp32:
residual = residual.to(torch.float32)
block_type = self.config.hybrid_override_pattern[self.layer_idx]
if block_type == "M":
hidden_states = self.mixer(
hidden_states,
cache_params=cache_params,
cache_position=cache_position,
)
elif block_type == "*":
hidden_states = self.mixer(
hidden_states,
attention_mask=attention_mask,
past_key_value=cache_params,
cache_position=cache_position,
)
hidden_states = hidden_states[0]
elif block_type == "-":
hidden_states = self.mixer(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class NemotronHMLP(nn.Module):
def __init__(self, config, layer_idx=None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
self.act_fn = ACT2FN[config.mlp_hidden_act]
@property
def o_proj(self):
return self.down_proj
def forward(self, x):
return self.down_proj(self.act_fn(self.up_proj(x)))
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class NemotronHRotaryEmbedding(nn.Module):
def __init__(self, config=None, device=None):
super().__init__()
self.rope_kwargs = {}
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = self._default_rope_init
self.rope_init_fn(self.config, device)
self.original_inv_freq = self.inv_freq
def _default_rope_init(self, config, device=None):
base = 10000.0
dim = config.head_dim
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
@torch.no_grad()
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
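# Note: the rotary helpers above (rotate_half, apply_rotary_pos_emb,
# NemotronHRotaryEmbedding) are defined but never applied by NemotronHAttention
# below, consistent with the "No RoPE" comment in its forward pass.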
class NemotronHAttention(nn.Module):
def __init__(self, config: NemotronHConfig, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = config.head_dim
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.is_causal = True
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
def forward(
self,
hidden_states,
attention_mask=None,
position_ids=None,
past_key_value=None,
output_attentions=False,
use_cache=False,
cache_position=None,
):
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# No RoPE — Nemotron-H attention has no rotary embeddings.
# KV cache for autoregressive generation
if past_key_value is not None:
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
# GQA expansion
key_states = key_states[:, :, None, :, :].expand(-1, -1, self.num_key_value_groups, -1, -1)
key_states = key_states.reshape(bsz, self.num_heads, -1, self.head_dim)
value_states = value_states[:, :, None, :, :].expand(-1, -1, self.num_key_value_groups, -1, -1)
value_states = value_states.reshape(bsz, self.num_heads, -1, self.head_dim)
# Use causal mask from model if available, otherwise infer
causal_mask = attention_mask
if causal_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = F.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
# ─── Model ────────────────────────────────────────────────────────────────────
@dataclass
class NemotronHOutput(ModelOutput):
last_hidden_state: torch.FloatTensor = None
cache_params: Optional[HybridMambaAttentionDynamicCache] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class NemotronHCausalLMOutput(ModelOutput):
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
cache_params: Optional[HybridMambaAttentionDynamicCache] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
class NemotronHPreTrainedModel(PreTrainedModel):
config_class = NemotronHConfig
base_model_prefix = "backbone"
_no_split_modules = ["NemotronHBlock"]
supports_gradient_checkpointing = True
_is_stateful = True
# Ignore any checkpoint keys that came through property-alias paths
# (e.g. from a checkpoint saved when self_attn/mlp were nn.Module aliases)
_keys_to_ignore_on_load_unexpected = [
r"backbone\.layers\.\d+\.self_attn\.",
r"backbone\.layers\.\d+\.mlp\.",
]
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, (nn.Linear, nn.Conv1d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _initialize_missing_keys(self, *args, **kwargs):
"""No-op. Accepts both the old (missing_keys, is_quantized) and the new (is_quantized) call signatures."""
pass
class NemotronHModel(NemotronHPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[NemotronHBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]
)
self.norm_f = NemotronHRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.gradient_checkpointing = False
self.post_init()
def get_input_embeddings(self):
return self.embeddings
def set_input_embeddings(self, new_embeddings):
self.embeddings = new_embeddings
def forward(
self,
input_ids=None,
inputs_embeds=None,
cache_params=None,
use_cache=None,
output_hidden_states=None,
return_dict=None,
cache_position=None,
attention_mask=None,
**kwargs,
):
output_hidden_states = output_hidden_states if output_hidden_states is not None else getattr(self.config, "output_hidden_states", False)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else getattr(self.config, "use_return_dict", True)
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids)
hidden_states = inputs_embeds
if use_cache and cache_params is None:
cache_params = HybridMambaAttentionDynamicCache(
self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
)
# Position IDs for attention layers
if cache_position is None:
cache_position = torch.arange(0, hidden_states.shape[1], device=hidden_states.device)
position_ids = cache_position[None, :].expand(hidden_states.shape[0], -1)
# Causal attention mask
causal_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position, cache_params)
all_hidden_states = () if output_hidden_states else None
for layer in self.layers:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
causal_mask,
position_ids,
cache_params,
cache_position,
)
else:
hidden_states = layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
cache_params=cache_params,
cache_position=cache_position,
)
if use_cache:
cache_params.has_previous_state = True
hidden_states = self.norm_f(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)
return NemotronHOutput(
last_hidden_state=hidden_states,
cache_params=cache_params if use_cache else None,
hidden_states=all_hidden_states,
)
def _update_causal_mask(self, attention_mask, input_tensor, cache_position, cache_params):
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# During generation, attention_mask from generate() already covers all
# past + current tokens — use its length as the authoritative target.
if attention_mask is not None and attention_mask.dim() == 2:
target_length = attention_mask.shape[-1]
elif cache_params is not None and cache_params.has_previous_state:
target_length = cache_params.get_seq_length() + sequence_length
else:
target_length = sequence_length
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=target_length - sequence_length + 1)
# Zero out past/current positions so each query attends to all keys <= its position.
# Critical for single-token generation (seq_len=1) where triu is skipped.
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None and attention_mask.dim() == 2:
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + (1.0 - attention_mask[:, None, None, :].to(causal_mask.dtype)) * min_dtype
causal_mask = torch.cat([padding_mask, causal_mask[:, :, :, mask_length:]], dim=-1) if mask_length < target_length else padding_mask
return causal_mask
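# Added example of the mask built above: with sequence_length = target_length = 3,
#   [[0, m, m],
#    [0, 0, m],
#    [0, 0, 0]]    where m = torch.finfo(dtype).min,
# and for single-token decoding (sequence_length = 1) the arange > cache_position
# comparison leaves only key positions beyond the current token at m.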
class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.backbone = NemotronHModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
@property
def model(self):
"""Alias so tooling that expects .model (LoRA, PEFT, etc.) works."""
return self.backbone
@model.setter
def model(self, value):
self.backbone = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def get_input_embeddings(self):
return self.backbone.get_input_embeddings()
def set_input_embeddings(self, new_embeddings):
return self.backbone.set_input_embeddings(new_embeddings)
def _update_model_kwargs_for_generation(self, outputs, model_kwargs, **kwargs):
model_kwargs["cache_params"] = outputs.get("cache_params", None)
if "cache_position" in model_kwargs:
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
# Extend attention_mask by 1 each generation step (HF default does this
# in super() but we override, so we must do it manually)
if "attention_mask" in model_kwargs:
model_kwargs["attention_mask"] = torch.cat([
model_kwargs["attention_mask"],
model_kwargs["attention_mask"].new_ones((model_kwargs["attention_mask"].shape[0], 1)),
], dim=-1)
return model_kwargs
def prepare_inputs_for_generation(
self,
input_ids,
cache_params=None,
inputs_embeds=None,
attention_mask=None,
cache_position=None,
**kwargs,
):
if cache_params is not None:
if input_ids.shape[1] != 1:
input_ids = input_ids[:, -1:]
if inputs_embeds is not None and cache_params is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update({
"cache_params": cache_params,
"cache_position": cache_position,
"attention_mask": attention_mask,
})
return model_inputs
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
inputs_embeds=None,
cache_params=None,
labels=None,
output_hidden_states=None,
return_dict=None,
use_cache=None,
cache_position=None,
**kwargs,
):
return_dict = return_dict if return_dict is not None else getattr(self.config, "use_return_dict", True)
nemotron_h_outputs = self.backbone(
input_ids,
cache_params=cache_params,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
cache_position=cache_position,
attention_mask=attention_mask,
)
hidden_states = nemotron_h_outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + nemotron_h_outputs[1:]
return (loss,) + output if loss is not None else output
return NemotronHCausalLMOutput(
loss=loss,
logits=logits,
cache_params=nemotron_h_outputs.cache_params if return_dict else None,
hidden_states=nemotron_h_outputs.hidden_states if return_dict else None,
)