# Hugging Face Hub page header (not part of the module):
#   uploader: mrs83
#   commit message: "Upload folder using huggingface_hub"
#   commit: b158f2b (verified)
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GenerationMixin, PreTrainedModel
from transformers.modeling_outputs import (
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
)
from .configuration_echo import EchoConfig
if TYPE_CHECKING:
# Force HF trust_remote_code AST parser to bundle triton_scan.py
pass
try:
# pyrefly: ignore [missing-import]
from vllm.model_executor.models.transformers import ALL_ATTENTION_FUNCTIONS
except ImportError:
ALL_ATTENTION_FUNCTIONS = {}
try:
    from transformers.cache_utils import Cache
except ImportError:

    class Cache:
        """Minimal stand-in used when ``transformers.cache_utils.Cache`` cannot be imported."""

        pass


class EchoCache(Cache):
    """
    Custom Cache to prevent Hugging Face's DynamicCache from dropping
    the (k_attn, v_attn) elements from the DSRN 4-tuple state.

    ``states`` is a list with one tuple per layer. The methods below
    distinguish two layouts by length: a 4-tuple whose elements 2 and 3 are
    treated as the attention key/value tensors, and anything else (for
    example a 2-tuple), which is treated as carrying no attention history.

    ``layers`` always refers to the same list object as ``states``; assigning
    a new list to ``states`` re-points ``layers`` as well.
    """

    def __init__(self, states=None):
        # The base-class initializer is intentionally not called; this object
        # only carries the per-layer state list.
        self.states = states if states is not None else []

    @property
    def states(self):
        """Per-layer state tuples."""
        return self._states

    @states.setter
    def states(self, value):
        # Keep the ``layers`` alias (HF expectation) pointing at the live list
        # even when callers rebind ``states`` after construction.
        self._states = value
        self.layers = value

    @property
    def is_compileable(self):
        return False

    def get_seq_length(self, layer_idx=0):
        """
        Return the cached attention length for ``layer_idx``.

        This is ``shape[2]`` of the third element of a 4-tuple state, and 0
        when the layer is missing or its state is not a 4-tuple.
        """
        if not self.states or len(self.states) <= layer_idx:
            return 0
        state = self.states[layer_idx]
        if len(state) == 4:
            return state[2].shape[2]
        return 0

    def get_max_length(self):
        return None

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[dict] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Shim for the Cache protocol; it never stores anything.

        EchoModel handles its own cache updates internally within the blocks,
        so this returns the key/value pair already held for ``layer_idx`` when
        that layer has a 4-tuple state, and otherwise echoes the arguments.
        """
        if len(self.states) > layer_idx:
            state = self.states[layer_idx]
            if len(state) == 4:
                return state[2], state[3]
        return key_states, value_states

    def get_usable_length(self, new_seq_length, layer_idx=0):
        return self.get_seq_length(layer_idx)

    def __getitem__(self, idx):
        return self.states[idx]

    def __len__(self):
        return len(self.states)

    def __iter__(self):
        return iter(self.states)

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Select rows ``beam_idx`` along dim 0 of every tensor in every layer state."""
        self.states = [
            tuple(tensor.index_select(0, beam_idx.to(tensor.device)) for tensor in layer_state)
            for layer_state in self.states
        ]
# --- STANDALONE KERNELS (AUTOMAGICALLY INLINED) ---
def _sequential_scan(a, b, h):
"""
Core sequential scan for a batch of sequences.
Vectorized across all dimensions except time.
"""
a.shape[:-1]
a.shape[-1]
# a, b: (..., T, D)
# h: (..., D)
T = a.shape[-2]
res = torch.empty_like(b)
curr_h = h
for t in range(T):
curr_h = a[..., t, :] * curr_h + b[..., t, :]
res[..., t, :] = curr_h
return res, curr_h
def dsrn_parallel_scan(g_t, m_t, c_0=None, chunk_size=32, use_triton=False):
"""
Parallel implementation of the DSRN slow-state update:
c_t = (1 - g_t) * c_{t-1} + g_t * m_t
Uses a Hierarchical Chunked Scan for O(T/K + K) speed and stability,
or a custom Triton kernel for dramatically reduced memory bandwidth.
"""
# Global Override: Disabling Triton scan while debugging LoRA NaN gradients
if use_triton and g_t.is_cuda:
try:
from .triton_scan import triton_dsrn_parallel_scan
return triton_dsrn_parallel_scan(g_t, m_t, c_0)
except ImportError:
import warnings
warnings.warn("Triton scan unavailable. Falling back to PyTorch scan.", UserWarning)
orig_dtype = g_t.dtype
a = (1.0 - g_t).float()
b = (g_t * m_t).float()
B, T, D = a.shape
device = a.device
# Pad T to be multiple of chunk_size
pad_len = (chunk_size - (T % chunk_size)) % chunk_size
if pad_len > 0:
a = F.pad(a, (0, 0, 0, pad_len), value=1.0)
b = F.pad(b, (0, 0, 0, pad_len), value=0.0)
new_T = T + pad_len
num_chunks = new_T // chunk_size
# 1. Reshape to (B, num_chunks, chunk_size, D)
a_chunks = a.view(B, num_chunks, chunk_size, D)
b_chunks = b.view(B, num_chunks, chunk_size, D)
# 2. Local scan within each chunk (vectorized across B and num_chunks)
h_init_local = torch.zeros(B, num_chunks, D, device=device, dtype=torch.float32)
c_res, c_final = _sequential_scan(a_chunks, b_chunks, h_init_local)
# Summary of a for each chunk (product of a)
a_final = torch.prod(a_chunks, dim=2) # (B, num_chunks, D)
# 3. Global scan across chunk summaries
h_0 = c_0.float() if c_0 is not None else torch.zeros(B, D, device=device, dtype=torch.float32)
# h_chunk_outputs[:, j] is the state AFTER chunk j.
h_chunk_outputs, _ = _sequential_scan(a_final, c_final, h_0)
# The state BEFORE chunk j is h_chunk_outputs[:, j-1].
h_starts = torch.cat([h_0.unsqueeze(1), h_chunk_outputs[:, :-1]], dim=1)
# 4. Final combine: h_{j, i} = a_prefix_{j, i} * h_starts[j] + c_res[j, i]
a_prefix = torch.cumprod(a_chunks, dim=2)
final_h = a_prefix * h_starts.unsqueeze(2) + c_res
# Reshape back and crop, then cast back to original dtype
return final_h.view(B, -1, D)[:, :T].to(orig_dtype)
def rms_norm_fn(hidden_states, weight, eps=1e-6):
    """
    Functional RMS normalisation over the last dimension.

    The statistics are computed in float32; the normalised values are cast
    back to the input dtype before being multiplied by ``weight``.
    """
    out_dtype = hidden_states.dtype
    x32 = hidden_states.contiguous().to(torch.float32)
    mean_sq = torch.mean(x32 * x32, dim=-1, keepdim=True)
    normed = x32 * torch.rsqrt(mean_sq + eps)
    return weight * normed.to(out_dtype)
def dsrn_parallel_kernel_legacy(
    model_block: nn.Module,
    x: torch.Tensor,
    h_prev: torch.Tensor,
    c_prev: torch.Tensor,
    eos_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    LayerNorm variant of the DSRN block kernel, evaluated over a whole sequence.

    Uses ``model_block``'s ``norm_fast``/``norm_ff`` weight and bias through
    ``F.layer_norm`` and always adds the ``linear_read`` projection of the
    slow state to the output.

    Args:
        model_block: module providing ``norm_fast``, ``gru_cell``,
            ``linear_pred``, ``surprise_lambda``, ``linear_gate``,
            ``linear_memory``, ``norm_ff``, ``mlp_up``/``mlp_act``/``mlp_down``
            and ``linear_read`` (and optionally ``use_triton``).
        x: input of shape (B, T, D).
        h_prev: fast state entering the sequence.
        c_prev: slow state entering the sequence.
        eos_mask: optional per-token mask (indexed as (B, T)); non-zero
            entries mark EOS tokens.

    Returns:
        ``(x_out, h_new, c_new, gate_stats)``: the block output, the fast and
        slow states after the last position (zeroed for rows whose last token
        is EOS), and the slow gate averaged over its last dimension.
    """
    _, _, D = x.shape
    scan_with_triton = getattr(model_block, "use_triton", False)

    # Positions that directly follow an EOS token. Position 0 is never marked:
    # a boundary before the first token is expected to be handled by the
    # caller through h_prev / c_prev.
    after_eos = None
    if eos_mask is not None:
        shifted = torch.roll(eos_mask, shifts=1, dims=1)
        shifted[:, 0] = 0
        after_eos = shifted.unsqueeze(-1) > 0

    # 1. Norm and projections for the fast state. Only the first and the
    #    third D-wide slices of the GRU input projection are used.
    fast_norm = model_block.norm_fast
    x_norm = F.layer_norm(x, (D,), weight=fast_norm.weight, bias=fast_norm.bias)
    gru = model_block.gru_cell
    gru_proj = F.linear(x_norm, gru.weight_ih, gru.bias_ih)
    z_all = torch.sigmoid(gru_proj[:, :, :D])
    r_all = torch.tanh(gru_proj[:, :, 2 * D :])
    if after_eos is not None:
        # z = 1 makes h_t = r_t, i.e. the previous fast state is discarded.
        z_all = torch.where(after_eos, torch.ones_like(z_all), z_all)

    # h_t = (1 - z_t) * h_{t-1} + z_t * r_t
    h_all = dsrn_parallel_scan(z_all, r_all, h_prev, use_triton=scan_with_triton)
    h_new = h_all[:, -1]

    # 2. Slow state. CAUSAL SHIFT: x[t] is predicted from h[t-1], so h_prev
    #    is prepended and the last state dropped.
    h_shifted = torch.cat([h_prev.unsqueeze(1), h_all[:, :-1, :]], dim=1)
    residual = x - model_block.linear_pred(h_shifted)
    error = torch.clamp(residual * residual, max=10.0).mean(dim=-1, keepdim=True)
    # softplus keeps the surprise coefficient strictly positive, so a larger
    # prediction error can only push the gate logits up.
    surprise_signal = error * F.softplus(model_block.surprise_lambda)
    g_all = torch.sigmoid(model_block.linear_gate(h_all) + surprise_signal)
    m_all = torch.tanh(model_block.linear_memory(h_all))
    if after_eos is not None:
        # NOTE(review): with g forced to 0 the recurrence
        # c_t = (1 - g) * c_{t-1} + g * m_t carries c_{t-1} forward unchanged,
        # whereas the fast-state override above discards h_{t-1}. Confirm that
        # holding (rather than clearing) the slow state here is intended.
        g_all = torch.where(after_eos, torch.zeros_like(g_all), g_all)

    c_all = dsrn_parallel_scan(g_all, m_all, c_prev, use_triton=scan_with_triton)
    c_new = c_all[:, -1]

    # Inter-chunk reset: if the LAST token is EOS, the states handed to the
    # next chunk are zeroed.
    if eos_mask is not None:
        keep = (1.0 - eos_mask[:, -1].float()).unsqueeze(-1)  # (B, 1)
        h_new = h_new * keep
        c_new = c_new * keep

    gate_stats = g_all.mean(dim=-1)

    # 3. Final MLP path on the normalised fast states, plus the continuous
    #    read of the slow state.
    ff_norm = model_block.norm_ff
    h_norm = F.layer_norm(h_all, (D,), weight=ff_norm.weight, bias=ff_norm.bias)
    mlp_out = model_block.mlp_down(model_block.mlp_act(model_block.mlp_up(h_norm)))
    x_out = x + mlp_out + model_block.linear_read(c_all)
    return x_out, h_new, c_new, gate_stats
def dsrn_parallel_kernel_hybrid(
    model_block: nn.Module,
    x: torch.Tensor,
    h_prev: torch.Tensor,
    c_prev: torch.Tensor,
    eos_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Hybrid DSRN kernel (RMSNorm + Surprise Read).

    Same computation as the LayerNorm kernel except that both normalisations
    go through ``rms_norm_fn`` (weight only) and the ``linear_read`` of the
    slow state is added only when ``model_block.use_hybrid_attention`` is
    truthy.

    Args:
        model_block: module providing the projections and norm weights used
            below (and optionally ``use_triton``).
        x: input of shape (B, T, D).
        h_prev: fast state entering the sequence.
        c_prev: slow state entering the sequence.
        eos_mask: optional per-token mask (indexed as (B, T)); non-zero
            entries mark EOS tokens.

    Returns:
        ``(x_out, h_new, c_new, gate_stats)``.
    """
    _, _, D = x.shape
    scan_with_triton = getattr(model_block, "use_triton", False)

    # Positions that directly follow an EOS token; position 0 is never marked.
    after_eos = None
    if eos_mask is not None:
        shifted = torch.roll(eos_mask, shifts=1, dims=1)
        shifted[:, 0] = 0
        after_eos = shifted.unsqueeze(-1) > 0

    # 1. Fast state (RMSNorm hardcoded for the hybrid path)
    x_norm = rms_norm_fn(x, model_block.norm_fast.weight)
    gru = model_block.gru_cell
    gru_proj = F.linear(x_norm, gru.weight_ih, gru.bias_ih)
    z_all = torch.sigmoid(gru_proj[:, :, :D])
    r_all = torch.tanh(gru_proj[:, :, 2 * D :])
    if after_eos is not None:
        # z = 1 makes h_t = r_t, i.e. the previous fast state is discarded.
        z_all = torch.where(after_eos, torch.ones_like(z_all), z_all)
    h_all = dsrn_parallel_scan(z_all, r_all, h_prev, use_triton=scan_with_triton)
    h_new = h_all[:, -1]

    # 2. Slow state. CAUSAL SHIFT: predict x[t] using h[t-1].
    h_shifted = torch.cat([h_prev.unsqueeze(1), h_all[:, :-1, :]], dim=1)
    residual = x - model_block.linear_pred(h_shifted)
    error = torch.clamp(residual * residual, max=10.0).mean(dim=-1, keepdim=True)
    # softplus keeps the surprise coefficient strictly positive.
    surprise_signal = error * F.softplus(model_block.surprise_lambda)
    g_all = torch.sigmoid(model_block.linear_gate(h_all) + surprise_signal)
    m_all = torch.tanh(model_block.linear_memory(h_all))
    if after_eos is not None:
        # NOTE(review): g = 0 leaves c_t equal to c_{t-1} under
        # c_t = (1 - g) * c_{t-1} + g * m_t, so the slow state is held across
        # the boundary rather than cleared. Confirm this is intended.
        g_all = torch.where(after_eos, torch.zeros_like(g_all), g_all)
    c_all = dsrn_parallel_scan(g_all, m_all, c_prev, use_triton=scan_with_triton)
    c_new = c_all[:, -1]

    # Inter-chunk reset: zero the outgoing states of rows ending in EOS.
    if eos_mask is not None:
        keep = (1.0 - eos_mask[:, -1].float()).unsqueeze(-1)
        h_new = h_new * keep
        c_new = c_new * keep

    gate_stats = g_all.mean(dim=-1)

    # 3. Final MLP
    h_norm = rms_norm_fn(h_all, model_block.norm_ff.weight)
    mlp_out = model_block.mlp_down(model_block.mlp_act(model_block.mlp_up(h_norm)))
    x_out = x + mlp_out

    # Continuous Read (Hybrid Feature)
    if model_block.use_hybrid_attention:
        x_out = x_out + model_block.linear_read(c_all)
    return x_out, h_new, c_new, gate_stats
def dsrn_parallel_kernel(
    model_block: nn.Module,
    x: torch.Tensor,
    h_prev: torch.Tensor,
    c_prev: torch.Tensor,
    eos_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Backward-compatible entry point that picks the kernel variant.

    The hybrid (RMSNorm) kernel is used when ``model_block.use_rmsnorm`` is
    truthy; a missing attribute or a falsy value selects the legacy
    (LayerNorm) kernel. All arguments are forwarded unchanged.
    """
    if getattr(model_block, "use_rmsnorm", False):
        kernel = dsrn_parallel_kernel_hybrid
    else:
        kernel = dsrn_parallel_kernel_legacy
    return kernel(model_block, x, h_prev, c_prev, eos_mask=eos_mask)
class HymbaRMSNorm(nn.Module):
    """RMS normalisation over the last dimension with a learned scale and no bias."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        HymbaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        # Statistics are taken in float32; the result is cast back to the
        # caller's dtype before the learned scale is applied.
        source_dtype = hidden_states.dtype
        x32 = hidden_states.to(torch.float32)
        mean_square = x32.pow(2).mean(-1, keepdim=True)
        x32 = x32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * x32.to(source_dtype)
class EchoRotaryEmbedding(nn.Module):
    """
    Lazily computed rotary-embedding cos/sin tables.

    The tables are plain attributes (not buffers) and are built on the first
    forward pass, then rebuilt whenever a longer sequence is requested or the
    input lives on a different device.
    """

    def __init__(self, dim, max_position_embeddings=4096, base=10000.0, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.device = device
        # We NO LONGER use buffers here because they are being corrupted by
        # Hugging Face's weight loading mechanism for this specific model.
        # We will compute and move them on the first forward pass.
        self._cos_cached = None
        self._sin_cached = None
        # Number of positions currently covered by the tables (0 = not built).
        self.max_seq_len_cached = 0

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Build tables for positions ``0 .. seq_len - 1`` on ``device`` in ``dtype``."""
        self.max_seq_len_cached = seq_len
        # Compute inv_freq locally
        inv_freq = 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)
        )
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self._cos_cached = emb.cos().to(dtype)
        self._sin_cached = emb.sin().to(dtype)

    def forward(self, x, seq_len=None):
        """
        Return ``(cos, sin)`` for the first ``seq_len`` positions, each of
        shape (seq_len, dim) and cast to ``x.dtype``.

        ``x`` is used only for its device and dtype, plus its second-to-last
        dimension as the default ``seq_len`` when none is given.
        """
        if seq_len is None:
            seq_len = x.shape[-2]
        if (
            self._cos_cached is None
            or seq_len > self.max_seq_len_cached
            or self._cos_cached.device != x.device
        ):
            # Keep the tables in float32 so a first call in reduced precision
            # does not permanently round the values later returned to
            # float32 callers; the cast to x.dtype happens on the way out.
            self._set_cos_sin_cache(
                seq_len=max(seq_len, self.max_position_embeddings),
                device=x.device,
                dtype=torch.float32,
            )
        return (
            self._cos_cached[:seq_len].to(dtype=x.dtype),
            self._sin_cached[:seq_len].to(dtype=x.dtype),
        )
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    # Split the last dimension at its (floor) midpoint and swap the two
    # pieces, negating the one that moves to the front.
    midpoint = x.shape[-1] // 2
    front, back = x[..., :midpoint], x[..., midpoint:]
    return torch.cat((-back, front), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """
    Apply rotary position embedding to the query and key tensors.

    ``cos`` and ``sin`` are indexed by ``position_ids`` along their first
    dimension and given an extra axis at ``unsqueeze_dim`` so they broadcast
    against ``q`` and ``k``. Returns the rotated ``(q, k)`` pair.
    """
    cos_at_pos = cos[position_ids].unsqueeze(unsqueeze_dim)  # (B, 1, T, D)
    sin_at_pos = sin[position_ids].unsqueeze(unsqueeze_dim)  # (B, 1, T, D)
    rotated = [(t * cos_at_pos) + (rotate_half(t) * sin_at_pos) for t in (q, k)]
    return rotated[0], rotated[1]
class SlidingWindowAttention(nn.Module):
    """
    Multi-head self-attention with rotary embeddings and a causal sliding window.

    Each query may attend to itself and at most ``window_size - 1`` preceding
    keys. The key/value history returned for caching is never truncated; the
    window is applied only when computing attention.
    """

    def __init__(self, config: EchoConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.window_size = getattr(config, "window_size", 128)
        self.qkv_proj = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
        self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.rotary_emb = EchoRotaryEmbedding(
            self.head_dim,
            base=getattr(config, "rope_theta", 10000.0),
        )

    def forward(
        self,
        x,
        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """
        Args:
            x: input of shape (B, T, C).
            past_key_values: optional ``(k, v)`` pair from earlier calls, each
                laid out as (B, num_heads, past_len, head_dim).
            position_ids: optional rotary positions for the T new tokens;
                defaults to ``past_len .. past_len + T - 1``.
            **kwargs: only ``attn_implementation`` is read, to look up the
                attention function.

        Returns:
            ``(output, (k, v))`` where ``output`` has shape (B, T, C) and
            ``(k, v)`` is the full, untruncated key/value history.
        """
        B, T, C = x.shape
        per_head = (B, T, self.num_heads, self.head_dim)
        # Project once, split into q/k/v and move heads in front of time.
        q, k, v = (
            part.view(per_head).transpose(1, 2) for part in self.qkv_proj(x).chunk(3, dim=-1)
        )

        past_len = 0 if past_key_values is None else past_key_values[0].shape[2]
        total_len = past_len + T

        # --- RoPE Injection ---
        if position_ids is None:
            # Fallback if position_ids was not passed: continue counting
            # after the cached history.
            position_ids = (
                torch.arange(past_len, total_len, dtype=torch.long, device=x.device)
                .unsqueeze(0)
                .view(-1, T)
            )
        cos, sin = self.rotary_emb(v, seq_len=total_len)
        q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
        # ----------------------

        if past_key_values is not None:
            k_past, v_past = past_key_values
            k = torch.cat([k_past, k], dim=2)
            v = torch.cat([v_past, v], dim=2)

        # The cache MUST store the full history, do not overwrite it with truncated slices
        current_key_value = (k, v)

        attn_fn = ALL_ATTENTION_FUNCTIONS.get(
            kwargs.get("attn_implementation", "sdpa"), F.scaled_dot_product_attention
        )

        if T > 1:
            # Training/Prefill: attend over the full k, v under an additive
            # band-limited causal mask of shape (T, kv_len) that is 0 where
            # attention is allowed and -inf elsewhere.
            kv_len = k.shape[2]
            query_pos = torch.arange(T, device=x.device).view(-1, 1) + (kv_len - T)
            key_pos = torch.arange(kv_len, device=x.device).view(1, -1)
            blocked = key_pos > query_pos  # future keys
            if self.window_size is not None:
                # Keys that are window_size or more steps behind the query.
                blocked = blocked | ((query_pos - key_pos) >= self.window_size)
            mask = torch.zeros((T, kv_len), device=x.device, dtype=x.dtype).masked_fill(
                blocked, float("-inf")
            )
            y = attn_fn(q, k, v, attn_mask=mask.unsqueeze(0).unsqueeze(0))
        else:
            # Decoding: the single query is the newest position, so no causal
            # mask is needed; only the last window_size keys/values are used.
            k_attn, v_attn = k, v
            if self.window_size is not None and k.shape[2] > self.window_size:
                k_attn = k[:, :, -self.window_size :, :]
                v_attn = v[:, :, -self.window_size :, :]
            y = attn_fn(q, k_attn, v_attn, is_causal=False)

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(y), current_key_value
class DSRNBlock(nn.Module):
    """
    One Echo layer: the DSRN fast/slow recurrent kernel, an MLP, and
    (optionally) a sliding-window attention branch added to the output.

    Per-layer state is ``(h, c)`` or, once the attention branch has run,
    ``(h, c, k_attn, v_attn)``.
    """

    def __init__(self, config: EchoConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.state_size = config.hidden_size * config.num_heads
        self.use_triton = getattr(config, "use_triton", True)
        self.use_hybrid_attention = getattr(config, "use_hybrid_attention", True)
        self.use_rmsnorm = getattr(config, "use_rmsnorm", True)

        # Fast State (GRU)
        if self.use_rmsnorm:
            self.norm_fast = HymbaRMSNorm(config.hidden_size)
        else:
            self.norm_fast = nn.LayerNorm(config.hidden_size)
        self.gru_cell = nn.GRUCell(config.hidden_size, config.hidden_size)

        # Hybrid Attention
        if self.use_hybrid_attention:
            self.attn = SlidingWindowAttention(config)

        # Slow State (DSRN)
        self.linear_read = nn.Linear(self.state_size, config.hidden_size, bias=False)
        self.linear_gate = nn.Linear(config.hidden_size, self.state_size)
        self.linear_memory = nn.Linear(config.hidden_size, self.state_size)

        # -- Surprise Mechanism --
        self.linear_pred = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.surprise_lambda = nn.Parameter(torch.zeros(self.state_size))

        # Feed-Forward
        if self.use_rmsnorm:
            self.norm_ff = HymbaRMSNorm(config.hidden_size)
        else:
            self.norm_ff = nn.LayerNorm(config.hidden_size)

        # Simple MLP: Linear -> GELU -> Linear
        # mlp_up / mlp_act / mlp_down are the ONLY registered submodules.
        # No self.mlp alias — that caused double-registration and spurious "missing keys".
        intermediate_size = getattr(
            config, "intermediate_size", int(config.hidden_size * getattr(config, "mlp_ratio", 4.0))
        )
        # Use getattr guard so configs loaded from old JSON (pre-mlp_bias field) default safely.
        _mlp_bias = getattr(config, "mlp_bias", False)
        self.mlp_up = nn.Linear(config.hidden_size, intermediate_size, bias=_mlp_bias)
        self.mlp_act = nn.GELU()
        self.mlp_down = nn.Linear(intermediate_size, config.hidden_size, bias=_mlp_bias)

    def forward(
        self, x: torch.Tensor, state_prev: Tuple[torch.Tensor, ...], **kwargs
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...], torch.Tensor]:
        """
        Run the block on ``x``.

        Args:
            x: block input; the kernel unpacks it as (B, T, D).
            state_prev: ``(h, c)`` or ``(h, c, k_attn, v_attn)``.
            **kwargs: forwarded to the attention branch only.

        Returns:
            ``(x_out, state_new, gate_stats)``. ``state_new`` is a 4-tuple
            when the attention branch ran and a 2-tuple otherwise.
        """
        h_prev, c_prev = state_prev[0], state_prev[1]

        # Recurrent kernel (handles norm, scans, MLP and slow-state read).
        x_out, h_new, c_new, gate_stats = dsrn_parallel_kernel(self, x, h_prev, c_prev)

        state_new: Tuple[torch.Tensor, ...] = (h_new, c_new)
        if self.use_hybrid_attention:
            # Re-apply norm for attention branch (cleanest for surgical transplant)
            x_norm = self.norm_fast(x)
            # A 2-tuple state means no attention history exists yet.
            attn_kv = (state_prev[2], state_prev[3]) if len(state_prev) == 4 else None
            attn_out, new_attn_kv = self.attn(x_norm, past_key_values=attn_kv, **kwargs)
            x_out = x_out + attn_out
            # Extend the state with the updated attention KV.
            if new_attn_kv is not None:
                state_new = (h_new, c_new, new_attn_kv[0], new_attn_kv[1])
        return x_out, state_new, gate_stats
class EchoPreTrainedModel(PreTrainedModel):
    """Common base for Echo models: config binding, load-time key filters and weight init."""

    config_class = EchoConfig
    base_model_prefix = "model"
    _no_split_modules = ["DSRNBlock"]
    # Silently drop legacy mlp.0.*/mlp.1.*/mlp.2.* alias keys if they exist in old
    # local training checkpoints from before the self.mlp aliasing was removed.
    # The canonical names are mlp_up.* / mlp_act.* / mlp_down.* which load fine.
    _keys_to_ignore_on_load_unexpected = [
        r".*\.mlp\.0\..*",
        r".*\.mlp\.1\..*",
        r".*\.mlp\.2\..*",
    ]

    def _init_weights(self, module):
        """
        Initialise one submodule: N(0, 0.02) weights for Linear and Embedding
        (Linear biases zeroed), and unit weight / zero bias for LayerNorm.
        Any other module type is left untouched.
        """
        if isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
            return
        if isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            return
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
class EchoModel(EchoPreTrainedModel):
    """Embedding, a stack of ``DSRNBlock`` layers and a final norm (no LM head)."""

    supports_gradient_checkpointing = True
    _supports_attention_backend = True

    def __init__(self, config: EchoConfig):
        super().__init__(config)
        self.embed_dim = config.embed_dim
        self.num_layers = config.num_layers
        self.num_heads = config.num_heads
        self.state_dim = config.embed_dim * config.num_heads
        self.embedding = nn.Embedding(config.vocab_size, config.embed_dim)
        self.blocks = nn.ModuleList([DSRNBlock(config) for _ in range(config.num_layers)])
        # NOTE(review): this default (False) differs from the default DSRNBlock
        # uses for the same flag (True); harmless as long as the config always
        # defines ``use_rmsnorm`` — confirm.
        if getattr(config, "use_rmsnorm", False):
            self.final_norm = HymbaRMSNorm(config.hidden_size)
        else:
            self.final_norm = nn.LayerNorm(config.hidden_size)
        self.gradient_checkpointing = False
        self.post_init()

        # --- ZOMBIE GRADIENT PATCH (FIXED) ---
        # Fixed: Now using controlled bias defaults to 1.0 to encourage open gates initially
        bias_val = getattr(config, "gate_bias_init", 1.0)
        for block in self.blocks:
            nn.init.constant_(block.linear_gate.bias, bias_val)
            # Init Surprise. For half-precision weights on CUDA the orthogonal
            # init is done on a float32 CPU tensor and copied over.
            if (
                block.linear_pred.weight.dtype in (torch.bfloat16, torch.float16)
                and block.linear_pred.weight.is_cuda
            ):
                _device = block.linear_pred.weight.device
                _dtype = block.linear_pred.weight.dtype
                temp_w = torch.empty_like(
                    block.linear_pred.weight, dtype=torch.float32, device="cpu"
                )
                nn.init.orthogonal_(temp_w, gain=0.1)
                with torch.no_grad():
                    block.linear_pred.weight.copy_(temp_w.to(device=_device, dtype=_dtype))
            else:
                nn.init.orthogonal_(block.linear_pred.weight, gain=0.1)
            nn.init.zeros_(block.surprise_lambda)
            # CRITICAL: Zero-Init Residual Output (Identity Start)
            nn.init.zeros_(block.mlp_down.weight)
            if block.mlp_down.bias is not None:
                nn.init.zeros_(block.mlp_down.bias)

    def _set_gradient_checkpointing(self, enable=True, gradient_checkpointing_func=None):
        """Enable/disable gradient checkpointing."""
        self.gradient_checkpointing = enable

    def get_input_embeddings(self):
        return self.embedding

    def set_input_embeddings(self, value):
        self.embedding = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_dsrn_telemetry: Optional[bool] = False,
        **kwargs,
    ) -> Tuple:
        """
        Run all blocks over the input.

        Args:
            input_ids: token ids of shape (B, T); exclusive with ``inputs_embeds``.
            past_key_values: per-layer states as a list/tuple or an
                ``EchoCache``. ``None`` or any object whose
                ``get_seq_length()`` returns 0 starts from zero states.
            inputs_embeds: pre-computed embeddings of shape (B, T, D).
            position_ids: forwarded to every block (and from there to the
                attention branch); ``None`` lets attention derive positions
                from its cached length.
            output_dsrn_telemetry: also return per-layer slow states and gate
                statistics.
            **kwargs: forwarded to every block.

        Returns:
            ``(hidden, cache)`` or, with telemetry,
            ``(hidden, cache, c_states, gate_stats)``. ``cache`` is an
            ``EchoCache``; a passed-in non-empty ``EchoCache`` is updated in
            place and returned.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds``
                are given.
        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, _ = input_ids.shape
            x = self.embedding(input_ids)
        elif inputs_embeds is not None:
            batch_size, _, _ = inputs_embeds.shape
            x = inputs_embeds
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        device = x.device

        # Initialize states if not provided or if it's an empty Cache object.
        # NOTE(review): an EchoCache holding only (h, c) 2-tuples reports
        # length 0 and is therefore treated as empty here.
        is_empty_cache = (
            hasattr(past_key_values, "get_seq_length") and past_key_values.get_seq_length() == 0
        )
        if past_key_values is None or is_empty_cache:
            past_key_values = [self._zero_state(batch_size, device, x.dtype) for _ in self.blocks]
        current_states = past_key_values

        next_states = []
        all_gate_stats = [] if output_dsrn_telemetry else None
        all_c_states = [] if output_dsrn_telemetry else None

        # Same keyword arguments for the plain and the checkpointed call, so
        # both paths see identical inputs (including position_ids).
        block_kwargs = dict(kwargs)
        block_kwargs["position_ids"] = position_ids

        # Layer-Major Execution
        for i, block in enumerate(self.blocks):
            state_i = current_states[i]
            if len(state_i) not in (2, 4):
                # Fallback for empty/malformed states: start this layer from
                # zeros in the activation dtype.
                state_i = self._zero_state(batch_size, device, x.dtype)

            if self.gradient_checkpointing and self.training:
                # Non-reentrant checkpointing passes extra keyword arguments
                # through to the wrapped block.
                out = torch.utils.checkpoint.checkpoint(
                    block, x, state_i, use_reentrant=False, **block_kwargs
                )
            else:
                out = block(x, state_i, **block_kwargs)

            x = out[0]
            next_states.append(out[1])
            if output_dsrn_telemetry:
                all_gate_stats.append(out[2])
                all_c_states.append(out[1][1])

        x = self.final_norm(x)

        if isinstance(current_states, EchoCache):
            # Reuse the caller's cache object so its identity is preserved.
            current_states.states = next_states
            next_states = current_states
        else:
            next_states = EchoCache(next_states)

        if output_dsrn_telemetry:
            return x, next_states, all_c_states, all_gate_stats
        return x, next_states

    def _zero_state(self, batch_size, device, dtype):
        """Return a fresh all-zero ``(h, c)`` pair for one layer."""
        h = torch.zeros(batch_size, self.embed_dim, device=device, dtype=dtype)
        c = torch.zeros(batch_size, self.state_dim, device=device, dtype=dtype)
        return (h, c)
class EchoForCausalLM(EchoPreTrainedModel, GenerationMixin):
    """``EchoModel`` with a bias-free linear language-modelling head."""

    _is_causal = True
    supports_gradient_checkpointing = True
    _supports_cache_class = False
    _supports_static_cache = False
    main_input_name = "input_ids"
    # Required by the modern HF tie_weights() mechanism (transformers ≥ 4.47).
    # Without this dict being non-None, tie_weights() returns early even when
    # tie_word_embeddings=True and get_input/output_embeddings() are both defined.
    _tied_weights_keys = {"lm_head.weight": "model.embedding.weight"}

    @property
    def _keys_to_ignore_on_load_missing(self):
        # When mlp_bias=False (the default, and the setting for all v0.1.2 checkpoints),
        # bias tensors are not present in the checkpoint and should not trigger warnings.
        # When mlp_bias=True, these keys WILL exist in the checkpoint — do not silence them.
        if not getattr(self.config, "mlp_bias", False):
            return [r"model\.blocks\.\d+\.mlp_(up|down)\.bias"]
        return []

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """
        Load the model, then zero any MLP bias parameters that exist even
        though the config says ``mlp_bias=False`` (a warning is emitted when
        that happens).
        """
        model = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
        # Defense-in-depth: if mlp_bias=False but bias tensors were somehow initialized
        # (e.g. an old code path created them), zero them out to prevent NaN/Inf
        # corruption when running in bfloat16.
        if not getattr(model.config, "mlp_bias", False):
            zeroed = 0
            with torch.no_grad():
                for name, param in model.named_parameters():
                    if "mlp_up.bias" in name or "mlp_down.bias" in name:
                        param.zero_()
                        zeroed += 1
            if zeroed:
                import warnings

                warnings.warn(
                    f"Zeroed {zeroed} MLP bias tensor(s) that were missing from the "
                    f"checkpoint. This indicates a config/checkpoint mismatch. "
                    f"Ensure mlp_bias=False in EchoConfig for v0.1.2 checkpoints.",
                    UserWarning,
                )
        return model

    def __init__(self, config: EchoConfig):
        super().__init__(config)
        self.model = EchoModel(config)
        self.lm_head = nn.Linear(config.embed_dim, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embedding

    def set_input_embeddings(self, value):
        self.model.embedding = value

    def _set_gradient_checkpointing(self, enable=True, gradient_checkpointing_func=None):
        """Enable/disable gradient checkpointing."""
        self.model._set_gradient_checkpointing(enable, gradient_checkpointing_func)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        output_dsrn_telemetry: Optional[bool] = False,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Compute next-token logits and, when ``labels`` is given, the shifted
        cross-entropy loss.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be provided.
        ``attention_mask``, ``output_attentions`` and ``output_hidden_states``
        are accepted for API compatibility but do not influence the
        computation here. With ``output_dsrn_telemetry`` the per-layer slow
        states and gate statistics of this call are stored on
        ``self._latest_c_states`` / ``self._latest_gate_stats``.

        Returns:
            A ``CausalLMOutputWithPast`` (``past_key_values`` is ``None`` when
            ``use_cache`` is false), or ``(loss?, logits, states)`` when
            ``return_dict`` is false.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else getattr(self.config, "output_attentions", False)
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else getattr(self.config, "output_hidden_states", False)
        )
        use_cache = use_cache if use_cache is not None else getattr(self.config, "use_cache", True)
        return_dict = (
            return_dict
            if return_dict is not None
            else getattr(self.config, "use_return_dict", True)
        )

        # kwargs may carry extra arguments that HF generate passes; they are
        # handed to the base model unchanged. position_ids is passed
        # explicitly alongside them.
        model_out = self.model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            output_dsrn_telemetry=output_dsrn_telemetry,
            **kwargs,
        )
        hidden_states = model_out[0]
        new_states = model_out[1]
        if len(model_out) > 2:
            # Telemetry was requested: keep the latest values for inspection.
            self._latest_c_states = model_out[2]
            self._latest_gate_stats = model_out[3]

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits, new_states)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=new_states if use_cache else None,
            hidden_states=None,  # EchoModel doesn't expose internal states yet
            attentions=None,  # EchoModel doesn't expose attention weights yet
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, **kwargs
    ):
        """
        Build the forward() inputs for one generation step.

        When a non-empty cache is present only the last token of
        ``input_ids`` is fed. A cache counts as empty when it is ``None``, an
        empty list, or an object whose ``get_seq_length()`` returns 0. A
        foreign cache object (e.g. a DynamicCache created by HF) is passed
        through as-is and left for ``EchoModel`` to interpret.
        """
        if past_key_values is None:
            is_empty = True
        elif hasattr(past_key_values, "get_seq_length"):
            is_empty = past_key_values.get_seq_length() == 0
        elif isinstance(past_key_values, list):
            is_empty = len(past_key_values) == 0
        else:
            is_empty = False

        # If past_key_values is used, we only need the last token
        if not is_empty:
            input_ids = input_ids[:, -1:]

        model_inputs = {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
            "use_cache": kwargs.get("use_cache"),
        }
        # Pass through extra kwargs like output_dsrn_telemetry
        model_inputs.update({k: v for k, v in kwargs.items() if k not in model_inputs})
        return model_inputs

    def _reorder_cache(self, past_key_values, beam_idx):
        """
        Reorders cache for beam search or contrastive search.
        past_key_values: List[Tuple(h, c, ...)]

        Returns a plain list of per-layer tuples (or ``None`` when no cache
        was given), with every tensor index-selected along dim 0.
        """
        if past_key_values is None:
            return None
        # Each layer_past is a tuple of tensors (h, c) or (h, c, k, v)
        return [
            tuple(p.index_select(0, beam_idx.to(p.device)) for p in layer_past)
            for layer_past in past_key_values
        ]
class EchoForSequenceClassification(EchoPreTrainedModel):
    """
    Echo-DSRN with a sequence-level classification head.

    This model is the *terminal* form of a fine-tuned classifier: it exposes
    only a ``classify()`` convenience method and a standard HF ``forward()``
    that returns :class:`~transformers.modeling_outputs.SequenceClassifierOutputWithPast`.
    It intentionally does **not** inherit :class:`~transformers.GenerationMixin` so
    chat-completion endpoints cannot be used accidentally.

    Typical construction path
    -------------------------
    1. Load ``EchoForCausalLM`` + LoRA adapter via :func:`merge_and_export`
       (see ``scripts/merge_clf_adapter.py``).
    2. The resulting merged weights are saved as ``EchoForSequenceClassification``
       alongside a ``config.json`` that carries ``num_labels``, ``id2label``, and
       ``label2id``.
    3. End-users load with::

           from echo_dsrn import EchoForSequenceClassification
           model = EchoForSequenceClassification.from_pretrained("your/hub-id")
           label, probs = model.classify("some text", tokenizer)
    """

    # Do NOT add GenerationMixin — this model must not generate text.
    main_input_name = "input_ids"

    def __init__(self, config: EchoConfig):
        """Build the backbone, the optional dropout, and the linear head.

        ``num_labels`` (default 2) and ``classifier_dropout`` (default 0.0)
        are read from ``config`` when present.
        """
        super().__init__(config)
        self.num_labels = getattr(config, "num_labels", 2)
        self.model = EchoModel(config)

        classifier_dropout = getattr(config, "classifier_dropout", 0.0)
        # Skip the dropout module entirely when it would be a no-op.
        self.dropout = nn.Dropout(classifier_dropout) if classifier_dropout > 0.0 else nn.Identity()
        self.classifier = nn.Linear(config.embed_dim, self.num_labels, bias=True)

        self.post_init()

    # ------------------------------------------------------------------
    # HF embedding hooks (required by PreTrainedModel)
    # ------------------------------------------------------------------
    def get_input_embeddings(self):
        """Return the backbone's input embedding module."""
        return self.model.embedding

    def set_input_embeddings(self, value):
        """Replace the backbone's input embedding module with ``value``."""
        self.model.embedding = value

    def _set_gradient_checkpointing(self, enable=True, gradient_checkpointing_func=None):
        """Delegate gradient-checkpointing configuration to the backbone."""
        self.model._set_gradient_checkpointing(enable, gradient_checkpointing_func)

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        """
        Run the backbone, pool the last attended token, and classify it.

        Parameters
        ----------
        attention_mask:
            Used only to locate the pooled token: for each row, the hidden
            state at the right-most non-zero mask position is selected, so
            both left- and right-padded batches are handled. It is not
            forwarded to the backbone. Without a mask the final time step is
            used.
        position_ids:
            Forwarded to the backbone through ``kwargs``.
        labels:
            - ``num_labels == 1``: regression target (``torch.float``).
            - ``num_labels > 1``, single integer per sample: cross-entropy class index.
            - ``num_labels > 1``, float vector per sample: multi-label BCE.
        use_cache:
            When truthy, the backbone's returned states are exposed as
            ``past_key_values`` on the dict output.
        return_dict:
            Falls back to ``config.use_return_dict`` (default ``True``).

        Returns
        -------
        ``SequenceClassifierOutputWithPast`` when ``return_dict`` is truthy,
        otherwise the tuple ``(loss?, logits, new_states)``.
        """
        if return_dict is None:
            return_dict = getattr(self.config, "use_return_dict", True)

        kwargs["position_ids"] = position_ids
        model_out = self.model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )
        hidden_states = model_out[0]  # (B, T, D)
        new_states = model_out[1]

        batch_size, seq_len = hidden_states.shape[:2]
        device = hidden_states.device

        # --- Pooling: last non-padding token ---
        if attention_mask is not None:
            # Multiply each position index by its mask bit and take the
            # argmax: this yields the right-most attended position whatever
            # the padding side. A fully masked row resolves to index 0.
            mask = attention_mask.to(device=device, dtype=torch.long)
            positions = torch.arange(mask.size(1), device=device)
            last_token_idx = (positions * mask).argmax(dim=-1)  # (B,)
        else:
            # No mask: use the true last token
            last_token_idx = torch.full(
                (batch_size,), seq_len - 1, dtype=torch.long, device=device
            )

        # Gather last-token hidden states: (B, D)
        pooled = hidden_states[torch.arange(batch_size, device=device), last_token_idx]
        pooled = self.dropout(pooled)
        logits = self.classifier(pooled)  # (B, num_labels)

        # --- Loss ---
        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # Regression
                loss = nn.MSELoss()(logits.squeeze(-1), labels.float())
            elif labels.is_floating_point():
                # Multi-label binary classification (any float dtype)
                loss = nn.BCEWithLogitsLoss()(logits, labels.float())
            else:
                # Standard multi-class
                loss = nn.CrossEntropyLoss()(
                    logits.view(-1, self.num_labels), labels.view(-1)
                )

        if not return_dict:
            output = (logits, new_states)
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=new_states if use_cache else None,
            hidden_states=None,
            attentions=None,
        )

    # ------------------------------------------------------------------
    # Convenience inference API
    # ------------------------------------------------------------------
    @torch.inference_mode()
    def classify(
        self,
        text: str,
        tokenizer,
        device: Optional[str] = None,
        return_probabilities: bool = True,
    ) -> Tuple[str, Optional[torch.Tensor]]:
        """
        High-level classification helper.

        Switches the model to eval mode as a side effect.

        Parameters
        ----------
        text:
            Raw string to classify.
        tokenizer:
            A HuggingFace ``PreTrainedTokenizer`` compatible with the model.
        device:
            Optional device string (e.g. ``"cuda"``). Defaults to the device
            of the model's first parameter.
        return_probabilities:
            If ``True`` (default), also return a softmax probability tensor
            over the classes.

        Returns
        -------
        label : str
            The predicted label string from ``config.id2label`` (the class
            index as a string when no mapping entry exists). For
            ``num_labels == 1`` this is the raw regression value as a string.
        probabilities : Tensor or None
            Shape ``(num_labels,)`` probability vector, or ``None`` if
            ``return_probabilities=False`` or the model is a regressor.
        """
        if device is None:
            try:
                device = str(next(self.parameters()).device)
            except StopIteration:
                device = "cpu"

        self.eval()

        # Format text if baked-in templates exist
        formatted_text = text
        sys_prompt = getattr(self.config, "system_prompt", None)
        usr_template = getattr(self.config, "user_template", None)
        if sys_prompt and usr_template:
            messages = [
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": usr_template.format(text=text)},
            ]
            # Format using the tokenizer's chat template
            try:
                formatted_text = tokenizer.apply_chat_template(
                    messages, add_generation_prompt=True, tokenize=False
                )
            except Exception as exc:
                # Best-effort: fall back to the raw text, but make the
                # fallback visible instead of failing silently.
                import warnings

                warnings.warn(
                    "EchoForSequenceClassification.classify: chat template could "
                    f"not be applied ({exc!r}); classifying the raw text instead.",
                    UserWarning,
                )
                formatted_text = text

        enc = tokenizer(formatted_text, return_tensors="pt", truncation=True)
        enc = {k: v.to(device) for k, v in enc.items()}

        # Force the dict output: ``.logits`` is read below regardless of the
        # config's ``use_return_dict`` setting.
        output = self(**enc, return_dict=True)
        logits = output.logits  # (1, num_labels)

        if self.num_labels == 1:
            # Regression: return raw value
            return str(logits.squeeze().item()), None

        probs = torch.softmax(logits, dim=-1).squeeze(0) if return_probabilities else None
        pred_id = int(logits.argmax(dim=-1).item())
        pred_label = getattr(self.config, "id2label", {}).get(pred_id, str(pred_id))
        return pred_label, probs

    @classmethod
    def from_causal_lm(
        cls,
        causal_lm_model,
        num_labels: int = 2,
        id2label: Optional[dict] = None,
        label2id: Optional[dict] = None,
        classifier_dropout: float = 0.0,
        label_token_ids: Optional[List[int]] = None,
        system_prompt: Optional[str] = None,
        user_template: Optional[str] = None,
    ) -> "EchoForSequenceClassification":
        """
        Construct an :class:`EchoForSequenceClassification` from a fully
        merged :class:`EchoForCausalLM` instance (i.e. after LoRA weights
        have been merged via ``peft.merge_adapter``).

        The backbone weights are copied; the ``lm_head`` is discarded.
        The source model's config is deep-copied before the classification
        fields are injected, so ``causal_lm_model`` itself is left untouched.

        Classifier head initialisation
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        If ``label_token_ids`` is provided (one token ID per class), the
        classifier weight rows are seeded directly from the corresponding
        ``lm_head`` weight rows.  This is the correct initialisation for
        **generative** adapters that were fine-tuned to emit a label token
        (e.g. ``"0"`` or ``"1"``): the backbone already knows how to push
        the last hidden state toward those tokens, so we preserve that signal
        instead of starting from random.

        Parameters
        ----------
        causal_lm_model:
            A loaded (and optionally LoRA-merged) ``EchoForCausalLM`` instance.
        num_labels:
            Number of output classes.
        id2label:
            Optional mapping ``{int -> str}`` for label names.
        label2id:
            Optional reverse mapping ``{str -> int}``.
        classifier_dropout:
            Dropout probability before the classification head.
        label_token_ids:
            Optional list of ``num_labels`` token IDs.  When supplied, row
            ``i`` of the ``lm_head`` weight matrix is copied into row ``i``
            of the classifier weight matrix, seeding the head from the
            causal model's learned token distributions.
            Example for Echo-DSRN NSFW adapter::

                label_token_ids=[29900, 29896]  # token IDs for "0" and "1"
        system_prompt:
            Optional string stored on the config as ``system_prompt``.
        user_template:
            Optional string stored on the config as ``user_template``.

        Returns
        -------
        EchoForSequenceClassification

        Raises
        ------
        ValueError
            If ``label_token_ids`` does not contain exactly ``num_labels``
            entries.
        RuntimeError
            Raised by the strict ``load_state_dict`` call if the backbone
            state-dict keys do not match.
        """
        import copy

        if id2label is None:
            id2label = {i: str(i) for i in range(num_labels)}
        if label2id is None:
            label2id = {v: k for k, v in id2label.items()}

        # Validate label_token_ids length
        if label_token_ids is not None and len(label_token_ids) != num_labels:
            raise ValueError(
                f"label_token_ids has {len(label_token_ids)} entries but num_labels={num_labels}. "
                "Must provide exactly one token ID per class."
            )

        # Clone config and inject classification fields. The deep copy keeps
        # the causal LM's own config (auto_map, num_labels, ...) intact.
        config = copy.deepcopy(causal_lm_model.config)
        config.num_labels = num_labels
        config.id2label = {int(k): v for k, v in id2label.items()}
        config.label2id = label2id
        config.classifier_dropout = classifier_dropout
        if system_prompt is not None:
            config.system_prompt = system_prompt
        if user_template is not None:
            config.user_template = user_template

        # Carry dtype forward so save_pretrained serialises it correctly
        src_dtype = getattr(causal_lm_model, "dtype", None)  # e.g. torch.bfloat16
        if src_dtype is not None:
            config.torch_dtype = str(src_dtype).replace("torch.", "")

        # Update auto_map so Hub users get the right class on from_pretrained
        config.auto_map = {
            "AutoConfig": "configuration_echo.EchoConfig",
            "AutoModel": "modeling_echo.EchoModel",
            "AutoModelForSequenceClassification": "modeling_echo.EchoForSequenceClassification",
        }

        # Build the classifier wrapper
        clf_model = cls(config)

        # Copy backbone weights. strict=True raises on any missing or
        # unexpected key, so no separate mismatch reporting is needed.
        clf_model.model.load_state_dict(causal_lm_model.model.state_dict(), strict=True)

        # --- Seed classifier head from lm_head rows (generative adapter path) ---
        if label_token_ids is not None:
            lm_head_weight = causal_lm_model.lm_head.weight  # (vocab_size, embed_dim)
            with torch.no_grad():
                for label_idx, token_id in enumerate(label_token_ids):
                    clf_model.classifier.weight[label_idx].copy_(lm_head_weight[token_id])
                # Zero-init bias so initial scores are purely from the weight rows
                torch.nn.init.zeros_(clf_model.classifier.bias)

        # --- Cast entire model to the source dtype ---
        # We cast everything uniformly AFTER all weight copies so that both
        # the backbone and the seeded classifier head end up in the same precision.
        if src_dtype is not None and src_dtype != torch.float32:
            clf_model = clf_model.to(src_dtype)
            # Persist in config using the current (non-deprecated) field name
            clf_model.config.dtype = str(src_dtype).replace("torch.", "")

        return clf_model