Lizzy-7B / modeling_lizzy.py
relogu's picture
Initial commit
edbbb7f
from __future__ import annotations
import math
import os
from typing import Any, cast
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
try:
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
except ImportError:
ROPE_INIT_FUNCTIONS = {}
try:
from fla.modules import FusedRMSNormGated, ShortConvolution
from fla.ops.gated_delta_rule import (
chunk_gated_delta_rule,
fused_recurrent_gated_delta_rule,
)
except ImportError:
chunk_gated_delta_rule = None
fused_recurrent_gated_delta_rule = None
FusedRMSNormGated = None
ShortConvolution = None
from .configuration_lizzy import LizzyConfig
class LizzyRMSNorm(nn.Module):
    """Root-mean-square normalization with a learned per-channel scale.

    Statistics are computed in float32; the normalized activations are cast
    back to the input dtype before being multiplied by ``weight``.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalize ``hidden_states`` over its last dimension."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(dim=-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)
def _make_norm(
norm_type: str,
hidden_size: int,
eps: float,
*,
has_bias: bool,
) -> nn.Module:
if norm_type == "rmsnorm":
return LizzyRMSNorm(hidden_size, eps=eps)
if norm_type == "layernorm":
return nn.LayerNorm(
hidden_size,
eps=eps,
elementwise_affine=True,
bias=has_bias,
)
msg = f"Unsupported norm_type: {norm_type}"
raise ValueError(msg)
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def _apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply the rotary embedding given by ``cos``/``sin`` to ``q`` and ``k``.

    Each tensor ``t`` becomes ``t * cos + _rotate_half(t) * sin``.
    """
    rotated_q, rotated_k = (
        (states * cos) + (_rotate_half(states) * sin) for states in (q, k)
    )
    return rotated_q, rotated_k
def _legacy_cache_length(
past_key_values: tuple[tuple[torch.Tensor, torch.Tensor], ...] | None,
) -> int:
if (
isinstance(past_key_values, tuple)
and len(past_key_values) > 0
and past_key_values[0] is not None
and past_key_values[0][0] is not None
):
return int(past_key_values[0][0].shape[2])
return 0
def _normalize_cache_position(
cache_position: torch.Tensor | None,
) -> torch.Tensor | None:
if cache_position is None:
return None
if cache_position.dim() == 0:
return cache_position.view(1)
if cache_position.dim() > 1:
return cache_position[0]
return cache_position
def _is_cache_object(value: Any) -> bool:
    """Tell whether ``value`` is an object-style cache (HF or Lizzy hybrid)."""
    return isinstance(value, (Cache, LizzyHybridDynamicCache))
def _compute_default_rope_parameters(
config: LizzyConfig,
device: torch.device,
) -> tuple[torch.Tensor, float]:
inv_freq = 1.0 / (
config.rope_theta
** (
torch.arange(0, config.head_dim, 2, device=device, dtype=torch.float32)
/ config.head_dim
)
)
return inv_freq, 1.0
def _compute_yarn_rope_parameters(
config: LizzyConfig,
device: torch.device,
) -> tuple[torch.Tensor, float]:
rope_scaling = dict(config.rope_scaling or {})
factor = float(rope_scaling["factor"])
attention_factor = rope_scaling.get("attention_factor")
mscale = rope_scaling.get("mscale")
mscale_all_dim = rope_scaling.get("mscale_all_dim")
original_max_position_embeddings = int(
rope_scaling.get("original_max_position_embeddings")
or config.max_position_embeddings
)
def get_mscale(scale: float, mscale_value: float = 1.0) -> float:
if scale <= 1.0:
return 1.0
return 0.1 * mscale_value * math.log(scale) + 1.0
if attention_factor is None:
if mscale is not None and mscale_all_dim is not None:
attention_factor = float(
get_mscale(factor, float(mscale))
/ get_mscale(factor, float(mscale_all_dim))
)
else:
attention_factor = get_mscale(factor)
beta_fast = float(rope_scaling.get("beta_fast") or 32.0)
beta_slow = float(rope_scaling.get("beta_slow") or 1.0)
truncate = bool(rope_scaling.get("truncate", True))
dim = config.head_dim
def find_correction_dim(
num_rotations: float,
*,
dim: int,
base: float,
max_position_embeddings: int,
) -> float:
return (
dim
* math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
/ (2 * math.log(base))
)
def find_correction_range(
low_rot: float,
high_rot: float,
*,
dim: int,
base: float,
max_position_embeddings: int,
truncate: bool,
) -> tuple[float, float]:
low = find_correction_dim(
low_rot,
dim=dim,
base=base,
max_position_embeddings=max_position_embeddings,
)
high = find_correction_dim(
high_rot,
dim=dim,
base=base,
max_position_embeddings=max_position_embeddings,
)
if truncate:
low = math.floor(low)
high = math.ceil(high)
return max(low, 0.0), min(high, dim - 1.0)
def linear_ramp_factor(
min_value: float,
max_value: float,
dim: int,
) -> torch.Tensor:
if min_value == max_value:
max_value += 0.001
linear_func = (
torch.arange(dim, dtype=torch.float32, device=device) - min_value
) / (max_value - min_value)
return torch.clamp(linear_func, 0, 1)
pos_freqs = config.rope_theta ** (
torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim
)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
low, high = find_correction_range(
beta_fast,
beta_slow,
dim=dim,
base=config.rope_theta,
max_position_embeddings=original_max_position_embeddings,
truncate=truncate,
)
inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2)
inv_freq = (
inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+ inv_freq_extrapolation * inv_freq_extrapolation_factor
)
return inv_freq, float(attention_factor)
def _compute_rope_parameters(
    config: LizzyConfig,
    device: torch.device,
    *,
    seq_len: int | torch.Tensor | None = None,
    rope_type_override: str | None = None,
) -> tuple[torch.Tensor, float]:
    """Dispatch to the RoPE initializer selected by the config.

    The rope type is ``rope_type_override`` when given, otherwise the
    ``rope_type`` (or legacy ``type``) entry of ``config.rope_scaling``.
    ``"default"`` and ``"yarn"`` use the local implementations; other types
    are looked up in ``ROPE_INIT_FUNCTIONS`` and fall back to the default
    parameters when no initializer is available or no scaling is configured.

    Returns:
        ``(inv_freq, attention_factor)`` with ``inv_freq`` as float32 on
        ``device``.
    """
    scaling_cfg = dict(config.rope_scaling or {})
    rope_type = rope_type_override
    if rope_type is None:
        if not scaling_cfg:
            return _compute_default_rope_parameters(config, device)
        legacy_type = scaling_cfg.get("type", "default")
        rope_type = str(scaling_cfg.get("rope_type", legacy_type))

    # "yarn" is checked before the empty-scaling fallback on purpose.
    if rope_type == "yarn":
        return _compute_yarn_rope_parameters(config, device)
    if rope_type == "default" or not scaling_cfg:
        return _compute_default_rope_parameters(config, device)

    rope_init_fn = ROPE_INIT_FUNCTIONS.get(rope_type)
    if not rope_init_fn:
        rope_init_fn = ROPE_INIT_FUNCTIONS.get("default")
    if rope_init_fn is None:
        return _compute_default_rope_parameters(config, device)

    inv_freq, attention_factor = rope_init_fn(config, device, seq_len=seq_len)
    inv_freq = inv_freq.to(device=device, dtype=torch.float32)
    return inv_freq, float(attention_factor)
def _looks_like_legacy_interval_rope_lizzy(config: LizzyConfig) -> bool:
rope_layer_flags = list(getattr(config, "rope_layer_flags", None) or [])
if rope_layer_flags and not all(bool(item) for item in rope_layer_flags):
return False
layer_types = list(getattr(config, "layer_types", None) or [])
if layer_types and any(str(item) != "full_attention" for item in layer_types):
return False
return (
str(getattr(config, "position_embedding_type", "")).lower() == "rope"
and not bool(getattr(config, "rope_scaling", None))
and int(getattr(config, "num_hidden_layers", 0) or 0) == 36
and int(getattr(config, "hidden_size", 0) or 0) == 2048
and int(getattr(config, "num_attention_heads", 0) or 0) == 16
and int(getattr(config, "num_key_value_heads", 0) or 0) == 4
and math.isclose(
float(getattr(config, "rope_theta", 0.0) or 0.0), 5_000_000.0
)
and not bool(getattr(config, "use_post_attn_norm", False))
and not bool(getattr(config, "use_post_mlp_norm", False))
and not bool(getattr(config, "use_qk_norm", False))
)
def _get_no_rope_layer_interval(config: LizzyConfig) -> int | None:
value = getattr(config, "no_rope_layer_interval", None)
if value is not None:
value = int(value)
if value > 0:
return value
if _looks_like_legacy_interval_rope_lizzy(config):
# Backward-compatible fallback for already-uploaded Lizzy
# checkpoints that should use NoPE on every 4th layer.
return 4
return None
def _get_rope_layer_flag(config: LizzyConfig, layer_idx: int) -> bool:
    """Decide whether layer ``layer_idx`` applies rotary embeddings.

    RoPE is off everywhere unless ``position_embedding_type`` is ``"rope"``.
    When a NoPE interval applies and the explicit per-layer flags are absent,
    too short for this layer, or all truthy, every ``interval``-th layer
    (1-based) is NoPE. Otherwise the explicit flag for the layer is used, and
    layers without a flag default to RoPE.
    """
    position_type = str(getattr(config, "position_embedding_type", "rope"))
    rope_enabled = position_type.lower() == "rope"
    rope_layer_flags = list(getattr(config, "rope_layer_flags", None) or [])
    interval = _get_no_rope_layer_interval(config)
    if not rope_enabled:
        return False

    if interval is not None:
        flags_defer_to_interval = (
            layer_idx >= len(rope_layer_flags)
            or not rope_layer_flags
            or all(bool(flag) for flag in rope_layer_flags)
        )
        if flags_defer_to_interval:
            return (layer_idx + 1) % interval != 0

    if 0 <= layer_idx < len(rope_layer_flags):
        return bool(rope_layer_flags[layer_idx])
    return True
def _get_layer_layout(config: LizzyConfig, layer_idx: int) -> str:
layer_layouts = list(getattr(config, "layer_layouts", None) or [])
if 0 <= layer_idx < len(layer_layouts):
return str(layer_layouts[layer_idx])
if bool(getattr(config, "use_post_attn_norm", False)) or bool(
getattr(config, "use_post_mlp_norm", False)
):
return "decoder_postnorm"
return "decoder_prenorm"
def _has_linear_attention(config: LizzyConfig) -> bool:
return any(
str(layer_type) == "linear_attention"
for layer_type in list(getattr(config, "layer_types", None) or [])
)
class LizzyHybridDynamicCache:
    """Cache for Lizzy checkpoints with mixed full and linear attention.

    Per layer, the cache holds either a growing key/value pair (full
    attention layers) or the short-convolution states for q/k/v plus the
    recurrent state (linear attention layers). All slots start as ``None``
    and are filled by the layers during the forward pass.
    """

    is_compileable = False

    def __init__(self, config: LizzyConfig) -> None:
        super().__init__()
        self.layer_types = list(config.layer_types)
        self.transformer_layers = [
            idx
            for idx, layer_type in enumerate(self.layer_types)
            if layer_type == "full_attention"
        ]
        linear_layers = [
            idx
            for idx, layer_type in enumerate(self.layer_types)
            if layer_type == "linear_attention"
        ]
        # ``None`` when the config has no linear-attention layer at all, so
        # that constructing the cache does not fail for such configs.
        self.last_linear_layer: int | None = (
            linear_layers[-1] if linear_layers else None
        )
        num_layers = config.num_hidden_layers
        self.recurrent_states: list[torch.Tensor | None] = [None] * num_layers
        self.key_cache: list[torch.Tensor | None] = [None] * num_layers
        self.value_cache: list[torch.Tensor | None] = [None] * num_layers
        self.conv_states_q: list[torch.Tensor | None] = [None] * num_layers
        self.conv_states_k: list[torch.Tensor | None] = [None] * num_layers
        self.conv_states_v: list[torch.Tensor | None] = [None] * num_layers

    def __len__(self) -> int:
        """Return the number of layers described by ``layer_types``."""
        return len(self.layer_types)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Append new key/value states for ``layer_idx`` along dimension 2.

        ``cache_kwargs`` is accepted for interface compatibility and ignored.
        Returns the full cached key and value tensors for the layer.
        """
        del cache_kwargs
        if self.key_cache[layer_idx] is None:
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
        else:
            self.key_cache[layer_idx] = torch.cat(
                [self.key_cache[layer_idx], key_states],
                dim=2,
            )
            self.value_cache[layer_idx] = torch.cat(
                [self.value_cache[layer_idx], value_states],
                dim=2,
            )
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    @staticmethod
    def _reorder_state(
        state: torch.Tensor,
        beam_idx: torch.LongTensor,
    ) -> torch.Tensor:
        """Expand ``state`` to the beam batch if needed, then gather rows."""
        batch_size = beam_idx.shape[0]
        if state.shape[0] < batch_size:
            state = state.repeat_interleave(batch_size // state.shape[0], dim=0)
        return state.index_select(0, beam_idx.to(state.device))

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Reorder every populated cache entry along the batch dimension.

        Entries whose batch is smaller than ``beam_idx`` are first repeated
        to match it. Unpopulated (``None``) slots are left untouched.
        """
        state_lists = (
            self.key_cache,
            self.value_cache,
            self.conv_states_q,
            self.conv_states_k,
            self.conv_states_v,
            self.recurrent_states,
        )
        for states in state_lists:
            for layer_idx, state in enumerate(states):
                if state is not None:
                    states[layer_idx] = self._reorder_state(state, beam_idx)

    def get_seq_length(self, layer_idx: int | None = 0) -> int:
        """Return the cached sequence length of a full-attention layer.

        If ``layer_idx`` is not a full-attention layer, the first
        full-attention layer is used instead. Returns ``0`` when there is no
        such layer or nothing has been cached yet.
        """
        if not self.transformer_layers:
            return 0
        layer_idx = (
            self.transformer_layers[0]
            if layer_idx not in self.transformer_layers
            else layer_idx
        )
        if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]:
        """Return ``(kv_length, kv_offset)`` for mask construction.

        ``layer_idx`` is ignored; the offset is always ``0``.
        """
        del layer_idx
        kv_offset = 0
        past_seen_tokens = self.get_seq_length()
        kv_length = query_length + past_seen_tokens
        return kv_length, kv_offset

    @property
    def has_previous_state(self) -> bool:
        # Mirror the upstream contract: once the final linear layer has cached
        # its conv state, single-token decode can switch to the recurrent path.
        if self.last_linear_layer is None:
            return False
        return self.conv_states_q[self.last_linear_layer] is not None
class LizzyHybridRMSNormGated(nn.Module):
    """RMS normalization followed by a SiLU gate.

    The input is normalized in float32, cast back to its dtype and scaled by
    ``weight``; the result is multiplied by ``silu(gate)`` (gate in float32)
    and returned in the input dtype.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(
        self,
        hidden_states: torch.Tensor,
        gate: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Normalize ``hidden_states`` and gate it.

        Raises:
            ValueError: if ``gate`` is ``None``.
        """
        if gate is None:
            raise ValueError("gate is required for gated RMSNorm.")
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        scaled = self.weight * normalized.to(original_dtype)
        gated = scaled * F.silu(gate.to(torch.float32))
        return gated.to(original_dtype)
class LizzyHybridShortConvolution(nn.Conv1d):
    """Depthwise causal 1-D convolution with an optional activation.

    Pure-PyTorch stand-in used when the FLA ``ShortConvolution`` kernel is not
    selected. One filter of width ``kernel_size`` is learned per channel
    (``groups == hidden_size``).

    Args:
        hidden_size: number of channels.
        kernel_size: width of the causal filter.
        bias: whether the convolution has a bias term.
        activation: key into ``ACT2FN`` for the output activation, or ``None``
            to apply no activation.
    """

    def __init__(
        self,
        hidden_size: int,
        kernel_size: int,
        bias: bool = False,
        activation: str | None = "silu",
    ) -> None:
        super().__init__(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=kernel_size,
            groups=hidden_size,
            padding=kernel_size - 1,
            bias=bias,
        )
        self.hidden_size = hidden_size
        self.conv_kernel_size = kernel_size
        # ``activation`` is annotated as optional; looking ``None`` up in
        # ACT2FN would raise, so map it to an identity instead.
        self.act_fn = nn.Identity() if activation is None else ACT2FN[activation]

    def forward(
        self,
        hidden_states: torch.Tensor,
        cache: torch.Tensor | None = None,
        use_precomputed: bool = False,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Convolve ``hidden_states`` causally along the sequence dimension.

        Args:
            hidden_states: tensor whose last two dimensions are
                ``(seq_len, channels)``; dimensions 1 and 2 are swapped
                internally for the convolution.
            cache: convolution state from a previous call, holding the
                trailing ``kernel_size - 1`` inputs per channel. Only read
                when ``use_precomputed`` is true.
            use_precomputed: continue from ``cache`` instead of starting from
                zero left-padding.
            **kwargs: ignored; accepted so the call signature is compatible
                with the FLA module.

        Returns:
            ``(output, conv_state)`` where ``output`` has the layout of the
            input and ``conv_state`` is the state to pass as ``cache`` next.

        Raises:
            ValueError: if ``use_precomputed`` is true but ``cache`` is None.
        """
        del kwargs
        seq_len, dim = hidden_states.shape[-2:]
        hidden_states = hidden_states.transpose(1, 2)
        if use_precomputed:
            if cache is None:
                msg = "cache is required when use_precomputed=True."
                raise ValueError(msg)
            # Prepend the cached inputs so no zero padding is needed.
            x_with_state = torch.cat([cache, hidden_states], dim=-1)
            out = F.conv1d(
                x_with_state,
                self.weight,
                self.bias,
                padding=0,
                groups=dim,
            )
            # Drop the oldest step to keep the state width constant.
            conv_state = x_with_state[:, :, 1:]
        else:
            out = F.conv1d(
                hidden_states,
                self.weight,
                self.bias,
                padding=self.conv_kernel_size - 1,
                groups=dim,
            )
            # Symmetric padding yields extra trailing steps; keep the causal
            # prefix only.
            out = out[:, :, :seq_len]
            # A positive left pad zero-fills short inputs; a negative one
            # crops long inputs to their last ``kernel_size - 1`` steps.
            conv_state = F.pad(
                hidden_states,
                (self.conv_kernel_size - 1 - hidden_states.shape[-1], 0),
            )
        out = self.act_fn(out)
        return out.transpose(1, 2), conv_state
def _apply_mask_to_padding_states(
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None,
) -> torch.Tensor:
# Match the upstream hybrid implementation: silence padded tokens before
# the DeltaNet projections so recurrent state does not absorb padding.
if (
attention_mask is not None
and attention_mask.shape[1] > 1
and attention_mask.shape[0] > 1
):
dtype = hidden_states.dtype
hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
return hidden_states
def _l2norm(
x: torch.Tensor,
dim: int = -1,
eps: float = 1e-6,
) -> torch.Tensor:
inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
return x * inv_norm
def _torch_chunk_gated_delta_rule(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    chunk_size: int = 64,
    initial_state: torch.Tensor | None = None,
    output_final_state: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Pure-PyTorch chunked gated delta rule.

    Processes the sequence in blocks of ``chunk_size`` steps, carrying a
    recurrent state from one block to the next. All math runs in float32 and
    the output is cast back to the dtype of ``query``.

    Args:
        query, key, value: 4-D tensors laid out as (batch, seq, heads, dim);
            dimensions 1 and 2 are swapped internally.
        g: per-step log decay, laid out as (batch, seq, heads); it is
            exponentiated after a within-chunk cumulative sum.
        beta: per-step weight applied to keys and values, same layout as
            ``g``.
        chunk_size: block length; the sequence is zero-padded on the right to
            a multiple of it.
        initial_state: optional starting recurrent state; zeros of shape
            (batch, heads, key_dim, value_dim) are used when ``None``.
        output_final_state: whether to return the final recurrent state.
        use_qk_l2norm_in_kernel: L2-normalize ``query`` and ``key`` over
            their last dimension first.

    Returns:
        ``(output, final_state)``. ``output`` is laid out as
        (batch, seq, heads, value_dim) with padding removed; ``final_state``
        is the float32 state, or ``None`` unless ``output_final_state``.
    """
    initial_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = _l2norm(query, dim=-1, eps=1e-6)
        key = _l2norm(key, dim=-1, eps=1e-6)
    # Move heads in front of the sequence axis and compute in float32.
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32)
        for x in (query, key, value, beta, g)
    ]
    batch_size, num_heads, sequence_length, k_head_dim = key.shape
    v_head_dim = value.shape[-1]
    # Right-pad the sequence axis so it splits evenly into chunks.
    pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
    query = F.pad(query, (0, 0, 0, pad_size))
    key = F.pad(key, (0, 0, 0, pad_size))
    value = F.pad(value, (0, 0, 0, pad_size))
    beta = F.pad(beta, (0, pad_size))
    g = F.pad(g, (0, pad_size))
    total_sequence_length = sequence_length + pad_size
    scale = 1 / (query.shape[-1] ** 0.5)
    query = query * scale
    v_beta = value * beta.unsqueeze(-1)
    k_beta = key * beta.unsqueeze(-1)
    # Split the sequence axis into (num_chunks, chunk_size).
    query, key, value, k_beta, v_beta = [
        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
        for x in (query, key, value, k_beta, v_beta)
    ]
    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
    # Upper triangle including the diagonal.
    mask = torch.triu(
        torch.ones(
            chunk_size,
            chunk_size,
            dtype=torch.bool,
            device=query.device,
        ),
        diagonal=0,
    )
    # Cumulative log decay within each chunk.
    g = g.cumsum(dim=-1)
    # exp(g_i - g_j) for i >= j, zero above the diagonal.
    decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
    attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
    # Row-by-row update: each row is combined with the already-updated rows
    # above it. NOTE(review): this looks like a forward substitution that
    # inverts a unit lower-triangular system — confirm against upstream.
    for idx in range(1, chunk_size):
        row = attn[..., idx, :idx].clone()
        sub = attn[..., :idx, :idx].clone()
        attn[..., idx, :idx] = row + (row.unsqueeze(-1) * sub).sum(-2)
    attn = attn + torch.eye(
        chunk_size,
        dtype=attn.dtype,
        device=attn.device,
    )
    value = attn @ v_beta
    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
    last_recurrent_state = (
        torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
        if initial_state is None
        else initial_state.to(value)
    )
    core_attn_out = torch.zeros_like(value)
    # Strictly upper triangle: the diagonal is kept for the intra-chunk term.
    mask = torch.triu(
        torch.ones(
            chunk_size,
            chunk_size,
            dtype=torch.bool,
            device=query.device,
        ),
        diagonal=1,
    )
    # Walk the chunks in order, carrying the recurrent state forward.
    for idx in range(0, total_sequence_length // chunk_size):
        q_i, k_i, v_i = query[:, :, idx], key[:, :, idx], value[:, :, idx]
        attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, idx]).masked_fill_(
            mask,
            0,
        )
        # Part of the values already explained by the incoming state.
        v_prime = (k_cumdecay[:, :, idx]) @ last_recurrent_state
        v_new = v_i - v_prime
        # Contribution of the incoming state to this chunk's outputs.
        attn_inter = (q_i * g[:, :, idx, :, None].exp()) @ last_recurrent_state
        core_attn_out[:, :, idx] = attn_inter + attn @ v_new
        # Decay the state over the whole chunk and add this chunk's update.
        last_recurrent_state = (
            last_recurrent_state * g[:, :, idx, -1, None, None].exp()
            + (
                k_i
                * (g[:, :, idx, -1, None] - g[:, :, idx]).exp()[..., None]
            ).transpose(-1, -2)
            @ v_new
        )
    if not output_final_state:
        last_recurrent_state = None
    # Merge the chunk axes back into one sequence axis and strip the padding.
    core_attn_out = core_attn_out.reshape(
        core_attn_out.shape[0],
        core_attn_out.shape[1],
        -1,
        core_attn_out.shape[-1],
    )
    core_attn_out = core_attn_out[:, :, :sequence_length]
    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
    return core_attn_out, last_recurrent_state
def _torch_recurrent_gated_delta_rule(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
initial_state: torch.Tensor | None,
output_final_state: bool,
use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
initial_dtype = query.dtype
if use_qk_l2norm_in_kernel:
query = _l2norm(query, dim=-1, eps=1e-6)
key = _l2norm(key, dim=-1, eps=1e-6)
query, key, value, beta, g = [
x.transpose(1, 2).contiguous().to(torch.float32)
for x in (query, key, value, beta, g)
]
batch_size, num_heads, sequence_length, k_head_dim = key.shape
v_head_dim = value.shape[-1]
scale = 1 / (query.shape[-1] ** 0.5)
query = query * scale
core_attn_out = torch.zeros(
batch_size,
num_heads,
sequence_length,
v_head_dim,
).to(value)
last_recurrent_state = (
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
if initial_state is None
else initial_state.to(value)
)
for idx in range(sequence_length):
q_t = query[:, :, idx]
k_t = key[:, :, idx]
v_t = value[:, :, idx]
g_t = g[:, :, idx].exp().unsqueeze(-1).unsqueeze(-1)
beta_t = beta[:, :, idx].unsqueeze(-1)
last_recurrent_state = last_recurrent_state * g_t
kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
delta = (v_t - kv_mem) * beta_t
last_recurrent_state = (
last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
)
core_attn_out[:, :, idx] = (
last_recurrent_state * q_t.unsqueeze(-1)
).sum(dim=-2)
if not output_final_state:
last_recurrent_state = None
core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
return core_attn_out, last_recurrent_state
class LizzyHybridGatedDeltaNet(nn.Module):
    """Gated DeltaNet linear-attention mixer for Lizzy hybrid layers.

    Projects the hidden states to q/k/v, runs each through a short depthwise
    convolution, applies the gated delta rule (chunked for multi-token input,
    recurrent for cached single-token decode) and finishes with a gated RMS
    norm and output projection. FLA kernels are used when available on CUDA
    and not disabled through ``LIZZY_DISABLE_HYBRID_FLA``; otherwise the local
    PyTorch implementations are used.
    """

    def __init__(self, config: LizzyConfig, layer_idx: int) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_v_heads = config.linear_num_value_heads
        self.num_k_heads = config.linear_num_key_heads
        self.head_k_dim = config.linear_key_head_dim
        self.head_v_dim = config.linear_value_head_dim
        self.key_dim = self.head_k_dim * self.num_k_heads
        self.value_dim = self.head_v_dim * self.num_v_heads
        self.layer_idx = layer_idx
        self.conv_kernel_size = config.linear_conv_kernel_dim
        self.allow_neg_eigval = config.linear_allow_neg_eigval
        self.eps = config.rms_norm_eps
        self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False)
        self.a_proj = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
        self.b_proj = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
        self.g_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False)
        self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
        # Step-02 conversion runs on CPU by default, even on GPU nodes. In that
        # flow Triton-backed FLA kernels will crash as soon as a CPU tensor
        # reaches them, so the wrapper can force the pure PyTorch fallback for
        # Hybrid layers via an environment switch.
        disable_fla_fast_path = os.environ.get(
            "LIZZY_DISABLE_HYBRID_FLA",
            "",
        ).strip().lower() in {"1", "true", "yes", "on"}
        use_fla_fast_path = (
            not disable_fla_fast_path
            and torch.cuda.is_available()
            and ShortConvolution is not None
            and chunk_gated_delta_rule is not None
            and fused_recurrent_gated_delta_rule is not None
            and FusedRMSNormGated is not None
        )
        # Keep the fast-path contract when FLA is present, but fall back to a
        # local implementation so the public Lizzy artifact never depends on
        # family-specific Transformers remote code.
        conv1d_class = (
            ShortConvolution if use_fla_fast_path else LizzyHybridShortConvolution
        )
        self.q_conv1d = conv1d_class(
            hidden_size=self.key_dim,
            kernel_size=self.conv_kernel_size,
            bias=False,
            activation="silu",
        )
        self.k_conv1d = conv1d_class(
            hidden_size=self.key_dim,
            kernel_size=self.conv_kernel_size,
            bias=False,
            activation="silu",
        )
        self.v_conv1d = conv1d_class(
            hidden_size=self.value_dim,
            kernel_size=self.conv_kernel_size,
            bias=False,
            activation="silu",
        )
        # A_log stores log(a) for a drawn uniformly between the two config
        # bounds; forward() uses -exp(A_log) as the decay rate.
        a = torch.empty(self.num_v_heads, dtype=torch.float32).uniform_(
            config.linear_a_log_min,
            config.linear_a_log_max,
        )
        self.A_log = nn.Parameter(torch.log(a))
        # dt is sampled log-uniformly in [dt_min, dt_max] and floored.
        dt = torch.exp(
            torch.rand(self.num_v_heads)
            * (math.log(config.linear_dt_max) - math.log(config.linear_dt_min))
            + math.log(config.linear_dt_min)
        )
        dt = torch.clamp(dt, min=config.linear_dt_init_floor)
        # Inverse softplus, so that softplus(dt_bias) == dt at initialization.
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        self.dt_bias = nn.Parameter(inv_dt)
        self.o_norm = (
            LizzyHybridRMSNormGated(self.head_v_dim, eps=1e-5)
            if not use_fla_fast_path
            else FusedRMSNormGated(
                self.head_v_dim,
                eps=1e-5,
                device=torch.cuda.current_device(),
                dtype=(
                    config.dtype
                    if config.dtype is not None
                    else torch.get_default_dtype()
                ),
            )
        )
        self.chunk_gated_delta_rule = (
            chunk_gated_delta_rule
            if use_fla_fast_path
            else _torch_chunk_gated_delta_rule
        )
        self.recurrent_gated_delta_rule = (
            fused_recurrent_gated_delta_rule
            if use_fla_fast_path
            else _torch_recurrent_gated_delta_rule
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: LizzyHybridDynamicCache | None = None,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Run the linear-attention mixer.

        Args:
            hidden_states: tensor of shape (batch, seq_len, hidden_size).
            cache_params: hybrid cache; when given, this layer's conv and
                recurrent states are read from and written back to it.
            attention_mask: optional 2-D padding mask applied to the input
                before any projection.
            **kwargs: ignored.

        Returns:
            Tensor of shape (batch, seq_len, hidden_size).
        """
        del kwargs
        hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
        batch_size, seq_len, _ = hidden_states.shape
        use_cache = cache_params is not None
        # Single-token decode with an already-populated cache takes the
        # recurrent path; everything else is processed in chunks.
        use_precomputed = (
            use_cache
            and getattr(cache_params, "has_previous_state", False)
            and seq_len == 1
        )
        # Test the cache with ``is not None`` everywhere: a truthiness test
        # would go through ``__len__`` and could disagree with the write-back
        # branches below.
        if cache_params is not None:
            conv_state_q = cache_params.conv_states_q[self.layer_idx]
            conv_state_k = cache_params.conv_states_k[self.layer_idx]
            conv_state_v = cache_params.conv_states_v[self.layer_idx]
            recurrent_state = cache_params.recurrent_states[self.layer_idx]
        else:
            conv_state_q = None
            conv_state_k = None
            conv_state_v = None
            recurrent_state = None
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
        q, new_conv_state_q = self.q_conv1d(
            q,
            cache=conv_state_q,
            use_precomputed=use_precomputed,
            output_final_state=use_cache,
        )
        k, new_conv_state_k = self.k_conv1d(
            k,
            cache=conv_state_k,
            use_precomputed=use_precomputed,
            output_final_state=use_cache,
        )
        v, new_conv_state_v = self.v_conv1d(
            v,
            cache=conv_state_v,
            use_precomputed=use_precomputed,
            output_final_state=use_cache,
        )
        if cache_params is not None:
            cache_params.conv_states_q[self.layer_idx] = new_conv_state_q
            cache_params.conv_states_k[self.layer_idx] = new_conv_state_k
            cache_params.conv_states_v[self.layer_idx] = new_conv_state_v
        # Split the channel dimension into heads.
        q = q.view(batch_size, seq_len, -1, self.head_k_dim)
        k = k.view(batch_size, seq_len, -1, self.head_k_dim)
        v = v.view(batch_size, seq_len, -1, self.head_v_dim)
        # Repeat q/k heads so their count matches the value heads.
        if self.num_v_heads > self.num_k_heads:
            expand_ratio = self.num_v_heads // self.num_k_heads
            q = q.repeat_interleave(expand_ratio, dim=2)
            k = k.repeat_interleave(expand_ratio, dim=2)
        beta = self.b_proj(hidden_states).sigmoid()
        if self.allow_neg_eigval:
            # Widen beta from (0, 1) to (0, 2).
            beta = beta * 2.0
        # Log decay per step and head; always <= 0.
        g = -self.A_log.float().exp() * F.softplus(
            self.a_proj(hidden_states).float() + self.dt_bias
        )
        delta_rule = (
            self.recurrent_gated_delta_rule
            if use_precomputed
            else self.chunk_gated_delta_rule
        )
        output, new_recurrent_state = delta_rule(
            q,
            k,
            v,
            g=g,
            beta=beta,
            initial_state=recurrent_state,
            output_final_state=use_cache,
            use_qk_l2norm_in_kernel=True,
        )
        if cache_params is not None:
            cache_params.recurrent_states[self.layer_idx] = new_recurrent_state
        gate = self.g_proj(hidden_states)
        # The gated norm works per head, so flatten to rows of head_v_dim.
        output = output.reshape(-1, self.head_v_dim)
        gate = gate.reshape(-1, self.head_v_dim)
        output = self.o_norm(output, gate)
        output = output.reshape(batch_size, seq_len, -1)
        output = self.o_proj(output)
        return output
class LizzyLinearAttention(nn.Module):
    """Adapter exposing the Gated DeltaNet mixer through the attention API.

    ``forward`` returns the ``(output, present_cache, attention_weights)``
    triple; the attention-weights slot is always ``None``.
    """

    def __init__(self, config: LizzyConfig, layer_idx: int) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        self.inner = LizzyHybridGatedDeltaNet(config, layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        past_key_value: Cache | None = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        **kwargs: Any,
    ) -> tuple[
        torch.Tensor,
        Cache | None,
        torch.Tensor | None,
    ]:
        """Run the inner mixer, passing the cache only if it is a cache object.

        ``output_attentions`` and extra keyword arguments are ignored. The
        second return slot echoes ``past_key_value`` when ``use_cache`` is set.
        """
        del kwargs, output_attentions
        cache_params = None
        if _is_cache_object(past_key_value):
            cache_params = past_key_value
        output = self.inner(
            hidden_states=hidden_states,
            cache_params=cache_params,
            attention_mask=attention_mask,
        )
        if use_cache:
            return output, past_key_value, None
        return output, None, None
class LizzyAttention(nn.Module):
def __init__(self, config: LizzyConfig, layer_idx: int) -> None:
    """Build the projections, optional QK norms and RoPE buffers for a layer.

    Args:
        config: model configuration; the attention geometry, layer types,
            RoPE settings and norm switches are read from it.
        layer_idx: index of this layer, used to look up its layer type and
            whether it applies rotary embeddings.
    """
    super().__init__()
    self.is_causal = True
    self.config = config
    self.layer_idx = layer_idx
    self.num_heads = config.num_attention_heads
    self.num_key_value_heads = config.num_key_value_heads
    self.num_key_value_groups = self.num_heads // self.num_key_value_heads
    self.head_dim = config.head_dim
    self.hidden_size = config.hidden_size
    self.scaling = self.head_dim**-0.5
    self.attention_dropout = config.attention_dropout
    self.position_embedding_type = config.position_embedding_type
    # Layers beyond the end of ``layer_types`` are treated as full attention.
    self.layer_type = (
        str(config.layer_types[layer_idx])
        if layer_idx < len(config.layer_types)
        else "full_attention"
    )
    self.use_rope = _get_rope_layer_flag(config, layer_idx)
    # An empty or missing override entry collapses to ``None``.
    self._rope_type_override = str(
        dict(config.rope_type_overrides or {}).get(self.layer_type) or ""
    ) or None
    # Without an explicit override, sliding-attention layers in a scaled-RoPE,
    # post-norm, QK-norm model that also has full-attention layers use the
    # unscaled "default" rope type.
    if (
        self._rope_type_override is None
        and self.layer_type == "sliding_attention"
        and bool(config.rope_scaling)
        and config.use_post_attn_norm
        and config.use_post_mlp_norm
        and config.use_qk_norm
        and any(str(item) == "full_attention" for item in config.layer_types)
    ):
        self._rope_type_override = "default"
    self.sliding_window = None
    if self.layer_type == "sliding_attention":
        self.sliding_window = config.sliding_window
    q_dim = self.num_heads * self.head_dim
    kv_dim = self.num_key_value_heads * self.head_dim
    self.q_proj = nn.Linear(
        config.hidden_size,
        q_dim,
        bias=config.attention_bias,
    )
    self.k_proj = nn.Linear(
        config.hidden_size,
        kv_dim,
        bias=config.attention_bias,
    )
    self.v_proj = nn.Linear(
        config.hidden_size,
        kv_dim,
        bias=config.attention_bias,
    )
    self.o_proj = nn.Linear(
        q_dim,
        config.hidden_size,
        bias=config.attention_bias,
    )
    # QK norms span the full projected width (all heads), not a single head.
    self.q_norm = (
        _make_norm(config.qk_norm_type, q_dim, config.norm_eps, has_bias=False)
        if config.use_qk_norm
        else None
    )
    self.k_norm = (
        _make_norm(config.qk_norm_type, kv_dim, config.norm_eps, has_bias=False)
        if config.use_qk_norm
        else None
    )
    self._rope_requires_runtime_update = False
    if self.use_rope:
        rope_scaling = dict(config.rope_scaling or {})
        rope_type = self._rope_type_override or str(
            rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
        )
        # "dynamic" RoPE is recomputed on every forward, so nothing is cached.
        self._rope_requires_runtime_update = rope_type == "dynamic"
        if self._rope_requires_runtime_update:
            self.register_buffer("_rope_inv_freq", None, persistent=False)
            self.register_buffer(
                "_rope_attention_factor", None, persistent=False,
            )
        else:
            # Precompute on CPU; the buffers are non-persistent, so they are
            # not written to the state dict.
            inv_freq, attention_factor = _compute_rope_parameters(
                config,
                device=torch.device("cpu"),
                seq_len=config.max_position_embeddings,
                rope_type_override=self._rope_type_override,
            )
            self.register_buffer("_rope_inv_freq", inv_freq, persistent=False)
            self.register_buffer(
                "_rope_attention_factor",
                torch.tensor(float(attention_factor), dtype=torch.float32),
                persistent=False,
            )
    else:
        # NoPE layer: register empty buffers so the attributes always exist.
        self.register_buffer("_rope_inv_freq", None, persistent=False)
        self.register_buffer("_rope_attention_factor", None, persistent=False)
def _build_rope(
self,
position_ids: torch.Tensor,
device: torch.device,
dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
if not self.use_rope:
msg = "RoPE requested but rope buffer is not initialized."
raise RuntimeError(msg)
inv_freq = self._rope_inv_freq
attention_factor_tensor = self._rope_attention_factor
if (
inv_freq is None
or attention_factor_tensor is None
or self._rope_requires_runtime_update
):
# Keep the sequence-length hint as a tensor so TorchDynamo/vLLM
# can trace this path without requiring capture_scalar_outputs.
# When low-memory loading leaves the non-persistent cache unset,
# rebuild from config for this forward only instead of mutating
# buffers inside the compiled graph.
seq_len = (
torch.max(position_ids) + 1 if position_ids.numel() > 0 else None
)
inv_freq, attention_factor = _compute_rope_parameters(
self.config,
device=device,
seq_len=seq_len,
rope_type_override=self._rope_type_override,
)
attention_factor_tensor = torch.tensor(
float(attention_factor),
device=device,
dtype=torch.float32,
)
else:
inv_freq = inv_freq.to(device=device)
attention_factor_tensor = attention_factor_tensor.to(
device=device,
dtype=torch.float32,
)
# Mirror the upstream HF decoder-only rotary path closely here.
# The matmul-based construction is slightly more numerically stable
# than the generic einsum formulation for strict parity probes.
inv_freq_expanded = (
inv_freq[None, :, None]
.to(device=device, dtype=torch.float32)
.expand(position_ids.shape[0], -1, 1)
)
position_ids_expanded = position_ids[:, None, :].to(torch.float32)
angles = torch.matmul(
inv_freq_expanded,
position_ids_expanded,
).transpose(1, 2)
angles = torch.cat((angles, angles), dim=-1)
cos = angles.cos().unsqueeze(1) * attention_factor_tensor
sin = angles.sin().unsqueeze(1) * attention_factor_tensor
cos = cos.to(dtype)
sin = sin.to(dtype)
return cos, sin
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
    position_ids: torch.Tensor | None = None,
    past_key_value: Cache | tuple[torch.Tensor, torch.Tensor] | None = None,
    cache_position: torch.Tensor | None = None,
    use_cache: bool = False,
    output_attentions: bool = False,
    **kwargs: Any,
) -> tuple[
    torch.Tensor,
    Cache | tuple[torch.Tensor, torch.Tensor] | None,
    torch.Tensor | None,
]:
    """Run softmax attention over ``hidden_states``.

    Args:
        hidden_states: Input of shape ``[batch, seq, hidden]``.
        attention_mask: Mask forwarded to the attention backend; in the
            eager path it is added to the raw attention scores.
        position_ids: Positions used to build rotary embeddings; required
            when this layer uses RoPE.
        past_key_value: Either a cache object (updated in place when
            ``use_cache`` is true) or a legacy ``(key, value)`` tuple.
        cache_position: Forwarded to the cache object's ``update``.
        use_cache: Whether to return/refresh cached key/value state.
        output_attentions: Whether to return the attention weights.
        **kwargs: Passed through to a non-eager attention backend.

    Returns:
        ``(attn_output, present_key_value, attn_weights)``;
        ``attn_weights`` is ``None`` unless ``output_attentions`` is set.

    Raises:
        ValueError: If RoPE is enabled and ``position_ids`` is missing.
    """
    batch_size, q_len, _ = hidden_states.shape
    cache_position = _normalize_cache_position(cache_position)

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)
    if self.q_norm is not None:
        query_states = self.q_norm(query_states)
    if self.k_norm is not None:
        key_states = self.k_norm(key_states)

    # Split the projection width into heads, then move heads ahead of seq.
    q_shape = (batch_size, q_len, self.num_heads, self.head_dim)
    kv_shape = (batch_size, q_len, self.num_key_value_heads, self.head_dim)
    query_states = query_states.view(q_shape).transpose(1, 2)
    key_states = key_states.view(kv_shape).transpose(1, 2)
    value_states = value_states.view(kv_shape).transpose(1, 2)

    if self.use_rope:
        if position_ids is None:
            msg = "position_ids are required for rope attention."
            raise ValueError(msg)
        cos, sin = self._build_rope(
            position_ids,
            hidden_states.device,
            query_states.dtype,
        )
        query_states, key_states = _apply_rotary_pos_emb(
            query_states,
            key_states,
            cos,
            sin,
        )

    present_key_value = None
    if _is_cache_object(past_key_value):
        if use_cache:
            key_states, value_states = past_key_value.update(
                key_states,
                value_states,
                self.layer_idx,
                cache_kwargs={"cache_position": cache_position},
            )
            present_key_value = past_key_value
        elif self.layer_idx < len(past_key_value):
            # Read-only use of the cache: prepend stored state without
            # writing the new keys/values back.
            past_key, past_value = past_key_value[self.layer_idx]
            if past_key is not None and past_value is not None:
                key_states = torch.cat([past_key, key_states], dim=2)
                value_states = torch.cat([past_value, value_states], dim=2)
    else:
        if past_key_value is not None:
            past_key, past_value = past_key_value
            key_states = torch.cat([past_key, key_states], dim=2)
            value_states = torch.cat([past_value, value_states], dim=2)
        if use_cache:
            present_key_value = (key_states, value_states)

    attn_impl = getattr(self.config, "_attn_implementation", "eager")
    if attn_impl == "flex_attention" and self.head_dim < 16:
        # NOTE(review): small head dims are routed to SDPA instead of
        # flex attention; the reason is not visible here.
        attn_impl = "sdpa"
    attention_interface = (
        ALL_ATTENTION_FUNCTIONS.get(attn_impl) if attn_impl != "eager" else None
    )

    if attention_interface is not None:
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )
        attn_output = attn_output.contiguous()
    else:
        if self.num_key_value_heads != self.num_heads:
            # Grouped-query attention: repeat each KV head for its group.
            groups = self.num_key_value_groups
            key_states = key_states.repeat_interleave(groups, dim=1)
            value_states = value_states.repeat_interleave(groups, dim=1)
        scores = torch.matmul(query_states, key_states.transpose(-1, -2))
        attn_weights = scores * self.scaling
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        # Softmax in float32, then return to the working dtype.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32)
        attn_weights = attn_weights.to(query_states.dtype)
        attn_weights = F.dropout(
            attn_weights,
            p=self.attention_dropout if self.training else 0.0,
            training=self.training,
        )
        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous()

    attn_output = attn_output.reshape(batch_size, q_len, -1).contiguous()
    attn_output = self.o_proj(attn_output)
    if not output_attentions:
        attn_weights = None
    return attn_output, present_key_value, attn_weights
def _refresh_attention_rope_buffers(module: nn.Module) -> None:
    """Rebuild non-persistent RoPE buffers after checkpoint load.

    Walks every ``LizzyAttention`` submodule of ``module`` and resets its
    rotary state from the layer's config: layers without RoPE and layers
    using the ``"dynamic"`` rope type get their cached buffers cleared,
    all other layers get freshly computed CPU buffers.
    """
    attention_layers = (
        sub for sub in module.modules() if isinstance(sub, LizzyAttention)
    )
    for attn in attention_layers:
        rope_enabled = _get_rope_layer_flag(attn.config, attn.layer_idx)
        attn.use_rope = rope_enabled

        needs_runtime_update = False
        if rope_enabled:
            scaling_cfg = dict(attn.config.rope_scaling or {})
            rope_type = attn._rope_type_override or str(
                scaling_cfg.get("rope_type", scaling_cfg.get("type", "default"))
            )
            needs_runtime_update = rope_type == "dynamic"
        attn._rope_requires_runtime_update = needs_runtime_update

        if not rope_enabled or needs_runtime_update:
            # Nothing to cache: either RoPE is off, or the parameters are
            # recomputed on every forward.
            attn._rope_inv_freq = None
            attn._rope_attention_factor = None
            continue

        # These buffers are derived from config rather than serialized weights.
        # Recompute them after load so low-memory materialization cannot leave
        # stale or uninitialized rotary state behind.
        inv_freq, attention_factor = _compute_rope_parameters(
            attn.config,
            device=torch.device("cpu"),
            seq_len=attn.config.max_position_embeddings,
            rope_type_override=attn._rope_type_override,
        )
        attn._rope_inv_freq = inv_freq
        attn._rope_attention_factor = torch.tensor(
            float(attention_factor),
            dtype=torch.float32,
        )
class LizzyMLP(nn.Module):
    """Feed-forward block with an optional gated variant.

    When ``config.mlp_type == "gated"`` the output is
    ``down_proj(act(gate_proj(x)) * up_proj(x))``; otherwise it is
    ``down_proj(act(up_proj(x)))`` and no gate projection is created.
    """

    def __init__(self, config: LizzyConfig) -> None:
        super().__init__()
        self.config = config
        self.act = ACT2FN[config.hidden_act]
        # The gate projection only exists for the gated variant.
        self.gate_proj = (
            nn.Linear(
                config.hidden_size,
                config.intermediate_size,
                bias=config.mlp_bias,
            )
            if config.mlp_type == "gated"
            else None
        )
        self.up_proj = nn.Linear(
            config.hidden_size,
            config.intermediate_size,
            bias=config.mlp_bias,
        )
        self.down_proj = nn.Linear(
            config.intermediate_size,
            config.hidden_size,
            bias=config.mlp_bias,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to ``hidden_states``.

        Raises:
            RuntimeError: If the config requests a gated MLP but the gate
                projection is missing.
        """
        if self.config.mlp_type != "gated":
            return self.down_proj(self.act(self.up_proj(hidden_states)))
        # Single guard: the previous second copy of this check could never
        # fire because the first one already raised.
        if self.gate_proj is None:
            msg = "Missing gated MLP projection layers."
            raise RuntimeError(msg)
        gate = self.act(self.gate_proj(hidden_states))
        return self.down_proj(gate * self.up_proj(hidden_states))
class LizzyDecoderLayer(nn.Module):
    """One decoder block: a token mixer followed by an MLP, each residual.

    The mixer is ``LizzyLinearAttention`` for ``"linear_attention"`` layers
    and ``LizzyAttention`` otherwise. Normalisation placement follows the
    layer layout: ``"decoder_prenorm"`` normalises sub-block inputs,
    ``"decoder_postnorm"`` normalises sub-block outputs before the residual
    add; any other layout uses no per-block norms.
    """

    def __init__(self, config: LizzyConfig, layer_idx: int) -> None:
        super().__init__()
        # Layers beyond the configured list default to full attention.
        if layer_idx < len(config.layer_types):
            self.layer_type = str(config.layer_types[layer_idx])
        else:
            self.layer_type = "full_attention"
        self.layer_layout = _get_layer_layout(config, layer_idx)

        is_linear = self.layer_type == "linear_attention"
        self.self_attn = None if is_linear else LizzyAttention(config, layer_idx)
        self.linear_attn = (
            LizzyLinearAttention(config, layer_idx) if is_linear else None
        )
        self.mlp = LizzyMLP(config)

        def norm_for(layout: str) -> nn.Module | None:
            # Build a norm only when this layer uses the given layout.
            if self.layer_layout != layout:
                return None
            return _make_norm(
                config.norm_type,
                config.hidden_size,
                config.norm_eps,
                has_bias=config.norm_has_bias,
            )

        self.pre_attn_norm = norm_for("decoder_prenorm")
        self.pre_mlp_norm = norm_for("decoder_prenorm")
        self.post_attn_norm = norm_for("decoder_postnorm")
        self.post_mlp_norm = norm_for("decoder_postnorm")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        past_key_value: Cache | tuple[torch.Tensor, torch.Tensor] | None = None,
        cache_position: torch.Tensor | None = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        **kwargs: Any,
    ) -> tuple[
        torch.Tensor,
        Cache | tuple[torch.Tensor, torch.Tensor] | None,
        torch.Tensor | None,
    ]:
        """Apply the mixer and MLP sub-blocks with residual connections.

        Returns:
            ``(hidden_states, present_key_value, attn_weights)`` where the
            last two come straight from the mixer.
        """
        # --- token-mixing sub-block ---
        mixer_input = hidden_states
        if self.pre_attn_norm is not None:
            mixer_input = self.pre_attn_norm(hidden_states)

        if self.linear_attn is None:
            assert self.self_attn is not None
            attn_output, present_key_value, attn_weights = self.self_attn(
                mixer_input,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                cache_position=cache_position,
                use_cache=use_cache,
                output_attentions=output_attentions,
                **kwargs,
            )
        else:
            # Linear attention only receives cache objects; legacy tuple
            # caches are replaced with None.
            linear_cache = (
                past_key_value if _is_cache_object(past_key_value) else None
            )
            attn_output, present_key_value, attn_weights = self.linear_attn(
                mixer_input,
                attention_mask=attention_mask,
                past_key_value=linear_cache,
                use_cache=use_cache,
                output_attentions=output_attentions,
                **kwargs,
            )

        if self.post_attn_norm is not None:
            attn_output = self.post_attn_norm(attn_output)
        hidden_states = hidden_states + attn_output

        # --- feed-forward sub-block ---
        mlp_input = hidden_states
        if self.pre_mlp_norm is not None:
            mlp_input = self.pre_mlp_norm(hidden_states)
        mlp_output = self.mlp(mlp_input)
        if self.post_mlp_norm is not None:
            mlp_output = self.post_mlp_norm(mlp_output)
        hidden_states = hidden_states + mlp_output

        return hidden_states, present_key_value, attn_weights
class LizzyPreTrainedModel(PreTrainedModel):
    """Shared base for Lizzy models: init scheme and checkpoint-load hooks."""

    config_class = LizzyConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LizzyDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True

    def _init_weights(self, module: nn.Module) -> None:
        """Initialise ``module`` in place.

        Linear and embedding weights are drawn from
        ``N(0, config.initializer_range)``; biases and the padding
        embedding row are zeroed; norm weights are set to one.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, (LizzyRMSNorm, nn.LayerNorm)):
            # LayerNorm may be built without affine weight and/or bias.
            if hasattr(module, "weight") and module.weight is not None:
                module.weight.data.fill_(1.0)
            if hasattr(module, "bias") and module.bias is not None:
                module.bias.data.zero_()

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str | os.PathLike[str] | None,
        *model_args: Any,
        **kwargs: Any,
    ) -> "LizzyPreTrainedModel":
        """Load a checkpoint, then rebuild state that is not serialised.

        After the regular load, the RoPE buffers are refreshed and, for
        models exposing both ``lm_head`` and ``model``, the weight-tying
        map and the TP/PP plans are set on the instance.
        """
        model = cast(
            "LizzyPreTrainedModel",
            super().from_pretrained(
                pretrained_model_name_or_path,
                *model_args,
                **kwargs,
            ),
        )
        _refresh_attention_rope_buffers(model)
        if hasattr(model, "lm_head") and hasattr(model, "model"):
            tied_weights_keys = getattr(type(model), "_tied_weights_keys", None)
            if isinstance(tied_weights_keys, dict) and tied_weights_keys:
                # Copy so the instance does not alias the class-level dict.
                model._tied_weights_keys = dict(tied_weights_keys)
            else:
                model._tied_weights_keys = {
                    "lm_head.weight": "model.embed_tokens.weight",
                }
            model._tp_plan = {"lm_head": "colwise_rep"}
            model._pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
        return model

    def load_state_dict(  # type: ignore[override]
        self,
        state_dict: dict[str, torch.Tensor],
        strict: bool = True,
        assign: bool = False,
    ) -> Any:
        """Load ``state_dict`` after renaming legacy MLP keys.

        Keys containing ``.mlp.fc_in.`` / ``.mlp.fc_out.`` are mapped to
        ``.mlp.up_proj.`` / ``.mlp.down_proj.``. RoPE buffers are refreshed
        afterwards.

        Raises:
            ValueError: If a legacy key and its renamed counterpart are both
                present with different tensors.
        """
        from collections import OrderedDict

        remapped_state_dict: dict[str, torch.Tensor] = OrderedDict()
        for key, value in state_dict.items():
            remapped_key = key
            if ".mlp.fc_in." in key:
                remapped_key = key.replace(".mlp.fc_in.", ".mlp.up_proj.")
            elif ".mlp.fc_out." in key:
                remapped_key = key.replace(".mlp.fc_out.", ".mlp.down_proj.")
            existing = remapped_state_dict.get(remapped_key)
            if existing is not None and not torch.equal(existing, value):
                msg = (
                    f"Conflicting legacy Lizzy MLP tensors"
                    f" for key: {remapped_key}"
                )
                raise ValueError(msg)
            remapped_state_dict[remapped_key] = value

        # torch reads per-module version info from ``state_dict._metadata``;
        # rebuilding the mapping would otherwise silently drop it. Metadata
        # entries are forwarded unchanged (module paths are not renamed).
        metadata = getattr(state_dict, "_metadata", None)
        if metadata is not None:
            remapped_state_dict._metadata = metadata

        load_result = super().load_state_dict(
            remapped_state_dict,
            strict=strict,
            assign=assign,
        )
        # RoPE buffers are intentionally non-persistent, so refresh them after
        # weight loading instead of trusting constructor-time allocations.
        _refresh_attention_rope_buffers(self)
        return load_result
class LizzyModel(LizzyPreTrainedModel):
    """Lizzy decoder stack: embeddings, decoder layers and a final norm."""

    def __init__(self, config: LizzyConfig) -> None:
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(
            config.vocab_size,
            config.hidden_size,
            self.padding_idx,
        )
        # Learned absolute positions are only used when configured.
        self.embed_positions = (
            nn.Embedding(config.max_position_embeddings, config.hidden_size)
            if config.position_embedding_type == "absolute"
            else None
        )
        self.layers = nn.ModuleList(
            LizzyDecoderLayer(config, layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        )
        self.norm = _make_norm(
            config.norm_type,
            config.hidden_size,
            config.norm_eps,
            has_bias=config.norm_has_bias,
        )
        self.embd_dropout = nn.Dropout(config.embd_dropout)
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        """Replace the token embedding module."""
        self.embed_tokens = value

    def _build_attention_mask(
        self,
        attention_mask: torch.Tensor | None,
        *,
        batch_size: int,
        q_len: int,
        kv_len: int,
        kv_offset: int,
        cache_position: torch.Tensor,
        device: torch.device,
        dtype: torch.dtype,
        sliding_window: int | None = None,
    ) -> torch.Tensor:
        """Build an additive ``[batch, 1, q_len, kv_len]`` attention bias.

        Allowed positions are ``0`` and masked positions are
        ``torch.finfo(dtype).min``.

        Args:
            attention_mask: Optional 2D ``[batch, sequence]`` padding mask
                where ``0`` marks padding.
            batch_size: Batch dimension of the result.
            q_len: Number of query positions.
            kv_len: Number of key positions (int or scalar tensor).
            kv_offset: Absolute position of the first key (int or scalar
                tensor).
            cache_position: Absolute positions of the queries.
            device: Device of the result.
            dtype: Floating dtype of the result.
            sliding_window: If set, keys older than this many positions
                behind the query are masked as well.

        Raises:
            ValueError: If ``attention_mask`` is given but is not 2D.
        """
        kv_len = (
            int(kv_len.item()) if isinstance(kv_len, torch.Tensor) else int(kv_len)
        )
        kv_offset = (
            int(kv_offset.item())
            if isinstance(kv_offset, torch.Tensor)
            else int(kv_offset)
        )
        min_value = torch.finfo(dtype).min
        source_positions = cache_position.to(device=device).view(-1, 1)
        target_positions = torch.arange(
            kv_offset,
            kv_offset + kv_len,
            device=device,
        ).unsqueeze(0)
        causal = torch.zeros((q_len, kv_len), dtype=dtype, device=device)
        # A query may not look at keys that come after it.
        causal = causal.masked_fill(target_positions > source_positions, min_value)
        if sliding_window is not None:
            lower_bound = source_positions - int(sliding_window) + 1
            causal = causal.masked_fill(target_positions < lower_bound, min_value)
        causal = causal.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1)
        if attention_mask is None:
            return causal
        if attention_mask.dim() != 2:
            msg = "attention_mask must be 2D [batch, sequence]."
            raise ValueError(msg)
        if attention_mask.shape[1] < kv_len:
            # Missing leading columns are treated as attendable.
            pad = torch.ones(
                (attention_mask.shape[0], kv_len - attention_mask.shape[1]),
                dtype=attention_mask.dtype,
                device=attention_mask.device,
            )
            attention_mask = torch.cat([pad, attention_mask], dim=1)
        elif attention_mask.shape[1] > kv_len:
            attention_mask = attention_mask[:, -kv_len:]
        is_padding = attention_mask[:, None, None, :].to(device=device) == 0
        # Fill instead of adding a second bias: ``min + min`` overflows to
        # ``-inf`` where a key is both causally masked and padding.
        return causal.masked_fill(is_padding, min_value)

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | dict[str, torch.Tensor] | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | tuple[
            tuple[torch.Tensor, torch.Tensor], ...
        ] | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        cache_position: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs: Any,
    ) -> BaseModelOutputWithPast | tuple[Any, ...]:
        """Run the decoder stack.

        Args:
            input_ids: Token ids; mutually exclusive with ``inputs_embeds``.
            attention_mask: 2D padding mask, or a dict of ready-made masks
                keyed by layer type.
            position_ids: Optional positions; defaults to ``cache_position``
                broadcast over the batch.
            past_key_values: Cache object or legacy per-layer tuples.
            inputs_embeds: Pre-computed input embeddings.
            cache_position: Absolute positions of the new tokens.
            use_cache: Whether to return cache state; defaults to
                ``config.use_cache``.
            output_attentions: Whether to collect attention weights.
            output_hidden_states: Whether to collect per-layer inputs and
                the final normalised output.
            return_dict: Return a ``BaseModelOutputWithPast`` (default) or
                a plain tuple.
            **kwargs: Forwarded to every decoder layer.

        Raises:
            ValueError: If not exactly one of ``input_ids`` /
                ``inputs_embeds`` is given, or if a populated non-hybrid
                cache is passed to a model with linear-attention layers.
        """
        if (input_ids is None) == (inputs_embeds is None):
            msg = "Exactly one of input_ids or inputs_embeds must be provided."
            raise ValueError(msg)
        output_attentions = bool(output_attentions)
        output_hidden_states = bool(output_hidden_states)
        use_cache = (
            bool(use_cache)
            if use_cache is not None
            else bool(self.config.use_cache)
        )
        return_dict = bool(return_dict) if return_dict is not None else True

        if inputs_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
            batch_size, seq_len = input_ids.shape
        else:
            hidden_states = inputs_embeds
            batch_size, seq_len, _ = inputs_embeds.shape

        cache_object = (
            past_key_values
            if _is_cache_object(past_key_values)
            else None
        )
        if use_cache and _has_linear_attention(self.config):
            # Transformers 5.4 seeds `generate()` with an empty DynamicCache
            # for standard causal decoders. Hybrid Lizzy checkpoints need the
            # mixed cache below instead, because linear-attention layers read
            # DeltaNet convolution/recurrent state during the prefill pass.
            if cache_object is not None and not isinstance(
                cache_object, LizzyHybridDynamicCache,
            ):
                if int(cache_object.get_seq_length()) > 0:
                    msg = (
                        "Hybrid Lizzy checkpoints require "
                        "LizzyHybridDynamicCache once generation cache "
                        "state is populated."
                    )
                    raise ValueError(msg)
                cache_object = LizzyHybridDynamicCache(config=self.config)
                past_key_values = cache_object
        if use_cache and cache_object is None and past_key_values is None:
            if _has_linear_attention(self.config):
                # Linear-attention checkpoints need a mixed cache that can hold
                # both KV tensors and recurrent DeltaNet state.
                cache_object = LizzyHybridDynamicCache(config=self.config)
            else:
                cache_object = DynamicCache()
            past_key_values = cache_object

        if cache_object is not None:
            past_length = int(cache_object.get_seq_length())
        else:
            past_length = _legacy_cache_length(past_key_values)
        cache_position = _normalize_cache_position(cache_position)
        if cache_position is None:
            cache_position = torch.arange(
                past_length,
                past_length + seq_len,
                dtype=torch.long,
                device=hidden_states.device,
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)
        if self.embed_positions is not None:
            hidden_states = hidden_states + self.embed_positions(position_ids)
        hidden_states = self.embd_dropout(hidden_states)
        if self.training and self.gradient_checkpointing:
            use_cache = False

        # Layers without a configured type are full attention, mirroring the
        # fallback in LizzyDecoderLayer. Padding the list here guarantees the
        # mask mapping below has an entry for every layer.
        layer_types = list(self.config.layer_types)
        missing_types = len(self.layers) - len(layer_types)
        if missing_types > 0:
            layer_types.extend(["full_attention"] * missing_types)
        has_linear_attention = any(
            str(layer_type) == "linear_attention" for layer_type in layer_types
        )
        _attn_impl = getattr(self.config, "_attn_implementation", "eager")

        if has_linear_attention and isinstance(attention_mask, dict):
            linear_attention_mask = attention_mask.get("linear_attention")
        else:
            linear_attention_mask = attention_mask
        if (
            has_linear_attention
            and cache_object is not None
            and getattr(cache_object, "has_previous_state", False)
        ):
            linear_attention_mask = None
        elif (
            has_linear_attention
            and attention_mask is not None
            and not isinstance(attention_mask, dict)
            and torch.all(attention_mask == 1)
        ):
            # An all-ones mask carries no padding information.
            linear_attention_mask = None

        attention_mask_mapping: dict[str, torch.Tensor | None]
        if (
            _attn_impl == "flash_attention_2"
            and not isinstance(attention_mask, dict)
        ):
            # Flash attention handles causal masking (via is_causal) and
            # padding (via 2D mask) natively; skip building a 4D mask.
            attention_mask_mapping = {
                lt: attention_mask
                for lt in dict.fromkeys(layer_types)
                if lt != "linear_attention"
            }
        elif isinstance(attention_mask, dict):
            attention_mask_mapping = {
                key: value
                for key, value in attention_mask.items()
                if key != "linear_attention"
            }
        else:
            # Match upstream decoder-only HF models: when SDPA sees plain
            # causal full attention with no padding mask to preserve, let it
            # use its native is_causal fast-path instead of forcing an
            # explicit 4D bias tensor.
            sdpa_native_causal = _attn_impl == "sdpa" and attention_mask is None
            attention_mask_mapping = {}
            for layer_type in dict.fromkeys(layer_types):
                if layer_type == "linear_attention":
                    continue
                if sdpa_native_causal and layer_type == "full_attention":
                    attention_mask_mapping[layer_type] = None
                    continue
                # One mask per layer type, sized from the first layer of
                # that type.
                layer_idx = layer_types.index(layer_type)
                if cache_object is not None:
                    kv_len, kv_offset = cache_object.get_mask_sizes(
                        seq_len, layer_idx,
                    )
                else:
                    kv_len = past_length + seq_len
                    kv_offset = 0
                attention_mask_mapping[layer_type] = self._build_attention_mask(
                    attention_mask,
                    batch_size=batch_size,
                    q_len=seq_len,
                    kv_len=kv_len,
                    kv_offset=kv_offset,
                    cache_position=cache_position,
                    device=hidden_states.device,
                    dtype=hidden_states.dtype,
                    sliding_window=(
                        self.config.sliding_window
                        if layer_type == "sliding_attention"
                        else None
                    ),
                )

        all_hidden_states = [] if output_hidden_states else None
        all_attentions = [] if output_attentions else None
        next_cache = (
            cache_object
            if cache_object is not None
            else ([] if use_cache else None)
        )
        gradient_checkpointing_func = getattr(
            self,
            "_gradient_checkpointing_func",
            checkpoint,
        )

        for idx, layer in enumerate(self.layers):
            if all_hidden_states is not None:
                all_hidden_states.append(hidden_states)
            layer_type = layer_types[idx]
            if layer_type == "linear_attention":
                layer_attention_mask = linear_attention_mask
            else:
                layer_attention_mask = attention_mask_mapping[layer_type]

            if cache_object is not None:
                layer_past: Cache | tuple[
                    torch.Tensor, torch.Tensor
                ] | None = cache_object
            elif past_key_values is not None:
                layer_past = past_key_values[idx]
                if layer_past is not None and layer_past[0] is None:
                    layer_past = None
            else:
                layer_past = None

            if self.training and self.gradient_checkpointing:
                # Bind the per-iteration values as defaults. Checkpointing
                # calls this function again during backward, after the loop
                # has finished; a plain closure would then see the *last*
                # layer and mask for every checkpointed segment.
                def custom_forward(
                    hidden_states: torch.Tensor,
                    _layer: nn.Module = layer,
                    _layer_mask: torch.Tensor | None = layer_attention_mask,
                ) -> Any:
                    layer_outputs = _layer(
                        hidden_states,
                        attention_mask=_layer_mask,
                        position_ids=position_ids,
                        past_key_value=None,
                        cache_position=cache_position,
                        use_cache=False,
                        output_attentions=output_attentions,
                        **kwargs,
                    )
                    if output_attentions:
                        return layer_outputs[0], layer_outputs[2]
                    return layer_outputs[0]

                checkpointed_outputs = gradient_checkpointing_func(
                    custom_forward, hidden_states,
                )
                if output_attentions:
                    hidden_states, attn_weights = checkpointed_outputs
                else:
                    hidden_states = checkpointed_outputs
                    attn_weights = None
                present = None
            else:
                hidden_states, present, attn_weights = layer(
                    hidden_states,
                    attention_mask=layer_attention_mask,
                    position_ids=position_ids,
                    past_key_value=layer_past,
                    cache_position=cache_position,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    **kwargs,
                )

            # Legacy tuple caches are collected per layer; cache objects are
            # updated in place by the layers themselves.
            if use_cache and next_cache is not None and cache_object is None:
                next_cache.append(present)
            if all_attentions is not None:
                all_attentions.append(attn_weights)

        hidden_states = self.norm(hidden_states)
        if all_hidden_states is not None:
            all_hidden_states.append(hidden_states)

        past_key_values_output: Cache | tuple[
            tuple[torch.Tensor, torch.Tensor], ...
        ] | None = None
        if use_cache and next_cache is not None:
            if cache_object is not None:
                past_key_values_output = cache_object
            else:
                past_key_values_output = tuple(next_cache)

        if not return_dict:
            output: tuple[Any, ...] = (hidden_states,)
            if past_key_values_output is not None:
                output = output + (past_key_values_output,)
            if all_hidden_states is not None:
                output = output + (tuple(all_hidden_states),)
            if all_attentions is not None:
                output = output + (tuple(all_attentions),)
            return output
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values_output,
            hidden_states=(
                tuple(all_hidden_states)
                if all_hidden_states is not None
                else None
            ),
            attentions=(
                tuple(all_attentions)
                if all_attentions is not None
                else None
            ),
        )
class LizzyForCausalLM(LizzyPreTrainedModel, GenerationMixin):
    """Lizzy decoder with a bias-free language-modelling head."""

    config_class = LizzyConfig
    # Transformers 5.4 expects an expanded target->source mapping here rather than
    # the older list-based shorthand.
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config: LizzyConfig) -> None:
        super().__init__(config)
        self.model = LizzyModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the token embedding module of the wrapped model."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        """Replace the token embedding module of the wrapped model."""
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self) -> nn.Module:
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Cache | tuple[
            tuple[torch.Tensor, torch.Tensor], ...
        ] | None = None,
        attention_mask: torch.Tensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Assemble the keyword arguments for one generation step.

        When a cache is present, ``input_ids`` is trimmed to the tokens
        covered by ``cache_position`` and ``attention_mask`` to the cached
        length plus those tokens. ``inputs_embeds`` is only forwarded when
        no cache has been passed.
        """
        past_length = 0
        if past_key_values is not None:
            if _is_cache_object(past_key_values):
                past_length = int(past_key_values.get_seq_length())
            else:
                past_length = _legacy_cache_length(past_key_values)

        cache_position = _normalize_cache_position(cache_position)
        if cache_position is None:
            if past_key_values is not None:
                new_tokens = input_ids.shape[1] - past_length
                if new_tokens <= 0:
                    # Caller passed only the fresh token(s); decode one step.
                    new_tokens = 1
                cache_position = torch.arange(
                    past_length,
                    past_length + new_tokens,
                    device=input_ids.device,
                )
            else:
                cache_position = torch.arange(
                    input_ids.shape[1],
                    device=input_ids.device,
                )

        if past_key_values is not None:
            input_ids = input_ids[:, -cache_position.shape[0] :]
            if attention_mask is not None:
                attn_mask_idx = past_length + input_ids.shape[1]
                attention_mask = attention_mask[:, -attn_mask_idx:]

        if inputs_embeds is not None and past_key_values is None:
            model_inputs: dict[str, Any] = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids.contiguous()}
        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "use_cache": kwargs.get("use_cache", self.config.use_cache),
            },
        )
        return model_inputs

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | tuple[
            tuple[torch.Tensor, torch.Tensor], ...
        ] | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        cache_position: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Any,
    ) -> CausalLMOutputWithPast | tuple[Any, ...]:
        """Run the decoder and project hidden states to vocabulary logits.

        Args:
            labels: Optional targets; when given, a next-token
                cross-entropy loss is computed over the full sequence.
            logits_to_keep: An int keeps the last ``logits_to_keep``
                positions (``0`` keeps all); a tensor is used directly as
                the sequence index.

        The remaining arguments are forwarded to ``LizzyModel.forward``.

        Returns:
            ``CausalLMOutputWithPast`` or, with ``return_dict=False``, a
            tuple starting with the loss (if any) and the logits.
        """
        # HF eval loaders call `forward()` without an explicit return_dict,
        # so local Lizzy exports must normalize the optional flag first.
        return_dict = bool(return_dict) if return_dict is not None else True
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )
        hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
        slice_indices = (
            slice(-logits_to_keep, None)
            if isinstance(logits_to_keep, int)
            else logits_to_keep
        )

        loss = None
        if labels is None:
            # Only project the positions that were asked for.
            logits = self.lm_head(hidden_states[:, slice_indices, :])
        else:
            # The loss needs every position, so project once and slice.
            full_logits = self.lm_head(hidden_states)
            logits = full_logits[:, slice_indices, :]
            shift_logits = full_logits[..., :-1, :].contiguous()
            # Labels may live on a different device than the head when the
            # model is split across devices; a no-op otherwise.
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            if loss is not None:
                output = (loss,) + output
            return output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )