Text Generation
Transformers
Safetensors
English
lizzy
lizzy-7b
flwrlabs
british-english
conversational
custom_code
Instructions to use flwrlabs/Lizzy-7B with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use flwrlabs/Lizzy-7B with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="flwrlabs/Lizzy-7B", trust_remote_code=True)
messages = [
    {"role": "user", "content": "Who are you?"},
]
pipe(messages)

# Load model directly
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("flwrlabs/Lizzy-7B", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use flwrlabs/Lizzy-7B with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm
# Start the vLLM server:
vllm serve "flwrlabs/Lizzy-7B"
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "flwrlabs/Lizzy-7B",
        "messages": [
            {
                "role": "user",
                "content": "What is the capital of France?"
            }
        ]
    }'
Use Docker
docker model run hf.co/flwrlabs/Lizzy-7B
- SGLang
How to use flwrlabs/Lizzy-7B with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang
# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "flwrlabs/Lizzy-7B" \
    --host 0.0.0.0 \
    --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "flwrlabs/Lizzy-7B",
        "messages": [
            {
                "role": "user",
                "content": "What is the capital of France?"
            }
        ]
    }'
Use Docker images
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "flwrlabs/Lizzy-7B" \
        --host 0.0.0.0 \
        --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "flwrlabs/Lizzy-7B",
        "messages": [
            {
                "role": "user",
                "content": "What is the capital of France?"
            }
        ]
    }'
- Docker Model Runner
How to use flwrlabs/Lizzy-7B with Docker Model Runner:
docker model run hf.co/flwrlabs/Lizzy-7B
| from __future__ import annotations | |
| import math | |
| import os | |
| from typing import Any, cast | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
| from torch.utils.checkpoint import checkpoint | |
| from transformers.activations import ACT2FN | |
| from transformers.cache_utils import Cache, DynamicCache | |
| from transformers.generation import GenerationMixin | |
| from transformers.modeling_outputs import ( | |
| BaseModelOutputWithPast, | |
| CausalLMOutputWithPast, | |
| ) | |
| from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel | |
| try: | |
| from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS | |
| except ImportError: | |
| ROPE_INIT_FUNCTIONS = {} | |
| try: | |
| from fla.modules import FusedRMSNormGated, ShortConvolution | |
| from fla.ops.gated_delta_rule import ( | |
| chunk_gated_delta_rule, | |
| fused_recurrent_gated_delta_rule, | |
| ) | |
| except ImportError: | |
| chunk_gated_delta_rule = None | |
| fused_recurrent_gated_delta_rule = None | |
| FusedRMSNormGated = None | |
| ShortConvolution = None | |
| from .configuration_lizzy import LizzyConfig | |
class LizzyRMSNorm(nn.Module):
    """Root-mean-square normalisation with a learnable per-channel scale.

    The mean of squares over the last dimension is computed in float32; the
    normalised activations are cast back to the input dtype before being
    multiplied by ``weight``.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        # Starts at ones, i.e. a pure normaliser before training.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        original_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(dim=-1, keepdim=True)
        normalised = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalised.to(original_dtype)
| def _make_norm( | |
| norm_type: str, | |
| hidden_size: int, | |
| eps: float, | |
| *, | |
| has_bias: bool, | |
| ) -> nn.Module: | |
| if norm_type == "rmsnorm": | |
| return LizzyRMSNorm(hidden_size, eps=eps) | |
| if norm_type == "layernorm": | |
| return nn.LayerNorm( | |
| hidden_size, | |
| eps=eps, | |
| elementwise_affine=True, | |
| bias=has_bias, | |
| ) | |
| msg = f"Unsupported norm_type: {norm_type}" | |
| raise ValueError(msg) | |
| def _rotate_half(x: torch.Tensor) -> torch.Tensor: | |
| x1 = x[..., : x.shape[-1] // 2] | |
| x2 = x[..., x.shape[-1] // 2 :] | |
| return torch.cat((-x2, x1), dim=-1) | |
def _apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings to the query and key tensors.

    Each tensor ``x`` becomes ``x * cos + _rotate_half(x) * sin``; ``cos`` and
    ``sin`` must broadcast against both ``q`` and ``k``.
    """

    def embed(tensor: torch.Tensor) -> torch.Tensor:
        return (tensor * cos) + (_rotate_half(tensor) * sin)

    return embed(q), embed(k)
| def _legacy_cache_length( | |
| past_key_values: tuple[tuple[torch.Tensor, torch.Tensor], ...] | None, | |
| ) -> int: | |
| if ( | |
| isinstance(past_key_values, tuple) | |
| and len(past_key_values) > 0 | |
| and past_key_values[0] is not None | |
| and past_key_values[0][0] is not None | |
| ): | |
| return int(past_key_values[0][0].shape[2]) | |
| return 0 | |
| def _normalize_cache_position( | |
| cache_position: torch.Tensor | None, | |
| ) -> torch.Tensor | None: | |
| if cache_position is None: | |
| return None | |
| if cache_position.dim() == 0: | |
| return cache_position.view(1) | |
| if cache_position.dim() > 1: | |
| return cache_position[0] | |
| return cache_position | |
def _is_cache_object(value: Any) -> bool:
    """Return True for a Transformers ``Cache`` or a ``LizzyHybridDynamicCache``."""
    cache_types = (Cache, LizzyHybridDynamicCache)
    return isinstance(value, cache_types)
| def _compute_default_rope_parameters( | |
| config: LizzyConfig, | |
| device: torch.device, | |
| ) -> tuple[torch.Tensor, float]: | |
| inv_freq = 1.0 / ( | |
| config.rope_theta | |
| ** ( | |
| torch.arange(0, config.head_dim, 2, device=device, dtype=torch.float32) | |
| / config.head_dim | |
| ) | |
| ) | |
| return inv_freq, 1.0 | |
| def _compute_yarn_rope_parameters( | |
| config: LizzyConfig, | |
| device: torch.device, | |
| ) -> tuple[torch.Tensor, float]: | |
| rope_scaling = dict(config.rope_scaling or {}) | |
| factor = float(rope_scaling["factor"]) | |
| attention_factor = rope_scaling.get("attention_factor") | |
| mscale = rope_scaling.get("mscale") | |
| mscale_all_dim = rope_scaling.get("mscale_all_dim") | |
| original_max_position_embeddings = int( | |
| rope_scaling.get("original_max_position_embeddings") | |
| or config.max_position_embeddings | |
| ) | |
| def get_mscale(scale: float, mscale_value: float = 1.0) -> float: | |
| if scale <= 1.0: | |
| return 1.0 | |
| return 0.1 * mscale_value * math.log(scale) + 1.0 | |
| if attention_factor is None: | |
| if mscale is not None and mscale_all_dim is not None: | |
| attention_factor = float( | |
| get_mscale(factor, float(mscale)) | |
| / get_mscale(factor, float(mscale_all_dim)) | |
| ) | |
| else: | |
| attention_factor = get_mscale(factor) | |
| beta_fast = float(rope_scaling.get("beta_fast") or 32.0) | |
| beta_slow = float(rope_scaling.get("beta_slow") or 1.0) | |
| truncate = bool(rope_scaling.get("truncate", True)) | |
| dim = config.head_dim | |
| def find_correction_dim( | |
| num_rotations: float, | |
| *, | |
| dim: int, | |
| base: float, | |
| max_position_embeddings: int, | |
| ) -> float: | |
| return ( | |
| dim | |
| * math.log(max_position_embeddings / (num_rotations * 2 * math.pi)) | |
| / (2 * math.log(base)) | |
| ) | |
| def find_correction_range( | |
| low_rot: float, | |
| high_rot: float, | |
| *, | |
| dim: int, | |
| base: float, | |
| max_position_embeddings: int, | |
| truncate: bool, | |
| ) -> tuple[float, float]: | |
| low = find_correction_dim( | |
| low_rot, | |
| dim=dim, | |
| base=base, | |
| max_position_embeddings=max_position_embeddings, | |
| ) | |
| high = find_correction_dim( | |
| high_rot, | |
| dim=dim, | |
| base=base, | |
| max_position_embeddings=max_position_embeddings, | |
| ) | |
| if truncate: | |
| low = math.floor(low) | |
| high = math.ceil(high) | |
| return max(low, 0.0), min(high, dim - 1.0) | |
| def linear_ramp_factor( | |
| min_value: float, | |
| max_value: float, | |
| dim: int, | |
| ) -> torch.Tensor: | |
| if min_value == max_value: | |
| max_value += 0.001 | |
| linear_func = ( | |
| torch.arange(dim, dtype=torch.float32, device=device) - min_value | |
| ) / (max_value - min_value) | |
| return torch.clamp(linear_func, 0, 1) | |
| pos_freqs = config.rope_theta ** ( | |
| torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim | |
| ) | |
| inv_freq_extrapolation = 1.0 / pos_freqs | |
| inv_freq_interpolation = 1.0 / (factor * pos_freqs) | |
| low, high = find_correction_range( | |
| beta_fast, | |
| beta_slow, | |
| dim=dim, | |
| base=config.rope_theta, | |
| max_position_embeddings=original_max_position_embeddings, | |
| truncate=truncate, | |
| ) | |
| inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2) | |
| inv_freq = ( | |
| inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) | |
| + inv_freq_extrapolation * inv_freq_extrapolation_factor | |
| ) | |
| return inv_freq, float(attention_factor) | |
def _compute_rope_parameters(
    config: LizzyConfig,
    device: torch.device,
    *,
    seq_len: int | torch.Tensor | None = None,
    rope_type_override: str | None = None,
) -> tuple[torch.Tensor, float]:
    """Return ``(inv_freq, attention_factor)`` for the configured RoPE type.

    Resolution order:
      * no override and an empty ``rope_scaling``: default RoPE;
      * type ``"default"`` or ``"yarn"``: the local implementations;
      * any other type with an empty ``rope_scaling``: default RoPE;
      * otherwise the Transformers ``ROPE_INIT_FUNCTIONS`` entry for the type,
        then its ``"default"`` entry, then the local default.

    The type comes from ``rope_type_override`` or, failing that, from
    ``rope_scaling["rope_type"]`` / ``rope_scaling["type"]``.
    """
    scaling = dict(config.rope_scaling or {})
    rope_type = rope_type_override
    if rope_type is None:
        if not scaling:
            return _compute_default_rope_parameters(config, device)
        rope_type = str(scaling.get("rope_type", scaling.get("type", "default")))

    local_builder = {
        "default": _compute_default_rope_parameters,
        "yarn": _compute_yarn_rope_parameters,
    }.get(rope_type)
    if local_builder is not None:
        return local_builder(config, device)
    if not scaling:
        return _compute_default_rope_parameters(config, device)

    init_fn = (
        ROPE_INIT_FUNCTIONS.get(rope_type) or ROPE_INIT_FUNCTIONS.get("default")
    )
    if init_fn is None:
        return _compute_default_rope_parameters(config, device)
    inv_freq, attention_factor = init_fn(config, device, seq_len=seq_len)
    # Normalise whatever the Transformers initialiser returned.
    return inv_freq.to(device=device, dtype=torch.float32), float(attention_factor)
| def _looks_like_legacy_interval_rope_lizzy(config: LizzyConfig) -> bool: | |
| rope_layer_flags = list(getattr(config, "rope_layer_flags", None) or []) | |
| if rope_layer_flags and not all(bool(item) for item in rope_layer_flags): | |
| return False | |
| layer_types = list(getattr(config, "layer_types", None) or []) | |
| if layer_types and any(str(item) != "full_attention" for item in layer_types): | |
| return False | |
| return ( | |
| str(getattr(config, "position_embedding_type", "")).lower() == "rope" | |
| and not bool(getattr(config, "rope_scaling", None)) | |
| and int(getattr(config, "num_hidden_layers", 0) or 0) == 36 | |
| and int(getattr(config, "hidden_size", 0) or 0) == 2048 | |
| and int(getattr(config, "num_attention_heads", 0) or 0) == 16 | |
| and int(getattr(config, "num_key_value_heads", 0) or 0) == 4 | |
| and math.isclose( | |
| float(getattr(config, "rope_theta", 0.0) or 0.0), 5_000_000.0 | |
| ) | |
| and not bool(getattr(config, "use_post_attn_norm", False)) | |
| and not bool(getattr(config, "use_post_mlp_norm", False)) | |
| and not bool(getattr(config, "use_qk_norm", False)) | |
| ) | |
| def _get_no_rope_layer_interval(config: LizzyConfig) -> int | None: | |
| value = getattr(config, "no_rope_layer_interval", None) | |
| if value is not None: | |
| value = int(value) | |
| if value > 0: | |
| return value | |
| if _looks_like_legacy_interval_rope_lizzy(config): | |
| # Backward-compatible fallback for already-uploaded Lizzy | |
| # checkpoints that should use NoPE on every 4th layer. | |
| return 4 | |
| return None | |
def _get_rope_layer_flag(config: LizzyConfig, layer_idx: int) -> bool:
    """Return whether decoder layer ``layer_idx`` applies RoPE.

    RoPE must be enabled globally (``position_embedding_type`` equal to
    ``"rope"``, case-insensitive, which is the default when unset).

    When a NoPE interval is available it decides the layer (every
    ``interval``-th layer, 1-based, skips RoPE). The exception is when
    ``rope_layer_flags`` contains a falsy entry and covers ``layer_idx``. In
    that case, or when there is no interval, an in-range flag is used.
    Otherwise the global setting applies.
    """
    rope_enabled = (
        str(getattr(config, "position_embedding_type", "rope")).lower() == "rope"
    )
    flags = list(getattr(config, "rope_layer_flags", None) or [])
    interval = _get_no_rope_layer_interval(config)

    interval_applies = interval is not None and (
        not flags
        or layer_idx >= len(flags)
        or all(bool(flag) for flag in flags)
    )
    if interval_applies:
        return rope_enabled and (layer_idx + 1) % interval != 0
    if 0 <= layer_idx < len(flags):
        return rope_enabled and bool(flags[layer_idx])
    return rope_enabled
| def _get_layer_layout(config: LizzyConfig, layer_idx: int) -> str: | |
| layer_layouts = list(getattr(config, "layer_layouts", None) or []) | |
| if 0 <= layer_idx < len(layer_layouts): | |
| return str(layer_layouts[layer_idx]) | |
| if bool(getattr(config, "use_post_attn_norm", False)) or bool( | |
| getattr(config, "use_post_mlp_norm", False) | |
| ): | |
| return "decoder_postnorm" | |
| return "decoder_prenorm" | |
| def _has_linear_attention(config: LizzyConfig) -> bool: | |
| return any( | |
| str(layer_type) == "linear_attention" | |
| for layer_type in list(getattr(config, "layer_types", None) or []) | |
| ) | |
class _PreviousStateFlag:
    """Boolean result that can also be called like a zero-argument method.

    The gated DeltaNet layer reads ``has_previous_state`` as an attribute
    (``getattr(cache_params, "has_previous_state", False)``) and tests its
    truthiness. When it was a plain method, the bound method was always
    truthy. Any single-token forward pass with a fresh cache then requested
    the precomputed conv path with no cached state.

    Returning this object from a property gives the real boolean on attribute
    access, while legacy ``has_previous_state()`` calls keep working.
    """

    __slots__ = ("_value",)

    def __init__(self, value: bool) -> None:
        self._value = bool(value)

    def __bool__(self) -> bool:
        return self._value

    def __call__(self) -> bool:
        return self._value

    def __repr__(self) -> str:
        return repr(self._value)


class LizzyHybridDynamicCache:
    """Cache for Lizzy checkpoints with mixed full and linear attention.

    Full-attention layers store key/value tensors that are concatenated along
    dim 2 on every ``update``. Linear-attention layers store per-layer
    short-convolution states for q/k/v plus a recurrent state, written
    directly by the layer. Every per-layer slot starts as ``None``.
    """

    is_compileable = False

    def __init__(self, config: LizzyConfig) -> None:
        super().__init__()
        self.layer_types = list(config.layer_types)
        self.transformer_layers = [
            idx
            for idx, layer_type in enumerate(self.layer_types)
            if layer_type == "full_attention"
        ]
        linear_layers = [
            idx
            for idx, layer_type in enumerate(self.layer_types)
            if layer_type == "linear_attention"
        ]
        # Index of the final linear-attention layer; ``None`` when the config
        # has no linear layers (this used to raise ``ValueError``).
        self.last_linear_layer = linear_layers[-1] if linear_layers else None
        num_layers = config.num_hidden_layers
        self.recurrent_states = [None] * num_layers
        self.key_cache = [None] * num_layers
        self.value_cache = [None] * num_layers
        self.conv_states_q = [None] * num_layers
        self.conv_states_k = [None] * num_layers
        self.conv_states_v = [None] * num_layers

    def __len__(self) -> int:
        return len(self.layer_types)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Append new key/value states for ``layer_idx`` and return the full cache."""
        del cache_kwargs
        if self.key_cache[layer_idx] is None:
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
        else:
            self.key_cache[layer_idx] = torch.cat(
                [self.key_cache[layer_idx], key_states],
                dim=2,
            )
            self.value_cache[layer_idx] = torch.cat(
                [self.value_cache[layer_idx], value_states],
                dim=2,
            )
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    @staticmethod
    def _reorder_tensor(
        tensor: torch.Tensor | None,
        beam_idx: torch.LongTensor,
        batch_size: int,
    ) -> torch.Tensor | None:
        """Expand ``tensor`` to ``batch_size`` rows if needed, then gather ``beam_idx``."""
        if tensor is None:
            return None
        if tensor.shape[0] < batch_size:
            # The cached batch is smaller than the beam batch (presumably a
            # cache built before beams were expanded): repeat each row so
            # every beam gets its own copy.
            tensor = tensor.repeat_interleave(batch_size // tensor.shape[0], dim=0)
        return tensor.index_select(0, beam_idx.to(tensor.device))

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Reorder every cached tensor along the batch dimension for beam search."""
        batch_size = beam_idx.shape[0]
        for layer_idx in range(len(self.key_cache)):
            if self.key_cache[layer_idx] is not None:
                self.key_cache[layer_idx] = self._reorder_tensor(
                    self.key_cache[layer_idx], beam_idx, batch_size
                )
                self.value_cache[layer_idx] = self._reorder_tensor(
                    self.value_cache[layer_idx], beam_idx, batch_size
                )
            if self.conv_states_q[layer_idx] is not None:
                for states in (
                    self.conv_states_q,
                    self.conv_states_k,
                    self.conv_states_v,
                    self.recurrent_states,
                ):
                    states[layer_idx] = self._reorder_tensor(
                        states[layer_idx], beam_idx, batch_size
                    )

    def get_seq_length(self, layer_idx: int | None = 0) -> int:
        """Return the cached length of a full-attention layer (0 if empty).

        Indices that are not full-attention layers fall back to the first
        full-attention layer.
        """
        if not self.transformer_layers:
            return 0
        layer_idx = (
            self.transformer_layers[0]
            if layer_idx not in self.transformer_layers
            else layer_idx
        )
        if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]:
        """Return ``(kv_length, kv_offset)``: past plus new tokens, offset 0."""
        del layer_idx
        kv_offset = 0
        past_seen_tokens = self.get_seq_length()
        kv_length = query_length + past_seen_tokens
        return kv_length, kv_offset

    @property
    def has_previous_state(self) -> _PreviousStateFlag:
        """Whether the final linear-attention layer has cached its conv state.

        Mirrors the upstream contract: once that state exists, single-token
        decode can switch to the recurrent path. The returned flag is truthy
        or falsy as expected and can also be called (``has_previous_state()``)
        for backward compatibility.
        """
        cached = (
            self.last_linear_layer is not None
            and self.conv_states_q[self.last_linear_layer] is not None
        )
        return _PreviousStateFlag(cached)
class LizzyHybridRMSNormGated(nn.Module):
    """RMS normalisation followed by a SiLU gate.

    The input is normalised in float32, cast back to its dtype and scaled by
    ``weight``. It is then multiplied by ``silu(gate)`` (computed in float32),
    and the product is returned in the input dtype.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(
        self,
        hidden_states: torch.Tensor,
        gate: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if gate is None:
            raise ValueError("gate is required for gated RMSNorm.")
        target_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        normalised = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        scaled = self.weight * normalised.to(target_dtype)
        gate_activation = F.silu(gate.to(torch.float32))
        return (scaled * gate_activation).to(target_dtype)
class LizzyHybridShortConvolution(nn.Conv1d):
    """Pure-PyTorch causal depthwise convolution with a resumable state.

    Fallback for the FLA ``ShortConvolution``. ``forward`` takes
    ``(batch, seq_len, hidden_size)`` input and returns the activated output
    in the same layout, together with a state holding the last
    ``kernel_size - 1`` input positions as ``(batch, hidden_size,
    kernel_size - 1)``. Passing that state back with ``use_precomputed=True``
    continues the convolution as if the sequence had never been split.
    """

    def __init__(
        self,
        hidden_size: int,
        kernel_size: int,
        bias: bool = False,
        activation: str | None = "silu",
    ) -> None:
        super().__init__(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=kernel_size,
            groups=hidden_size,
            padding=kernel_size - 1,
            bias=bias,
        )
        self.hidden_size = hidden_size
        self.conv_kernel_size = kernel_size
        # The signature allows ``activation=None``; map it to identity instead
        # of failing the ACT2FN lookup.
        self.act_fn = nn.Identity() if activation is None else ACT2FN[activation]

    def forward(
        self,
        hidden_states: torch.Tensor,
        cache: torch.Tensor | None = None,
        use_precomputed: bool = False,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Convolve ``hidden_states`` causally and return ``(output, new_state)``.

        Args:
            hidden_states: Input of shape ``(batch, seq_len, hidden_size)``.
            cache: State returned by a previous call; required when
                ``use_precomputed`` is True.
            use_precomputed: Prepend ``cache`` instead of zero padding.
            **kwargs: Ignored (accepted for parity with the FLA module, e.g.
                ``output_final_state``).

        Raises:
            ValueError: If ``use_precomputed`` is True but ``cache`` is None.
        """
        del kwargs
        seq_len, dim = hidden_states.shape[-2:]
        state_len = self.conv_kernel_size - 1
        hidden_states = hidden_states.transpose(1, 2)
        if use_precomputed:
            if cache is None:
                msg = "cache is required when use_precomputed=True."
                raise ValueError(msg)
            x_with_state = torch.cat([cache, hidden_states], dim=-1)
            out = F.conv1d(
                x_with_state,
                self.weight,
                self.bias,
                padding=0,
                groups=dim,
            )
            # Keep exactly the outputs for the new tokens and the last
            # ``state_len`` inputs, so the state does not grow when more than
            # one token is fed at once.
            out = out[:, :, max(out.shape[-1] - seq_len, 0) :]
            conv_state = x_with_state[:, :, max(x_with_state.shape[-1] - state_len, 0) :]
        else:
            out = F.conv1d(
                hidden_states,
                self.weight,
                self.bias,
                padding=state_len,
                groups=dim,
            )
            # Symmetric padding: the first ``seq_len`` outputs are causal.
            out = out[:, :, :seq_len]
            # Left-pad short inputs with zeros (or crop long ones) so the
            # state always has ``state_len`` positions.
            conv_state = F.pad(
                hidden_states,
                (state_len - hidden_states.shape[-1], 0),
            )
        out = self.act_fn(out)
        return out.transpose(1, 2), conv_state
| def _apply_mask_to_padding_states( | |
| hidden_states: torch.Tensor, | |
| attention_mask: torch.Tensor | None, | |
| ) -> torch.Tensor: | |
| # Match the upstream hybrid implementation: silence padded tokens before | |
| # the DeltaNet projections so recurrent state does not absorb padding. | |
| if ( | |
| attention_mask is not None | |
| and attention_mask.shape[1] > 1 | |
| and attention_mask.shape[0] > 1 | |
| ): | |
| dtype = hidden_states.dtype | |
| hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) | |
| return hidden_states | |
| def _l2norm( | |
| x: torch.Tensor, | |
| dim: int = -1, | |
| eps: float = 1e-6, | |
| ) -> torch.Tensor: | |
| inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) | |
| return x * inv_norm | |
| def _torch_chunk_gated_delta_rule( | |
| query: torch.Tensor, | |
| key: torch.Tensor, | |
| value: torch.Tensor, | |
| g: torch.Tensor, | |
| beta: torch.Tensor, | |
| chunk_size: int = 64, | |
| initial_state: torch.Tensor | None = None, | |
| output_final_state: bool = False, | |
| use_qk_l2norm_in_kernel: bool = False, | |
| ) -> tuple[torch.Tensor, torch.Tensor | None]: | |
| initial_dtype = query.dtype | |
| if use_qk_l2norm_in_kernel: | |
| query = _l2norm(query, dim=-1, eps=1e-6) | |
| key = _l2norm(key, dim=-1, eps=1e-6) | |
| query, key, value, beta, g = [ | |
| x.transpose(1, 2).contiguous().to(torch.float32) | |
| for x in (query, key, value, beta, g) | |
| ] | |
| batch_size, num_heads, sequence_length, k_head_dim = key.shape | |
| v_head_dim = value.shape[-1] | |
| pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size | |
| query = F.pad(query, (0, 0, 0, pad_size)) | |
| key = F.pad(key, (0, 0, 0, pad_size)) | |
| value = F.pad(value, (0, 0, 0, pad_size)) | |
| beta = F.pad(beta, (0, pad_size)) | |
| g = F.pad(g, (0, pad_size)) | |
| total_sequence_length = sequence_length + pad_size | |
| scale = 1 / (query.shape[-1] ** 0.5) | |
| query = query * scale | |
| v_beta = value * beta.unsqueeze(-1) | |
| k_beta = key * beta.unsqueeze(-1) | |
| query, key, value, k_beta, v_beta = [ | |
| x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) | |
| for x in (query, key, value, k_beta, v_beta) | |
| ] | |
| g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) | |
| mask = torch.triu( | |
| torch.ones( | |
| chunk_size, | |
| chunk_size, | |
| dtype=torch.bool, | |
| device=query.device, | |
| ), | |
| diagonal=0, | |
| ) | |
| g = g.cumsum(dim=-1) | |
| decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril() | |
| attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) | |
| for idx in range(1, chunk_size): | |
| row = attn[..., idx, :idx].clone() | |
| sub = attn[..., :idx, :idx].clone() | |
| attn[..., idx, :idx] = row + (row.unsqueeze(-1) * sub).sum(-2) | |
| attn = attn + torch.eye( | |
| chunk_size, | |
| dtype=attn.dtype, | |
| device=attn.device, | |
| ) | |
| value = attn @ v_beta | |
| k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) | |
| last_recurrent_state = ( | |
| torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) | |
| if initial_state is None | |
| else initial_state.to(value) | |
| ) | |
| core_attn_out = torch.zeros_like(value) | |
| mask = torch.triu( | |
| torch.ones( | |
| chunk_size, | |
| chunk_size, | |
| dtype=torch.bool, | |
| device=query.device, | |
| ), | |
| diagonal=1, | |
| ) | |
| for idx in range(0, total_sequence_length // chunk_size): | |
| q_i, k_i, v_i = query[:, :, idx], key[:, :, idx], value[:, :, idx] | |
| attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, idx]).masked_fill_( | |
| mask, | |
| 0, | |
| ) | |
| v_prime = (k_cumdecay[:, :, idx]) @ last_recurrent_state | |
| v_new = v_i - v_prime | |
| attn_inter = (q_i * g[:, :, idx, :, None].exp()) @ last_recurrent_state | |
| core_attn_out[:, :, idx] = attn_inter + attn @ v_new | |
| last_recurrent_state = ( | |
| last_recurrent_state * g[:, :, idx, -1, None, None].exp() | |
| + ( | |
| k_i | |
| * (g[:, :, idx, -1, None] - g[:, :, idx]).exp()[..., None] | |
| ).transpose(-1, -2) | |
| ) | |
| if not output_final_state: | |
| last_recurrent_state = None | |
| core_attn_out = core_attn_out.reshape( | |
| core_attn_out.shape[0], | |
| core_attn_out.shape[1], | |
| -1, | |
| core_attn_out.shape[-1], | |
| ) | |
| core_attn_out = core_attn_out[:, :, :sequence_length] | |
| core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) | |
| return core_attn_out, last_recurrent_state | |
| def _torch_recurrent_gated_delta_rule( | |
| query: torch.Tensor, | |
| key: torch.Tensor, | |
| value: torch.Tensor, | |
| g: torch.Tensor, | |
| beta: torch.Tensor, | |
| initial_state: torch.Tensor | None, | |
| output_final_state: bool, | |
| use_qk_l2norm_in_kernel: bool = False, | |
| ) -> tuple[torch.Tensor, torch.Tensor | None]: | |
| initial_dtype = query.dtype | |
| if use_qk_l2norm_in_kernel: | |
| query = _l2norm(query, dim=-1, eps=1e-6) | |
| key = _l2norm(key, dim=-1, eps=1e-6) | |
| query, key, value, beta, g = [ | |
| x.transpose(1, 2).contiguous().to(torch.float32) | |
| for x in (query, key, value, beta, g) | |
| ] | |
| batch_size, num_heads, sequence_length, k_head_dim = key.shape | |
| v_head_dim = value.shape[-1] | |
| scale = 1 / (query.shape[-1] ** 0.5) | |
| query = query * scale | |
| core_attn_out = torch.zeros( | |
| batch_size, | |
| num_heads, | |
| sequence_length, | |
| v_head_dim, | |
| ).to(value) | |
| last_recurrent_state = ( | |
| torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) | |
| if initial_state is None | |
| else initial_state.to(value) | |
| ) | |
| for idx in range(sequence_length): | |
| q_t = query[:, :, idx] | |
| k_t = key[:, :, idx] | |
| v_t = value[:, :, idx] | |
| g_t = g[:, :, idx].exp().unsqueeze(-1).unsqueeze(-1) | |
| beta_t = beta[:, :, idx].unsqueeze(-1) | |
| last_recurrent_state = last_recurrent_state * g_t | |
| kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) | |
| delta = (v_t - kv_mem) * beta_t | |
| last_recurrent_state = ( | |
| last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) | |
| ) | |
| core_attn_out[:, :, idx] = ( | |
| last_recurrent_state * q_t.unsqueeze(-1) | |
| ).sum(dim=-2) | |
| if not output_final_state: | |
| last_recurrent_state = None | |
| core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) | |
| return core_attn_out, last_recurrent_state | |
| class LizzyHybridGatedDeltaNet(nn.Module): | |
| def __init__(self, config: LizzyConfig, layer_idx: int) -> None: | |
| super().__init__() | |
| self.hidden_size = config.hidden_size | |
| self.num_v_heads = config.linear_num_value_heads | |
| self.num_k_heads = config.linear_num_key_heads | |
| self.head_k_dim = config.linear_key_head_dim | |
| self.head_v_dim = config.linear_value_head_dim | |
| self.key_dim = self.head_k_dim * self.num_k_heads | |
| self.value_dim = self.head_v_dim * self.num_v_heads | |
| self.layer_idx = layer_idx | |
| self.conv_kernel_size = config.linear_conv_kernel_dim | |
| self.allow_neg_eigval = config.linear_allow_neg_eigval | |
| self.eps = config.rms_norm_eps | |
| self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False) | |
| self.k_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False) | |
| self.v_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False) | |
| self.a_proj = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) | |
| self.b_proj = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) | |
| self.g_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False) | |
| self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) | |
| # Step-02 conversion runs on CPU by default, even on GPU nodes. In that | |
| # flow Triton-backed FLA kernels will crash as soon as a CPU tensor | |
| # reaches them, so the wrapper can force the pure PyTorch fallback for | |
| # Hybrid layers via an environment switch. | |
| disable_fla_fast_path = os.environ.get( | |
| "LIZZY_DISABLE_HYBRID_FLA", | |
| "", | |
| ).strip().lower() in {"1", "true", "yes", "on"} | |
| use_fla_fast_path = ( | |
| not disable_fla_fast_path | |
| and | |
| torch.cuda.is_available() | |
| and ShortConvolution is not None | |
| and chunk_gated_delta_rule is not None | |
| and fused_recurrent_gated_delta_rule is not None | |
| and FusedRMSNormGated is not None | |
| ) | |
| # Keep the fast-path contract when FLA is present, but fall back to a | |
| # local implementation so the public Lizzy artifact never depends on | |
| # family-specific Transformers remote code. | |
| conv1d_class = ( | |
| ShortConvolution if use_fla_fast_path else LizzyHybridShortConvolution | |
| ) | |
| self.q_conv1d = conv1d_class( | |
| hidden_size=self.key_dim, | |
| kernel_size=self.conv_kernel_size, | |
| bias=False, | |
| activation="silu", | |
| ) | |
| self.k_conv1d = conv1d_class( | |
| hidden_size=self.key_dim, | |
| kernel_size=self.conv_kernel_size, | |
| bias=False, | |
| activation="silu", | |
| ) | |
| self.v_conv1d = conv1d_class( | |
| hidden_size=self.value_dim, | |
| kernel_size=self.conv_kernel_size, | |
| bias=False, | |
| activation="silu", | |
| ) | |
| a = torch.empty(self.num_v_heads, dtype=torch.float32).uniform_( | |
| config.linear_a_log_min, | |
| config.linear_a_log_max, | |
| ) | |
| self.A_log = nn.Parameter(torch.log(a)) | |
| dt = torch.exp( | |
| torch.rand(self.num_v_heads) | |
| * (math.log(config.linear_dt_max) - math.log(config.linear_dt_min)) | |
| + math.log(config.linear_dt_min) | |
| ) | |
| dt = torch.clamp(dt, min=config.linear_dt_init_floor) | |
| inv_dt = dt + torch.log(-torch.expm1(-dt)) | |
| self.dt_bias = nn.Parameter(inv_dt) | |
| self.o_norm = ( | |
| LizzyHybridRMSNormGated(self.head_v_dim, eps=1e-5) | |
| if not use_fla_fast_path | |
| else FusedRMSNormGated( | |
| self.head_v_dim, | |
| eps=1e-5, | |
| device=torch.cuda.current_device(), | |
| dtype=( | |
| config.dtype | |
| if config.dtype is not None | |
| else torch.get_default_dtype() | |
| ), | |
| ) | |
| ) | |
| self.chunk_gated_delta_rule = ( | |
| chunk_gated_delta_rule | |
| if use_fla_fast_path | |
| else _torch_chunk_gated_delta_rule | |
| ) | |
| self.recurrent_gated_delta_rule = ( | |
| ( | |
| fused_recurrent_gated_delta_rule | |
| if use_fla_fast_path | |
| else _torch_recurrent_gated_delta_rule | |
| ) | |
| ) | |
def forward(
    self,
    hidden_states: torch.Tensor,
    cache_params: LizzyHybridDynamicCache | None = None,
    attention_mask: torch.Tensor | None = None,
    **kwargs: Any,
) -> torch.Tensor:
    """Run the gated DeltaNet token mixer over ``hidden_states``.

    Pipeline:
    1. Project to q/k/v.
    2. Pass each through its own short convolution (``q_conv1d``/``k_conv1d``/``v_conv1d``).
    3. Compute the per-head write strength ``beta`` and decay ``g``.
    4. Run either the single-step recurrent kernel or the chunked kernel.
    5. Apply the gated output norm and ``o_proj``.

    Args:
        hidden_states: input of shape ``(batch_size, seq_len, hidden)``
            (rank fixed by the unpacking below).
        cache_params: optional hybrid cache. When given, this layer's conv
            and recurrent states are read from it and the new states are
            written back.
        attention_mask: forwarded to ``_apply_mask_to_padding_states``
            (presumably zeroes padded positions -- see that helper).
        **kwargs: accepted for interface compatibility and ignored.

    Returns:
        The ``o_proj`` output, computed from a ``(batch_size, seq_len, -1)``
        reshape of the normalized mixer output.
    """
    del kwargs
    hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
    batch_size, seq_len, _ = hidden_states.shape
    use_cache = cache_params is not None
    # Single-token steps on an already-populated cache use the recurrent
    # kernel; prefill and multi-token inputs use the chunked kernel.
    use_precomputed = (
        use_cache
        and getattr(cache_params, "has_previous_state", False)
        and seq_len == 1
    )
    # Test cache presence with ``is not None`` rather than truthiness: a
    # cache object may define ``__len__`` and evaluate falsy while present,
    # which would silently discard the stored conv/recurrent state.
    if cache_params is not None:
        conv_state_q = cache_params.conv_states_q[self.layer_idx]
        conv_state_k = cache_params.conv_states_k[self.layer_idx]
        conv_state_v = cache_params.conv_states_v[self.layer_idx]
        recurrent_state = cache_params.recurrent_states[self.layer_idx]
    else:
        conv_state_q = conv_state_k = conv_state_v = None
        recurrent_state = None
    q = self.q_proj(hidden_states)
    k = self.k_proj(hidden_states)
    v = self.v_proj(hidden_states)
    q, new_conv_state_q = self.q_conv1d(
        q,
        cache=conv_state_q,
        use_precomputed=use_precomputed,
        output_final_state=use_cache,
    )
    k, new_conv_state_k = self.k_conv1d(
        k,
        cache=conv_state_k,
        use_precomputed=use_precomputed,
        output_final_state=use_cache,
    )
    v, new_conv_state_v = self.v_conv1d(
        v,
        cache=conv_state_v,
        use_precomputed=use_precomputed,
        output_final_state=use_cache,
    )
    if cache_params is not None:
        cache_params.conv_states_q[self.layer_idx] = new_conv_state_q
        cache_params.conv_states_k[self.layer_idx] = new_conv_state_k
        cache_params.conv_states_v[self.layer_idx] = new_conv_state_v
    q = q.view(batch_size, seq_len, -1, self.head_k_dim)
    k = k.view(batch_size, seq_len, -1, self.head_k_dim)
    v = v.view(batch_size, seq_len, -1, self.head_v_dim)
    # Grouped heads: when there are more value heads than key heads, each
    # q/k head is repeated so every value head has a matching q/k head.
    if self.num_v_heads > self.num_k_heads:
        expand_ratio = self.num_v_heads // self.num_k_heads
        q = q.repeat_interleave(expand_ratio, dim=2)
        k = k.repeat_interleave(expand_ratio, dim=2)
    beta = self.b_proj(hidden_states).sigmoid()
    if self.allow_neg_eigval:
        # Widen beta's range from (0, 1) to (0, 2).
        beta = beta * 2.0
    # Per-head log-decay, kept in float32: -exp(A_log) * softplus(a + dt_bias).
    g = -self.A_log.float().exp() * F.softplus(
        self.a_proj(hidden_states).float() + self.dt_bias
    )
    delta_rule = (
        self.recurrent_gated_delta_rule
        if use_precomputed
        else self.chunk_gated_delta_rule
    )
    output, new_recurrent_state = delta_rule(
        q,
        k,
        v,
        g=g,
        beta=beta,
        initial_state=recurrent_state,
        output_final_state=use_cache,
        use_qk_l2norm_in_kernel=True,
    )
    if cache_params is not None:
        cache_params.recurrent_states[self.layer_idx] = new_recurrent_state
    gate = self.g_proj(hidden_states)
    # The gated norm runs per head: flatten to (-1, head_v_dim) rows.
    output = output.reshape(-1, self.head_v_dim)
    gate = gate.reshape(-1, self.head_v_dim)
    output = self.o_norm(output, gate)
    output = output.reshape(batch_size, seq_len, -1)
    output = self.o_proj(output)
    return output
class LizzyLinearAttention(nn.Module):
    """Expose the gated DeltaNet mixer through the attention-layer call shape.

    The decoder layer calls softmax attention and linear attention the same
    way and expects ``(output, present, attn_weights)`` back. This adapter
    delegates to ``LizzyHybridGatedDeltaNet``. It always returns ``None`` for
    the weights, because linear attention produces no attention matrix.
    """

    def __init__(self, config: LizzyConfig, layer_idx: int) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        self.inner = LizzyHybridGatedDeltaNet(config, layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        past_key_value: Cache | None = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        **kwargs: Any,
    ) -> tuple[
        torch.Tensor,
        Cache | None,
        torch.Tensor | None,
    ]:
        """Mix ``hidden_states`` with the DeltaNet kernel.

        Only a real cache object is forwarded to the inner mixer; anything
        else (e.g. ``None``) is treated as "no cache". When ``use_cache`` is
        set, ``past_key_value`` is handed back unchanged as the present
        cache, since the inner mixer updates it in place.
        """
        # Extra kwargs and the weights flag exist only for interface parity.
        del kwargs, output_attentions
        inner_cache = past_key_value if _is_cache_object(past_key_value) else None
        mixed = self.inner(
            hidden_states=hidden_states,
            cache_params=inner_cache,
            attention_mask=attention_mask,
        )
        if use_cache:
            return mixed, past_key_value, None
        return mixed, None, None
class LizzyAttention(nn.Module):
    """Softmax self-attention layer (full or sliding-window).

    Optional features:
    - rotary position embeddings (RoPE);
    - whole-projection Q/K normalisation;
    - grouped key/value heads.

    The layer type comes from ``config.layer_types[layer_idx]``, defaulting
    to ``"full_attention"`` when the index is out of range. Only
    ``"sliding_attention"`` layers set ``self.sliding_window``. Attention is
    computed either through a registered backend in
    ``ALL_ATTENTION_FUNCTIONS`` or through the eager matmul/softmax path at
    the end of ``forward``.
    """

    def __init__(self, config: LizzyConfig, layer_idx: int) -> None:
        super().__init__()
        self.is_causal = True
        self.config = config
        self.layer_idx = layer_idx
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        # Number of query heads that share one key/value head.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.head_dim = config.head_dim
        self.hidden_size = config.hidden_size
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.position_embedding_type = config.position_embedding_type
        self.layer_type = (
            str(config.layer_types[layer_idx])
            if layer_idx < len(config.layer_types)
            else "full_attention"
        )
        self.use_rope = _get_rope_layer_flag(config, layer_idx)
        # Per-layer-type RoPE override from config; an empty/missing entry
        # normalises to None.
        self._rope_type_override = str(
            dict(config.rope_type_overrides or {}).get(self.layer_type) or ""
        ) or None
        # With no explicit override, sliding-window layers fall back to
        # unscaled ("default") RoPE. This requires rope_scaling to be set,
        # post-attn/post-mlp norms and QK norm to be enabled, and at least
        # one full-attention layer to exist. NOTE(review): presumably this
        # mirrors an upstream architecture that applies RoPE scaling only
        # to full-attention layers -- confirm against the reference model.
        if (
            self._rope_type_override is None
            and self.layer_type == "sliding_attention"
            and bool(config.rope_scaling)
            and config.use_post_attn_norm
            and config.use_post_mlp_norm
            and config.use_qk_norm
            and any(str(item) == "full_attention" for item in config.layer_types)
        ):
            self._rope_type_override = "default"
        self.sliding_window = None
        if self.layer_type == "sliding_attention":
            self.sliding_window = config.sliding_window
        q_dim = self.num_heads * self.head_dim
        kv_dim = self.num_key_value_heads * self.head_dim
        self.q_proj = nn.Linear(
            config.hidden_size,
            q_dim,
            bias=config.attention_bias,
        )
        self.k_proj = nn.Linear(
            config.hidden_size,
            kv_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            config.hidden_size,
            kv_dim,
            bias=config.attention_bias,
        )
        self.o_proj = nn.Linear(
            q_dim,
            config.hidden_size,
            bias=config.attention_bias,
        )
        # QK norm is applied over the full projected width (q_dim / kv_dim),
        # before the split into heads.
        self.q_norm = (
            _make_norm(config.qk_norm_type, q_dim, config.norm_eps, has_bias=False)
            if config.use_qk_norm
            else None
        )
        self.k_norm = (
            _make_norm(config.qk_norm_type, kv_dim, config.norm_eps, has_bias=False)
            if config.use_qk_norm
            else None
        )
        self._rope_requires_runtime_update = False
        if self.use_rope:
            rope_scaling = dict(config.rope_scaling or {})
            rope_type = self._rope_type_override or str(
                rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
            )
            # "dynamic" RoPE depends on the runtime sequence length, so no
            # tables are precomputed; _build_rope rebuilds them per call.
            self._rope_requires_runtime_update = rope_type == "dynamic"
            if self._rope_requires_runtime_update:
                self.register_buffer("_rope_inv_freq", None, persistent=False)
                self.register_buffer(
                    "_rope_attention_factor", None, persistent=False,
                )
            else:
                # Non-persistent: derived from config, never saved in checkpoints.
                inv_freq, attention_factor = _compute_rope_parameters(
                    config,
                    device=torch.device("cpu"),
                    seq_len=config.max_position_embeddings,
                    rope_type_override=self._rope_type_override,
                )
                self.register_buffer("_rope_inv_freq", inv_freq, persistent=False)
                self.register_buffer(
                    "_rope_attention_factor",
                    torch.tensor(float(attention_factor), dtype=torch.float32),
                    persistent=False,
                )
        else:
            self.register_buffer("_rope_inv_freq", None, persistent=False)
            self.register_buffer("_rope_attention_factor", None, persistent=False)

    def _build_rope(
        self,
        position_ids: torch.Tensor,
        device: torch.device,
        dtype: torch.dtype,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(cos, sin)`` rotary tables for ``position_ids``.

        ``position_ids`` is indexed as ``(batch, seq)``. The results have
        shape ``(batch, 1, seq, 2 * len(inv_freq))`` and are scaled by the
        RoPE attention factor, then cast to ``dtype``. If the cached buffers
        are missing or the RoPE type is "dynamic", the parameters are
        recomputed from config for this call only.

        Raises:
            RuntimeError: if called on a layer with ``use_rope`` False.
        """
        if not self.use_rope:
            msg = "RoPE requested but rope buffer is not initialized."
            raise RuntimeError(msg)
        inv_freq = self._rope_inv_freq
        attention_factor_tensor = self._rope_attention_factor
        if (
            inv_freq is None
            or attention_factor_tensor is None
            or self._rope_requires_runtime_update
        ):
            # Keep the sequence-length hint as a tensor so TorchDynamo/vLLM
            # can trace this path without requiring capture_scalar_outputs.
            # When low-memory loading leaves the non-persistent cache unset,
            # rebuild from config for this forward only instead of mutating
            # buffers inside the compiled graph.
            seq_len = (
                torch.max(position_ids) + 1 if position_ids.numel() > 0 else None
            )
            inv_freq, attention_factor = _compute_rope_parameters(
                self.config,
                device=device,
                seq_len=seq_len,
                rope_type_override=self._rope_type_override,
            )
            attention_factor_tensor = torch.tensor(
                float(attention_factor),
                device=device,
                dtype=torch.float32,
            )
        else:
            inv_freq = inv_freq.to(device=device)
            attention_factor_tensor = attention_factor_tensor.to(
                device=device,
                dtype=torch.float32,
            )
        # Mirror the upstream HF decoder-only rotary path closely here.
        # The matmul-based construction is slightly more numerically stable
        # than the generic einsum formulation for strict parity probes.
        inv_freq_expanded = (
            inv_freq[None, :, None]
            .to(device=device, dtype=torch.float32)
            .expand(position_ids.shape[0], -1, 1)
        )
        position_ids_expanded = position_ids[:, None, :].to(torch.float32)
        # (batch, n_freq, 1) @ (batch, 1, seq) -> (batch, n_freq, seq),
        # then transposed to (batch, seq, n_freq).
        angles = torch.matmul(
            inv_freq_expanded,
            position_ids_expanded,
        ).transpose(1, 2)
        angles = torch.cat((angles, angles), dim=-1)
        # unsqueeze(1) adds a broadcastable head axis.
        cos = angles.cos().unsqueeze(1) * attention_factor_tensor
        sin = angles.sin().unsqueeze(1) * attention_factor_tensor
        cos = cos.to(dtype)
        sin = sin.to(dtype)
        return cos, sin

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        past_key_value: Cache | tuple[torch.Tensor, torch.Tensor] | None = None,
        cache_position: torch.Tensor | None = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        **kwargs: Any,
    ) -> tuple[
        torch.Tensor,
        Cache | tuple[torch.Tensor, torch.Tensor] | None,
        torch.Tensor | None,
    ]:
        """Attend over ``hidden_states`` of shape ``(batch, q_len, hidden)``.

        Returns ``(attn_output, present_key_value, attn_weights)``.
        ``attn_weights`` is ``None`` unless ``output_attentions`` is set.
        ``present_key_value`` is the cache object (cache-object path with
        ``use_cache``), a ``(key, value)`` tuple (legacy/no-cache paths with
        ``use_cache``), or ``None``.

        Raises:
            ValueError: if RoPE is enabled and ``position_ids`` is None.
        """
        batch_size, q_len, _ = hidden_states.shape
        cache_position = _normalize_cache_position(cache_position)
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        if self.q_norm is not None:
            query_states = self.q_norm(query_states)
        if self.k_norm is not None:
            key_states = self.k_norm(key_states)
        # Split into heads and move the head axis forward:
        # (batch, heads, q_len, head_dim).
        query_states = query_states.view(
            batch_size, q_len, self.num_heads, self.head_dim,
        )
        query_states = query_states.transpose(1, 2)
        key_states = key_states.view(
            batch_size,
            q_len,
            self.num_key_value_heads,
            self.head_dim,
        )
        key_states = key_states.transpose(1, 2)
        value_states = value_states.view(
            batch_size,
            q_len,
            self.num_key_value_heads,
            self.head_dim,
        )
        value_states = value_states.transpose(1, 2)
        if self.use_rope:
            if position_ids is None:
                msg = "position_ids are required for rope attention."
                raise ValueError(msg)
            cos, sin = self._build_rope(
                position_ids, hidden_states.device, query_states.dtype,
            )
            query_states, key_states = _apply_rotary_pos_emb(
                query_states,
                key_states,
                cos,
                sin,
            )
        # Cache handling:
        #  * Cache object + use_cache: append via update() and attend over
        #    the full K/V it returns.
        #  * Cache object without use_cache: read this layer's cached K/V
        #    (if any) and prepend it, without writing back.
        #  * Legacy (key, value) tuple: concatenate along the sequence axis
        #    (dim=2).
        #  * No cache: attend over the current tokens only.
        if _is_cache_object(past_key_value):
            if use_cache:
                key_states, value_states = past_key_value.update(
                    key_states,
                    value_states,
                    self.layer_idx,
                    cache_kwargs={"cache_position": cache_position},
                )
                present_key_value = past_key_value
            elif self.layer_idx < len(past_key_value):
                past_key, past_value = past_key_value[self.layer_idx]
                if past_key is not None and past_value is not None:
                    key_states = torch.cat([past_key, key_states], dim=2)
                    value_states = torch.cat([past_value, value_states], dim=2)
                present_key_value = None
            else:
                present_key_value = None
        elif past_key_value is not None:
            past_key, past_value = past_key_value
            key_states = torch.cat([past_key, key_states], dim=2)
            value_states = torch.cat([past_value, value_states], dim=2)
            present_key_value = (key_states, value_states) if use_cache else None
        else:
            present_key_value = (key_states, value_states) if use_cache else None
        attention_interface = None
        attn_impl = getattr(self.config, "_attn_implementation", "eager")
        # flex_attention is bypassed for head_dim < 16 in favour of SDPA.
        # NOTE(review): presumably a flex-attention kernel minimum -- confirm.
        if attn_impl == "flex_attention" and self.head_dim < 16:
            attn_impl = "sdpa"
        if attn_impl != "eager":
            # An unregistered backend name yields None, which falls through
            # to the eager path below.
            attention_interface = ALL_ATTENTION_FUNCTIONS.get(attn_impl)
        if attention_interface is not None:
            attn_output, attn_weights = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask,
                dropout=0.0 if not self.training else self.attention_dropout,
                scaling=self.scaling,
                sliding_window=self.sliding_window,
                **kwargs,
            )
            attn_output = attn_output.contiguous()
        else:
            # Eager path: replicate K/V heads to match the query head count.
            if self.num_key_value_heads != self.num_heads:
                key_states = key_states.repeat_interleave(
                    self.num_key_value_groups, dim=1,
                )
                value_states = value_states.repeat_interleave(
                    self.num_key_value_groups, dim=1,
                )
            attn_weights = torch.matmul(
                query_states,
                key_states.transpose(-1, -2),
            ) * self.scaling
            # The mask is additive (large negative at blocked positions).
            if attention_mask is not None:
                attn_weights = attn_weights + attention_mask
            # Softmax in float32 for stability, then cast back.
            attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32)
            attn_weights = attn_weights.to(query_states.dtype)
            attn_weights = F.dropout(
                attn_weights,
                p=self.attention_dropout if self.training else 0.0,
                training=self.training,
            )
            attn_output = torch.matmul(attn_weights, value_states)
            # Back to (batch, q_len, heads, head_dim), the layout the
            # backend path also reshapes from below.
            attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        if not output_attentions:
            attn_weights = None
        return attn_output, present_key_value, attn_weights
def _refresh_attention_rope_buffers(module: nn.Module) -> None:
    """Rebuild non-persistent RoPE buffers after checkpoint load.

    For every ``LizzyAttention`` inside *module*, ``use_rope`` and the
    resolved RoPE type are re-derived from the layer's config:

    * RoPE disabled: the cached buffers are cleared.
    * ``"dynamic"`` RoPE: the cached buffers are cleared, because
      ``_build_rope`` recomputes them per forward.
    * Otherwise: ``_rope_inv_freq`` and ``_rope_attention_factor`` are
      recomputed from config.

    The tables are still computed on CPU, so their values are identical to
    constructor-time ones. They are then moved to the device of the layer's
    parameters. This stops a model loaded directly onto an accelerator from
    keeping CPU-resident rotary tables that ``_build_rope`` would have to
    copy to the device on every forward. Layers with no parameters, or
    meta-device parameters, keep the buffers on CPU.
    """
    for child in module.modules():
        if not isinstance(child, LizzyAttention):
            continue
        should_use_rope = _get_rope_layer_flag(child.config, child.layer_idx)
        child.use_rope = should_use_rope
        if not should_use_rope:
            child._rope_requires_runtime_update = False
            child._rope_inv_freq = None
            child._rope_attention_factor = None
            continue
        rope_scaling = dict(child.config.rope_scaling or {})
        rope_type = child._rope_type_override or str(
            rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
        )
        child._rope_requires_runtime_update = rope_type == "dynamic"
        if child._rope_requires_runtime_update:
            child._rope_inv_freq = None
            child._rope_attention_factor = None
            continue
        # These buffers are derived from config rather than serialized weights.
        # Recompute them after load so low-memory materialization cannot leave
        # stale or uninitialized rotary state behind.
        inv_freq, attention_factor = _compute_rope_parameters(
            child.config,
            device=torch.device("cpu"),
            seq_len=child.config.max_position_embeddings,
            rope_type_override=child._rope_type_override,
        )
        first_param = next(child.parameters(), None)
        target_device = (
            first_param.device
            if first_param is not None and first_param.device.type != "meta"
            else torch.device("cpu")
        )
        child._rope_inv_freq = inv_freq.to(device=target_device)
        child._rope_attention_factor = torch.tensor(
            float(attention_factor),
            dtype=torch.float32,
            device=target_device,
        )
class LizzyMLP(nn.Module):
    """Position-wise feed-forward block.

    With ``config.mlp_type == "gated"`` it computes
    ``down_proj(act(gate_proj(x)) * up_proj(x))``. For any other
    ``mlp_type`` it computes ``down_proj(act(up_proj(x)))``. ``act`` is
    ``ACT2FN[config.hidden_act]``.
    """

    def __init__(self, config: LizzyConfig) -> None:
        super().__init__()
        self.config = config
        self.act = ACT2FN[config.hidden_act]
        # Only the gated variant owns a gate projection.
        self.gate_proj = (
            nn.Linear(
                config.hidden_size,
                config.intermediate_size,
                bias=config.mlp_bias,
            )
            if config.mlp_type == "gated"
            else None
        )
        self.up_proj = nn.Linear(
            config.hidden_size,
            config.intermediate_size,
            bias=config.mlp_bias,
        )
        self.down_proj = nn.Linear(
            config.intermediate_size,
            config.hidden_size,
            bias=config.mlp_bias,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward transform to ``hidden_states``.

        Raises:
            RuntimeError: if the config requests a gated MLP but
                ``gate_proj`` is ``None``. ``__init__`` always creates it
                for gated configs, so this only fires if module state and
                config diverge after construction.
        """
        if self.config.mlp_type != "gated":
            return self.down_proj(self.act(self.up_proj(hidden_states)))
        # Single guard: the original checked this condition twice in a row.
        if self.gate_proj is None:
            msg = "Missing gated MLP projection layers."
            raise RuntimeError(msg)
        gated = self.act(self.gate_proj(hidden_states)) * self.up_proj(
            hidden_states
        )
        return self.down_proj(gated)
class LizzyDecoderLayer(nn.Module):
    """One decoder block: a token mixer followed by an MLP, each with a residual.

    The mixer is ``LizzyLinearAttention`` for ``"linear_attention"`` layers
    and ``LizzyAttention`` otherwise. The layout from ``_get_layer_layout``
    picks where normalisation goes:

    * ``"decoder_prenorm"``: the inputs of both sub-blocks are normalised.
    * ``"decoder_postnorm"``: the outputs of both sub-blocks are normalised
      before the residual add.
    * Any other layout: no per-layer norms are created.
    """

    def __init__(self, config: LizzyConfig, layer_idx: int) -> None:
        super().__init__()
        known_types = config.layer_types
        if layer_idx < len(known_types):
            self.layer_type = str(known_types[layer_idx])
        else:
            self.layer_type = "full_attention"
        self.layer_layout = _get_layer_layout(config, layer_idx)
        is_linear = self.layer_type == "linear_attention"
        # Exactly one of the two mixers is instantiated; the other stays None.
        self.self_attn = None if is_linear else LizzyAttention(config, layer_idx)
        self.linear_attn = LizzyLinearAttention(config, layer_idx) if is_linear else None
        self.mlp = LizzyMLP(config)

        def build_norm(enabled: bool):
            # Every per-layer norm shares type, width, eps and bias setting.
            if not enabled:
                return None
            return _make_norm(
                config.norm_type,
                config.hidden_size,
                config.norm_eps,
                has_bias=config.norm_has_bias,
            )

        is_prenorm = self.layer_layout == "decoder_prenorm"
        is_postnorm = self.layer_layout == "decoder_postnorm"
        # Creation order is kept fixed so parameter initialisation consumes
        # the RNG in the same sequence.
        self.pre_attn_norm = build_norm(is_prenorm)
        self.pre_mlp_norm = build_norm(is_prenorm)
        self.post_attn_norm = build_norm(is_postnorm)
        self.post_mlp_norm = build_norm(is_postnorm)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        past_key_value: Cache | tuple[torch.Tensor, torch.Tensor] | None = None,
        cache_position: torch.Tensor | None = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        **kwargs: Any,
    ) -> tuple[
        torch.Tensor,
        Cache | tuple[torch.Tensor, torch.Tensor] | None,
        torch.Tensor | None,
    ]:
        """Run the mixer and MLP sub-blocks.

        Returns ``(hidden_states, present_key_value, attn_weights)`` as
        produced by the mixer, with the updated hidden states.
        """
        # --- token-mixing sub-block ---
        shortcut = hidden_states
        if self.pre_attn_norm is None:
            mixer_in = hidden_states
        else:
            mixer_in = self.pre_attn_norm(hidden_states)
        if self.linear_attn is None:
            assert self.self_attn is not None
            mixed, present_key_value, attn_weights = self.self_attn(
                mixer_in,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                cache_position=cache_position,
                use_cache=use_cache,
                output_attentions=output_attentions,
                **kwargs,
            )
        else:
            # Linear attention only understands real cache objects.
            linear_cache = (
                past_key_value if _is_cache_object(past_key_value) else None
            )
            mixed, present_key_value, attn_weights = self.linear_attn(
                mixer_in,
                attention_mask=attention_mask,
                past_key_value=linear_cache,
                use_cache=use_cache,
                output_attentions=output_attentions,
                **kwargs,
            )
        if self.post_attn_norm is not None:
            mixed = self.post_attn_norm(mixed)
        hidden_states = shortcut + mixed

        # --- feed-forward sub-block ---
        shortcut = hidden_states
        if self.pre_mlp_norm is None:
            ff_in = hidden_states
        else:
            ff_in = self.pre_mlp_norm(hidden_states)
        ff_out = self.mlp(ff_in)
        if self.post_mlp_norm is not None:
            ff_out = self.post_mlp_norm(ff_out)
        return shortcut + ff_out, present_key_value, attn_weights
class LizzyPreTrainedModel(PreTrainedModel):
    """Base class connecting Lizzy models to the ``PreTrainedModel`` API.

    It declares the config class and backend capability flags. It also
    defines weight initialisation and two loading hooks:
    ``from_pretrained`` refreshes the derived RoPE buffers and sets
    tied-weight/parallelism plans, and ``load_state_dict`` remaps legacy
    MLP parameter names.
    """

    config_class = LizzyConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LizzyDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True

    def _init_weights(self, module: nn.Module) -> None:
        """Initialise *module* in place.

        * ``nn.Linear`` and ``nn.Embedding`` weights are drawn from
          N(0, ``config.initializer_range``).
        * Linear biases and the embedding padding row are zeroed.
        * ``LizzyRMSNorm``/``nn.LayerNorm`` weights are set to 1 and their
          biases to 0, when present.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, (LizzyRMSNorm, nn.LayerNorm)):
            if hasattr(module, "weight") and module.weight is not None:
                module.weight.data.fill_(1.0)
            if hasattr(module, "bias") and module.bias is not None:
                module.bias.data.zero_()

    # Must be a classmethod: callers (including the Auto* factories) invoke
    # ``Cls.from_pretrained(path, ...)``. Without the decorator the path
    # binds to ``cls`` and both the missing positional argument and the
    # zero-argument ``super()`` call fail.
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str | os.PathLike[str] | None,
        *model_args: Any,
        **kwargs: Any,
    ) -> "LizzyPreTrainedModel":
        """Load via the base implementation, then fix up derived state.

        After loading:

        * the non-persistent RoPE buffers are rebuilt from config;
        * models exposing both ``lm_head`` and ``model`` get a dict-form
          ``_tied_weights_keys`` (the class's own dict if it has a non-empty
          one, otherwise ``lm_head.weight -> model.embed_tokens.weight``);
        * those models also get tensor/pipeline-parallel plans for
          ``lm_head``.
        """
        model = cast(
            "LizzyPreTrainedModel",
            super().from_pretrained(
                pretrained_model_name_or_path,
                *model_args,
                **kwargs,
            ),
        )
        _refresh_attention_rope_buffers(model)
        if hasattr(model, "lm_head") and hasattr(model, "model"):
            tied_weights_keys = getattr(type(model), "_tied_weights_keys", None)
            if isinstance(tied_weights_keys, dict) and tied_weights_keys:
                model._tied_weights_keys = dict(tied_weights_keys)
            else:
                model._tied_weights_keys = {
                    "lm_head.weight": "model.embed_tokens.weight",
                }
            model._tp_plan = {"lm_head": "colwise_rep"}
            model._pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
        return model

    def load_state_dict(  # type: ignore[override]
        self,
        state_dict: dict[str, torch.Tensor],
        strict: bool = True,
        assign: bool = False,
    ) -> Any:
        """Load *state_dict* after renaming legacy MLP keys.

        ``.mlp.fc_in.`` becomes ``.mlp.up_proj.`` and ``.mlp.fc_out.``
        becomes ``.mlp.down_proj.``. RoPE buffers are refreshed after
        loading.

        Raises:
            ValueError: if two source keys map to the same target key with
                tensors that are not equal.
        """
        remapped_state_dict: dict[str, torch.Tensor] = {}
        for key, value in state_dict.items():
            remapped_key = key
            if ".mlp.fc_in." in key:
                remapped_key = key.replace(".mlp.fc_in.", ".mlp.up_proj.")
            elif ".mlp.fc_out." in key:
                remapped_key = key.replace(".mlp.fc_out.", ".mlp.down_proj.")
            existing = remapped_state_dict.get(remapped_key)
            if existing is not None and not torch.equal(existing, value):
                msg = (
                    f"Conflicting legacy Lizzy MLP tensors"
                    f" for key: {remapped_key}"
                )
                raise ValueError(msg)
            remapped_state_dict[remapped_key] = value
        load_result = super().load_state_dict(
            remapped_state_dict,
            strict=strict,
            assign=assign,
        )
        # RoPE buffers are intentionally non-persistent, so refresh them after
        # weight loading instead of trusting constructor-time allocations.
        _refresh_attention_rope_buffers(self)
        return load_result
| class LizzyModel(LizzyPreTrainedModel): | |
def __init__(self, config: LizzyConfig) -> None:
    """Assemble the token embeddings, decoder stack and final norm from *config*.

    A learned position table (``embed_positions``) is created only when
    ``config.position_embedding_type == "absolute"``; otherwise it is None.
    """
    super().__init__(config)
    self.padding_idx = config.pad_token_id
    self.vocab_size = config.vocab_size
    self.embed_tokens = nn.Embedding(
        config.vocab_size, config.hidden_size, self.padding_idx,
    )
    if config.position_embedding_type == "absolute":
        self.embed_positions = nn.Embedding(
            config.max_position_embeddings, config.hidden_size,
        )
    else:
        self.embed_positions = None
    # Submodules are created in a fixed order so weight init draws from the
    # RNG in the same sequence.
    self.layers = nn.ModuleList(
        [LizzyDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)]
    )
    self.norm = _make_norm(
        config.norm_type,
        config.hidden_size,
        config.norm_eps,
        has_bias=config.norm_has_bias,
    )
    self.embd_dropout = nn.Dropout(config.embd_dropout)
    self.gradient_checkpointing = False
    self.post_init()
def get_input_embeddings(self) -> nn.Embedding:
    """Return the token embedding table that serves as the model's input layer."""
    token_table = self.embed_tokens
    return token_table
def set_input_embeddings(self, value: nn.Embedding) -> None:
    """Install *value* as the token embedding table."""
    setattr(self, "embed_tokens", value)
| def _build_attention_mask( | |
| self, | |
| attention_mask: torch.Tensor | None, | |
| *, | |
| batch_size: int, | |
| q_len: int, | |
| kv_len: int, | |
| kv_offset: int, | |
| cache_position: torch.Tensor, | |
| device: torch.device, | |
| dtype: torch.dtype, | |
| sliding_window: int | None = None, | |
| ) -> torch.Tensor: | |
| kv_len = ( | |
| int(kv_len.item()) if isinstance(kv_len, torch.Tensor) else int(kv_len) | |
| ) | |
| kv_offset = ( | |
| int(kv_offset.item()) | |
| if isinstance(kv_offset, torch.Tensor) | |
| else int(kv_offset) | |
| ) | |
| min_value = torch.finfo(dtype).min | |
| source_positions = cache_position.to(device=device).view(-1, 1) | |
| target_positions = torch.arange( | |
| kv_offset, | |
| kv_offset + kv_len, | |
| device=device, | |
| ).unsqueeze(0) | |
| causal = torch.zeros((q_len, kv_len), dtype=dtype, device=device) | |
| causal = causal.masked_fill(target_positions > source_positions, min_value) | |
| if sliding_window is not None: | |
| lower_bound = source_positions - int(sliding_window) + 1 | |
| causal = causal.masked_fill(target_positions < lower_bound, min_value) | |
| causal = causal.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1) | |
| if attention_mask is None: | |
| return causal | |
| if attention_mask.dim() != 2: | |
| msg = "attention_mask must be 2D [batch, sequence]." | |
| raise ValueError(msg) | |
| if attention_mask.shape[1] < kv_len: | |
| pad = torch.ones( | |
| (attention_mask.shape[0], kv_len - attention_mask.shape[1]), | |
| dtype=attention_mask.dtype, | |
| device=attention_mask.device, | |
| ) | |
| attention_mask = torch.cat([pad, attention_mask], dim=1) | |
| elif attention_mask.shape[1] > kv_len: | |
| attention_mask = attention_mask[:, -kv_len:] | |
| expanded = attention_mask[:, None, None, :].to(device=device) | |
| padding = (expanded == 0).to(dtype) * min_value | |
| return causal + padding | |
| def forward( | |
| self, | |
| input_ids: torch.LongTensor | None = None, | |
| attention_mask: torch.Tensor | dict[str, torch.Tensor] | None = None, | |
| position_ids: torch.LongTensor | None = None, | |
| past_key_values: Cache | tuple[ | |
| tuple[torch.Tensor, torch.Tensor], ... | |
| ] | None = None, | |
| inputs_embeds: torch.FloatTensor | None = None, | |
| cache_position: torch.LongTensor | None = None, | |
| use_cache: bool | None = None, | |
| output_attentions: bool | None = None, | |
| output_hidden_states: bool | None = None, | |
| return_dict: bool | None = None, | |
| **kwargs: Any, | |
| ) -> BaseModelOutputWithPast | tuple[Any, ...]: | |
| if (input_ids is None) == (inputs_embeds is None): | |
| msg = "Exactly one of input_ids or inputs_embeds must be provided." | |
| raise ValueError(msg) | |
| output_attentions = ( | |
| bool(output_attentions) if output_attentions is not None else False | |
| ) | |
| output_hidden_states = ( | |
| bool(output_hidden_states) | |
| if output_hidden_states is not None | |
| else False | |
| ) | |
| use_cache = ( | |
| bool(use_cache) | |
| if use_cache is not None | |
| else bool(self.config.use_cache) | |
| ) | |
| return_dict = bool(return_dict) if return_dict is not None else True | |
| if inputs_embeds is None: | |
| hidden_states = self.embed_tokens(input_ids) | |
| batch_size, seq_len = input_ids.shape | |
| else: | |
| hidden_states = inputs_embeds | |
| batch_size, seq_len, _ = inputs_embeds.shape | |
| cache_object = ( | |
| past_key_values | |
| if _is_cache_object(past_key_values) | |
| else None | |
| ) | |
| if use_cache and _has_linear_attention(self.config): | |
| # Transformers 5.4 seeds `generate()` with an empty DynamicCache | |
| # for standard causal decoders. Hybrid Lizzy checkpoints need the | |
| # mixed cache below instead, because linear-attention layers read | |
| # DeltaNet convolution/recurrent state during the prefill pass. | |
| if cache_object is not None and not isinstance( | |
| cache_object, LizzyHybridDynamicCache, | |
| ): | |
| if int(cache_object.get_seq_length()) > 0: | |
| msg = ( | |
| "Hybrid Lizzy checkpoints require " | |
| "LizzyHybridDynamicCache once generation cache " | |
| "state is populated." | |
| ) | |
| raise ValueError(msg) | |
| cache_object = LizzyHybridDynamicCache(config=self.config) | |
| past_key_values = cache_object | |
| if use_cache and cache_object is None and past_key_values is None: | |
| if _has_linear_attention(self.config): | |
| # Linear-attention checkpoints need a mixed cache that can hold | |
| # both KV tensors and recurrent DeltaNet state. | |
| cache_object = LizzyHybridDynamicCache(config=self.config) | |
| else: | |
| cache_object = DynamicCache() | |
| past_key_values = cache_object | |
| if cache_object is not None: | |
| past_length = int(cache_object.get_seq_length()) | |
| else: | |
| past_length = _legacy_cache_length(past_key_values) | |
| cache_position = _normalize_cache_position(cache_position) | |
| if cache_position is None: | |
| cache_position = torch.arange( | |
| past_length, | |
| past_length + seq_len, | |
| dtype=torch.long, | |
| device=hidden_states.device, | |
| ) | |
| if position_ids is None: | |
| position_ids = cache_position.unsqueeze(0).expand(batch_size, -1) | |
| if self.embed_positions is not None: | |
| hidden_states = hidden_states + self.embed_positions(position_ids) | |
| hidden_states = self.embd_dropout(hidden_states) | |
| if self.training and self.gradient_checkpointing: | |
| use_cache = False | |
| layer_types = list(self.config.layer_types) | |
| if not layer_types: | |
| layer_types = ["full_attention"] * len(self.layers) | |
| has_linear_attention = any( | |
| str(layer_type) == "linear_attention" for layer_type in layer_types | |
| ) | |
| _attn_impl = getattr(self.config, "_attn_implementation", "eager") | |
| if has_linear_attention and isinstance(attention_mask, dict): | |
| linear_attention_mask = attention_mask.get("linear_attention") | |
| else: | |
| linear_attention_mask = attention_mask | |
| if ( | |
| has_linear_attention | |
| and cache_object is not None | |
| and getattr(cache_object, "has_previous_state", False) | |
| ): | |
| linear_attention_mask = None | |
| elif ( | |
| has_linear_attention | |
| and attention_mask is not None | |
| and not isinstance(attention_mask, dict) | |
| and torch.all(attention_mask == 1) | |
| ): | |
| linear_attention_mask = None | |
| if ( | |
| _attn_impl == "flash_attention_2" | |
| and not isinstance(attention_mask, dict) | |
| ): | |
| # Flash attention handles causal masking (via is_causal) and | |
| # padding (via 2D mask) natively; skip building a 4D mask. | |
| attention_mask_mapping = { | |
| lt: attention_mask | |
| for lt in dict.fromkeys(layer_types) | |
| if lt != "linear_attention" | |
| } | |
| elif _attn_impl == "sdpa" and attention_mask is None: | |
| attention_mask_mapping = {} | |
| for layer_type in dict.fromkeys(layer_types): | |
| if layer_type == "linear_attention": | |
| continue | |
| if layer_type == "full_attention": | |
| # Match upstream decoder-only HF models: when SDPA sees | |
| # plain causal full attention with no padding mask to | |
| # preserve, let it use its native is_causal fast-path | |
| # instead of forcing an explicit 4D bias tensor. | |
| attention_mask_mapping[layer_type] = None | |
| continue | |
| layer_idx = layer_types.index(layer_type) | |
| if cache_object is not None: | |
| kv_len, kv_offset = cache_object.get_mask_sizes( | |
| seq_len, layer_idx, | |
| ) | |
| else: | |
| kv_len = past_length + seq_len | |
| kv_offset = 0 | |
| attention_mask_mapping[layer_type] = self._build_attention_mask( | |
| attention_mask, | |
| batch_size=batch_size, | |
| q_len=seq_len, | |
| kv_len=kv_len, | |
| kv_offset=kv_offset, | |
| cache_position=cache_position, | |
| device=hidden_states.device, | |
| dtype=hidden_states.dtype, | |
| sliding_window=( | |
| self.config.sliding_window | |
| if layer_type == "sliding_attention" | |
| else None | |
| ), | |
| ) | |
| elif isinstance(attention_mask, dict): | |
| attention_mask_mapping = { | |
| key: value | |
| for key, value in attention_mask.items() | |
| if key != "linear_attention" | |
| } | |
| else: | |
| attention_mask_mapping: dict[str, torch.Tensor] = {} | |
| for layer_type in dict.fromkeys(layer_types): | |
| if layer_type == "linear_attention": | |
| continue | |
| layer_idx = layer_types.index(layer_type) | |
| if cache_object is not None: | |
| kv_len, kv_offset = cache_object.get_mask_sizes( | |
| seq_len, layer_idx, | |
| ) | |
| else: | |
| kv_len = past_length + seq_len | |
| kv_offset = 0 | |
| attention_mask_mapping[layer_type] = self._build_attention_mask( | |
| attention_mask, | |
| batch_size=batch_size, | |
| q_len=seq_len, | |
| kv_len=kv_len, | |
| kv_offset=kv_offset, | |
| cache_position=cache_position, | |
| device=hidden_states.device, | |
| dtype=hidden_states.dtype, | |
| sliding_window=( | |
| self.config.sliding_window | |
| if layer_type == "sliding_attention" | |
| else None | |
| ), | |
| ) | |
| all_hidden_states = [] if output_hidden_states else None | |
| all_attentions = [] if output_attentions else None | |
| next_cache = ( | |
| cache_object | |
| if cache_object is not None | |
| else ([] if use_cache else None) | |
| ) | |
| gradient_checkpointing_func = getattr( | |
| self, | |
| "_gradient_checkpointing_func", | |
| checkpoint, | |
| ) | |
| for idx, layer in enumerate(self.layers): | |
| if output_hidden_states and all_hidden_states is not None: | |
| all_hidden_states.append(hidden_states) | |
| layer_type = ( | |
| layer_types[idx] | |
| if idx < len(layer_types) | |
| else "full_attention" | |
| ) | |
| if layer_type == "linear_attention": | |
| layer_attention_mask = linear_attention_mask | |
| else: | |
| layer_attention_mask = attention_mask_mapping[layer_type] | |
| if cache_object is not None: | |
| layer_past: Cache | tuple[ | |
| torch.Tensor, torch.Tensor | |
| ] | None = cache_object | |
| elif past_key_values is not None: | |
| layer_past = past_key_values[idx] | |
| if layer_past is not None and layer_past[0] is None: | |
| layer_past = None | |
| else: | |
| layer_past = None | |
| if self.training and self.gradient_checkpointing: | |
| def custom_forward(hidden_states: torch.Tensor) -> Any: | |
| layer_outputs = layer( | |
| hidden_states, | |
| attention_mask=layer_attention_mask, | |
| position_ids=position_ids, | |
| past_key_value=None, | |
| cache_position=cache_position, | |
| use_cache=False, | |
| output_attentions=output_attentions, | |
| **kwargs, | |
| ) | |
| if output_attentions: | |
| return layer_outputs[0], layer_outputs[2] | |
| return layer_outputs[0] | |
| checkpointed_outputs = gradient_checkpointing_func( | |
| custom_forward, hidden_states, | |
| ) | |
| if output_attentions: | |
| hidden_states, attn_weights = checkpointed_outputs | |
| else: | |
| hidden_states = checkpointed_outputs | |
| attn_weights = None | |
| present = None | |
| else: | |
| hidden_states, present, attn_weights = layer( | |
| hidden_states, | |
| attention_mask=layer_attention_mask, | |
| position_ids=position_ids, | |
| past_key_value=layer_past, | |
| cache_position=cache_position, | |
| use_cache=use_cache, | |
| output_attentions=output_attentions, | |
| **kwargs, | |
| ) | |
| if use_cache and next_cache is not None and cache_object is None: | |
| next_cache.append(present) | |
| if output_attentions and all_attentions is not None: | |
| all_attentions.append(attn_weights) | |
| hidden_states = self.norm(hidden_states) | |
| if output_hidden_states and all_hidden_states is not None: | |
| all_hidden_states.append(hidden_states) | |
| past_key_values_output: Cache | tuple[ | |
| tuple[torch.Tensor, torch.Tensor], ... | |
| ] | None = None | |
| if use_cache and next_cache is not None: | |
| if cache_object is not None: | |
| past_key_values_output = cache_object | |
| else: | |
| past_key_values_output = tuple(next_cache) | |
| if not return_dict: | |
| output: tuple[Any, ...] = (hidden_states,) | |
| if past_key_values_output is not None: | |
| output = output + (past_key_values_output,) | |
| if output_hidden_states and all_hidden_states is not None: | |
| output = output + (tuple(all_hidden_states),) | |
| if output_attentions and all_attentions is not None: | |
| output = output + (tuple(all_attentions),) | |
| return output | |
| return BaseModelOutputWithPast( | |
| last_hidden_state=hidden_states, | |
| past_key_values=past_key_values_output, | |
| hidden_states=( | |
| tuple(all_hidden_states) | |
| if all_hidden_states is not None | |
| else None | |
| ), | |
| attentions=( | |
| tuple(all_attentions) | |
| if all_attentions is not None | |
| else None | |
| ), | |
| ) | |
class LizzyForCausalLM(LizzyPreTrainedModel, GenerationMixin):
    """Lizzy decoder with a linear language-modelling head on top.

    Wraps :class:`LizzyModel` and projects its final hidden states to
    vocabulary logits through a bias-free ``lm_head``. When ``labels`` are
    supplied, :meth:`forward` also returns the next-token cross-entropy loss.
    """

    config_class = LizzyConfig
    # Transformers 5.4 expects an expanded target->source mapping here rather than
    # the older list-based shorthand.
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config: LizzyConfig) -> None:
        """Build the base model and the ``hidden_size -> vocab_size`` head."""
        super().__init__(config)
        self.model = LizzyModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the base model's input embedding module."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        """Replace the base model's input embedding module."""
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self) -> nn.Module:
        """Return the language-modelling head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
        """Replace the language-modelling head."""
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Cache | tuple[
            tuple[torch.Tensor, torch.Tensor], ...
        ] | None = None,
        attention_mask: torch.Tensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Assemble the keyword arguments for one decoding step.

        Args:
            input_ids: Token ids seen so far (prompt plus generated tokens).
            past_key_values: Cache object or legacy per-layer tuple from the
                previous step, or ``None`` before anything is cached.
            attention_mask: Optional mask; it is trimmed to its last
                ``past_length + new_tokens`` columns when a cache is present.
            inputs_embeds: Optional prompt embeddings. They are consumed only
                while nothing has been cached yet; afterwards new tokens are
                fed through ``input_ids``.
            cache_position: Positions of the new tokens. When omitted, it is
                derived from the cached length.
            **kwargs: Extra generation kwargs. ``use_cache`` and
                ``logits_to_keep`` are forwarded; the rest are ignored.

        Returns:
            A dict of keyword arguments for :meth:`forward`.
        """
        past_length = 0
        if past_key_values is not None:
            if _is_cache_object(past_key_values):
                past_length = int(past_key_values.get_seq_length())
            else:
                past_length = _legacy_cache_length(past_key_values)
        # ``generate`` may hand over an already-constructed but still empty
        # cache object on the first step. "Nothing cached yet" is therefore
        # decided by the cached length, not by ``past_key_values is None``.
        use_inputs_embeds = inputs_embeds is not None and past_length == 0
        cache_position = _normalize_cache_position(cache_position)
        if cache_position is None:
            if use_inputs_embeds:
                # input_ids may be empty when generating from embeddings only.
                cache_position = torch.arange(
                    inputs_embeds.shape[1],
                    device=inputs_embeds.device,
                )
            elif past_key_values is not None:
                # Only tokens beyond the cached prefix are new; always feed
                # at least one token.
                new_tokens = max(input_ids.shape[1] - past_length, 1)
                cache_position = torch.arange(
                    past_length,
                    past_length + new_tokens,
                    device=input_ids.device,
                )
            else:
                cache_position = torch.arange(
                    input_ids.shape[1],
                    device=input_ids.device,
                )
        if past_key_values is not None:
            input_ids = input_ids[:, -cache_position.shape[0] :]
            if attention_mask is not None:
                attn_mask_idx = past_length + input_ids.shape[1]
                attention_mask = attention_mask[:, -attn_mask_idx:]
        if use_inputs_embeds:
            model_inputs: dict[str, Any] = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids.contiguous()}
        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "use_cache": kwargs.get("use_cache", self.config.use_cache),
            },
        )
        # ``generate`` supplies logits_to_keep (normally 1) so the head is
        # applied only to the positions it samples from; dropping it would
        # materialise vocabulary logits for the whole prompt.
        if "logits_to_keep" in kwargs:
            model_inputs["logits_to_keep"] = kwargs["logits_to_keep"]
        return model_inputs

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | tuple[
            tuple[torch.Tensor, torch.Tensor], ...
        ] | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        cache_position: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Any,
    ) -> CausalLMOutputWithPast | tuple[Any, ...]:
        """Run the decoder and project its hidden states to vocabulary logits.

        Args:
            labels: Optional target ids. The loss compares the logits at
                position ``t`` with ``labels[..., t + 1]``; label ``-100`` is
                ignored (the ``F.cross_entropy`` default ``ignore_index``).
            logits_to_keep: Integer ``k`` keeps logits for the last ``k``
                positions (``0`` keeps all). A tensor is used directly as the
                sequence index. The loss always uses the full sequence.
            return_dict: Defaults to ``True`` when not given.
            **kwargs: Forwarded to the base model, except
                ``num_items_in_batch``. When present, the loss is a token sum
                divided by it instead of a per-batch mean.

        Returns:
            ``CausalLMOutputWithPast``, or a tuple
            ``([loss,] logits, *base_model_outputs[1:])`` when ``return_dict``
            is false.
        """
        # HF eval loaders call `forward()` without an explicit return_dict,
        # so local Lizzy exports must normalize the optional flag first.
        return_dict = bool(return_dict) if return_dict is not None else True
        # Because this forward accepts **kwargs, HF Trainer passes the number
        # of target tokens in the whole gradient-accumulation step and then
        # skips its own division by the accumulation steps. It is a loss
        # kwarg, not a base-model input, so keep it out of self.model().
        num_items_in_batch = kwargs.pop("num_items_in_batch", None)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state if return_dict else outputs[0]
        slice_indices = (
            slice(-logits_to_keep, None)
            if isinstance(logits_to_keep, int)
            else logits_to_keep
        )
        if labels is not None:
            # The loss needs every position, so project the full sequence once
            # and slice the returned logits from it.
            full_logits = self.lm_head(hidden_states)
            logits = full_logits[:, slice_indices, :]
        else:
            full_logits = None
            logits = self.lm_head(hidden_states[:, slice_indices, :])
        loss = None
        if labels is not None:
            # Upcast before log-softmax: half-precision logits lose accuracy in
            # the log-sum-exp. Labels may sit on another device when the model
            # is sharded, so move them next to the logits.
            shift_logits = full_logits[..., :-1, :].float()
            shift_labels = labels[..., 1:].to(shift_logits.device)
            flat_logits = shift_logits.reshape(-1, shift_logits.size(-1))
            flat_labels = shift_labels.reshape(-1)
            if num_items_in_batch is None:
                loss = F.cross_entropy(flat_logits, flat_labels)
            else:
                loss = F.cross_entropy(flat_logits, flat_labels, reduction="sum")
                if torch.is_tensor(num_items_in_batch):
                    num_items_in_batch = num_items_in_batch.to(loss.device)
                loss = loss / num_items_in_batch
        if not return_dict:
            output = (logits,) + outputs[1:]
            if loss is not None:
                output = (loss,) + output
            return output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )