# ymodel3-n1 / ymodel3_eval.py
# Uploaded by SnifferCaptain ("Upload 5 files", commit b425c8f).
"""Self-contained ymodel3 inference module.
Only depends on: torch, safetensors.
No dependency on kernel.*, model.ymodel3, transformers.
"""
from __future__ import annotations
import json
import math
from pathlib import Path
from typing import Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file as load_safetensors
# ── Config ──────────────────────────────────────────────────────────
class YConfig3:
    """Hyper-parameters for the ymodel3 architecture.

    Each attribute is taken from the keyword argument of the same name, falling back
    to the default listed in ``__init__``. Keyword arguments that are not listed are
    ignored.
    """

    model_type = "ynet3"

    def __init__(self, **kwargs):
        defaults = {
            "dropout": 0.0,
            "bos_token_id": 151644,
            "eos_token_id": 151645,
            "pad_token_id": 151643,
            "hidden_act": "silu",
            "hidden_size": 768,
            "num_hidden_layers": 8,
            "max_position_embeddings": 8192,
            "vocab_size": 6400,
            "rms_norm_eps": 1e-6,
            "rope_theta": 5e4,
            "rope_scaling": None,
            "dtype": "float32",
            "self_distill": True,
            "intermediate_size": 1536,
            "n_routed_experts": 0,
            "moe_topk": 2,
            "score_func": "softmax",
            "n_shared_experts": 0,
            "top_k_layer_dense": 1,
            "aux_loss_alpha": 0.02,
            "seq_aux": False,
            "norm_topk_prob": True,
            "noisy_expert": 0.0,
            "moe_backend": "compact",
            "router_bias_enabled": True,
            "router_bias_update_rate": 1e-3,
            "router_bias_clamp": 5.0,
            "num_heads": 12,
            "mla_kv_lora_rank": 64,
            "mla_qk_nope_head_dim": 64,
            "mla_qk_rope_head_dim": 32,
            "mla_attn_impl": "absorb",
            "qkv_lora": False,
        }
        for name, fallback in defaults.items():
            setattr(self, name, kwargs.get(name, fallback))
        # A missing or falsy (None / 0) expert width falls back to the dense FFN width.
        self.expert_intermediate_size = kwargs.get("expert_intermediate_size", None) or self.intermediate_size

    @property
    def head_dim(self) -> int:
        """Per-head query/key width: non-RoPE part plus RoPE part."""
        return self.mla_qk_rope_head_dim + self.mla_qk_nope_head_dim

    def scale_lvl(self, lvl: int = 0):
        """Overwrite the architecture fields with a size preset and return ``self``.

        Supported levels are 0, -1 and -2. Fields a preset does not mention keep their
        current values (only level 0 resets ``router_bias_update_rate``).

        Raises:
            ValueError: for any other level.
        """
        presets = (
            (0, dict(
                hidden_size=1024,
                num_hidden_layers=8,
                num_heads=8,
                mla_kv_lora_rank=256,
                mla_qk_nope_head_dim=192,
                mla_qk_rope_head_dim=64,
                intermediate_size=2048,
                expert_intermediate_size=512,
                n_routed_experts=16,
                moe_topk=1,
                n_shared_experts=0,
                top_k_layer_dense=1,
                router_bias_update_rate=1e-3,
            )),
            (-1, dict(
                hidden_size=768,
                num_hidden_layers=8,
                num_heads=6,
                mla_kv_lora_rank=128,
                mla_qk_nope_head_dim=64,
                mla_qk_rope_head_dim=64,
                intermediate_size=1536,
                expert_intermediate_size=768,
                n_routed_experts=0,
                moe_topk=2,
                n_shared_experts=0,
                top_k_layer_dense=8,
            )),
            (-2, dict(
                hidden_size=512,
                num_hidden_layers=4,
                num_heads=4,
                mla_kv_lora_rank=128,
                mla_qk_nope_head_dim=64,
                mla_qk_rope_head_dim=32,
                intermediate_size=1024,
                expert_intermediate_size=512,
                n_routed_experts=0,
                moe_topk=2,
                n_shared_experts=0,
                top_k_layer_dense=4,
            )),
        )
        for level, overrides in presets:
            if lvl == level:
                for name, value in overrides.items():
                    setattr(self, name, value)
                return self
        raise ValueError(f"invalid ymodel3 scale level: {lvl}")

    @classmethod
    def from_json_file(cls, path: str) -> "YConfig3":
        """Build a config from a UTF-8 JSON object file whose keys are constructor kwargs."""
        with open(path, "r", encoding="utf-8") as handle:
            payload = json.load(handle)
        return cls(**payload)

    @classmethod
    def from_dict(cls, data: dict) -> "YConfig3":
        """Build a config from a mapping of constructor kwargs."""
        return cls(**data)
# ── Basic modules ──────────────────────────────────────────────────
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Statistics are computed in float32 and the result is cast back to the input
    dtype. The learnable scale is created as float32 ones.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x32 * inv_rms * self.weight.float()).to(x.dtype)
class SEBlock(nn.Module):
    """Per-position squeeze-and-excitation style gate: ``x * sigmoid(W2 act(W1 x))``.

    The gate is computed from each feature vector on its own (no pooling). The
    bottleneck width is ``max(reduction, dim // reduction)``. ``act`` defaults to
    SiLU when not given (or falsy).
    """

    def __init__(self, dim: int, reduction: int = 16, act: Optional[nn.Module] = None):
        super().__init__()
        bottleneck = max(reduction, dim // reduction)
        layers = [
            nn.Linear(dim, bottleneck, bias=False),
            act or nn.SiLU(),
            nn.Linear(bottleneck, dim, bias=False),
            nn.Sigmoid(),
        ]
        # Keep the Sequential layout: checkpoint keys are se.0.weight / se.2.weight.
        self.se = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = self.se(x)
        return x * gate
# ── RoPE helpers ──────────────────────────────────────────────────
def _yarn_linear_ramp(low: float, high: float, dim: int) -> torch.Tensor:
if low == high:
high += 0.001
linear = (torch.arange(dim, dtype=torch.float32) - low) / (high - low)
return torch.clamp(linear, 0.0, 1.0)
def _yarn_correction_dim(num_rotations: float, dim: int, theta: float, max_position_embeddings: int) -> float:
return dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi)) / (2 * math.log(theta))
def precompute_freqs_cis(
    dim: int,
    end: int,
    theta: float,
    rope_scaling: Optional[dict] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Precompute RoPE cosine/sine tables for positions ``0 .. end-1``.

    Args:
        dim: rotary width; the tables hold one frequency per pair of channels
            (the YaRN ramp below uses ``dim // 2``, so ``dim`` is assumed even — TODO confirm).
        end: number of positions to tabulate.
        theta: RoPE base.
        rope_scaling: optional dict. When its ``"type"`` (default ``"yarn"``) is ``"yarn"``
            and ``"factor"`` > 1, the frequencies are blended between ``freqs / factor`` and
            ``freqs`` with a linear ramp over frequency index, and ``"attention_factor"``
            (default 1.0) multiplies both tables.

    Returns:
        ``(cos, sin)``, each of shape ``(end, dim)`` for even ``dim``. The per-frequency values
        are repeated twice along the last axis to match ``rotate_half``'s half-split layout.
    """
    # Base inverse frequencies theta ** (-2i / dim).
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    attention_factor = 1.0
    if rope_scaling and str(rope_scaling.get("type", "yarn")).lower() == "yarn":
        factor = float(rope_scaling.get("factor", 1.0))
        if factor > 1.0:
            # Context length the unscaled frequencies are assumed to cover; defaults to `end`.
            original = int(rope_scaling.get("original_max_position_embeddings", end))
            beta_fast = float(rope_scaling.get("beta_fast", 32.0))
            beta_slow = float(rope_scaling.get("beta_slow", 1.0))
            # Frequency indices completing beta_fast / beta_slow rotations over `original`
            # positions. Unlike some reference implementations these are not clamped to
            # [0, dim // 2 - 1]; the ramp itself is clamped to [0, 1].
            low = math.floor(_yarn_correction_dim(beta_fast, dim, theta, original))
            high = math.ceil(_yarn_correction_dim(beta_slow, dim, theta, original))
            ramp = _yarn_linear_ramp(low, high, dim // 2)
            # NOTE(review): ramp is 0 for small indices (high-frequency channels) and 1 for large
            # ones, so this divides the *high*-frequency channels by `factor` and leaves the
            # low-frequency ones untouched. Reference YaRN does the opposite
            # (freqs * (1 - ramp) + freqs / factor * ramp). Confirm against the training code
            # before changing — weights may have been trained with this exact formula.
            freqs = freqs / factor * (1.0 - ramp) + freqs * ramp
            # NOTE(review): this file's indentation was lost; whether this line belongs inside
            # `if factor > 1.0` or directly under the yarn check must be confirmed upstream.
            attention_factor = float(rope_scaling.get("attention_factor", 1.0))
    t = torch.arange(end)
    # Angle table: positions x frequencies.
    freqs = torch.outer(t, freqs).float()
    # attention_factor scales both cos and sin. RoPE is applied to both q and k, so the rotary
    # part of the q.k logit is scaled by attention_factor ** 2.
    freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1) * attention_factor
    freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1) * attention_factor
    return freqs_cos, freqs_sin
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map ``[a | b]`` (split at ``last_dim // 2``) to ``[-b | a]`` along the last axis."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
def apply_rope_to_single(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply rotary embedding ``x * cos + rotate_half(x) * sin`` to one tensor.

    A 2-D ``cos``/``sin`` (positions, dim) gets two leading singleton axes. A 3-D one
    (batch, positions, dim) gets a singleton inserted at axis 1. This suits ``x`` laid out
    as (batch, heads, positions, dim) — TODO confirm for other callers. Any other rank is
    used as-is.
    """
    rank = cos.dim()
    if rank == 3:
        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
    elif rank == 2:
        cos, sin = cos[None, None], sin[None, None]
    rotated = rotate_half(x)
    return (x * cos) + (rotated * sin)
# ── Attention ──────────────────────────────────────────────────────
class MLGA(nn.Module):
    """Multihead Latent Gated Attention.

    Keys and values come from a shared, RMS-normalized latent ``c_kv`` of width
    ``mla_kv_lora_rank``:
      * keys are ``[wkv_b(c_kv) for this head | k_pe]``, where the RoPE key ``k_pe`` is a
        single vector broadcast to every head;
      * values are ``c_kv`` itself.

    Each head therefore outputs a ``kv_lora_rank``-wide vector. The concatenation is gated by
    ``sigmoid(z_proj(x))`` and projected back with ``o_proj``.

    The per-layer cache is ``(c_kv, k_pe)``, concatenated along dim 1 (sequence).
    ``mla_attn_impl == "naive"`` materializes per-head K/V and uses SDPA. Any other value
    uses the "absorb" path, which folds ``wkv_b`` into the query.
    """

    def __init__(self, config: YConfig3, layer_id: int):
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.dropout = config.dropout
        self.kv_lora_rank = config.mla_kv_lora_rank
        self.qk_nope_head_dim = config.mla_qk_nope_head_dim
        self.qk_rope_head_dim = config.mla_qk_rope_head_dim
        self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
        self.attn_impl = config.mla_attn_impl
        self.softmax_scale = self.qk_head_dim ** -0.5
        # Every head emits a kv_lora_rank-wide vector (values are the shared latent).
        self.out_dim = self.num_heads * self.kv_lora_rank
        self.wq = nn.Linear(self.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
        # Down-projection to [latent c_kv | shared RoPE key].
        self.wkv_a = nn.Linear(self.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim, bias=False)
        self.kv_norm = RMSNorm(self.kv_lora_rank, config.rms_norm_eps)
        # Up-projection of the latent to each head's non-RoPE key part.
        self.wkv_b = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim, bias=False)
        # Output-gate logits (sigmoid-gated before o_proj).
        self.z_proj = nn.Linear(self.hidden_size, self.out_dim, bias=False)
        self.o_proj = nn.Linear(self.out_dim, self.hidden_size, bias=False)

    def _project_q(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (q_nope, q_pe) of shapes (bsz, seq, heads, nope) and (bsz, seq, heads, rope)."""
        bsz, seq_len, _ = x.shape
        q = self.wq(x)
        q = q.reshape(bsz, seq_len, self.num_heads, self.qk_head_dim)
        return q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

    def _project_kv(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the normalized latent (bsz, seq, kv_lora_rank) and the rotated k_pe (bsz, seq, 1, rope)."""
        raw = self.wkv_a(x)
        c_kv, k_pe = raw.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        c_kv = self.kv_norm(c_kv)
        # Add a singleton head axis for RoPE, then move it after the sequence axis.
        k_pe = apply_rope_to_single(k_pe.unsqueeze(1), cos, sin).permute(0, 2, 1, 3)
        return c_kv, k_pe

    def _explicit_kv(self, c_kv: torch.Tensor, k_pe: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Materialize per-head K (bsz, t, heads, nope+rope) and V (bsz, t, heads, kv_lora_rank)."""
        bsz, seq_len, _ = c_kv.shape
        k_nope = self.wkv_b(c_kv).reshape(bsz, seq_len, self.num_heads, self.qk_nope_head_dim)
        k = torch.cat([k_nope, k_pe.expand(-1, -1, self.num_heads, -1)], dim=-1)
        v = c_kv.unsqueeze(2).expand(-1, -1, self.num_heads, -1)
        return k, v

    def _attention_mask(
        self,
        attention_mask: Optional[torch.Tensor],
        bsz: int,
        seq_len: int,
        total_len: int,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """Build a boolean SDPA mask (True = may attend) of shape (bsz, heads, seq_len, total_len).

        The ``seq_len`` queries are the last positions of a ``total_len``-long sequence, so
        query ``i`` may see keys ``0 .. total_len - seq_len + i``. This causal constraint is
        combined with the optional key-padding mask (nonzero = keep; only its last
        ``total_len`` columns are used). Rows left with no visible key fall back to the causal
        mask alone: SDPA returns NaN for fully-masked rows, and those NaNs would propagate to
        later layers.
        """
        if device is None and attention_mask is not None:
            device = attention_mask.device
        causal = torch.ones(seq_len, total_len, dtype=torch.bool, device=device)
        causal = causal.tril(diagonal=total_len - seq_len).view(1, 1, seq_len, total_len)
        mask = causal
        if attention_mask is not None:
            if attention_mask.shape[-1] != total_len:
                attention_mask = attention_mask[..., -total_len:]
            keep = attention_mask.reshape(bsz, 1, 1, total_len).to(device=causal.device).bool()
            mask = causal & keep
            mask = torch.where(mask.any(dim=-1, keepdim=True), mask, causal)
        return mask.expand(bsz, self.num_heads, seq_len, total_len)

    def _forward_sdpa(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        c_kv: torch.Tensor,
        k_pe: torch.Tensor,
        z: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Attention with explicit per-head K/V via ``F.scaled_dot_product_attention``."""
        bsz, seq_len, _, _ = q_nope.shape
        total_len = c_kv.shape[1]
        k, v = self._explicit_kv(c_kv, k_pe)
        q = torch.cat([q_nope, q_pe], dim=-1).permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)
        if attention_mask is None and seq_len == total_len:
            # No cache and no padding: SDPA's built-in causal mask is exact.
            attn_mask, is_causal = None, True
        elif attention_mask is None and seq_len == 1:
            # A single newest query may see every cached key.
            attn_mask, is_causal = None, False
        else:
            # SDPA's is_causal aligns the triangle top-left (wrong when queries sit at the end
            # of a longer cached sequence), and a padding mask alone would expose future
            # tokens, so build the combined causal + padding mask explicitly.
            attn_mask = self._attention_mask(attention_mask, bsz, seq_len, total_len, q.device)
            is_causal = False
        out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal,
            scale=self.softmax_scale,
        )
        out = out.permute(0, 2, 1, 3).reshape(bsz, seq_len, self.out_dim)
        out = out * torch.sigmoid(z)
        return self.o_proj(out)

    def _forward_absorb(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        c_kv: torch.Tensor,
        k_pe: torch.Tensor,
        z: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Attention computed directly against the latent cache (no per-head K/V)."""
        bsz, seq_len, _, _ = q_nope.shape
        total_len = c_kv.shape[1]
        # Fold wkv_b into the query: q_nope . (W_h c) == (W_h^T q_nope) . c.
        w = self.wkv_b.weight.reshape(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
        q_nope_c = torch.einsum("bshd,hdc->bshc", q_nope, w)
        scores = torch.einsum("bshc,btc->bsht", q_nope_c, c_kv)
        scores = scores + torch.einsum("bshr,btr->bsht", q_pe, k_pe.squeeze(2))
        scores = scores * self.softmax_scale
        # Causal mask for the new queries, left-padded with zeros so every cached key is visible.
        causal = torch.full((seq_len, seq_len), float("-inf"), device=scores.device, dtype=scores.dtype)
        causal = torch.triu(causal, diagonal=1).reshape(1, seq_len, 1, seq_len)
        scores = scores + F.pad(causal, (total_len - seq_len, 0), value=0.0)
        if attention_mask is not None:
            if attention_mask.shape[-1] != total_len:
                attention_mask = attention_mask[..., -total_len:]
            # Large finite penalty (not -inf) so fully padded rows stay finite.
            scores = scores + (1.0 - attention_mask.reshape(bsz, 1, 1, total_len).float()) * -1e9
        probs = torch.softmax(scores.float(), dim=-1).to(q_nope.dtype)
        out = torch.einsum("bsht,btc->bshc", probs, c_kv).reshape(bsz, seq_len, self.out_dim)
        out = out * torch.sigmoid(z)
        return self.o_proj(out)

    def forward(
        self,
        x: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        past_key_values: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
        """Attend ``x`` (bsz, seq, hidden) to the cached + current latents.

        Returns ``(output, cache)``. ``cache`` is ``(c_kv, k_pe)`` covering past and current
        positions when ``use_cache`` is true, else ``None``.
        """
        bsz, seq_len, _ = x.shape
        cos, sin = position_embeddings
        if cos.dim() == 2:
            cos = cos[:seq_len, : self.qk_rope_head_dim]
            sin = sin[:seq_len, : self.qk_rope_head_dim]
        else:
            cos = cos[:, :seq_len, : self.qk_rope_head_dim]
            sin = sin[:, :seq_len, : self.qk_rope_head_dim]
        q_nope, q_pe = self._project_q(x)
        q_pe = apply_rope_to_single(q_pe.permute(0, 2, 1, 3), cos, sin).permute(0, 2, 1, 3)
        c_kv, k_pe = self._project_kv(x, cos, sin)
        z = self.z_proj(x)
        if past_key_values is not None:
            past_c, past_pe = past_key_values
            c_kv = torch.cat([past_c, c_kv], dim=1)
            k_pe = torch.cat([past_pe, k_pe], dim=1)
        new_past = (c_kv, k_pe) if use_cache else None
        if self.attn_impl == "naive":
            out = self._forward_sdpa(q_nope, q_pe, c_kv, k_pe, z, attention_mask)
        else:
            out = self._forward_absorb(q_nope, q_pe, c_kv, k_pe, z, attention_mask)
        out = F.dropout(out, p=self.dropout, training=self.training)
        return out, new_past
# ── FFN / MoE ──────────────────────────────────────────────────────
_ACT_FNS = {
"silu": F.silu,
"swish": F.silu,
"relu": F.relu,
"gelu": lambda x: F.gelu(x, approximate="tanh"),
"sigmoid": torch.sigmoid,
}
_ACT_MODULES = {
"silu": nn.SiLU,
"swish": nn.SiLU,
"relu": nn.ReLU,
"gelu": lambda: nn.GELU(approximate="tanh"),
"sigmoid": nn.Sigmoid,
}
class DenseFFN(nn.Module):
    """Gated feed-forward layer: ``down(act(gate(x)) * up(x))`` with dropout on the hidden.

    ``intermediate_size`` overrides ``config.intermediate_size`` when truthy. Unknown
    ``config.hidden_act`` names fall back to SiLU.
    """

    def __init__(self, config: YConfig3, intermediate_size: Optional[int] = None):
        super().__init__()
        width = intermediate_size or config.intermediate_size
        hidden = config.hidden_size
        self.up_proj = nn.Linear(hidden, width, bias=False)
        self.gate_proj = nn.Linear(hidden, width, bias=False)
        self.down_proj = nn.Linear(width, hidden, bias=False)
        self.hidden_act = config.hidden_act
        self.act = _ACT_FNS.get(self.hidden_act, F.silu)
        self.dropout = config.dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = self.act(self.gate_proj(x)) * self.up_proj(x)
        gated = F.dropout(gated, p=self.dropout, training=self.training)
        return self.down_proj(gated)
class MoEGate(nn.Module):
    """Top-k expert router (inference only; the returned aux loss is always zero).

    Experts are *selected* on the routing scores plus the optional ``router_bias`` buffer.
    The returned weights are the unbiased scores of the selected experts, renormalized to
    sum to 1 when ``norm_topk_prob`` is set and more than one expert is chosen.
    """

    def __init__(self, config: YConfig3):
        super().__init__()
        n_experts = config.n_routed_experts
        self.n_routed_experts = n_experts
        self.topk = min(config.moe_topk, max(1, n_experts))
        self.score_func = config.score_func
        self.norm_topk_prob = config.norm_topk_prob
        self.aux_loss_alpha = config.aux_loss_alpha
        self.seq_aux = config.seq_aux
        self.router_bias_enabled = config.router_bias_enabled
        self.router_bias_update_rate = config.router_bias_update_rate
        self.router_bias_clamp = config.router_bias_clamp
        self.weight = nn.Linear(int(config.hidden_size), int(n_experts), bias=False)
        # Saved in the state dict only when enabled; otherwise a non-persistent None slot.
        if self.router_bias_enabled:
            self.register_buffer("router_bias", torch.zeros(n_experts), persistent=True)
        else:
            self.register_buffer("router_bias", None, persistent=False)

    def forward(self, x: torch.Tensor, aux_mask: Optional[torch.Tensor] = None):
        """Route ``x`` (bsz, seq, hidden); return (indices, weights, aux_loss), each (bsz, seq, topk) but aux."""
        bsz, seq_len, hidden = x.shape
        logits32 = self.weight(x.reshape(-1, hidden)).float()
        if self.score_func == "softmax":
            probs = logits32.softmax(dim=-1)
        elif self.score_func == "sigmoid":
            probs = logits32.sigmoid()
        else:
            raise ValueError(f"unsupported MoE score_func: {self.score_func}")
        route_scores = probs.to(x.dtype)
        selection = route_scores
        if self.router_bias is not None:
            selection = selection + self.router_bias.to(dtype=selection.dtype).unsqueeze(0)
        topk_idx = selection.topk(self.topk, dim=-1, sorted=False).indices
        topk_weight = torch.gather(route_scores, 1, topk_idx)
        if self.norm_topk_prob and self.topk > 1:
            w32 = topk_weight.float()
            topk_weight = (w32 / (w32.sum(dim=-1, keepdim=True) + 1e-20)).to(x.dtype)
        out_shape = (bsz, seq_len, self.topk)
        return (
            topk_idx.reshape(out_shape),
            topk_weight.reshape(out_shape),
            x.new_zeros((), dtype=x.dtype),
        )
def _torch_moe_swiglu(
    x: torch.Tensor,
    topk_idx: torch.Tensor,
    topk_weight: torch.Tensor,
    w_up: torch.Tensor,
    w_down: torch.Tensor,
    activation: str = "silu",
) -> torch.Tensor:
    """Pure PyTorch MoE SwiGLU forward (inference only, no noisy_expert).

    ``w_up[e]`` maps hidden -> [up | gate] (2 * inter rows) and ``w_down[e]`` maps inter ->
    hidden. Each token's output is the route-weighted sum of ``down(act(gate) * up)`` over
    its selected experts. The result has the shape of ``x``.
    """
    tokens = x.reshape(-1, x.shape[-1])
    n_tokens = tokens.shape[0]
    expert_ids = topk_idx.reshape(n_tokens, -1)
    route_weights = topk_weight.reshape(n_tokens, -1)
    expert_hidden = w_down.shape[-1]
    act = _ACT_FNS.get(activation, F.silu)
    result = torch.zeros_like(tokens)
    for e in range(w_up.shape[0]):
        # (token row, top-k slot) pairs routed to expert e.
        rows, slots = (expert_ids == e).nonzero(as_tuple=True)
        if rows.numel() == 0:
            continue
        up, gate = F.linear(tokens[rows], w_up[e]).split(expert_hidden, dim=-1)
        contrib = F.linear(act(gate) * up, w_down[e])
        result.index_add_(0, rows, contrib * route_weights[rows, slots].unsqueeze(-1))
    return result.reshape(x.shape)
class YMoE(nn.Module):
    """Pure PyTorch eval MoE (no Triton dependency).

    Acts as a plain ``DenseFFN`` unless ``config.n_routed_experts > 0`` and
    ``layer_id >= config.top_k_layer_dense``, in which case tokens are routed by ``MoEGate``
    to stacked SwiGLU experts:
      * ``w_up``: (n_experts, 2 * expert_intermediate_size, hidden_size)
      * ``w_down``: (n_experts, hidden_size, expert_intermediate_size)
    """

    def __init__(self, config: YConfig3, layer_id: int):
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = config.hidden_size
        self.expert_intermediate_size = config.expert_intermediate_size
        self.intermediate_size = self.expert_intermediate_size
        self.n_routed_experts = config.n_routed_experts
        self.use_moe = self.n_routed_experts > 0 and layer_id >= config.top_k_layer_dense
        self.noisy_expert = config.noisy_expert
        if self.use_moe:
            self.dense = None
            self.gate = MoEGate(config)
            up_shape = (self.n_routed_experts, 2 * self.expert_intermediate_size, self.hidden_size)
            down_shape = (self.n_routed_experts, self.hidden_size, self.expert_intermediate_size)
            self.w_up = nn.Parameter(torch.empty(up_shape))
            self.w_down = nn.Parameter(torch.empty(down_shape))
            for weight in (self.w_up, self.w_down):
                nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
        else:
            self.dense = DenseFFN(config)
            self.gate = None
            self.w_up = None
            self.w_down = None

    def forward(self, x: torch.Tensor, aux_mask: Optional[torch.Tensor] = None):
        """Return ``(output, aux_loss)``; aux_loss is None on the dense path."""
        if self.use_moe:
            topk_idx, topk_weight, aux_loss = self.gate(x, aux_mask)
            # NOTE(review): experts always use SiLU, independent of config.hidden_act.
            mixed = _torch_moe_swiglu(x, topk_idx, topk_weight, self.w_up, self.w_down, activation="silu")
            return mixed, aux_loss
        return self.dense(x), None
# ── Transformer block ──────────────────────────────────────────────
class YBlock3(nn.Module):
    """Pre-norm decoder block.

    Two residual branches, in order:
      * ``x + MLGA(SE1(RMSNorm(x)))``
      * ``x + FFN/MoE(SE2(RMSNorm(x)))``

    Returns ``(x, attention_cache, aux_loss)``.
    """

    def __init__(self, config: YConfig3, layer_id: int):
        super().__init__()
        self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.attn = MLGA(config, layer_id)
        self.ffn = YMoE(config, layer_id)
        # _ACT_MODULES values are zero-argument factories (module classes or a lambda);
        # calling one yields a fresh activation module. Unknown names fall back to SiLU.
        act_module = _ACT_MODULES.get(config.hidden_act, nn.SiLU)
        self.se1 = SEBlock(config.hidden_size, act=act_module())
        # NOTE(review): se2 always uses SiLU regardless of config.hidden_act; kept for parity.
        self.se2 = SEBlock(config.hidden_size, act=nn.SiLU())

    def forward(
        self,
        x: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        past_key_values=None,
        use_cache: bool = False,
        attention_mask: Optional[torch.Tensor] = None,
        aux_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        # Attention branch.
        x0 = self.se1(self.input_layernorm(x))
        attn_out, past = self.attn(
            x0,
            position_embeddings,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            use_cache=use_cache,
        )
        x = x + attn_out
        # Feed-forward / MoE branch.
        x0 = self.se2(self.post_attention_layernorm(x))
        ffn_out, aux_loss = self.ffn(x0, aux_mask)
        x = x + ffn_out
        return x, past, aux_loss
# ── Full model ────────────────────────────────────────────────────
class YModel3(nn.Module):
    """Decoder stack: token embedding -> ``num_hidden_layers`` YBlock3 layers -> final RMSNorm.

    RoPE cos/sin tables for ``max_position_embeddings`` positions are precomputed once and
    stored as non-persistent buffers, so they are not part of the state dict.
    """

    def __init__(self, config: YConfig3):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.num_layers = config.num_hidden_layers
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.dropout = config.dropout
        self.use_self_distill = config.self_distill
        self.layers = nn.ModuleList(YBlock3(config, layer_idx) for layer_idx in range(config.num_hidden_layers))
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
        cos_table, sin_table = precompute_freqs_cis(
            dim=config.mla_qk_rope_head_dim,
            end=config.max_position_embeddings,
            theta=config.rope_theta,
            rope_scaling=config.rope_scaling,
        )
        self.register_buffer("freqs_cos", cos_table, persistent=False)
        self.register_buffer("freqs_sin", sin_table, persistent=False)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[list] = None,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """Run the stack on ``input_ids`` (bsz, seq).

        Positions default to ``cache_position``, which defaults to the cached length onward.
        The cached length is read from dim 1 of layer 0's cached latent.

        Returns ``(normed_hidden, per_layer_caches or None, None, aux_loss)``.
        ``aux_loss`` is only accumulated in training mode.
        """
        _, seq_len = input_ids.shape
        if past_key_values is None and use_cache:
            past_key_values = [None] * self.num_layers
        if cache_position is None:
            first_cache = past_key_values[0] if past_key_values is not None else None
            offset = first_cache[0].shape[1] if first_cache is not None else 0
            cache_position = torch.arange(offset, offset + seq_len, device=input_ids.device)
        hidden = F.dropout(self.embed_tokens(input_ids), p=self.dropout, training=self.training)
        positions = cache_position if position_ids is None else position_ids
        device = hidden.device
        position_embeddings = (self.freqs_cos[positions].to(device), self.freqs_sin[positions].to(device))
        caches = [] if use_cache else None
        total_aux = None
        for layer_idx, block in enumerate(self.layers):
            layer_cache = None if past_key_values is None else past_key_values[layer_idx]
            hidden, new_cache, block_aux = block(
                hidden,
                position_embeddings=position_embeddings,
                past_key_values=layer_cache,
                attention_mask=attention_mask,
                use_cache=use_cache,
                aux_mask=None,
            )
            if use_cache:
                caches.append(new_cache)
            if self.training and block_aux is not None:
                total_aux = block_aux if total_aux is None else total_aux + block_aux
        return self.norm(hidden), caches, None, total_aux
class _InferenceOutput:
"""Simple container for model outputs (replaces transformers CausalLMOutputWithPast)."""
__slots__ = ("last_hidden_state", "logits", "past_key_values", "dist_loss", "aux_loss")
def __init__(self):
self.last_hidden_state = None
self.logits = None
self.past_key_values = None
self.dist_loss = None
self.aux_loss = None
def __setitem__(self, key, value):
setattr(self, key, value)
class YForCausalLM3(nn.Module):
    """Pure PyTorch CausalLM wrapper for ymodel3 inference (no transformers dependency).

    ``lm_head`` shares its weight with ``model.embed_tokens``. After construction the module
    is cast to ``config.dtype`` when that is "float16", "bfloat16" or "float32".
    """

    config_class = YConfig3

    def __init__(self, config: Optional[YConfig3] = None):
        super().__init__()
        self.config = config or YConfig3()
        self.model = YModel3(self.config)
        self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
        self.model.embed_tokens.weight = self.lm_head.weight
        # Most recent forward() result, kept for callers that read ``model.OUT``.
        self.OUT = _InferenceOutput()
        dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}.get(self.config.dtype)
        if dtype is not None:
            self.to(dtype)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[list] = None,
        use_cache: bool = False,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """Run the decoder and project to vocabulary logits.

        ``logits_to_keep`` controls which positions get logits:
          * an int ``n`` keeps the last ``n`` positions (0 keeps all);
          * a tensor is used directly as an index on the sequence axis.

        A fresh output object is returned on every call (and stored in ``self.OUT``), so
        results a caller still holds are not overwritten by later calls.
        """
        h, past_kvs, dist_loss, aux_loss = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_ids=kwargs.get("position_ids", None),
        )
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(h[:, slice_indices, :])
        out = _InferenceOutput()
        out["last_hidden_state"] = h
        out["logits"] = logits
        out["past_key_values"] = past_kvs
        out["dist_loss"] = dist_loss
        out["aux_loss"] = aux_loss
        self.OUT = out
        return out

    def generate(
        self,
        inputs,
        attention_mask=None,
        max_new_tokens=8192,
        temperature=0.85,
        top_p=0.85,
        top_k=50,
        eos_token_id=None,
        streamer=None,
        use_cache=True,
        num_return_sequences=1,
        do_sample=True,
        repetition_penalty=1.0,
        **kwargs,
    ):
        """Extend ``inputs`` (token ids, (batch, seq)) by up to ``max_new_tokens`` tokens.

        Per step:
          * Optional repetition penalty on tokens already present. Positive logits are
            divided by the penalty and negative ones multiplied, so the penalty always
            lowers their probability.
          * When sampling: temperature, then top-k, then nucleus (top-p) filtering, then
            multinomial sampling.
          * Greedy argmax is used when ``do_sample`` is False or ``temperature <= 0``.

        ``attention_mask`` (1 = real token) is extended with ones for each generated token
        and passed on every step, including cached ones.

        ``eos_token_id`` may be an int, a sequence or a tensor of ids:
          * a row that emits one keeps repeating that token afterwards;
          * generation stops once every row has finished.

        Returns the prompt ids followed by the generated ids.
        """
        input_ids = kwargs.get("input_ids", inputs).repeat(num_return_sequences, 1)
        if attention_mask is not None:
            attention_mask = attention_mask.repeat(num_return_sequences, 1)
        batch = input_ids.shape[0]
        eos_ids = None
        if eos_token_id is not None:
            eos_ids = torch.as_tensor(eos_token_id, device=input_ids.device).reshape(-1)
        finished = torch.zeros(batch, dtype=torch.bool, device=input_ids.device)
        greedy = not do_sample or temperature <= 0
        past_key_values = None
        if streamer:
            streamer.put(input_ids.cpu())
        with torch.no_grad():
            for _ in range(max_new_tokens):
                step_ids = input_ids[:, -1:] if use_cache and past_key_values is not None else input_ids
                outputs = self.forward(step_ids, attention_mask, past_key_values, use_cache=use_cache, logits_to_keep=1)
                # float32 so softmax / cumsum used for sampling are not done in half precision.
                logits = outputs.logits[:, -1, :].float()
                if repetition_penalty != 1.0:
                    seen_ids = input_ids.long()
                    seen = logits.gather(1, seen_ids)
                    seen = torch.where(seen < 0, seen * repetition_penalty, seen / repetition_penalty)
                    logits = logits.scatter(1, seen_ids, seen)
                if greedy:
                    next_token = torch.argmax(logits, dim=-1, keepdim=True)
                else:
                    logits = logits / temperature
                    if top_k > 0:
                        kth = torch.topk(logits, min(top_k, logits.size(-1)))[0][..., -1, None]
                        logits = logits.masked_fill(logits < kth, -float("inf"))
                    if top_p < 1.0:
                        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                        remove = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) > top_p
                        # Shift right so the token that crosses the threshold is kept.
                        remove[..., 1:] = remove[..., :-1].clone()
                        remove[..., 0] = False
                        logits = logits.masked_fill(remove.scatter(1, sorted_indices, remove), -float("inf"))
                    next_token = torch.multinomial(torch.softmax(logits, dim=-1), 1)
                if eos_ids is not None:
                    # Finished rows keep repeating their last (EOS) token.
                    next_token = torch.where(finished.unsqueeze(-1), input_ids[:, -1:], next_token)
                    finished = finished | torch.isin(next_token.squeeze(-1), eos_ids)
                next_token = next_token.to(input_ids.dtype)
                input_ids = torch.cat([input_ids, next_token], dim=-1)
                if attention_mask is not None:
                    attention_mask = torch.cat([attention_mask, attention_mask.new_ones((batch, 1))], dim=-1)
                past_key_values = outputs.past_key_values if use_cache else None
                if streamer:
                    streamer.put(next_token.cpu())
                if eos_ids is not None and bool(finished.all()):
                    break
        if streamer:
            streamer.end()
        return input_ids
# ── Loading utilities ──────────────────────────────────────────────
def _load_state_dict(path: Union[str, Path]) -> dict[str, torch.Tensor]:
path = Path(path)
if path.is_dir():
safetensors_path = path / "model.safetensors"
bin_path = path / "pytorch_model.bin"
if safetensors_path.exists():
path = safetensors_path
elif bin_path.exists():
path = bin_path
else:
raise FileNotFoundError(f"no model.safetensors or pytorch_model.bin found in {path}")
if path.suffix == ".safetensors":
return load_safetensors(str(path), device="cpu")
return torch.load(path, map_location="cpu", weights_only=True)
def load_ymodel3_eval(path: Union[str, Path], config: Optional[YConfig3] = None, strict: bool = True) -> YForCausalLM3:
    """Build a YForCausalLM3 from a checkpoint file or directory and put it in eval mode.

    When ``config`` is None it is read from ``config.json``: inside ``path`` if ``path`` is a
    directory, otherwise next to the checkpoint file.

    Raises:
        FileNotFoundError: if no config is given and no ``config.json`` is found, or if no
            checkpoint file is found.
    """
    if config is None:
        config_path = Path(path) / "config.json" if Path(path).is_dir() else Path(path).with_name("config.json")
        if not config_path.exists():
            raise FileNotFoundError("config is required when config.json is not next to the checkpoint")
        config = YConfig3.from_json_file(str(config_path))
    model = YForCausalLM3(config)
    state = _load_state_dict(path)
    # YForCausalLM3 ties embed_tokens and lm_head to one Parameter, but checkpoints written
    # with de-duplicated shared tensors may store only one of the two names. Fill in the
    # missing alias so strict loading does not fail on it.
    tied = ("model.embed_tokens.weight", "lm_head.weight")
    for present, missing in (tied, tied[::-1]):
        if present in state and missing not in state:
            state[missing] = state[present]
    model.load_state_dict(state, strict=strict)
    model.eval()
    return model
# ── Backward-compatible aliases ────────────────────────────────────
# Older names kept so existing imports keep resolving to the same classes.
YModel3Eval, YForCausalLM3Eval = YModel3, YForCausalLM3