from __future__ import annotations import math import os from typing import Any, cast import torch import torch.nn.functional as F from torch import nn from torch.utils.checkpoint import checkpoint from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache from transformers.generation import GenerationMixin from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, ) from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel try: from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS except ImportError: ROPE_INIT_FUNCTIONS = {} try: from fla.modules import FusedRMSNormGated, ShortConvolution from fla.ops.gated_delta_rule import ( chunk_gated_delta_rule, fused_recurrent_gated_delta_rule, ) except ImportError: chunk_gated_delta_rule = None fused_recurrent_gated_delta_rule = None FusedRMSNormGated = None ShortConvolution = None from .configuration_lizzy import LizzyConfig class LizzyRMSNorm(nn.Module): def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(dim=-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt( variance + self.variance_epsilon ) return self.weight * hidden_states.to(input_dtype) def _make_norm( norm_type: str, hidden_size: int, eps: float, *, has_bias: bool, ) -> nn.Module: if norm_type == "rmsnorm": return LizzyRMSNorm(hidden_size, eps=eps) if norm_type == "layernorm": return nn.LayerNorm( hidden_size, eps=eps, elementwise_affine=True, bias=has_bias, ) msg = f"Unsupported norm_type: {norm_type}" raise ValueError(msg) def _rotate_half(x: torch.Tensor) -> torch.Tensor: x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return 
torch.cat((-x2, x1), dim=-1) def _apply_rotary_pos_emb( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: q_embed = (q * cos) + (_rotate_half(q) * sin) k_embed = (k * cos) + (_rotate_half(k) * sin) return q_embed, k_embed def _legacy_cache_length( past_key_values: tuple[tuple[torch.Tensor, torch.Tensor], ...] | None, ) -> int: if ( isinstance(past_key_values, tuple) and len(past_key_values) > 0 and past_key_values[0] is not None and past_key_values[0][0] is not None ): return int(past_key_values[0][0].shape[2]) return 0 def _normalize_cache_position( cache_position: torch.Tensor | None, ) -> torch.Tensor | None: if cache_position is None: return None if cache_position.dim() == 0: return cache_position.view(1) if cache_position.dim() > 1: return cache_position[0] return cache_position def _is_cache_object(value: Any) -> bool: return isinstance(value, Cache) or isinstance(value, LizzyHybridDynamicCache) def _compute_default_rope_parameters( config: LizzyConfig, device: torch.device, ) -> tuple[torch.Tensor, float]: inv_freq = 1.0 / ( config.rope_theta ** ( torch.arange(0, config.head_dim, 2, device=device, dtype=torch.float32) / config.head_dim ) ) return inv_freq, 1.0 def _compute_yarn_rope_parameters( config: LizzyConfig, device: torch.device, ) -> tuple[torch.Tensor, float]: rope_scaling = dict(config.rope_scaling or {}) factor = float(rope_scaling["factor"]) attention_factor = rope_scaling.get("attention_factor") mscale = rope_scaling.get("mscale") mscale_all_dim = rope_scaling.get("mscale_all_dim") original_max_position_embeddings = int( rope_scaling.get("original_max_position_embeddings") or config.max_position_embeddings ) def get_mscale(scale: float, mscale_value: float = 1.0) -> float: if scale <= 1.0: return 1.0 return 0.1 * mscale_value * math.log(scale) + 1.0 if attention_factor is None: if mscale is not None and mscale_all_dim is not None: attention_factor = float( get_mscale(factor, 
float(mscale)) / get_mscale(factor, float(mscale_all_dim)) ) else: attention_factor = get_mscale(factor) beta_fast = float(rope_scaling.get("beta_fast") or 32.0) beta_slow = float(rope_scaling.get("beta_slow") or 1.0) truncate = bool(rope_scaling.get("truncate", True)) dim = config.head_dim def find_correction_dim( num_rotations: float, *, dim: int, base: float, max_position_embeddings: int, ) -> float: return ( dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi)) / (2 * math.log(base)) ) def find_correction_range( low_rot: float, high_rot: float, *, dim: int, base: float, max_position_embeddings: int, truncate: bool, ) -> tuple[float, float]: low = find_correction_dim( low_rot, dim=dim, base=base, max_position_embeddings=max_position_embeddings, ) high = find_correction_dim( high_rot, dim=dim, base=base, max_position_embeddings=max_position_embeddings, ) if truncate: low = math.floor(low) high = math.ceil(high) return max(low, 0.0), min(high, dim - 1.0) def linear_ramp_factor( min_value: float, max_value: float, dim: int, ) -> torch.Tensor: if min_value == max_value: max_value += 0.001 linear_func = ( torch.arange(dim, dtype=torch.float32, device=device) - min_value ) / (max_value - min_value) return torch.clamp(linear_func, 0, 1) pos_freqs = config.rope_theta ** ( torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim ) inv_freq_extrapolation = 1.0 / pos_freqs inv_freq_interpolation = 1.0 / (factor * pos_freqs) low, high = find_correction_range( beta_fast, beta_slow, dim=dim, base=config.rope_theta, max_position_embeddings=original_max_position_embeddings, truncate=truncate, ) inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2) inv_freq = ( inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + inv_freq_extrapolation * inv_freq_extrapolation_factor ) return inv_freq, float(attention_factor) def _compute_rope_parameters( config: LizzyConfig, device: torch.device, *, seq_len: int | torch.Tensor | None = 
None, rope_type_override: str | None = None, ) -> tuple[torch.Tensor, float]: rope_scaling = dict(config.rope_scaling or {}) rope_type = rope_type_override if rope_type is None: if not rope_scaling: return _compute_default_rope_parameters(config, device) rope_type = str( rope_scaling.get("rope_type", rope_scaling.get("type", "default")) ) if rope_type == "default": return _compute_default_rope_parameters(config, device) if rope_type == "yarn": return _compute_yarn_rope_parameters(config, device) if not rope_scaling: return _compute_default_rope_parameters(config, device) rope_init_fn = ( ROPE_INIT_FUNCTIONS.get(rope_type) or ROPE_INIT_FUNCTIONS.get("default") ) if rope_init_fn is None: return _compute_default_rope_parameters(config, device) inv_freq, attention_factor = rope_init_fn(config, device, seq_len=seq_len) return inv_freq.to(device=device, dtype=torch.float32), float(attention_factor) def _looks_like_legacy_interval_rope_lizzy(config: LizzyConfig) -> bool: rope_layer_flags = list(getattr(config, "rope_layer_flags", None) or []) if rope_layer_flags and not all(bool(item) for item in rope_layer_flags): return False layer_types = list(getattr(config, "layer_types", None) or []) if layer_types and any(str(item) != "full_attention" for item in layer_types): return False return ( str(getattr(config, "position_embedding_type", "")).lower() == "rope" and not bool(getattr(config, "rope_scaling", None)) and int(getattr(config, "num_hidden_layers", 0) or 0) == 36 and int(getattr(config, "hidden_size", 0) or 0) == 2048 and int(getattr(config, "num_attention_heads", 0) or 0) == 16 and int(getattr(config, "num_key_value_heads", 0) or 0) == 4 and math.isclose( float(getattr(config, "rope_theta", 0.0) or 0.0), 5_000_000.0 ) and not bool(getattr(config, "use_post_attn_norm", False)) and not bool(getattr(config, "use_post_mlp_norm", False)) and not bool(getattr(config, "use_qk_norm", False)) ) def _get_no_rope_layer_interval(config: LizzyConfig) -> int | None: value = 
getattr(config, "no_rope_layer_interval", None) if value is not None: value = int(value) if value > 0: return value if _looks_like_legacy_interval_rope_lizzy(config): # Backward-compatible fallback for already-uploaded Lizzy # checkpoints that should use NoPE on every 4th layer. return 4 return None def _get_rope_layer_flag(config: LizzyConfig, layer_idx: int) -> bool: rope_enabled = str( getattr(config, "position_embedding_type", "rope") ).lower() == "rope" rope_layer_flags = list(getattr(config, "rope_layer_flags", None) or []) no_rope_layer_interval = _get_no_rope_layer_interval(config) if ( no_rope_layer_interval is not None and ( layer_idx >= len(rope_layer_flags) or not rope_layer_flags or all(bool(item) for item in rope_layer_flags) ) ): return rope_enabled and ((layer_idx + 1) % no_rope_layer_interval != 0) if 0 <= layer_idx < len(rope_layer_flags): return rope_enabled and bool(rope_layer_flags[layer_idx]) return rope_enabled def _get_layer_layout(config: LizzyConfig, layer_idx: int) -> str: layer_layouts = list(getattr(config, "layer_layouts", None) or []) if 0 <= layer_idx < len(layer_layouts): return str(layer_layouts[layer_idx]) if bool(getattr(config, "use_post_attn_norm", False)) or bool( getattr(config, "use_post_mlp_norm", False) ): return "decoder_postnorm" return "decoder_prenorm" def _has_linear_attention(config: LizzyConfig) -> bool: return any( str(layer_type) == "linear_attention" for layer_type in list(getattr(config, "layer_types", None) or []) ) class LizzyHybridDynamicCache: """Cache for Lizzy checkpoints with mixed full and linear attention.""" is_compileable = False def __init__(self, config: LizzyConfig) -> None: super().__init__() self.layer_types = list(config.layer_types) self.transformer_layers = [ idx for idx, layer_type in enumerate(self.layer_types) if layer_type == "full_attention" ] self.last_linear_layer = ( len(self.layer_types) - 1 - self.layer_types[::-1].index("linear_attention") ) self.recurrent_states = [None for _ in 
range(config.num_hidden_layers)] self.key_cache = [None for _ in range(config.num_hidden_layers)] self.value_cache = [None for _ in range(config.num_hidden_layers)] self.conv_states_q = [None for _ in range(config.num_hidden_layers)] self.conv_states_k = [None for _ in range(config.num_hidden_layers)] self.conv_states_v = [None for _ in range(config.num_hidden_layers)] def __len__(self) -> int: return len(self.layer_types) def update( self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, cache_kwargs: dict[str, Any] | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: del cache_kwargs if self.key_cache[layer_idx] is None: self.key_cache[layer_idx] = key_states self.value_cache[layer_idx] = value_states else: self.key_cache[layer_idx] = torch.cat( [self.key_cache[layer_idx], key_states], dim=2, ) self.value_cache[layer_idx] = torch.cat( [self.value_cache[layer_idx], value_states], dim=2, ) return self.key_cache[layer_idx], self.value_cache[layer_idx] def reorder_cache(self, beam_idx: torch.LongTensor) -> None: batch_size = beam_idx.shape[0] for layer_idx in range(len(self.key_cache)): if self.key_cache[layer_idx] is not None: if self.key_cache[layer_idx].shape[0] < batch_size: expand_ratio = ( batch_size // self.key_cache[layer_idx].shape[0] ) self.key_cache[layer_idx] = ( self.key_cache[layer_idx].repeat_interleave( expand_ratio, dim=0, ) ) self.value_cache[layer_idx] = ( self.value_cache[layer_idx].repeat_interleave( expand_ratio, dim=0, ) ) device = self.key_cache[layer_idx].device self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select( 0, beam_idx.to(device), ) self.value_cache[layer_idx] = ( self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) ) if self.conv_states_q[layer_idx] is not None: if self.conv_states_q[layer_idx].shape[0] < batch_size: expand_ratio = ( batch_size // self.conv_states_q[layer_idx].shape[0] ) self.conv_states_q[layer_idx] = ( self.conv_states_q[layer_idx].repeat_interleave( expand_ratio, 
dim=0, ) ) self.conv_states_k[layer_idx] = ( self.conv_states_k[layer_idx].repeat_interleave( expand_ratio, dim=0, ) ) self.conv_states_v[layer_idx] = ( self.conv_states_v[layer_idx].repeat_interleave( expand_ratio, dim=0, ) ) self.recurrent_states[layer_idx] = ( self.recurrent_states[layer_idx].repeat_interleave( expand_ratio, dim=0, ) ) device = self.conv_states_q[layer_idx].device self.conv_states_q[layer_idx] = ( self.conv_states_q[layer_idx].index_select( 0, beam_idx.to(device), ) ) self.conv_states_k[layer_idx] = ( self.conv_states_k[layer_idx].index_select( 0, beam_idx.to(device), ) ) self.conv_states_v[layer_idx] = ( self.conv_states_v[layer_idx].index_select( 0, beam_idx.to(device), ) ) self.recurrent_states[layer_idx] = ( self.recurrent_states[layer_idx].index_select( 0, beam_idx.to(device), ) ) def get_seq_length(self, layer_idx: int | None = 0) -> int: if not self.transformer_layers: return 0 layer_idx = ( self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx ) if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None: return 0 return self.key_cache[layer_idx].shape[-2] def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: del layer_idx kv_offset = 0 past_seen_tokens = self.get_seq_length() kv_length = query_length + past_seen_tokens return kv_length, kv_offset @property def has_previous_state(self) -> bool: # Mirror the upstream contract: once the final linear layer has cached # its conv state, single-token decode can switch to the recurrent path. return self.conv_states_q[self.last_linear_layer] is not None class LizzyHybridRMSNormGated(nn.Module): def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward( self, hidden_states: torch.Tensor, gate: torch.Tensor | None = None, ) -> torch.Tensor: if gate is None: msg = "gate is required for gated RMSNorm." 
raise ValueError(msg) input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt( variance + self.variance_epsilon ) hidden_states = self.weight * hidden_states.to(input_dtype) hidden_states = hidden_states * F.silu(gate.to(torch.float32)) return hidden_states.to(input_dtype) class LizzyHybridShortConvolution(nn.Conv1d): def __init__( self, hidden_size: int, kernel_size: int, bias: bool = False, activation: str | None = "silu", ) -> None: super().__init__( in_channels=hidden_size, out_channels=hidden_size, kernel_size=kernel_size, groups=hidden_size, padding=kernel_size - 1, bias=bias, ) self.hidden_size = hidden_size self.conv_kernel_size = kernel_size self.act_fn = ACT2FN[activation] def forward( self, hidden_states: torch.Tensor, cache: torch.Tensor | None = None, use_precomputed: bool = False, **kwargs: Any, ) -> tuple[torch.Tensor, torch.Tensor]: del kwargs seq_len, dim = hidden_states.shape[-2:] hidden_states = hidden_states.transpose(1, 2) if use_precomputed: if cache is None: msg = "cache is required when use_precomputed=True." raise ValueError(msg) x_with_state = torch.cat([cache, hidden_states], dim=-1) out = F.conv1d( x_with_state, self.weight, self.bias, padding=0, groups=dim, ) conv_state = x_with_state[:, :, 1:] else: out = F.conv1d( hidden_states, self.weight, self.bias, padding=self.conv_kernel_size - 1, groups=dim, ) out = out[:, :, :seq_len] conv_state = F.pad( hidden_states, (self.conv_kernel_size - 1 - hidden_states.shape[-1], 0), ) out = self.act_fn(out) return out.transpose(1, 2), conv_state def _apply_mask_to_padding_states( hidden_states: torch.Tensor, attention_mask: torch.Tensor | None, ) -> torch.Tensor: # Match the upstream hybrid implementation: silence padded tokens before # the DeltaNet projections so recurrent state does not absorb padding. 
if ( attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1 ): dtype = hidden_states.dtype hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) return hidden_states def _l2norm( x: torch.Tensor, dim: int = -1, eps: float = 1e-6, ) -> torch.Tensor: inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) return x * inv_norm def _torch_chunk_gated_delta_rule( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, g: torch.Tensor, beta: torch.Tensor, chunk_size: int = 64, initial_state: torch.Tensor | None = None, output_final_state: bool = False, use_qk_l2norm_in_kernel: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: initial_dtype = query.dtype if use_qk_l2norm_in_kernel: query = _l2norm(query, dim=-1, eps=1e-6) key = _l2norm(key, dim=-1, eps=1e-6) query, key, value, beta, g = [ x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g) ] batch_size, num_heads, sequence_length, k_head_dim = key.shape v_head_dim = value.shape[-1] pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size query = F.pad(query, (0, 0, 0, pad_size)) key = F.pad(key, (0, 0, 0, pad_size)) value = F.pad(value, (0, 0, 0, pad_size)) beta = F.pad(beta, (0, pad_size)) g = F.pad(g, (0, pad_size)) total_sequence_length = sequence_length + pad_size scale = 1 / (query.shape[-1] ** 0.5) query = query * scale v_beta = value * beta.unsqueeze(-1) k_beta = key * beta.unsqueeze(-1) query, key, value, k_beta, v_beta = [ x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta) ] g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) mask = torch.triu( torch.ones( chunk_size, chunk_size, dtype=torch.bool, device=query.device, ), diagonal=0, ) g = g.cumsum(dim=-1) decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril() attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) for idx in range(1, 
chunk_size): row = attn[..., idx, :idx].clone() sub = attn[..., :idx, :idx].clone() attn[..., idx, :idx] = row + (row.unsqueeze(-1) * sub).sum(-2) attn = attn + torch.eye( chunk_size, dtype=attn.dtype, device=attn.device, ) value = attn @ v_beta k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) last_recurrent_state = ( torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) if initial_state is None else initial_state.to(value) ) core_attn_out = torch.zeros_like(value) mask = torch.triu( torch.ones( chunk_size, chunk_size, dtype=torch.bool, device=query.device, ), diagonal=1, ) for idx in range(0, total_sequence_length // chunk_size): q_i, k_i, v_i = query[:, :, idx], key[:, :, idx], value[:, :, idx] attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, idx]).masked_fill_( mask, 0, ) v_prime = (k_cumdecay[:, :, idx]) @ last_recurrent_state v_new = v_i - v_prime attn_inter = (q_i * g[:, :, idx, :, None].exp()) @ last_recurrent_state core_attn_out[:, :, idx] = attn_inter + attn @ v_new last_recurrent_state = ( last_recurrent_state * g[:, :, idx, -1, None, None].exp() + ( k_i * (g[:, :, idx, -1, None] - g[:, :, idx]).exp()[..., None] ).transpose(-1, -2) @ v_new ) if not output_final_state: last_recurrent_state = None core_attn_out = core_attn_out.reshape( core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1], ) core_attn_out = core_attn_out[:, :, :sequence_length] core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) return core_attn_out, last_recurrent_state def _torch_recurrent_gated_delta_rule( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, g: torch.Tensor, beta: torch.Tensor, initial_state: torch.Tensor | None, output_final_state: bool, use_qk_l2norm_in_kernel: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: initial_dtype = query.dtype if use_qk_l2norm_in_kernel: query = _l2norm(query, dim=-1, eps=1e-6) key = _l2norm(key, dim=-1, eps=1e-6) query, key, value, beta, g = [ 
x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g) ] batch_size, num_heads, sequence_length, k_head_dim = key.shape v_head_dim = value.shape[-1] scale = 1 / (query.shape[-1] ** 0.5) query = query * scale core_attn_out = torch.zeros( batch_size, num_heads, sequence_length, v_head_dim, ).to(value) last_recurrent_state = ( torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) if initial_state is None else initial_state.to(value) ) for idx in range(sequence_length): q_t = query[:, :, idx] k_t = key[:, :, idx] v_t = value[:, :, idx] g_t = g[:, :, idx].exp().unsqueeze(-1).unsqueeze(-1) beta_t = beta[:, :, idx].unsqueeze(-1) last_recurrent_state = last_recurrent_state * g_t kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) delta = (v_t - kv_mem) * beta_t last_recurrent_state = ( last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) ) core_attn_out[:, :, idx] = ( last_recurrent_state * q_t.unsqueeze(-1) ).sum(dim=-2) if not output_final_state: last_recurrent_state = None core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) return core_attn_out, last_recurrent_state class LizzyHybridGatedDeltaNet(nn.Module): def __init__(self, config: LizzyConfig, layer_idx: int) -> None: super().__init__() self.hidden_size = config.hidden_size self.num_v_heads = config.linear_num_value_heads self.num_k_heads = config.linear_num_key_heads self.head_k_dim = config.linear_key_head_dim self.head_v_dim = config.linear_value_head_dim self.key_dim = self.head_k_dim * self.num_k_heads self.value_dim = self.head_v_dim * self.num_v_heads self.layer_idx = layer_idx self.conv_kernel_size = config.linear_conv_kernel_dim self.allow_neg_eigval = config.linear_allow_neg_eigval self.eps = config.rms_norm_eps self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False) 
self.a_proj = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) self.b_proj = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) self.g_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False) self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) # Step-02 conversion runs on CPU by default, even on GPU nodes. In that # flow Triton-backed FLA kernels will crash as soon as a CPU tensor # reaches them, so the wrapper can force the pure PyTorch fallback for # Hybrid layers via an environment switch. disable_fla_fast_path = os.environ.get( "LIZZY_DISABLE_HYBRID_FLA", "", ).strip().lower() in {"1", "true", "yes", "on"} use_fla_fast_path = ( not disable_fla_fast_path and torch.cuda.is_available() and ShortConvolution is not None and chunk_gated_delta_rule is not None and fused_recurrent_gated_delta_rule is not None and FusedRMSNormGated is not None ) # Keep the fast-path contract when FLA is present, but fall back to a # local implementation so the public Lizzy artifact never depends on # family-specific Transformers remote code. 
conv1d_class = ( ShortConvolution if use_fla_fast_path else LizzyHybridShortConvolution ) self.q_conv1d = conv1d_class( hidden_size=self.key_dim, kernel_size=self.conv_kernel_size, bias=False, activation="silu", ) self.k_conv1d = conv1d_class( hidden_size=self.key_dim, kernel_size=self.conv_kernel_size, bias=False, activation="silu", ) self.v_conv1d = conv1d_class( hidden_size=self.value_dim, kernel_size=self.conv_kernel_size, bias=False, activation="silu", ) a = torch.empty(self.num_v_heads, dtype=torch.float32).uniform_( config.linear_a_log_min, config.linear_a_log_max, ) self.A_log = nn.Parameter(torch.log(a)) dt = torch.exp( torch.rand(self.num_v_heads) * (math.log(config.linear_dt_max) - math.log(config.linear_dt_min)) + math.log(config.linear_dt_min) ) dt = torch.clamp(dt, min=config.linear_dt_init_floor) inv_dt = dt + torch.log(-torch.expm1(-dt)) self.dt_bias = nn.Parameter(inv_dt) self.o_norm = ( LizzyHybridRMSNormGated(self.head_v_dim, eps=1e-5) if not use_fla_fast_path else FusedRMSNormGated( self.head_v_dim, eps=1e-5, device=torch.cuda.current_device(), dtype=( config.dtype if config.dtype is not None else torch.get_default_dtype() ), ) ) self.chunk_gated_delta_rule = ( chunk_gated_delta_rule if use_fla_fast_path else _torch_chunk_gated_delta_rule ) self.recurrent_gated_delta_rule = ( ( fused_recurrent_gated_delta_rule if use_fla_fast_path else _torch_recurrent_gated_delta_rule ) ) def forward( self, hidden_states: torch.Tensor, cache_params: LizzyHybridDynamicCache | None = None, attention_mask: torch.Tensor | None = None, **kwargs: Any, ) -> torch.Tensor: del kwargs hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask) batch_size, seq_len, _ = hidden_states.shape use_cache = cache_params is not None use_precomputed = ( use_cache and getattr(cache_params, "has_previous_state", False) and seq_len == 1 ) conv_state_q = ( cache_params.conv_states_q[self.layer_idx] if cache_params else None ) conv_state_k = ( 
cache_params.conv_states_k[self.layer_idx] if cache_params else None ) conv_state_v = ( cache_params.conv_states_v[self.layer_idx] if cache_params else None ) recurrent_state = ( cache_params.recurrent_states[self.layer_idx] if cache_params else None ) q = self.q_proj(hidden_states) k = self.k_proj(hidden_states) v = self.v_proj(hidden_states) q, new_conv_state_q = self.q_conv1d( q, cache=conv_state_q, use_precomputed=use_precomputed, output_final_state=use_cache, ) k, new_conv_state_k = self.k_conv1d( k, cache=conv_state_k, use_precomputed=use_precomputed, output_final_state=use_cache, ) v, new_conv_state_v = self.v_conv1d( v, cache=conv_state_v, use_precomputed=use_precomputed, output_final_state=use_cache, ) if cache_params is not None: cache_params.conv_states_q[self.layer_idx] = new_conv_state_q cache_params.conv_states_k[self.layer_idx] = new_conv_state_k cache_params.conv_states_v[self.layer_idx] = new_conv_state_v q = q.view(batch_size, seq_len, -1, self.head_k_dim) k = k.view(batch_size, seq_len, -1, self.head_k_dim) v = v.view(batch_size, seq_len, -1, self.head_v_dim) if self.num_v_heads > self.num_k_heads: expand_ratio = self.num_v_heads // self.num_k_heads q = q.repeat_interleave(expand_ratio, dim=2) k = k.repeat_interleave(expand_ratio, dim=2) beta = self.b_proj(hidden_states).sigmoid() if self.allow_neg_eigval: beta = beta * 2.0 g = -self.A_log.float().exp() * F.softplus( self.a_proj(hidden_states).float() + self.dt_bias ) if use_precomputed: output, new_recurrent_state = self.recurrent_gated_delta_rule( q, k, v, g=g, beta=beta, initial_state=recurrent_state, output_final_state=use_cache, use_qk_l2norm_in_kernel=True, ) else: output, new_recurrent_state = self.chunk_gated_delta_rule( q, k, v, g=g, beta=beta, initial_state=recurrent_state, output_final_state=use_cache, use_qk_l2norm_in_kernel=True, ) if cache_params is not None: cache_params.recurrent_states[self.layer_idx] = new_recurrent_state gate = self.g_proj(hidden_states) output = 
output.reshape(-1, self.head_v_dim) gate = gate.reshape(-1, self.head_v_dim) output = self.o_norm(output, gate) output = output.reshape(batch_size, seq_len, -1) output = self.o_proj(output) return output class LizzyLinearAttention(nn.Module): def __init__(self, config: LizzyConfig, layer_idx: int) -> None: super().__init__() self.layer_idx = layer_idx self.inner = LizzyHybridGatedDeltaNet(config, layer_idx) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, past_key_value: Cache | None = None, use_cache: bool = False, output_attentions: bool = False, **kwargs: Any, ) -> tuple[ torch.Tensor, Cache | None, torch.Tensor | None, ]: del kwargs, output_attentions output = self.inner( hidden_states=hidden_states, cache_params=( past_key_value if _is_cache_object(past_key_value) else None ), attention_mask=attention_mask, ) present = past_key_value if use_cache else None return output, present, None class LizzyAttention(nn.Module): def __init__(self, config: LizzyConfig, layer_idx: int) -> None: super().__init__() self.is_causal = True self.config = config self.layer_idx = layer_idx self.num_heads = config.num_attention_heads self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.head_dim = config.head_dim self.hidden_size = config.hidden_size self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.position_embedding_type = config.position_embedding_type self.layer_type = ( str(config.layer_types[layer_idx]) if layer_idx < len(config.layer_types) else "full_attention" ) self.use_rope = _get_rope_layer_flag(config, layer_idx) self._rope_type_override = str( dict(config.rope_type_overrides or {}).get(self.layer_type) or "" ) or None if ( self._rope_type_override is None and self.layer_type == "sliding_attention" and bool(config.rope_scaling) and config.use_post_attn_norm and config.use_post_mlp_norm and config.use_qk_norm and 
any(str(item) == "full_attention" for item in config.layer_types)
        ):
            # Pin this layer to the "default" rope type. The override takes
            # precedence over config.rope_scaling in the rope_type lookup below
            # and is forwarded to _compute_rope_parameters().
            self._rope_type_override = "default"
        # Only "sliding_attention" layers carry a window; the value is passed to
        # the attention kernel as `sliding_window` and stays None otherwise.
        self.sliding_window = None
        if self.layer_type == "sliding_attention":
            self.sliding_window = config.sliding_window
        q_dim = self.num_heads * self.head_dim
        kv_dim = self.num_key_value_heads * self.head_dim
        self.q_proj = nn.Linear(
            config.hidden_size,
            q_dim,
            bias=config.attention_bias,
        )
        self.k_proj = nn.Linear(
            config.hidden_size,
            kv_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            config.hidden_size,
            kv_dim,
            bias=config.attention_bias,
        )
        self.o_proj = nn.Linear(
            q_dim,
            config.hidden_size,
            bias=config.attention_bias,
        )
        # Optional QK-norm over the full projection width (q_dim / kv_dim);
        # forward() applies it before the projections are split into heads.
        self.q_norm = (
            _make_norm(config.qk_norm_type, q_dim, config.norm_eps, has_bias=False)
            if config.use_qk_norm
            else None
        )
        self.k_norm = (
            _make_norm(config.qk_norm_type, kv_dim, config.norm_eps, has_bias=False)
            if config.use_qk_norm
            else None
        )
        # RoPE tables live in non-persistent buffers (never written to the
        # state_dict). For "dynamic" rope they are left as None because
        # _build_rope() recomputes them from the runtime positions.
        self._rope_requires_runtime_update = False
        if self.use_rope:
            rope_scaling = dict(config.rope_scaling or {})
            rope_type = self._rope_type_override or str(
                rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
            )
            self._rope_requires_runtime_update = rope_type == "dynamic"
            if self._rope_requires_runtime_update:
                self.register_buffer("_rope_inv_freq", None, persistent=False)
                self.register_buffer(
                    "_rope_attention_factor",
                    None,
                    persistent=False,
                )
            else:
                # Static rope: precompute once on CPU for the configured
                # maximum context; _build_rope() moves it to the input device.
                inv_freq, attention_factor = _compute_rope_parameters(
                    config,
                    device=torch.device("cpu"),
                    seq_len=config.max_position_embeddings,
                    rope_type_override=self._rope_type_override,
                )
                self.register_buffer("_rope_inv_freq", inv_freq, persistent=False)
                self.register_buffer(
                    "_rope_attention_factor",
                    torch.tensor(float(attention_factor), dtype=torch.float32),
                    persistent=False,
                )
        else:
            self.register_buffer("_rope_inv_freq", None, persistent=False)
            self.register_buffer("_rope_attention_factor", None, persistent=False)

    def _build_rope(
        self,
        position_ids: torch.Tensor,
        device: torch.device,
        dtype: torch.dtype,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(cos, sin)`` rotary tables for ``position_ids``.

        ``position_ids`` is indexed as ``[:, None, :]`` below, so it must be
        2D (batch, seq). Both returned tensors have shape
        (batch, 1, seq, 2 * len(inv_freq)), are multiplied by the rope
        attention factor, and are cast to ``dtype``.

        Raises:
            RuntimeError: if this layer does not use RoPE.
        """
        if not self.use_rope:
            msg = "RoPE requested but rope buffer is not initialized."
            raise RuntimeError(msg)
        inv_freq = self._rope_inv_freq
        attention_factor_tensor = self._rope_attention_factor
        if (
            inv_freq is None
            or attention_factor_tensor is None
            or self._rope_requires_runtime_update
        ):
            # Keep the sequence-length hint as a tensor so TorchDynamo/vLLM
            # can trace this path without requiring capture_scalar_outputs.
            # When low-memory loading leaves the non-persistent cache unset,
            # rebuild from config for this forward only instead of mutating
            # buffers inside the compiled graph.
            seq_len = (
                torch.max(position_ids) + 1 if position_ids.numel() > 0 else None
            )
            inv_freq, attention_factor = _compute_rope_parameters(
                self.config,
                device=device,
                seq_len=seq_len,
                rope_type_override=self._rope_type_override,
            )
            attention_factor_tensor = torch.tensor(
                float(attention_factor),
                device=device,
                dtype=torch.float32,
            )
        else:
            inv_freq = inv_freq.to(device=device)
            attention_factor_tensor = attention_factor_tensor.to(
                device=device,
                dtype=torch.float32,
            )
        # Mirror the upstream HF decoder-only rotary path closely here.
        # The matmul-based construction is slightly more numerically stable
        # than the generic einsum formulation for strict parity probes.
        # (batch, F, 1) @ (batch, 1, seq) -> (batch, F, seq) -> (batch, seq, F)
        inv_freq_expanded = (
            inv_freq[None, :, None]
            .to(device=device, dtype=torch.float32)
            .expand(position_ids.shape[0], -1, 1)
        )
        position_ids_expanded = position_ids[:, None, :].to(torch.float32)
        angles = torch.matmul(
            inv_freq_expanded,
            position_ids_expanded,
        ).transpose(1, 2)
        # Duplicate the angles to match the _rotate_half layout.
        angles = torch.cat((angles, angles), dim=-1)
        cos = angles.cos().unsqueeze(1) * attention_factor_tensor
        sin = angles.sin().unsqueeze(1) * attention_factor_tensor
        cos = cos.to(dtype)
        sin = sin.to(dtype)
        return cos, sin

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        past_key_value: Cache | tuple[torch.Tensor, torch.Tensor] | None = None,
        cache_position: torch.Tensor | None = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        **kwargs: Any,
    ) -> tuple[
        torch.Tensor,
        Cache | tuple[torch.Tensor, torch.Tensor] | None,
        torch.Tensor | None,
    ]:
        """Self-attention over ``hidden_states`` with optional QK-norm and RoPE.

        ``hidden_states`` is unpacked as (batch, q_len, hidden). Keys/values
        are merged with ``past_key_value`` when given. A cache object is
        extended through its ``update()`` method when ``use_cache`` is set and
        only read otherwise. A legacy ``(key, value)`` tuple is concatenated
        along the sequence axis.

        Returns:
            ``(attn_output, present_key_value, attn_weights)``.
            ``attn_weights`` is None unless ``output_attentions`` is set (and,
            for kernel backends, only if the backend returns weights).

        Raises:
            ValueError: if RoPE is enabled and ``position_ids`` is None.
        """
        batch_size, q_len, _ = hidden_states.shape
        cache_position = _normalize_cache_position(cache_position)
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        if self.q_norm is not None:
            query_states = self.q_norm(query_states)
        if self.k_norm is not None:
            key_states = self.k_norm(key_states)
        # Split into heads: (batch, q_len, heads, head_dim) -> (batch, heads, q_len, head_dim).
        query_states = query_states.view(
            batch_size,
            q_len,
            self.num_heads,
            self.head_dim,
        )
        query_states = query_states.transpose(1, 2)
        key_states = key_states.view(
            batch_size,
            q_len,
            self.num_key_value_heads,
            self.head_dim,
        )
        key_states = key_states.transpose(1, 2)
        value_states = value_states.view(
            batch_size,
            q_len,
            self.num_key_value_heads,
            self.head_dim,
        )
        value_states = value_states.transpose(1, 2)
        if self.use_rope:
            if position_ids is None:
                msg = "position_ids are required for rope attention."
                raise ValueError(msg)
            cos, sin = self._build_rope(
                position_ids,
                hidden_states.device,
                query_states.dtype,
            )
            query_states, key_states = _apply_rotary_pos_emb(
                query_states,
                key_states,
                cos,
                sin,
            )
        if _is_cache_object(past_key_value):
            if use_cache:
                # Let the cache append this step and return the full K/V.
                key_states, value_states = past_key_value.update(
                    key_states,
                    value_states,
                    self.layer_idx,
                    cache_kwargs={"cache_position": cache_position},
                )
                present_key_value = past_key_value
            elif self.layer_idx < len(past_key_value):
                # Read-only: prepend this layer's cached K/V without updating.
                past_key, past_value = past_key_value[self.layer_idx]
                if past_key is not None and past_value is not None:
                    key_states = torch.cat([past_key, key_states], dim=2)
                    value_states = torch.cat([past_value, value_states], dim=2)
                present_key_value = None
            else:
                present_key_value = None
        elif past_key_value is not None:
            # Legacy (key, value) tuple for this layer; dim 2 is the sequence axis.
            past_key, past_value = past_key_value
            key_states = torch.cat([past_key, key_states], dim=2)
            value_states = torch.cat([past_value, value_states], dim=2)
            present_key_value = (key_states, value_states) if use_cache else None
        else:
            present_key_value = (key_states, value_states) if use_cache else None
        attention_interface = None
        attn_impl = getattr(self.config, "_attn_implementation", "eager")
        # flex_attention is bypassed for head_dim < 16 and routed to SDPA
        # (presumably a flex kernel size limitation - confirm).
        if attn_impl == "flex_attention" and self.head_dim < 16:
            attn_impl = "sdpa"
        if attn_impl != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS.get(attn_impl)
        if attention_interface is not None:
            attn_output, attn_weights = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask,
                dropout=0.0 if not self.training else self.attention_dropout,
                scaling=self.scaling,
                sliding_window=self.sliding_window,
                **kwargs,
            )
            attn_output = attn_output.contiguous()
        else:
            # Eager fallback: expand KV heads to the query head count (GQA),
            # add the additive mask, softmax in float32, then cast back.
            if self.num_key_value_heads != self.num_heads:
                key_states = key_states.repeat_interleave(
                    self.num_key_value_groups,
                    dim=1,
                )
                value_states = value_states.repeat_interleave(
                    self.num_key_value_groups,
                    dim=1,
                )
            attn_weights = torch.matmul(
                query_states,
                key_states.transpose(-1, -2),
            ) * self.scaling
            if attention_mask is not None:
                attn_weights = attn_weights + attention_mask
            attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32)
            attn_weights = attn_weights.to(query_states.dtype)
            attn_weights = F.dropout(
                attn_weights,
                p=self.attention_dropout if self.training else 0.0,
                training=self.training,
            )
            attn_output = torch.matmul(attn_weights, value_states)
            # Back to (batch, q_len, heads, head_dim) before merging heads.
            attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        if not output_attentions:
            attn_weights = None
        return attn_output, present_key_value, attn_weights


def _refresh_attention_rope_buffers(module: nn.Module) -> None:
    """Rebuild non-persistent RoPE buffers after checkpoint load."""
    for child in module.modules():
        if not isinstance(child, LizzyAttention):
            continue
        # Re-derive whether this layer uses RoPE from its (loaded) config.
        should_use_rope = _get_rope_layer_flag(child.config, child.layer_idx)
        child.use_rope = should_use_rope
        if not should_use_rope:
            child._rope_requires_runtime_update = False
            child._rope_inv_freq = None
            child._rope_attention_factor = None
            continue
        rope_scaling = dict(child.config.rope_scaling or {})
        rope_type = child._rope_type_override or str(
            rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
        )
        # "dynamic" rope tables are recomputed per forward in _build_rope().
        child._rope_requires_runtime_update = rope_type == "dynamic"
        if child._rope_requires_runtime_update:
            child._rope_inv_freq = None
            child._rope_attention_factor = None
            continue
        # These buffers are derived from config rather than serialized weights.
        # Recompute them after load so low-memory materialization cannot leave
        # stale or uninitialized rotary state behind.
inv_freq, attention_factor = _compute_rope_parameters( child.config, device=torch.device("cpu"), seq_len=child.config.max_position_embeddings, rope_type_override=child._rope_type_override, ) child._rope_inv_freq = inv_freq child._rope_attention_factor = torch.tensor( float(attention_factor), dtype=torch.float32, ) class LizzyMLP(nn.Module): def __init__(self, config: LizzyConfig) -> None: super().__init__() self.config = config self.act = ACT2FN[config.hidden_act] self.gate_proj = ( nn.Linear( config.hidden_size, config.intermediate_size, bias=config.mlp_bias, ) if config.mlp_type == "gated" else None ) self.up_proj = nn.Linear( config.hidden_size, config.intermediate_size, bias=config.mlp_bias, ) self.down_proj = nn.Linear( config.intermediate_size, config.hidden_size, bias=config.mlp_bias, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.gate_proj is None and self.config.mlp_type == "gated": msg = "Missing gated MLP projection layers." raise RuntimeError(msg) if self.config.mlp_type == "gated": if self.gate_proj is None: msg = "Missing gated MLP projection layers." 
raise RuntimeError(msg) return self.down_proj(self.act( self.gate_proj(hidden_states)) * self.up_proj(hidden_states) ) return self.down_proj(self.act(self.up_proj(hidden_states))) class LizzyDecoderLayer(nn.Module): def __init__(self, config: LizzyConfig, layer_idx: int) -> None: super().__init__() self.layer_type = ( str(config.layer_types[layer_idx]) if layer_idx < len(config.layer_types) else "full_attention" ) self.layer_layout = _get_layer_layout(config, layer_idx) self.self_attn = ( LizzyAttention(config, layer_idx) if self.layer_type != "linear_attention" else None ) self.linear_attn = ( LizzyLinearAttention(config, layer_idx) if self.layer_type == "linear_attention" else None ) self.mlp = LizzyMLP(config) self.pre_attn_norm = ( _make_norm( config.norm_type, config.hidden_size, config.norm_eps, has_bias=config.norm_has_bias, ) if self.layer_layout == "decoder_prenorm" else None ) self.pre_mlp_norm = ( _make_norm( config.norm_type, config.hidden_size, config.norm_eps, has_bias=config.norm_has_bias, ) if self.layer_layout == "decoder_prenorm" else None ) self.post_attn_norm = ( _make_norm( config.norm_type, config.hidden_size, config.norm_eps, has_bias=config.norm_has_bias, ) if self.layer_layout == "decoder_postnorm" else None ) self.post_mlp_norm = ( _make_norm( config.norm_type, config.hidden_size, config.norm_eps, has_bias=config.norm_has_bias ) if self.layer_layout == "decoder_postnorm" else None ) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, position_ids: torch.Tensor | None = None, past_key_value: Cache | tuple[torch.Tensor, torch.Tensor] | None = None, cache_position: torch.Tensor | None = None, use_cache: bool = False, output_attentions: bool = False, **kwargs: Any, ) -> tuple[ torch.Tensor, Cache | tuple[torch.Tensor, torch.Tensor] | None, torch.Tensor | None, ]: residual = hidden_states attn_inputs = ( self.pre_attn_norm(hidden_states) if self.pre_attn_norm is not None else hidden_states ) if 
self.linear_attn is not None: attn_output, present_key_value, attn_weights = self.linear_attn( attn_inputs, attention_mask=attention_mask, past_key_value=( past_key_value if _is_cache_object(past_key_value) else None ), use_cache=use_cache, output_attentions=output_attentions, **kwargs, ) else: assert self.self_attn is not None attn_output, present_key_value, attn_weights = self.self_attn( attn_inputs, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, cache_position=cache_position, use_cache=use_cache, output_attentions=output_attentions, **kwargs, ) if self.post_attn_norm is not None: attn_output = self.post_attn_norm(attn_output) hidden_states = residual + attn_output residual = hidden_states mlp_inputs = ( self.pre_mlp_norm(hidden_states) if self.pre_mlp_norm is not None else hidden_states ) mlp_output = self.mlp(mlp_inputs) if self.post_mlp_norm is not None: mlp_output = self.post_mlp_norm(mlp_output) hidden_states = residual + mlp_output return hidden_states, present_key_value, attn_weights class LizzyPreTrainedModel(PreTrainedModel): config_class = LizzyConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["LizzyDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_attention_backend = True def _init_weights(self, module: nn.Module) -> None: std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, (LizzyRMSNorm, nn.LayerNorm)): if hasattr(module, "weight") and module.weight is not None: module.weight.data.fill_(1.0) if hasattr(module, "bias") and module.bias is not None: 
module.bias.data.zero_() @classmethod def from_pretrained( cls, pretrained_model_name_or_path: str | os.PathLike[str] | None, *model_args: Any, **kwargs: Any, ) -> "LizzyPreTrainedModel": model = cast( "LizzyPreTrainedModel", super().from_pretrained( pretrained_model_name_or_path, *model_args, **kwargs, ), ) _refresh_attention_rope_buffers(model) if hasattr(model, "lm_head") and hasattr(model, "model"): tied_weights_keys = getattr(type(model), "_tied_weights_keys", None) if isinstance(tied_weights_keys, dict) and tied_weights_keys: model._tied_weights_keys = dict(tied_weights_keys) else: model._tied_weights_keys = { "lm_head.weight": "model.embed_tokens.weight", } model._tp_plan = {"lm_head": "colwise_rep"} model._pp_plan = {"lm_head": (["hidden_states"], ["logits"])} return model def load_state_dict( # type: ignore[override] self, state_dict: dict[str, torch.Tensor], strict: bool = True, assign: bool = False, ) -> Any: remapped_state_dict: dict[str, torch.Tensor] = {} for key, value in state_dict.items(): remapped_key = key if ".mlp.fc_in." in key: remapped_key = key.replace(".mlp.fc_in.", ".mlp.up_proj.") elif ".mlp.fc_out." in key: remapped_key = key.replace(".mlp.fc_out.", ".mlp.down_proj.") existing = remapped_state_dict.get(remapped_key) if existing is not None and not torch.equal(existing, value): msg = ( f"Conflicting legacy Lizzy MLP tensors" f" for key: {remapped_key}" ) raise ValueError(msg) remapped_state_dict[remapped_key] = value load_result = super().load_state_dict( remapped_state_dict, strict=strict, assign=assign, ) # RoPE buffers are intentionally non-persistent, so refresh them after # weight loading instead of trusting constructor-time allocations. 
_refresh_attention_rope_buffers(self) return load_result class LizzyModel(LizzyPreTrainedModel): def __init__(self, config: LizzyConfig) -> None: super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding( config.vocab_size, config.hidden_size, self.padding_idx, ) self.embed_positions = ( nn.Embedding(config.max_position_embeddings, config.hidden_size) if config.position_embedding_type == "absolute" else None ) self.layers = nn.ModuleList( LizzyDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers) ) self.norm = _make_norm( config.norm_type, config.hidden_size, config.norm_eps, has_bias=config.norm_has_bias, ) self.embd_dropout = nn.Dropout(config.embd_dropout) self.gradient_checkpointing = False self.post_init() def get_input_embeddings(self) -> nn.Embedding: return self.embed_tokens def set_input_embeddings(self, value: nn.Embedding) -> None: self.embed_tokens = value def _build_attention_mask( self, attention_mask: torch.Tensor | None, *, batch_size: int, q_len: int, kv_len: int, kv_offset: int, cache_position: torch.Tensor, device: torch.device, dtype: torch.dtype, sliding_window: int | None = None, ) -> torch.Tensor: kv_len = ( int(kv_len.item()) if isinstance(kv_len, torch.Tensor) else int(kv_len) ) kv_offset = ( int(kv_offset.item()) if isinstance(kv_offset, torch.Tensor) else int(kv_offset) ) min_value = torch.finfo(dtype).min source_positions = cache_position.to(device=device).view(-1, 1) target_positions = torch.arange( kv_offset, kv_offset + kv_len, device=device, ).unsqueeze(0) causal = torch.zeros((q_len, kv_len), dtype=dtype, device=device) causal = causal.masked_fill(target_positions > source_positions, min_value) if sliding_window is not None: lower_bound = source_positions - int(sliding_window) + 1 causal = causal.masked_fill(target_positions < lower_bound, min_value) causal = causal.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1) if 
attention_mask is None: return causal if attention_mask.dim() != 2: msg = "attention_mask must be 2D [batch, sequence]." raise ValueError(msg) if attention_mask.shape[1] < kv_len: pad = torch.ones( (attention_mask.shape[0], kv_len - attention_mask.shape[1]), dtype=attention_mask.dtype, device=attention_mask.device, ) attention_mask = torch.cat([pad, attention_mask], dim=1) elif attention_mask.shape[1] > kv_len: attention_mask = attention_mask[:, -kv_len:] expanded = attention_mask[:, None, None, :].to(device=device) padding = (expanded == 0).to(dtype) * min_value return causal + padding def forward( self, input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | dict[str, torch.Tensor] | None = None, position_ids: torch.LongTensor | None = None, past_key_values: Cache | tuple[ tuple[torch.Tensor, torch.Tensor], ... ] | None = None, inputs_embeds: torch.FloatTensor | None = None, cache_position: torch.LongTensor | None = None, use_cache: bool | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, **kwargs: Any, ) -> BaseModelOutputWithPast | tuple[Any, ...]: if (input_ids is None) == (inputs_embeds is None): msg = "Exactly one of input_ids or inputs_embeds must be provided." 
raise ValueError(msg) output_attentions = ( bool(output_attentions) if output_attentions is not None else False ) output_hidden_states = ( bool(output_hidden_states) if output_hidden_states is not None else False ) use_cache = ( bool(use_cache) if use_cache is not None else bool(self.config.use_cache) ) return_dict = bool(return_dict) if return_dict is not None else True if inputs_embeds is None: hidden_states = self.embed_tokens(input_ids) batch_size, seq_len = input_ids.shape else: hidden_states = inputs_embeds batch_size, seq_len, _ = inputs_embeds.shape cache_object = ( past_key_values if _is_cache_object(past_key_values) else None ) if use_cache and _has_linear_attention(self.config): # Transformers 5.4 seeds `generate()` with an empty DynamicCache # for standard causal decoders. Hybrid Lizzy checkpoints need the # mixed cache below instead, because linear-attention layers read # DeltaNet convolution/recurrent state during the prefill pass. if cache_object is not None and not isinstance( cache_object, LizzyHybridDynamicCache, ): if int(cache_object.get_seq_length()) > 0: msg = ( "Hybrid Lizzy checkpoints require " "LizzyHybridDynamicCache once generation cache " "state is populated." ) raise ValueError(msg) cache_object = LizzyHybridDynamicCache(config=self.config) past_key_values = cache_object if use_cache and cache_object is None and past_key_values is None: if _has_linear_attention(self.config): # Linear-attention checkpoints need a mixed cache that can hold # both KV tensors and recurrent DeltaNet state. 
cache_object = LizzyHybridDynamicCache(config=self.config) else: cache_object = DynamicCache() past_key_values = cache_object if cache_object is not None: past_length = int(cache_object.get_seq_length()) else: past_length = _legacy_cache_length(past_key_values) cache_position = _normalize_cache_position(cache_position) if cache_position is None: cache_position = torch.arange( past_length, past_length + seq_len, dtype=torch.long, device=hidden_states.device, ) if position_ids is None: position_ids = cache_position.unsqueeze(0).expand(batch_size, -1) if self.embed_positions is not None: hidden_states = hidden_states + self.embed_positions(position_ids) hidden_states = self.embd_dropout(hidden_states) if self.training and self.gradient_checkpointing: use_cache = False layer_types = list(self.config.layer_types) if not layer_types: layer_types = ["full_attention"] * len(self.layers) has_linear_attention = any( str(layer_type) == "linear_attention" for layer_type in layer_types ) _attn_impl = getattr(self.config, "_attn_implementation", "eager") if has_linear_attention and isinstance(attention_mask, dict): linear_attention_mask = attention_mask.get("linear_attention") else: linear_attention_mask = attention_mask if ( has_linear_attention and cache_object is not None and getattr(cache_object, "has_previous_state", False) ): linear_attention_mask = None elif ( has_linear_attention and attention_mask is not None and not isinstance(attention_mask, dict) and torch.all(attention_mask == 1) ): linear_attention_mask = None if ( _attn_impl == "flash_attention_2" and not isinstance(attention_mask, dict) ): # Flash attention handles causal masking (via is_causal) and # padding (via 2D mask) natively; skip building a 4D mask. 
attention_mask_mapping = { lt: attention_mask for lt in dict.fromkeys(layer_types) if lt != "linear_attention" } elif _attn_impl == "sdpa" and attention_mask is None: attention_mask_mapping = {} for layer_type in dict.fromkeys(layer_types): if layer_type == "linear_attention": continue if layer_type == "full_attention": # Match upstream decoder-only HF models: when SDPA sees # plain causal full attention with no padding mask to # preserve, let it use its native is_causal fast-path # instead of forcing an explicit 4D bias tensor. attention_mask_mapping[layer_type] = None continue layer_idx = layer_types.index(layer_type) if cache_object is not None: kv_len, kv_offset = cache_object.get_mask_sizes( seq_len, layer_idx, ) else: kv_len = past_length + seq_len kv_offset = 0 attention_mask_mapping[layer_type] = self._build_attention_mask( attention_mask, batch_size=batch_size, q_len=seq_len, kv_len=kv_len, kv_offset=kv_offset, cache_position=cache_position, device=hidden_states.device, dtype=hidden_states.dtype, sliding_window=( self.config.sliding_window if layer_type == "sliding_attention" else None ), ) elif isinstance(attention_mask, dict): attention_mask_mapping = { key: value for key, value in attention_mask.items() if key != "linear_attention" } else: attention_mask_mapping: dict[str, torch.Tensor] = {} for layer_type in dict.fromkeys(layer_types): if layer_type == "linear_attention": continue layer_idx = layer_types.index(layer_type) if cache_object is not None: kv_len, kv_offset = cache_object.get_mask_sizes( seq_len, layer_idx, ) else: kv_len = past_length + seq_len kv_offset = 0 attention_mask_mapping[layer_type] = self._build_attention_mask( attention_mask, batch_size=batch_size, q_len=seq_len, kv_len=kv_len, kv_offset=kv_offset, cache_position=cache_position, device=hidden_states.device, dtype=hidden_states.dtype, sliding_window=( self.config.sliding_window if layer_type == "sliding_attention" else None ), ) all_hidden_states = [] if output_hidden_states else 
None all_attentions = [] if output_attentions else None next_cache = ( cache_object if cache_object is not None else ([] if use_cache else None) ) gradient_checkpointing_func = getattr( self, "_gradient_checkpointing_func", checkpoint, ) for idx, layer in enumerate(self.layers): if output_hidden_states and all_hidden_states is not None: all_hidden_states.append(hidden_states) layer_type = ( layer_types[idx] if idx < len(layer_types) else "full_attention" ) if layer_type == "linear_attention": layer_attention_mask = linear_attention_mask else: layer_attention_mask = attention_mask_mapping[layer_type] if cache_object is not None: layer_past: Cache | tuple[ torch.Tensor, torch.Tensor ] | None = cache_object elif past_key_values is not None: layer_past = past_key_values[idx] if layer_past is not None and layer_past[0] is None: layer_past = None else: layer_past = None if self.training and self.gradient_checkpointing: def custom_forward(hidden_states: torch.Tensor) -> Any: layer_outputs = layer( hidden_states, attention_mask=layer_attention_mask, position_ids=position_ids, past_key_value=None, cache_position=cache_position, use_cache=False, output_attentions=output_attentions, **kwargs, ) if output_attentions: return layer_outputs[0], layer_outputs[2] return layer_outputs[0] checkpointed_outputs = gradient_checkpointing_func( custom_forward, hidden_states, ) if output_attentions: hidden_states, attn_weights = checkpointed_outputs else: hidden_states = checkpointed_outputs attn_weights = None present = None else: hidden_states, present, attn_weights = layer( hidden_states, attention_mask=layer_attention_mask, position_ids=position_ids, past_key_value=layer_past, cache_position=cache_position, use_cache=use_cache, output_attentions=output_attentions, **kwargs, ) if use_cache and next_cache is not None and cache_object is None: next_cache.append(present) if output_attentions and all_attentions is not None: all_attentions.append(attn_weights) hidden_states = 
self.norm(hidden_states) if output_hidden_states and all_hidden_states is not None: all_hidden_states.append(hidden_states) past_key_values_output: Cache | tuple[ tuple[torch.Tensor, torch.Tensor], ... ] | None = None if use_cache and next_cache is not None: if cache_object is not None: past_key_values_output = cache_object else: past_key_values_output = tuple(next_cache) if not return_dict: output: tuple[Any, ...] = (hidden_states,) if past_key_values_output is not None: output = output + (past_key_values_output,) if output_hidden_states and all_hidden_states is not None: output = output + (tuple(all_hidden_states),) if output_attentions and all_attentions is not None: output = output + (tuple(all_attentions),) return output return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values_output, hidden_states=( tuple(all_hidden_states) if all_hidden_states is not None else None ), attentions=( tuple(all_attentions) if all_attentions is not None else None ), ) class LizzyForCausalLM(LizzyPreTrainedModel, GenerationMixin): config_class = LizzyConfig # Transformers 5.4 expects an expanded target->source mapping here rather than # the older list-based shorthand. 
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config: LizzyConfig) -> None: super().__init__(config) self.model = LizzyModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.post_init() def get_input_embeddings(self) -> nn.Embedding: return self.model.get_input_embeddings() def set_input_embeddings(self, value: nn.Embedding) -> None: self.model.set_input_embeddings(value) def get_output_embeddings(self) -> nn.Module: return self.lm_head def set_output_embeddings(self, new_embeddings: nn.Module) -> None: self.lm_head = new_embeddings def prepare_inputs_for_generation( self, input_ids: torch.LongTensor, past_key_values: Cache | tuple[ tuple[torch.Tensor, torch.Tensor], ... ] | None = None, attention_mask: torch.Tensor | None = None, inputs_embeds: torch.FloatTensor | None = None, cache_position: torch.LongTensor | None = None, **kwargs: Any, ) -> dict[str, Any]: past_length = 0 if past_key_values is not None: if _is_cache_object(past_key_values): past_length = int(past_key_values.get_seq_length()) else: past_length = _legacy_cache_length(past_key_values) cache_position = _normalize_cache_position(cache_position) if cache_position is None: if past_key_values is not None: new_tokens = input_ids.shape[1] - past_length if new_tokens <= 0: new_tokens = 1 cache_position = torch.arange( past_length, past_length + new_tokens, device=input_ids.device, ) else: cache_position = torch.arange( input_ids.shape[1], device=input_ids.device, ) if past_key_values is not None: input_ids = input_ids[:, -cache_position.shape[0] :] if attention_mask is not None: attn_mask_idx = (past_length + input_ids.shape[1]) attention_mask = attention_mask[:, -attn_mask_idx :] if inputs_embeds is not None and past_key_values is None: model_inputs: dict[str, Any] = {"inputs_embeds": 
inputs_embeds} else: model_inputs = {"input_ids": input_ids.contiguous()} model_inputs.update( { "past_key_values": past_key_values, "attention_mask": attention_mask, "cache_position": cache_position, "use_cache": kwargs.get("use_cache", self.config.use_cache), }, ) return model_inputs def forward( self, input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, past_key_values: Cache | tuple[ tuple[torch.Tensor, torch.Tensor], ... ] | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, cache_position: torch.LongTensor | None = None, use_cache: bool | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, logits_to_keep: int | torch.Tensor = 0, **kwargs: Any, ) -> CausalLMOutputWithPast | tuple[Any, ...]: # HF eval loaders call `forward()` without an explicit return_dict, # so local Lizzy exports must normalize the optional flag first. 
return_dict = bool(return_dict) if return_dict is not None else True outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, cache_position=cache_position, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, **kwargs, ) hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state slice_indices = ( slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep ) if labels is not None: full_logits = self.lm_head(hidden_states) logits = full_logits[:, slice_indices, :] else: full_logits = None logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: shift_logits = full_logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss = F.cross_entropy( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ) if not return_dict: output = (logits,) + outputs[1:] if loss is not None: output = (loss,) + output return output return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, )