| """ |
| openNemo — Pure-PyTorch NemotronH model. |
| |
| Drop-in replacement for NVIDIA's modeling_nemotron_h.py that removes ALL |
| external CUDA kernel dependencies (mamba-ssm, causal-conv1d). This makes |
| the model fully compatible with bitsandbytes quantization (4-bit / 8-bit) |
| and trainable on consumer GPUs with QLoRA. |
| |
| Changes from original: |
| - Removed: mamba_ssm imports (selective_state_update, ssd_combined, rmsnorm_fn) |
| - Removed: causal_conv1d imports (causal_conv1d_fn, causal_conv1d_update) |
| - Rewrote: MambaRMSNormGated → pure PyTorch (no rmsnorm_fn) |
| - Removed: NemotronHMamba2Mixer.cuda_kernels_forward (dropped entirely) |
| - Rewrote: NemotronHMamba2Mixer.torch_forward → _chunked_forward (optimized chunked scan) |
| - Rewrote: forward() routing → always uses _chunked_forward (no fast_path check) |
| - Added: causal_conv1d_naive for causal 1D convolution |
| - All weights are binary-compatible — load original checkpoints directly. |
| """ |
|
|
| from dataclasses import dataclass |
| from typing import Any, Dict, Optional, Tuple, Union |
|
|
| import torch |
| import torch.utils.checkpoint |
| from torch import nn |
| from torch.nn import CrossEntropyLoss |
| import torch.nn.functional as F |
|
|
| from transformers.activations import ACT2FN |
| from transformers.cache_utils import DynamicCache |
| from transformers.generation import GenerationMixin |
| from transformers.modeling_attn_mask_utils import AttentionMaskConverter |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.utils import ( |
| ModelOutput, |
| add_code_sample_docstrings, |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| logging, |
| ) |
| from transformers.utils.import_utils import ( |
| is_flash_attn_2_available, |
| is_flash_attn_greater_or_equal_2_10, |
| ) |
| from .configuration_nemotron_h import NemotronHConfig |
|
|
| logger = logging.get_logger(__name__) |
|
|
| if is_flash_attn_2_available(): |
| from transformers.modeling_flash_attention_utils import _flash_attention_forward |
|
|
| _CHECKPOINT_FOR_DOC = "nvidia/Nemotron-H-56B-Base-8K" |
| _CONFIG_FOR_DOC = "NemotronHConfig" |
|
|
|
|
| |
|
|
| def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): |
| """Padding x tensor with `pad_size` on the seq_len dim (dim=1).""" |
| pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) |
| return F.pad(input_tensor, pad_shape, mode="constant", value=0) |
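|
| # Example: pad_size=3 turns a (batch, seq_len, heads, head_dim) tensor into |
| # (batch, seq_len + 3, heads, head_dim); 3-D (batch, seq_len, dim) inputs are padded the same way. |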
|
|
|
|
| def reshape_into_chunks(input_tensor, pad_size, chunk_size): |
| """Pad and reshape into chunks along seq_len dim.""" |
| input_tensor = pad_tensor_by_size(input_tensor, pad_size) |
| if len(input_tensor.shape) == 3: |
| return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) |
| else: |
| return input_tensor.reshape( |
| input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3] |
| ) |
|
|
|
|
| def segment_sum(input_tensor): |
| """Stable segment sum via cumulative sums and masking.""" |
| chunk_size = input_tensor.size(-1) |
| input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size) |
| mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1) |
| input_tensor = input_tensor.masked_fill(~mask, 0) |
| tensor_segsum = torch.cumsum(input_tensor, dim=-2) |
| mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0) |
| tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf) |
| return tensor_segsum |
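|
| # Worked example (values follow directly from the definition above): |
| #   segment_sum(torch.tensor([1., 2., 3.])) == |
| #       [[0., -inf, -inf], |
| #        [2.,   0., -inf], |
| #        [5.,   3.,   0.]] |
| # i.e. entry [i, j] holds sum(x[j + 1 : i + 1]); exponentiating it yields the same lower-triangular |
| # decay matrix that _chunked_forward builds from cumulative-sum differences. |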
|
|
|
|
| def apply_mask_to_padding_states(hidden_states, attention_mask): |
| """Zero out hidden states for padding tokens.""" |
| if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: |
| dtype = hidden_states.dtype |
| hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) |
| return hidden_states |
|
|
|
|
| def causal_conv1d_naive(x, weight, bias=None, activation="silu"): |
| """ |
| Pure-PyTorch causal 1D depthwise convolution. |
| x: (batch, channels, seq_len) |
| weight: (channels, kernel_size) |
| bias: (channels,) or None |
| Returns: (batch, channels, seq_len) |
| """ |
| channels, kernel_size = weight.shape |
| |
| x_padded = F.pad(x, (kernel_size - 1, 0)) |
| |
| weight_conv = weight.unsqueeze(1) |
| out = F.conv1d(x_padded, weight_conv, bias=bias, groups=channels) |
| if activation in ("silu", "swish"): |
| out = F.silu(out) |
| return out |
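|
| # Minimal shape/causality sketch (illustrative): with |
| #   x = torch.randn(2, 8, 16)    # (batch=2, channels=8, seq_len=16) |
| #   w = torch.randn(8, 4)        # (channels=8, kernel_size=4) |
| # causal_conv1d_naive(x, w) returns a (2, 8, 16) tensor, and output position t depends only on |
| # x[..., :t + 1]. This matches what self.conv1d (padding=kernel_size - 1, then truncation to |
| # seq_len) computes inside NemotronHMamba2Mixer._chunked_forward. |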
|
|
|
|
| def rms_norm_gated(hidden_states, weight, gate=None, eps=1e-5, group_size=None): |
| """ |
| Pure-PyTorch gated RMSNorm — replaces mamba_ssm's rmsnorm_fn. |
| norm_before_gate=False (matching NVIDIA's original): gate first, then normalize. |
| """ |
| input_dtype = hidden_states.dtype |
| hidden_states = hidden_states.to(torch.float32) |
|
|
| |
| if gate is not None: |
| hidden_states = hidden_states * F.silu(gate.to(torch.float32)) |
|
|
| if group_size is not None and group_size < hidden_states.shape[-1]: |
| |
| orig_shape = hidden_states.shape |
| hidden_states = hidden_states.reshape(*orig_shape[:-1], -1, group_size) |
| variance = hidden_states.pow(2).mean(-1, keepdim=True) |
| hidden_states = hidden_states * torch.rsqrt(variance + eps) |
| hidden_states = hidden_states.reshape(orig_shape) |
| else: |
| variance = hidden_states.pow(2).mean(-1, keepdim=True) |
| hidden_states = hidden_states * torch.rsqrt(variance + eps) |
|
|
| hidden_states = weight.to(torch.float32) * hidden_states |
|
|
| return hidden_states.to(input_dtype) |
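|
| # Equivalent formulation: with z = hidden_states * silu(gate) computed in fp32, each group g of |
| # `group_size` channels is scaled by rsqrt(mean(z_g ** 2) + eps), and the result is multiplied |
| # elementwise by `weight` over the full hidden dimension, i.e. y = weight * z / rms_group(z). |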
|
|
|
|
| |
|
|
| class HybridMambaAttentionDynamicCache(DynamicCache): |
| """ |
| Cache for hybrid Mamba-Attention model. Handles both attention KV cache |
| and Mamba conv/SSM state cache. |
| """ |
|
|
| def __init__(self, config, batch_size, dtype=torch.float16, device=None): |
| super().__init__() |
| self.dtype = dtype |
| self.hybrid_override_pattern = config.hybrid_override_pattern |
| self.has_previous_state = False |
| intermediate_size = config.mamba_num_heads * config.mamba_head_dim |
| ssm_state_size = config.ssm_state_size |
| conv_kernel_size = config.conv_kernel |
| # The conv cache spans all convolved channels (x, B and C); the SSM cache is per-head. |
| # These shapes match what update_conv_state / update_ssm_state write back during prefill. |
| conv_dim = intermediate_size + 2 * config.n_groups * ssm_state_size |
| self.conv_kernel_size = conv_kernel_size |
| self.conv_states = [] |
| self.ssm_states = [] |
| self.transformer_layers = [] |
| for i in range(config.num_hidden_layers): |
| if self.hybrid_override_pattern[i] == "M": |
| self.conv_states += [ |
| torch.zeros(batch_size, conv_dim, conv_kernel_size, device=device, dtype=dtype) |
| ] |
| self.ssm_states += [ |
| torch.zeros( |
| batch_size, config.mamba_num_heads, config.mamba_head_dim, ssm_state_size, |
| device=device, dtype=dtype, |
| ) |
| ] |
| else: |
| self.conv_states += [torch.tensor([[]] * batch_size, device=device)] |
| self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] |
| self.transformer_layers.append(i) |
|
|
| self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] |
| self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] |
|
|
| def update(self, key_states, value_states, layer_idx, cache_kwargs=None): |
| if self.key_cache[layer_idx].shape[-1] == 0: |
| self.key_cache[layer_idx] = key_states |
| self.value_cache[layer_idx] = value_states |
| else: |
| self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) |
| self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) |
| return self.key_cache[layer_idx], self.value_cache[layer_idx] |
|
|
| def reorder_cache(self, beam_idx): |
| for layer_idx in range(len(self.key_cache)): |
| device = self.key_cache[layer_idx].device |
| self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) |
| device = self.value_cache[layer_idx].device |
| self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) |
| device = self.conv_states[layer_idx].device |
| self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) |
| device = self.ssm_states[layer_idx].device |
| self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) |
|
|
| def get_seq_length(self, layer_idx=0): |
| layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx |
| if len(self.key_cache) <= layer_idx: |
| return 0 |
| return self.key_cache[layer_idx].shape[-2] |
|
|
| def to_legacy_cache(self): |
| raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") |
|
|
| @classmethod |
| def from_legacy_cache(cls, past_key_values=None): |
| raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") |
|
|
| def update_conv_state(self, layer_idx, new_conv_state, cache_init=False): |
| if cache_init: |
| self.conv_states[layer_idx] = new_conv_state.to(self.conv_states[layer_idx].device) |
| else: |
| self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) |
| self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states[layer_idx].device) |
| return self.conv_states[layer_idx] |
|
|
| def update_ssm_state(self, layer_idx, new_ssm_state): |
| self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device) |
| return self.ssm_states[layer_idx] |
|
|
| def reset(self): |
| for s in self.conv_states: |
| if s.numel() > 0: |
| s.zero_() |
| for s in self.ssm_states: |
| if s.numel() > 0: |
| s.zero_() |
|
|
|
|
| |
|
|
| class MambaRMSNormGated(nn.Module): |
| def __init__(self, hidden_size, group_size=None, eps=1e-5): |
| super().__init__() |
| self.weight = nn.Parameter(torch.ones(hidden_size)) |
| self.variance_epsilon = eps |
| self.group_size = group_size if group_size is not None else hidden_size |
|
|
| def forward(self, hidden_states, gate=None): |
| return rms_norm_gated( |
| hidden_states, |
| self.weight, |
| gate=gate, |
| eps=self.variance_epsilon, |
| group_size=self.group_size, |
| ) |
|
|
|
|
| |
|
|
| class NemotronHMamba2Mixer(nn.Module): |
| """ |
| Pure-PyTorch Mamba2 SSM mixer. Weight-compatible with the original |
| NVIDIA implementation but uses no external CUDA kernels. |
| """ |
|
|
| def __init__(self, config: NemotronHConfig, layer_idx: int): |
| super().__init__() |
| self.num_heads = config.mamba_num_heads |
| self.hidden_size = config.hidden_size |
| self.ssm_state_size = config.ssm_state_size |
| self.conv_kernel_size = config.conv_kernel |
| self.intermediate_size = config.mamba_num_heads * config.mamba_head_dim |
| self.layer_idx = layer_idx |
| self.use_conv_bias = config.use_conv_bias |
| self.activation = config.mamba_hidden_act |
| self.act = ACT2FN[config.mamba_hidden_act] |
|
|
| self.layer_norm_epsilon = config.layer_norm_epsilon |
| self.n_groups = config.n_groups |
| self.head_dim = config.mamba_head_dim |
| self.chunk_size = config.chunk_size |
|
|
| self.time_step_limit = config.time_step_limit |
| self.time_step_min = config.time_step_min |
| self.time_step_max = config.time_step_max |
|
|
| self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size |
| self.conv1d = nn.Conv1d( |
| in_channels=self.conv_dim, |
| out_channels=self.conv_dim, |
| bias=config.use_conv_bias, |
| kernel_size=config.conv_kernel, |
| groups=self.conv_dim, |
| padding=config.conv_kernel - 1, |
| ) |
|
|
| projection_size = self.intermediate_size + self.conv_dim + self.num_heads |
| self.in_proj = nn.Linear(self.hidden_size, projection_size, bias=config.use_bias) |
|
|
| self.dt_bias = nn.Parameter(torch.ones(self.num_heads)) |
| A = torch.arange(1, self.num_heads + 1) |
| self.A_log = nn.Parameter(torch.log(A)) |
| self.A_log._no_weight_decay = True |
| self.norm = MambaRMSNormGated( |
| self.intermediate_size, |
| eps=self.layer_norm_epsilon, |
| group_size=self.intermediate_size // self.n_groups, |
| ) |
| self.D = nn.Parameter(torch.ones(self.num_heads)) |
| self.D._no_weight_decay = True |
| self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) |
| self.use_bias = config.use_bias |
|
|
| @property |
| def o_proj(self): |
| """Alias for tooling that expects o_proj.""" |
| return self.out_proj |
|
|
| def _single_step_forward( |
| self, |
| hidden_states, |
| cache_params, |
| attention_mask=None, |
| ): |
| """Single token generation step with cache.""" |
| batch_size = hidden_states.shape[0] |
| groups_time_state_size = self.n_groups * self.ssm_state_size |
|
|
| hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) |
| projected_states = self.in_proj(hidden_states) |
|
|
| d_mlp = ( |
| projected_states.shape[-1] |
| - 2 * self.intermediate_size |
| - 2 * self.n_groups * self.ssm_state_size |
| - self.num_heads |
| ) // 2 |
|
|
| _, _, gate, hidden_states_B_C, dt = projected_states.squeeze(1).split( |
| [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 |
| ) |
|
|
| |
| conv_state = cache_params.conv_states[self.layer_idx] |
| conv_state = conv_state.roll(shifts=-1, dims=-1) |
| conv_state[:, :, -1] = hidden_states_B_C |
| cache_params.conv_states[self.layer_idx] = conv_state |
|
|
| hidden_states_B_C = torch.sum( |
| conv_state * self.conv1d.weight.squeeze(1), dim=-1 |
| ) |
| if self.use_conv_bias: |
| hidden_states_B_C = hidden_states_B_C + self.conv1d.bias |
| hidden_states_B_C = self.act(hidden_states_B_C) |
|
|
| hidden_states_inner, B, C = torch.split( |
| hidden_states_B_C, |
| [self.intermediate_size, groups_time_state_size, groups_time_state_size], |
| dim=-1, |
| ) |
|
|
| |
| A = -torch.exp(self.A_log.float()) |
| A = A[:, None, ...].expand(-1, self.head_dim)[:, :, None].expand(-1, -1, self.ssm_state_size).to(torch.float32) |
| dt_expanded = dt[:, :, None].expand(-1, -1, self.head_dim) |
| dt_bias = self.dt_bias[:, None].expand(-1, self.head_dim) |
| D = self.D[:, None].expand(-1, self.head_dim) |
|
|
| B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups) |
| C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) |
| hidden_reshaped = hidden_states_inner.view(batch_size, self.num_heads, self.head_dim) |
|
|
| |
| dt_with_bias = F.softplus(dt_expanded + dt_bias) |
| dt_with_bias = torch.clamp(dt_with_bias, self.time_step_limit[0], self.time_step_limit[1]) |
| dA = torch.exp(dt_with_bias.unsqueeze(-1) * A.unsqueeze(0)) |
|
|
| |
| B_expanded = B[:, :, None, :].expand(-1, -1, self.num_heads // self.n_groups, -1) |
| B_expanded = B_expanded.reshape(batch_size, self.num_heads, self.ssm_state_size) |
|
|
| dBx = dt_with_bias.unsqueeze(-1) * B_expanded.unsqueeze(2) * hidden_reshaped.unsqueeze(-1) |
|
|
| ssm_state = cache_params.ssm_states[self.layer_idx] |
| ssm_state = ssm_state.to(dA.device, dtype=dA.dtype) |
| new_ssm_state = ssm_state * dA + dBx |
| cache_params.ssm_states[self.layer_idx] = new_ssm_state.to(cache_params.ssm_states[self.layer_idx].dtype) |
|
|
| |
| C_expanded = C[:, :, None, :].expand(-1, -1, self.num_heads // self.n_groups, -1) |
| C_expanded = C_expanded.reshape(batch_size, self.num_heads, self.ssm_state_size) |
|
|
| y = (new_ssm_state.to(C_expanded.dtype) * C_expanded.unsqueeze(2)).sum(-1) |
| y = y + hidden_reshaped * D.unsqueeze(0) |
| y = y.reshape(batch_size, -1) |
| y = self.norm(y, gate) |
| out = self.out_proj(y)[:, None, ...] |
| return out |
|
|
| |
| def _chunked_forward(self, input_states, cache_params=None, attention_mask=None): |
| """ |
| Full sequence forward pass using chunked SSD scan. |
| This is the torch_forward from the original, which works correctly |
| with bitsandbytes quantization. |
| """ |
| batch_size, seq_len, _ = input_states.shape |
| dtype = input_states.dtype |
| groups_time_state_size = self.n_groups * self.ssm_state_size |
|
|
| |
| input_states = apply_mask_to_padding_states(input_states, attention_mask) |
| projected_states = self.in_proj(input_states) |
| d_mlp = ( |
| projected_states.shape[-1] |
| - 2 * self.intermediate_size |
| - 2 * self.n_groups * self.ssm_state_size |
| - self.num_heads |
| ) // 2 |
| _, _, gate, hidden_states_B_C, dt = projected_states.split( |
| [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 |
| ) |
|
|
| |
| if cache_params is not None: |
| hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) |
| conv_states = F.pad( |
| hidden_states_B_C_transposed, |
| (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0), |
| ) |
| cache_params.update_conv_state( |
| layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True |
| ) |
|
|
| |
| hidden_states_B_C = self.act( |
| self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2) |
| ) |
|
|
| hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) |
| hidden_states, B, C = torch.split( |
| hidden_states_B_C, |
| [self.intermediate_size, groups_time_state_size, groups_time_state_size], |
| dim=-1, |
| ) |
|
|
| |
| |
| |
| |
| A = -torch.exp(self.A_log.to(dtype)) |
| dt = F.softplus(dt + self.dt_bias) |
| dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) |
| hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim) |
| B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size) |
| C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size) |
| B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2) |
| C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2) |
| pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size |
|
|
| D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) |
|
|
| |
| hidden_states = hidden_states * dt[..., None] |
| A = A.to(hidden_states.dtype) * dt |
|
|
| |
| |
| |
| |
| hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] |
|
|
| A = A.permute(0, 3, 1, 2) |
| A_cumsum = torch.cumsum(A, dim=-1) |
|
|
| |
| |
| |
| |
| |
| L_arg = A_cumsum[..., :, None] - A_cumsum[..., None, :] |
| causal_mask = torch.tril(torch.ones( |
| self.chunk_size, self.chunk_size, device=L_arg.device, dtype=torch.bool)) |
| L = torch.exp(L_arg.masked_fill(~causal_mask, float('-inf'))) |
|
|
| |
| |
| |
| G = torch.einsum('bnchs, bnkhs -> bnckh', C, B) |
| M = G * L.permute(0, 2, 3, 4, 1) |
|
|
| |
| |
| Y_diag = torch.einsum('bnijh, bnjhd -> bnihd', M, hidden_states) |
|
|
| |
| decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) |
| B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None] |
|
|
| |
| states = torch.einsum('bnchs, bnchd -> bnhds', B_decay, hidden_states) |
|
|
| if cache_params is not None and cache_params.has_previous_state: |
| previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device) |
| else: |
| previous_states = torch.zeros_like(states[:, :1]) |
| states = torch.cat([previous_states, states], dim=1) |
|
|
| |
| |
| chunk_cumA = torch.cumsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)), dim=-1) |
| n_plus1 = chunk_cumA.shape[-1] |
| decay_arg = chunk_cumA[..., :, None] - chunk_cumA[..., None, :] |
| chunk_mask = torch.tril(torch.ones( |
| n_plus1, n_plus1, device=decay_arg.device, dtype=torch.bool)) |
| decay_chunk = torch.exp(decay_arg.masked_fill(~chunk_mask, float('-inf'))) |
| decay_chunk = decay_chunk.transpose(1, 3) |
|
|
| |
| # Sum over the source-chunk index (dim 1 of decay_chunk after the transpose above), giving the |
| # SSM state entering each chunk plus the final state. |
| new_states = torch.einsum('bjih, bjhds -> bihds', decay_chunk, states) |
| states, ssm_state = new_states[:, :-1], new_states[:, -1] |
|
|
| |
| state_decay_out = torch.exp(A_cumsum) |
| |
| Y_off = torch.einsum('bnchs, bnhds -> bnchd', C, states) |
| Y_off = Y_off * state_decay_out.permute(0, 2, 3, 1)[..., None] |
|
|
| y = Y_diag + Y_off |
| y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) |
| y = y + D_residual |
| if pad_size > 0: |
| y = y[:, :seq_len, :, :] |
| y = y.reshape(batch_size, seq_len, -1) |
|
|
| |
| if ssm_state is not None and cache_params is not None: |
| cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state) |
|
|
| scan_output = self.norm(y, gate) |
| contextualized_states = self.out_proj(scan_output.to(dtype)) |
| return contextualized_states |
| |
|
|
| def forward( |
| self, |
| hidden_states, |
| cache_params: Optional[HybridMambaAttentionDynamicCache] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| ): |
| |
| if cache_params is not None and cache_position is not None and cache_position[0] > 0: |
| return self._single_step_forward(hidden_states, cache_params, attention_mask) |
|
|
| |
| # Zero out padded positions before the full-sequence scan. |
| hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) |
|
|
| return self._chunked_forward(hidden_states, cache_params, attention_mask) |
|
|
|
|
| |
|
|
| class NemotronHRMSNorm(nn.Module): |
| def __init__(self, hidden_size, eps=1e-6): |
| super().__init__() |
| self.weight = nn.Parameter(torch.ones(hidden_size)) |
| self.variance_epsilon = eps |
|
|
| def forward(self, hidden_states): |
| input_dtype = hidden_states.dtype |
| hidden_states = hidden_states.to(torch.float32) |
| variance = hidden_states.pow(2).mean(-1, keepdim=True) |
| hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) |
| return (self.weight.to(torch.float32) * hidden_states).to(input_dtype) |
|
|
|
|
| |
|
|
| class NemotronHBlock(nn.Module): |
| def __init__(self, config, layer_idx): |
| super().__init__() |
| self.config = config |
| self.layer_idx = layer_idx |
| self.residual_in_fp32 = config.residual_in_fp32 |
| self.norm = NemotronHRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) |
| block_type = config.hybrid_override_pattern[layer_idx] |
| if block_type == "M": |
| self.mixer = NemotronHMamba2Mixer(config, layer_idx=layer_idx) |
| elif block_type == "*": |
| self.mixer = NemotronHAttention(config, layer_idx=layer_idx) |
| elif block_type == "-": |
| self.mixer = NemotronHMLP(config, layer_idx=layer_idx) |
| else: |
| raise ValueError(f"Unknown block type: {block_type}") |
|
|
| |
| |
| |
| @property |
| def self_attn(self): |
| return self.mixer |
|
|
| @property |
| def mlp(self): |
| return self.mixer |
|
|
| def forward( |
| self, |
| hidden_states, |
| attention_mask=None, |
| position_ids=None, |
| cache_params=None, |
| cache_position=None, |
| ): |
| residual = hidden_states |
| hidden_states = self.norm(hidden_states) |
| if self.residual_in_fp32: |
| residual = residual.to(torch.float32) |
|
|
| block_type = self.config.hybrid_override_pattern[self.layer_idx] |
| if block_type == "M": |
| hidden_states = self.mixer( |
| hidden_states, |
| cache_params=cache_params, |
| cache_position=cache_position, |
| ) |
| elif block_type == "*": |
| hidden_states = self.mixer( |
| hidden_states, |
| attention_mask=attention_mask, |
| past_key_value=cache_params, |
| cache_position=cache_position, |
| ) |
| hidden_states = hidden_states[0] |
| elif block_type == "-": |
| hidden_states = self.mixer(hidden_states) |
|
|
| hidden_states = residual + hidden_states |
| return hidden_states |
|
|
|
|
| class NemotronHMLP(nn.Module): |
| def __init__(self, config, layer_idx=None): |
| super().__init__() |
| self.config = config |
| self.layer_idx = layer_idx |
| self.hidden_size = config.hidden_size |
| self.intermediate_size = config.intermediate_size |
| self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) |
| self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) |
| self.act_fn = ACT2FN[config.mlp_hidden_act] |
|
|
| @property |
| def o_proj(self): |
| return self.down_proj |
|
|
| def forward(self, x): |
| return self.down_proj(self.act_fn(self.up_proj(x))) |
|
|
|
|
| def rotate_half(x): |
| x1 = x[..., : x.shape[-1] // 2] |
| x2 = x[..., x.shape[-1] // 2 :] |
| return torch.cat((-x2, x1), dim=-1) |
|
|
|
|
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| cos = cos.unsqueeze(unsqueeze_dim) |
| sin = sin.unsqueeze(unsqueeze_dim) |
| q_embed = (q * cos) + (rotate_half(q) * sin) |
| k_embed = (k * cos) + (rotate_half(k) * sin) |
| return q_embed, k_embed |
|
|
|
|
| class NemotronHRotaryEmbedding(nn.Module): |
| def __init__(self, config=None, device=None): |
| super().__init__() |
| self.rope_kwargs = {} |
| self.rope_type = "default" |
| self.max_seq_len_cached = config.max_position_embeddings |
| self.original_max_seq_len = config.max_position_embeddings |
| self.config = config |
| self.rope_init_fn = self._default_rope_init |
| self.rope_init_fn(self.config, device) |
| self.original_inv_freq = self.inv_freq |
|
|
| def _default_rope_init(self, config, device=None): |
| base = 10000.0 |
| dim = config.head_dim |
| inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) |
| self.register_buffer("inv_freq", inv_freq, persistent=False) |
|
|
| @torch.no_grad() |
| def forward(self, x, position_ids): |
| inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) |
| position_ids_expanded = position_ids[:, None, :].float() |
| device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" |
| with torch.autocast(device_type=device_type, enabled=False): |
| freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) |
| emb = torch.cat((freqs, freqs), dim=-1) |
| cos = emb.cos() |
| sin = emb.sin() |
| return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) |
|
|
|
|
| class NemotronHAttention(nn.Module): |
| def __init__(self, config: NemotronHConfig, layer_idx: int): |
| super().__init__() |
| self.config = config |
| self.layer_idx = layer_idx |
| self.attention_dropout = config.attention_dropout |
| self.hidden_size = config.hidden_size |
| self.num_heads = config.num_attention_heads |
| self.head_dim = config.head_dim |
| self.num_key_value_heads = config.num_key_value_heads |
| self.num_key_value_groups = self.num_heads // self.num_key_value_heads |
| self.is_causal = True |
|
|
| self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) |
| self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) |
| self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) |
| self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) |
|
|
| def forward( |
| self, |
| hidden_states, |
| attention_mask=None, |
| position_ids=None, |
| past_key_value=None, |
| output_attentions=False, |
| use_cache=False, |
| cache_position=None, |
| ): |
| bsz, q_len, _ = hidden_states.size() |
|
|
| query_states = self.q_proj(hidden_states) |
| key_states = self.k_proj(hidden_states) |
| value_states = self.v_proj(hidden_states) |
|
|
| query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
| key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
| value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
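|
| # Note: no positional embedding is applied to q/k here. NemotronHRotaryEmbedding and |
| # apply_rotary_pos_emb above are kept for tooling compatibility but are not used by this layer. |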
|
|
| |
|
|
| |
| if past_key_value is not None: |
| key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx) |
|
|
| |
| key_states = key_states[:, :, None, :, :].expand(-1, -1, self.num_key_value_groups, -1, -1) |
| key_states = key_states.reshape(bsz, self.num_heads, -1, self.head_dim) |
| value_states = value_states[:, :, None, :, :].expand(-1, -1, self.num_key_value_groups, -1, -1) |
| value_states = value_states.reshape(bsz, self.num_heads, -1, self.head_dim) |
|
|
| |
| causal_mask = attention_mask |
| if causal_mask is not None: |
| causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] |
|
|
| is_causal = True if causal_mask is None and q_len > 1 else False |
| attn_output = F.scaled_dot_product_attention( |
| query_states, |
| key_states, |
| value_states, |
| attn_mask=causal_mask, |
| dropout_p=self.attention_dropout if self.training else 0.0, |
| is_causal=is_causal, |
| ) |
|
|
| attn_output = attn_output.transpose(1, 2).contiguous() |
| attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim) |
| attn_output = self.o_proj(attn_output) |
| return attn_output, None, past_key_value |
|
|
|
|
| |
|
|
| @dataclass |
| class NemotronHOutput(ModelOutput): |
| last_hidden_state: torch.FloatTensor = None |
| cache_params: Optional[HybridMambaAttentionDynamicCache] = None |
| hidden_states: Optional[Tuple[torch.FloatTensor]] = None |
| attentions: Optional[Tuple[torch.FloatTensor]] = None |
|
|
|
|
| @dataclass |
| class NemotronHCausalLMOutput(ModelOutput): |
| loss: Optional[torch.FloatTensor] = None |
| logits: torch.FloatTensor = None |
| cache_params: Optional[HybridMambaAttentionDynamicCache] = None |
| hidden_states: Optional[Tuple[torch.FloatTensor]] = None |
| attentions: Optional[Tuple[torch.FloatTensor]] = None |
|
|
|
|
| class NemotronHPreTrainedModel(PreTrainedModel): |
| config_class = NemotronHConfig |
| base_model_prefix = "backbone" |
| _no_split_modules = ["NemotronHBlock"] |
| supports_gradient_checkpointing = True |
| _is_stateful = True |
| |
| |
| _keys_to_ignore_on_load_unexpected = [ |
| r"backbone\.layers\.\d+\.self_attn\.", |
| r"backbone\.layers\.\d+\.mlp\.", |
| ] |
|
|
| def _init_weights(self, module): |
| std = self.config.initializer_range |
| if isinstance(module, (nn.Linear, nn.Conv1d)): |
| module.weight.data.normal_(mean=0.0, std=std) |
| if module.bias is not None: |
| module.bias.data.zero_() |
| elif isinstance(module, nn.Embedding): |
| module.weight.data.normal_(mean=0.0, std=std) |
| if module.padding_idx is not None: |
| module.weight.data[module.padding_idx].zero_() |
|
|
| def _initialize_missing_keys(self, *args, **kwargs): |
| """No-op override; accepts both the old (missing_keys, is_quantized) and new (is_quantized) call signatures.""" |
| pass |
|
|
|
|
| class NemotronHModel(NemotronHPreTrainedModel): |
| def __init__(self, config): |
| super().__init__(config) |
| self.padding_idx = config.pad_token_id |
| self.vocab_size = config.vocab_size |
| self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) |
| self.layers = nn.ModuleList( |
| [NemotronHBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)] |
| ) |
| self.norm_f = NemotronHRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) |
| self.gradient_checkpointing = False |
| self.post_init() |
|
|
| def get_input_embeddings(self): |
| return self.embeddings |
|
|
| def set_input_embeddings(self, new_embeddings): |
| self.embeddings = new_embeddings |
|
|
| def forward( |
| self, |
| input_ids=None, |
| inputs_embeds=None, |
| cache_params=None, |
| use_cache=None, |
| output_hidden_states=None, |
| return_dict=None, |
| cache_position=None, |
| attention_mask=None, |
| **kwargs, |
| ): |
| if output_hidden_states is None: |
| output_hidden_states = getattr(self.config, "output_hidden_states", False) |
| use_cache = use_cache if use_cache is not None else self.config.use_cache |
| if return_dict is None: |
| return_dict = getattr(self.config, "use_return_dict", True) |
|
|
| if inputs_embeds is None: |
| inputs_embeds = self.embeddings(input_ids) |
|
|
| hidden_states = inputs_embeds |
|
|
| if use_cache and cache_params is None: |
| cache_params = HybridMambaAttentionDynamicCache( |
| self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype |
| ) |
|
|
| |
| if cache_position is None: |
| cache_position = torch.arange(0, hidden_states.shape[1], device=hidden_states.device) |
| position_ids = cache_position[None, :].expand(hidden_states.shape[0], -1) |
|
|
| |
| causal_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position, cache_params) |
|
|
| all_hidden_states = () if output_hidden_states else None |
|
|
| for layer in self.layers: |
| if output_hidden_states: |
| all_hidden_states = all_hidden_states + (hidden_states,) |
|
|
| if self.gradient_checkpointing and self.training: |
| hidden_states = self._gradient_checkpointing_func( |
| layer.__call__, |
| hidden_states, |
| causal_mask, |
| position_ids, |
| cache_params, |
| cache_position, |
| ) |
| else: |
| hidden_states = layer( |
| hidden_states, |
| attention_mask=causal_mask, |
| position_ids=position_ids, |
| cache_params=cache_params, |
| cache_position=cache_position, |
| ) |
|
|
| if use_cache: |
| cache_params.has_previous_state = True |
|
|
| hidden_states = self.norm_f(hidden_states) |
|
|
| if output_hidden_states: |
| all_hidden_states = all_hidden_states + (hidden_states,) |
|
|
| if not return_dict: |
| return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None) |
|
|
| return NemotronHOutput( |
| last_hidden_state=hidden_states, |
| cache_params=cache_params if use_cache else None, |
| hidden_states=all_hidden_states, |
| ) |
|
|
| def _update_causal_mask(self, attention_mask, input_tensor, cache_position, cache_params): |
| dtype, device = input_tensor.dtype, input_tensor.device |
| min_dtype = torch.finfo(dtype).min |
| sequence_length = input_tensor.shape[1] |
|
|
| |
| |
| if attention_mask is not None and attention_mask.dim() == 2: |
| target_length = attention_mask.shape[-1] |
| elif cache_params is not None and cache_params.has_previous_state: |
| target_length = cache_params.get_seq_length() + sequence_length |
| else: |
| target_length = sequence_length |
|
|
| causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) |
| if sequence_length != 1: |
| causal_mask = torch.triu(causal_mask, diagonal=target_length - sequence_length + 1) |
| |
| |
| causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) |
| causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) |
|
|
| if attention_mask is not None and attention_mask.dim() == 2: |
| mask_length = attention_mask.shape[-1] |
| padding_mask = causal_mask[:, :, :, :mask_length] + (1.0 - attention_mask[:, None, None, :].to(causal_mask.dtype)) * min_dtype |
| causal_mask = torch.cat([padding_mask, causal_mask[:, :, :, mask_length:]], dim=-1) if mask_length < target_length else padding_mask |
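|
| # Example: seq_len = 3, no cache, no padding -> causal_mask[0, 0] == |
| #   [[0, min, min], |
| #    [0,   0, min], |
| #    [0,   0,   0]] |
| # where min = torch.finfo(dtype).min; the mask is broadcast over batch with a singleton head dim. |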
|
|
| return causal_mask |
|
|
|
|
| class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin): |
| _tied_weights_keys = ["lm_head.weight"] |
|
|
| def __init__(self, config): |
| super().__init__(config) |
| self.backbone = NemotronHModel(config) |
| self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) |
| self.post_init() |
|
|
| @property |
| def model(self): |
| """Alias so tooling that expects .model (LoRA, PEFT, etc.) works.""" |
| return self.backbone |
|
|
| @model.setter |
| def model(self, value): |
| self.backbone = value |
|
|
| def get_output_embeddings(self): |
| return self.lm_head |
|
|
| def set_output_embeddings(self, new_embeddings): |
| self.lm_head = new_embeddings |
|
|
| def get_input_embeddings(self): |
| return self.backbone.get_input_embeddings() |
|
|
| def set_input_embeddings(self, new_embeddings): |
| return self.backbone.set_input_embeddings(new_embeddings) |
|
|
| def _update_model_kwargs_for_generation(self, outputs, model_kwargs, **kwargs): |
| model_kwargs["cache_params"] = outputs.get("cache_params", None) |
| if "cache_position" in model_kwargs: |
| model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1 |
| |
| |
| if "attention_mask" in model_kwargs: |
| model_kwargs["attention_mask"] = torch.cat([ |
| model_kwargs["attention_mask"], |
| model_kwargs["attention_mask"].new_ones((model_kwargs["attention_mask"].shape[0], 1)), |
| ], dim=-1) |
| return model_kwargs |
|
|
| def prepare_inputs_for_generation( |
| self, |
| input_ids, |
| cache_params=None, |
| inputs_embeds=None, |
| attention_mask=None, |
| cache_position=None, |
| **kwargs, |
| ): |
| if cache_params is not None: |
| if input_ids.shape[1] != 1: |
| input_ids = input_ids[:, -1:] |
|
|
| if inputs_embeds is not None and cache_params is None: |
| model_inputs = {"inputs_embeds": inputs_embeds} |
| else: |
| model_inputs = {"input_ids": input_ids} |
|
|
| model_inputs.update({ |
| "cache_params": cache_params, |
| "cache_position": cache_position, |
| "attention_mask": attention_mask, |
| }) |
| return model_inputs |
|
|
| def forward( |
| self, |
| input_ids=None, |
| attention_mask=None, |
| position_ids=None, |
| inputs_embeds=None, |
| cache_params=None, |
| labels=None, |
| output_hidden_states=None, |
| return_dict=None, |
| use_cache=None, |
| cache_position=None, |
| **kwargs, |
| ): |
| if return_dict is None: |
| return_dict = getattr(self.config, "use_return_dict", True) |
|
|
| nemotron_h_outputs = self.backbone( |
| input_ids, |
| cache_params=cache_params, |
| inputs_embeds=inputs_embeds, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| use_cache=use_cache, |
| cache_position=cache_position, |
| attention_mask=attention_mask, |
| ) |
| hidden_states = nemotron_h_outputs[0] |
|
|
| logits = self.lm_head(hidden_states) |
|
|
| loss = None |
| if labels is not None: |
| shift_logits = logits[..., :-1, :].contiguous() |
| shift_labels = labels[..., 1:].contiguous() |
| loss_fct = CrossEntropyLoss() |
| shift_logits = shift_logits.view(-1, self.config.vocab_size) |
| shift_labels = shift_labels.view(-1) |
| shift_labels = shift_labels.to(shift_logits.device) |
| loss = loss_fct(shift_logits, shift_labels) |
|
|
| if not return_dict: |
| output = (logits,) + nemotron_h_outputs[1:] |
| return (loss,) + output if loss is not None else output |
|
|
| return NemotronHCausalLMOutput( |
| loss=loss, |
| logits=logits, |
| cache_params=nemotron_h_outputs.cache_params if return_dict else None, |
| hidden_states=nemotron_h_outputs.hidden_states if return_dict else None, |
| ) |
|
|