Echo-DSRN-114M-Base / modeling_echo.py
mrs83's picture
Upload 10 files
5938d1e verified
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GenerationMixin, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_echo import EchoConfig
try:
from vllm.model_executor.models.transformers import ALL_ATTENTION_FUNCTIONS
except ImportError:
ALL_ATTENTION_FUNCTIONS = {}
try:
from transformers.cache_utils import Cache
except ImportError:
class Cache:
pass
class EchoCache(Cache):
    """
    Per-layer container for the recurrent state produced by the DSRN blocks.

    ``states`` holds one tuple per layer, either ``(h, c)`` or
    ``(h, c, k_attn, v_attn)``. A dedicated class is used so that the
    attention key/value tensors of the 4-tuple are kept intact instead of
    being dropped by Hugging Face's DynamicCache.
    """

    def __init__(self, states=None):
        # NOTE(review): the base-class initialiser is intentionally not called
        # here (matches the original behaviour) — confirm against the installed
        # transformers version.
        self.states = [] if states is None else states

    def get_seq_length(self, layer_idx=0):
        """Return the cached attention length of ``layer_idx``.

        The length is read from dim 2 of the cached key tensor. Layers that
        are missing, or that only carry ``(h, c)``, report 0.
        """
        if self.states and layer_idx < len(self.states):
            layer_state = self.states[layer_idx]
            if len(layer_state) == 4:
                return layer_state[2].shape[2]
        return 0

    def get_max_length(self):
        """The cache has no upper bound, so ``None`` is returned."""
        return None

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[dict] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Shim that satisfies the Cache protocol without mutating anything.

        The blocks build the updated state tuples themselves. If the layer
        already stores a 4-tuple its key/value tensors are returned; otherwise
        the inputs are handed back unchanged.
        """
        if layer_idx < len(self.states):
            layer_state = self.states[layer_idx]
            if len(layer_state) == 4:
                return layer_state[2], layer_state[3]
        return key_states, value_states

    def get_usable_length(self, new_seq_length, layer_idx=0):
        """Same as ``get_seq_length``; ``new_seq_length`` is not used."""
        return self.get_seq_length(layer_idx)

    def __getitem__(self, idx):
        return self.states[idx]

    def __len__(self):
        return len(self.states)

    def __iter__(self):
        return iter(self.states)

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Re-index every stored tensor along dim 0 with ``beam_idx`` (in place on the cache)."""
        self.states = [
            tuple(t.index_select(0, beam_idx.to(t.device)) for t in layer_state)
            for layer_state in self.states
        ]
# --- STANDALONE KERNELS (AUTOMAGICALLY INLINED) ---
def _sequential_scan(a, b, h):
"""
Core sequential scan for a batch of sequences.
Vectorized across all dimensions except time.
"""
a.shape[:-1]
a.shape[-1]
# a, b: (..., T, D)
# h: (..., D)
T = a.shape[-2]
res = torch.empty_like(b)
curr_h = h
for t in range(T):
curr_h = a[..., t, :] * curr_h + b[..., t, :]
res[..., t, :] = curr_h
return res, curr_h
def dsrn_parallel_scan(g_t, m_t, c_0=None, chunk_size=32, use_triton=False):
    """
    Evaluate the gated recurrence ``c_t = (1 - g_t) * c_{t-1} + g_t * m_t``.

    The PyTorch path is a two-level chunked scan: the sequence is cut into
    chunks of ``chunk_size`` steps, each chunk is scanned from a zero state
    (all chunks at once), the chunk summaries are scanned to recover the state
    entering each chunk, and the two results are combined. This needs
    ``chunk_size`` local steps plus ``T / chunk_size`` global steps. All
    arithmetic is done in float32 and cast back to the dtype of ``g_t``.

    Args:
        g_t: gate values, shape (B, T, D).
        m_t: candidate values, broadcast-compatible with ``g_t``.
        c_0: optional initial state, shape (B, D); zeros when omitted.
        chunk_size: number of time steps per chunk.
        use_triton: try the Triton kernel from ``.triton_scan`` when the input
            lives on a CUDA device; falls back to PyTorch with a warning if
            the import fails.

    Returns:
        Tensor of shape (B, T, D) with the state after every time step.
    """
    if use_triton and g_t.is_cuda:
        try:
            from .triton_scan import triton_dsrn_parallel_scan

            return triton_dsrn_parallel_scan(g_t, m_t, c_0)
        except ImportError:
            import warnings

            warnings.warn("Triton scan unavailable. Falling back to PyTorch scan.", UserWarning)

    out_dtype = g_t.dtype
    decay = (1.0 - g_t).float()
    drive = (g_t * m_t).float()
    B, T, D = decay.shape
    device = decay.device

    # Right-pad the time axis to a whole number of chunks. Padding with
    # decay=1 / drive=0 leaves the state untouched on the padded steps.
    pad_len = (chunk_size - (T % chunk_size)) % chunk_size
    if pad_len > 0:
        decay = F.pad(decay, (0, 0, 0, pad_len), value=1.0)
        drive = F.pad(drive, (0, 0, 0, pad_len), value=0.0)
    num_chunks = (T + pad_len) // chunk_size

    decay_chunks = decay.view(B, num_chunks, chunk_size, D)
    drive_chunks = drive.view(B, num_chunks, chunk_size, D)

    # Level 1: scan every chunk independently, starting from zero.
    zero_start = torch.zeros(B, num_chunks, D, device=device, dtype=torch.float32)
    local_states, local_final = _sequential_scan(decay_chunks, drive_chunks, zero_start)

    # Total decay accumulated across each chunk, shape (B, num_chunks, D).
    chunk_decay = torch.prod(decay_chunks, dim=2)

    # Level 2: scan the chunk summaries to get the state after each chunk.
    if c_0 is not None:
        h_0 = c_0.float()
    else:
        h_0 = torch.zeros(B, D, device=device, dtype=torch.float32)
    chunk_ends, _ = _sequential_scan(chunk_decay, local_final, h_0)

    # The state entering chunk j is the state leaving chunk j - 1.
    chunk_starts = torch.cat([h_0.unsqueeze(1), chunk_ends[:, :-1]], dim=1)

    # Combine: carried-in state decayed by the in-chunk prefix product, plus
    # the locally accumulated contribution.
    decay_prefix = torch.cumprod(decay_chunks, dim=2)
    combined = decay_prefix * chunk_starts.unsqueeze(2) + local_states

    # Flatten the chunks, drop the padding, restore the caller's dtype.
    return combined.view(B, -1, D)[:, :T].to(out_dtype)
def rms_norm_fn(hidden_states, weight, eps=1e-6):
    """
    Functional RMS normalisation over the last dimension.

    The statistics are computed in float32; the normalised tensor is cast back
    to the input dtype before being scaled by ``weight``.
    """
    orig_dtype = hidden_states.dtype
    h32 = hidden_states.contiguous().to(torch.float32)
    mean_sq = torch.mean(h32 * h32, dim=-1, keepdim=True)
    normed = h32 * torch.rsqrt(mean_sq + eps)
    return weight * normed.to(orig_dtype)
def dsrn_parallel_kernel_legacy(
    model_block: nn.Module,
    x: torch.Tensor,
    h_prev: torch.Tensor,
    c_prev: torch.Tensor,
    eos_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Legacy DSRN kernel: LayerNorm variant with an unconditional slow-state read.

    Steps:
      1. LayerNorm the input and project it with the GRU cell's input weights
         to obtain the fast-state gate ``z`` and candidate ``r``; scan
         ``h_t = (1 - z_t) * h_{t-1} + z_t * r_t``.
      2. Predict ``x[t]`` from ``h[t-1]``; the clamped squared error, scaled by
         ``surprise_lambda``, is added to the slow gate logits. Scan the slow
         state ``c`` with gate ``g`` and candidate ``m``.
      3. Residual MLP on the normalised fast state, plus ``linear_read(c)``.

    Args:
        model_block: module providing the parameters (``norm_fast``,
            ``gru_cell``, ``linear_pred``, ``linear_gate``, ``linear_memory``,
            ``linear_read``, ``norm_ff``, ``mlp_*``, ``surprise_lambda``).
        x: input of shape (B, T, D).
        h_prev: fast state entering the sequence, shape (B, D).
        c_prev: slow state entering the sequence.
        eos_mask: optional (B, T) mask, non-zero where a token is EOS.

    Returns:
        ``(x_out, h_new, c_new, gate_stats)``; ``gate_stats`` is the slow gate
        averaged over its last dimension.
    """
    _, _, D = x.shape
    use_triton = getattr(model_block, "use_triton", False)

    # Positions that directly follow an EOS token. Position 0 is never flagged
    # here: a boundary at the very start is handled by the caller passing
    # zeroed h_prev / c_prev.
    follows_eos = None
    if eos_mask is not None:
        shifted_mask = torch.roll(eos_mask, shifts=1, dims=1)
        shifted_mask[:, 0] = 0
        follows_eos = shifted_mask.unsqueeze(-1) > 0

    # --- 1. Fast state ---
    x_norm = F.layer_norm(
        x,
        (D,),
        weight=model_block.norm_fast.weight,
        bias=model_block.norm_fast.bias,
    )
    gru_proj = F.linear(x_norm, model_block.gru_cell.weight_ih, model_block.gru_cell.bias_ih)
    # First D columns feed the update gate, the last D the candidate; the
    # middle D columns of the GRU projection are not used.
    z_all = torch.sigmoid(gru_proj[:, :, :D])
    r_all = torch.tanh(gru_proj[:, :, 2 * D :])

    if follows_eos is not None:
        # z = 1 makes h_t = r_t, i.e. the previous fast state is discarded.
        z_all = torch.where(follows_eos, torch.ones_like(z_all), z_all)

    h_all = dsrn_parallel_scan(z_all, r_all, h_prev, use_triton=use_triton)
    h_new = h_all[:, -1]

    # --- 2. Slow state ---
    # Causal shift: x[t] is predicted from h[t-1], so prepend h_prev and drop
    # the last step.
    h_shifted = torch.cat([h_prev.unsqueeze(1), h_all[:, :-1, :]], dim=1)
    diff = x - model_block.linear_pred(h_shifted)
    error = torch.clamp(diff * diff, max=10.0).mean(dim=-1, keepdim=True)
    surprise_signal = error * model_block.surprise_lambda

    g_all = torch.sigmoid(model_block.linear_gate(h_all) + surprise_signal)
    m_all = torch.tanh(model_block.linear_memory(h_all))

    if follows_eos is not None:
        # NOTE(review): g = 0 carries c_{t-1} over unchanged at these
        # positions (it does not clear it) — confirm this is the intended
        # "reset" for the slow state.
        g_all = torch.where(follows_eos, torch.zeros_like(g_all), g_all)

    c_all = dsrn_parallel_scan(g_all, m_all, c_prev, use_triton=use_triton)
    c_new = c_all[:, -1]

    # If the last token is EOS, the states handed to the next chunk are zeroed.
    if eos_mask is not None:
        keep = (1.0 - eos_mask[:, -1].float()).unsqueeze(-1)  # (B, 1)
        h_new = h_new * keep
        c_new = c_new * keep

    gate_stats = g_all.mean(dim=-1)

    # --- 3. Output path ---
    h_norm = F.layer_norm(
        h_all, (D,), weight=model_block.norm_ff.weight, bias=model_block.norm_ff.bias
    )
    mlp_out = model_block.mlp_down(model_block.mlp_act(model_block.mlp_up(h_norm)))
    x_out = x + mlp_out
    # The slow state is always read back into the residual stream here.
    x_out = x_out + model_block.linear_read(c_all)
    return x_out, h_new, c_new, gate_stats
def dsrn_parallel_kernel_hybrid(
    model_block: nn.Module,
    x: torch.Tensor,
    h_prev: torch.Tensor,
    c_prev: torch.Tensor,
    eos_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Hybrid DSRN kernel: RMSNorm variant with an optional slow-state read.

    Same data flow as the legacy kernel, except that both normalisations use
    ``rms_norm_fn`` (weight only, no bias) and ``linear_read(c)`` is added to
    the output only when ``model_block.use_hybrid_attention`` is truthy.

    Args:
        model_block: module providing the parameters.
        x: input of shape (B, T, D).
        h_prev: fast state entering the sequence, shape (B, D).
        c_prev: slow state entering the sequence.
        eos_mask: optional (B, T) mask, non-zero where a token is EOS.

    Returns:
        ``(x_out, h_new, c_new, gate_stats)``; ``gate_stats`` is the slow gate
        averaged over its last dimension.
    """
    _, _, D = x.shape
    use_triton = getattr(model_block, "use_triton", False)

    # Positions immediately after an EOS token (position 0 is never flagged).
    follows_eos = None
    if eos_mask is not None:
        shifted_mask = torch.roll(eos_mask, shifts=1, dims=1)
        shifted_mask[:, 0] = 0
        follows_eos = shifted_mask.unsqueeze(-1) > 0

    # --- 1. Fast state ---
    x_norm = rms_norm_fn(x, model_block.norm_fast.weight)
    gru_proj = F.linear(x_norm, model_block.gru_cell.weight_ih, model_block.gru_cell.bias_ih)
    z_all = torch.sigmoid(gru_proj[:, :, :D])
    r_all = torch.tanh(gru_proj[:, :, 2 * D :])

    if follows_eos is not None:
        # z = 1 makes h_t = r_t, dropping the previous fast state.
        z_all = torch.where(follows_eos, torch.ones_like(z_all), z_all)

    h_all = dsrn_parallel_scan(z_all, r_all, h_prev, use_triton=use_triton)
    h_new = h_all[:, -1]

    # --- 2. Slow state ---
    # Causal shift: x[t] is predicted from h[t-1].
    h_shifted = torch.cat([h_prev.unsqueeze(1), h_all[:, :-1, :]], dim=1)
    diff = x - model_block.linear_pred(h_shifted)
    error = torch.clamp(diff * diff, max=10.0).mean(dim=-1, keepdim=True)
    surprise_signal = error * model_block.surprise_lambda

    g_all = torch.sigmoid(model_block.linear_gate(h_all) + surprise_signal)
    m_all = torch.tanh(model_block.linear_memory(h_all))

    if follows_eos is not None:
        # g = 0 carries the previous slow state over unchanged at these steps.
        g_all = torch.where(follows_eos, torch.zeros_like(g_all), g_all)

    c_all = dsrn_parallel_scan(g_all, m_all, c_prev, use_triton=use_triton)
    c_new = c_all[:, -1]

    # A trailing EOS zeroes the states handed to the next chunk.
    if eos_mask is not None:
        keep = (1.0 - eos_mask[:, -1].float()).unsqueeze(-1)
        h_new = h_new * keep
        c_new = c_new * keep

    gate_stats = g_all.mean(dim=-1)

    # --- 3. Output path ---
    h_norm = rms_norm_fn(h_all, model_block.norm_ff.weight)
    mlp_out = model_block.mlp_down(model_block.mlp_act(model_block.mlp_up(h_norm)))
    x_out = x + mlp_out

    if model_block.use_hybrid_attention:
        x_out = x_out + model_block.linear_read(c_all)
    return x_out, h_new, c_new, gate_stats
def dsrn_parallel_kernel(
    model_block: nn.Module,
    x: torch.Tensor,
    h_prev: torch.Tensor,
    c_prev: torch.Tensor,
    eos_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Backward-compatible entry point that picks the kernel for a block.

    Blocks whose ``use_rmsnorm`` attribute is truthy go through the hybrid
    kernel; all others (including blocks without the attribute) use the
    legacy kernel.
    """
    wants_rmsnorm = getattr(model_block, "use_rmsnorm", False)
    kernel = dsrn_parallel_kernel_hybrid if wants_rmsnorm else dsrn_parallel_kernel_legacy
    return kernel(model_block, x, h_prev, c_prev, eos_mask=eos_mask)
class HymbaRMSNorm(nn.Module):
    """RMS normalisation with a learnable per-feature scale (equivalent to T5LayerNorm)."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise the last dimension in float32, cast back, then apply the scale."""
        orig_dtype = hidden_states.dtype
        h32 = hidden_states.to(torch.float32)
        mean_sq = h32.pow(2).mean(-1, keepdim=True)
        scaled = h32 * torch.rsqrt(mean_sq + self.variance_epsilon)
        return self.weight * scaled.to(orig_dtype)
class EchoRotaryEmbedding(nn.Module):
    """
    Rotary position embedding tables (cos / sin), built lazily.

    The tables are kept as plain Python attributes rather than registered
    buffers: buffers were being corrupted by Hugging Face's weight loading for
    this model, so they are computed on the first forward pass instead.
    """

    def __init__(self, dim, max_position_embeddings=4096, base=10000.0, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.device = device
        # Number of positions currently covered by the cached tables.
        self.max_seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin tables for ``seq_len`` positions on ``device``."""
        self.max_seq_len_cached = seq_len
        # Frequencies are computed in float32 and only the final tables are
        # cast to the requested dtype.
        inv_freq = 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)
        )
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self._cos_cached = emb.cos().to(dtype)
        self._sin_cached = emb.sin().to(dtype)

    def forward(self, x, seq_len=None):
        """
        Return ``(cos, sin)`` tables for the first ``seq_len`` positions.

        Args:
            x: reference tensor; only its device, dtype and (when ``seq_len``
                is omitted) its second-to-last dimension are used.
            seq_len: number of positions required. Defaults to
                ``x.shape[-2]`` — assumes ``x`` is laid out as
                (..., seq, head_dim), as the attention module in this file
                passes it.

        Returns:
            Two tensors of shape (seq_len, dim) in the dtype of ``x``.
        """
        if seq_len is None:
            # The signature advertises None as a valid default; without this
            # the comparison / max() below would raise a TypeError.
            seq_len = x.shape[-2]

        needs_rebuild = (
            self._cos_cached is None
            or seq_len > self.max_seq_len_cached
            or self._cos_cached.device != x.device
        )
        if needs_rebuild:
            self._set_cos_sin_cache(
                seq_len=max(seq_len, self.max_position_embeddings), device=x.device, dtype=x.dtype
            )
        return (
            self._cos_cached[:seq_len].to(dtype=x.dtype),
            self._sin_cached[:seq_len].to(dtype=x.dtype),
        )
def rotate_half(x):
    """Split the last dimension in two halves ``(a, b)`` and return ``(-b, a)``."""
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """
    Apply rotary position embeddings to the query and key tensors.

    ``cos`` / ``sin`` are indexed with ``position_ids`` and given an extra
    axis at ``unsqueeze_dim`` so they broadcast against ``q`` and ``k``
    (with the default this yields (B, 1, T, D) for (B, T) position ids).
    """
    cos_sel = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_sel = sin[position_ids].unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos_sel) + (rotate_half(t) * sin_sel)

    return _rotate(q), _rotate(k)
class SlidingWindowAttention(nn.Module):
    """
    Multi-head self-attention with rotary embeddings and a causal sliding window.

    Each query attends to itself and at most ``window_size - 1`` preceding
    tokens. The returned key/value cache always holds the full history; the
    window only restricts what is attended to.
    """

    def __init__(self, config: EchoConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.window_size = getattr(config, "window_size", 128)
        self.qkv_proj = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
        self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.rotary_emb = EchoRotaryEmbedding(
            self.head_dim,
            base=getattr(config, "rope_theta", 10000.0),
        )

    def _sliding_causal_mask(self, q_len, kv_len, device, dtype):
        """Build an additive (q_len, kv_len) mask: 0 where allowed, -inf elsewhere.

        The queries are the last ``q_len`` positions of the ``kv_len`` keys, so
        query ``i`` sits at absolute position ``kv_len - q_len + i``. A key is
        visible when it is not in the future and lies within the window.
        """
        q_pos = torch.arange(kv_len - q_len, kv_len, device=device).view(-1, 1)
        k_pos = torch.arange(kv_len, device=device).view(1, -1)
        distance = q_pos - k_pos
        allowed = distance >= 0
        if self.window_size is not None:
            allowed = allowed & (distance < self.window_size)
        mask = torch.zeros((q_len, kv_len), device=device, dtype=dtype)
        return mask.masked_fill(~allowed, float("-inf"))

    def forward(
        self,
        x,
        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """
        Args:
            x: input of shape (B, T, C).
            past_key_values: optional ``(k, v)`` cache, each (B, heads, S, head_dim).
            position_ids: optional (B, T) absolute positions; when omitted they
                continue from the cached length.
            **kwargs: ``attn_implementation`` selects an entry of
                ``ALL_ATTENTION_FUNCTIONS``; anything else is ignored.

        Returns:
            ``(output, (k, v))`` where ``output`` is (B, T, C) and ``(k, v)``
            is the full, untruncated key/value history including this call.
        """
        B, T, C = x.shape
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)

        # (B, T, C) -> (B, heads, T, head_dim)
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        past_len = past_key_values[0].shape[2] if past_key_values is not None else 0
        kv_seq_len = past_len + T

        # --- RoPE ---
        if position_ids is None:
            position_ids = torch.arange(
                past_len, kv_seq_len, dtype=torch.long, device=x.device
            ).unsqueeze(0)
        cos, sin = self.rotary_emb(v, seq_len=kv_seq_len)
        q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

        if past_key_values is not None:
            k_past, v_past = past_key_values
            k = torch.cat([k_past, k], dim=2)
            v = torch.cat([v_past, v], dim=2)

        # The cache must keep the full history; truncation below is for the
        # attention computation only.
        current_key_value = (k, v)

        # NOTE(review): entries registered by transformers / vLLM may expect a
        # (module, query, key, value, attention_mask, ...) call signature —
        # confirm the positional (q, k, v) call below is compatible with them.
        attn_fn = ALL_ATTENTION_FUNCTIONS.get(
            kwargs.get("attn_implementation", "sdpa"), F.scaled_dot_product_attention
        )

        if T > 1:
            # Training / prefill: attend over all keys under a band-limited
            # causal mask. The mask spans the cached prefix as well, so a
            # multi-token call on top of an existing cache stays shape-correct.
            mask = self._sliding_causal_mask(T, k.shape[2], x.device, x.dtype)
            y = attn_fn(q, k, v, attn_mask=mask.unsqueeze(0).unsqueeze(0))
        else:
            # Decoding: the single query is the newest token, so keeping only
            # the last window_size keys is enough and no causal mask is needed.
            k_attn, v_attn = k, v
            if self.window_size is not None and k_attn.shape[2] > self.window_size:
                k_attn = k_attn[:, :, -self.window_size :, :]
                v_attn = v_attn[:, :, -self.window_size :, :]
            y = attn_fn(q, k_attn, v_attn, is_causal=False)

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(y), current_key_value
class DSRNBlock(nn.Module):
    """
    One DSRN layer: fast GRU-style state, slow gated memory, residual MLP and
    (optionally) a sliding-window attention branch.

    The recurrent math lives in ``dsrn_parallel_kernel``, which reads this
    module's parameters directly; of the GRU cell it uses ``weight_ih`` and
    ``bias_ih``.
    """

    def __init__(self, config: EchoConfig):
        super().__init__()
        hidden = config.hidden_size
        self.hidden_size = hidden
        self.state_size = hidden * config.num_heads
        self.use_triton = getattr(config, "use_triton", True)
        self.use_hybrid_attention = getattr(config, "use_hybrid_attention", True)
        self.use_rmsnorm = getattr(config, "use_rmsnorm", True)

        # Fast state (GRU)
        self.norm_fast = HymbaRMSNorm(hidden) if self.use_rmsnorm else nn.LayerNorm(hidden)
        self.gru_cell = nn.GRUCell(hidden, hidden)

        # Hybrid attention branch
        if self.use_hybrid_attention:
            self.attn = SlidingWindowAttention(config)

        # Slow state (DSRN)
        self.linear_read = nn.Linear(self.state_size, hidden, bias=False)
        self.linear_gate = nn.Linear(hidden, self.state_size)
        self.linear_memory = nn.Linear(hidden, self.state_size)

        # Surprise mechanism
        self.linear_pred = nn.Linear(hidden, hidden, bias=False)
        self.surprise_lambda = nn.Parameter(torch.zeros(self.state_size))

        # Feed-forward
        self.norm_ff = HymbaRMSNorm(hidden) if self.use_rmsnorm else nn.LayerNorm(hidden)

        # Simple MLP: Linear -> GELU -> Linear.
        # mlp_up / mlp_act / mlp_down are the only registered submodules; there
        # is deliberately no self.mlp alias, which used to cause
        # double-registration and spurious "missing keys".
        default_intermediate = int(hidden * getattr(config, "mlp_ratio", 4.0))
        intermediate_size = getattr(config, "intermediate_size", default_intermediate)
        self.mlp_up = nn.Linear(hidden, intermediate_size)
        self.mlp_act = nn.GELU()
        self.mlp_down = nn.Linear(intermediate_size, hidden)

    def forward(
        self, x: torch.Tensor, state_prev: Tuple[torch.Tensor, ...], **kwargs
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Args:
            x: layer input.
            state_prev: ``(h, c)`` or ``(h, c, k_attn, v_attn)``.
            **kwargs: forwarded to the attention branch only.

        Returns:
            ``(x_out, state)`` where ``state`` is ``(h, c, k, v)`` when the
            attention branch is active and ``(h, c)`` otherwise.
        """
        h_prev, c_prev = state_prev[0], state_prev[1]

        # Recurrent path (gate statistics are not used at this level).
        x_out, h_new, c_new, _gate_stats = dsrn_parallel_kernel(self, x, h_prev, c_prev)

        if not self.use_hybrid_attention:
            return x_out, (h_new, c_new)

        # The attention branch gets its own normalised view of the raw input.
        x_norm = self.norm_fast(x)

        # A 4-tuple state carries the attention key/value history.
        attn_kv = (state_prev[2], state_prev[3]) if len(state_prev) == 4 else None
        attn_out, new_attn_kv = self.attn(x_norm, past_key_values=attn_kv, **kwargs)
        x_out = x_out + attn_out

        if new_attn_kv is None:
            return x_out, (h_new, c_new)
        return x_out, (h_new, c_new, new_attn_kv[0], new_attn_kv[1])
class EchoPreTrainedModel(PreTrainedModel):
    """Shared Hugging Face base class for the Echo models (config binding and weight init)."""

    config_class = EchoConfig
    base_model_prefix = "model"
    _no_split_modules = ["DSRNBlock"]
    # Old local training checkpoints (from before the self.mlp alias was
    # removed) may still contain mlp.0.* / mlp.1.* / mlp.2.* keys. They are
    # ignored on load; the canonical names are mlp_up.* / mlp_act.* / mlp_down.*.
    _keys_to_ignore_on_load_unexpected = [
        r".*\.mlp\.0\..*",
        r".*\.mlp\.1\..*",
        r".*\.mlp\.2\..*",
    ]

    def _init_weights(self, module):
        """Initialise Linear / Embedding weights with N(0, 0.02) and reset LayerNorm to identity."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
class EchoModel(EchoPreTrainedModel):
    """
    Stack of DSRN blocks over a token embedding, followed by a final norm.

    ``forward`` returns the normalised hidden states together with an
    ``EchoCache`` holding one state tuple per layer.
    """

    supports_gradient_checkpointing = True
    _supports_attention_backend = True

    def __init__(self, config: EchoConfig):
        super().__init__(config)
        self.embed_dim = config.embed_dim
        self.num_layers = config.num_layers
        self.num_heads = config.num_heads
        self.state_dim = config.embed_dim * config.num_heads
        self.embedding = nn.Embedding(config.vocab_size, config.embed_dim)
        self.blocks = nn.ModuleList([DSRNBlock(config) for _ in range(config.num_layers)])
        # NOTE(review): the fallback default here is False while DSRNBlock
        # falls back to True; this only matters for configs that do not define
        # use_rmsnorm — confirm EchoConfig always sets it.
        if getattr(config, "use_rmsnorm", False):
            self.final_norm = HymbaRMSNorm(config.hidden_size)
        else:
            self.final_norm = nn.LayerNorm(config.hidden_size)
        self.gradient_checkpointing = False
        self.post_init()

        # Block-specific initialisation applied on top of the generic init:
        # controlled gate bias, orthogonal predictor, zero surprise scale and
        # a zero-initialised residual output so each block starts as identity.
        bias_val = getattr(config, "gate_bias_init", 0.0)
        for block in self.blocks:
            nn.init.constant_(block.linear_gate.bias, bias_val)

            pred_weight = block.linear_pred.weight
            if pred_weight.dtype in (torch.bfloat16, torch.float16) and pred_weight.is_cuda:
                # Half-precision CUDA weights: run the orthogonal init on a
                # float32 CPU tensor and copy the result over.
                _device = pred_weight.device
                _dtype = pred_weight.dtype
                temp_w = torch.empty_like(pred_weight, dtype=torch.float32, device="cpu")
                nn.init.orthogonal_(temp_w, gain=0.1)
                with torch.no_grad():
                    pred_weight.copy_(temp_w.to(device=_device, dtype=_dtype))
            else:
                nn.init.orthogonal_(pred_weight, gain=0.1)
            nn.init.zeros_(block.surprise_lambda)

            # Zero-init residual output (identity start).
            nn.init.zeros_(block.mlp_down.weight)
            nn.init.zeros_(block.mlp_down.bias)

    def _set_gradient_checkpointing(self, enable=True, gradient_checkpointing_func=None):
        """Enable/disable gradient checkpointing."""
        self.gradient_checkpointing = enable

    def get_input_embeddings(self):
        return self.embedding

    def set_input_embeddings(self, value):
        self.embedding = value

    def _zero_state(self, batch_size, device, dtype):
        """Return a fresh ``(h, c)`` pair of zeros for one layer."""
        h = torch.zeros(batch_size, self.embed_dim, device=device, dtype=dtype)
        c = torch.zeros(batch_size, self.state_dim, device=device, dtype=dtype)
        return (h, c)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Args:
            input_ids: (B, T) token ids; mutually exclusive with ``inputs_embeds``.
            past_key_values: per-layer states (list of tuples or ``EchoCache``).
                ``None`` or a cache reporting length 0 starts from zero states.
            inputs_embeds: (B, T, D) embeddings used instead of ``input_ids``.
            position_ids: optional positions, forwarded to every block (and
                from there to the attention branch).
            **kwargs: forwarded to every block.

        Returns:
            ``(hidden_states, EchoCache)``.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds``
                are given.
        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is not None:
            batch_size = input_ids.shape[0]
            x = self.embedding(input_ids)
        elif inputs_embeds is not None:
            batch_size = inputs_embeds.shape[0]
            x = inputs_embeds
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = x.device

        # Start from zero states when nothing (or an empty Cache object) was given.
        cache_is_empty = (
            hasattr(past_key_values, "get_seq_length") and past_key_values.get_seq_length() == 0
        )
        if past_key_values is None or cache_is_empty:
            current_states = [
                self._zero_state(batch_size, device, x.dtype) for _ in range(self.num_layers)
            ]
        else:
            current_states = past_key_values

        # position_ids is a named parameter, so it is not part of **kwargs; it
        # has to be re-attached explicitly or the blocks never see it.
        block_kwargs = dict(kwargs)
        block_kwargs["position_ids"] = position_ids

        next_states = []
        for i, block in enumerate(self.blocks):
            state_i = current_states[i]
            if len(state_i) not in (2, 4):
                # Empty / malformed entry: fall back to zeros in the activation
                # dtype so half-precision runs do not hit a dtype mismatch.
                state_i = self._zero_state(batch_size, device, x.dtype)

            if self.gradient_checkpointing and self.training:
                # Non-reentrant checkpointing forwards keyword arguments to the
                # wrapped module, so both paths call the block identically.
                x, new_state = torch.utils.checkpoint.checkpoint(
                    block, x, state_i, use_reentrant=False, **block_kwargs
                )
            else:
                x, new_state = block(x, state_i, **block_kwargs)
            next_states.append(new_state)

        x = self.final_norm(x)
        return x, EchoCache(next_states)
class EchoForCausalLM(EchoPreTrainedModel, GenerationMixin):
    """Echo backbone with a linear language-modelling head and generation helpers."""

    _is_causal = True
    supports_gradient_checkpointing = True
    _supports_cache_class = False
    _supports_static_cache = False
    main_input_name = "input_ids"

    def __init__(self, config: EchoConfig):
        super().__init__(config)
        self.model = EchoModel(config)
        self.lm_head = nn.Linear(config.embed_dim, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def _set_gradient_checkpointing(self, enable=True, gradient_checkpointing_func=None):
        """Enable/disable gradient checkpointing."""
        self.model._set_gradient_checkpointing(enable, gradient_checkpointing_func)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Run the backbone and the LM head, optionally computing the LM loss.

        Args:
            input_ids: (B, T) token ids. Optional so that callers can pass
                ``inputs_embeds`` alone.
            attention_mask: accepted for API compatibility; not used here.
            position_ids: forwarded to the backbone.
            past_key_values: recurrent state from a previous call.
            inputs_embeds: embeddings used instead of ``input_ids``.
            labels: (B, T) targets; the loss is next-token cross entropy on
                the shifted logits.
            use_cache: whether the returned output carries the new state.
            output_attentions / output_hidden_states: resolved against the
                config but not exposed by the model yet.
            return_dict: return a ``CausalLMOutputWithPast`` instead of a tuple.
            **kwargs: forwarded to the backbone.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else getattr(self.config, "output_attentions", False)
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else getattr(self.config, "output_hidden_states", False)
        )
        use_cache = use_cache if use_cache is not None else getattr(self.config, "use_cache", True)
        return_dict = (
            return_dict
            if return_dict is not None
            else getattr(self.config, "use_return_dict", True)
        )

        # HF generate may pass extra arguments; they travel in **kwargs, with
        # position_ids handed over explicitly alongside them.
        hidden_states, new_states = self.model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            **kwargs,
        )
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits, new_states)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=new_states if use_cache else None,
            hidden_states=None,  # EchoModel doesn't expose internal states yet
            attentions=None,  # EchoModel doesn't expose attention weights yet
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, **kwargs
    ):
        """
        Build the model inputs for one generation step.

        When a non-empty state is available only the last token is fed;
        otherwise the whole ``input_ids`` is passed. A state counts as empty
        when it is ``None``, an empty list, or a cache object whose
        ``get_seq_length()`` is 0 (newer transformers versions may hand over
        such an object). The state itself is passed through untouched and
        interpreted by ``EchoModel``.
        """
        has_state = past_key_values is not None
        if has_state and hasattr(past_key_values, "get_seq_length"):
            has_state = past_key_values.get_seq_length() != 0
        elif has_state and isinstance(past_key_values, list):
            has_state = len(past_key_values) != 0

        if has_state:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
            "use_cache": kwargs.get("use_cache"),
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        """
        Reorder the per-layer states for beam search or contrastive search.

        Each layer entry is a tuple of tensors, ``(h, c)`` or ``(h, c, k, v)``;
        every tensor is re-indexed along dim 0. Returns a plain list of
        tuples, or ``None`` when no state was given.
        """
        if past_key_values is None:
            return None
        return [
            tuple(p.index_select(0, beam_idx.to(p.device)) for p in layer_past)
            for layer_past in past_key_values
        ]