"""Self-contained ymodel3 inference module. Only depends on: torch, safetensors. No dependency on kernel.*, model.ymodel3, transformers. """ from __future__ import annotations import json import math from pathlib import Path from typing import Optional, Union import torch import torch.nn as nn import torch.nn.functional as F from safetensors.torch import load_file as load_safetensors # ── Config ────────────────────────────────────────────────────────── class YConfig3: model_type = "ynet3" def __init__(self, **kwargs): self.dropout = kwargs.get("dropout", 0.0) self.bos_token_id = kwargs.get("bos_token_id", 151644) self.eos_token_id = kwargs.get("eos_token_id", 151645) self.pad_token_id = kwargs.get("pad_token_id", 151643) self.hidden_act = kwargs.get("hidden_act", "silu") self.hidden_size = kwargs.get("hidden_size", 768) self.num_hidden_layers = kwargs.get("num_hidden_layers", 8) self.max_position_embeddings = kwargs.get("max_position_embeddings", 8192) self.vocab_size = kwargs.get("vocab_size", 6400) self.rms_norm_eps = kwargs.get("rms_norm_eps", 1e-6) self.rope_theta = kwargs.get("rope_theta", 5e4) self.rope_scaling = kwargs.get("rope_scaling", None) self.dtype = kwargs.get("dtype", "float32") self.self_distill = kwargs.get("self_distill", True) self.intermediate_size = kwargs.get("intermediate_size", 1536) self.expert_intermediate_size = kwargs.get("expert_intermediate_size", None) or self.intermediate_size self.n_routed_experts = kwargs.get("n_routed_experts", 0) self.moe_topk = kwargs.get("moe_topk", 2) self.score_func = kwargs.get("score_func", "softmax") self.n_shared_experts = kwargs.get("n_shared_experts", 0) self.top_k_layer_dense = kwargs.get("top_k_layer_dense", 1) self.aux_loss_alpha = kwargs.get("aux_loss_alpha", 0.02) self.seq_aux = kwargs.get("seq_aux", False) self.norm_topk_prob = kwargs.get("norm_topk_prob", True) self.noisy_expert = kwargs.get("noisy_expert", 0.0) self.moe_backend = kwargs.get("moe_backend", "compact") self.router_bias_enabled = kwargs.get("router_bias_enabled", True) self.router_bias_update_rate = kwargs.get("router_bias_update_rate", 1e-3) self.router_bias_clamp = kwargs.get("router_bias_clamp", 5.0) self.num_heads = kwargs.get("num_heads", 12) self.mla_kv_lora_rank = kwargs.get("mla_kv_lora_rank", 64) self.mla_qk_nope_head_dim = kwargs.get("mla_qk_nope_head_dim", 64) self.mla_qk_rope_head_dim = kwargs.get("mla_qk_rope_head_dim", 32) self.mla_attn_impl = kwargs.get("mla_attn_impl", "absorb") self.qkv_lora = kwargs.get("qkv_lora", False) @property def head_dim(self) -> int: return self.mla_qk_nope_head_dim + self.mla_qk_rope_head_dim def scale_lvl(self, lvl: int = 0): if lvl == 0: self.hidden_size = 1024 self.num_hidden_layers = 8 self.num_heads = 8 self.mla_kv_lora_rank = 256 self.mla_qk_nope_head_dim = 192 self.mla_qk_rope_head_dim = 64 self.intermediate_size = 2048 self.expert_intermediate_size = 512 self.n_routed_experts = 16 self.moe_topk = 1 self.n_shared_experts = 0 self.top_k_layer_dense = 1 self.router_bias_update_rate = 1e-3 elif lvl == -1: self.hidden_size = 768 self.num_hidden_layers = 8 self.num_heads = 6 self.mla_kv_lora_rank = 128 self.mla_qk_nope_head_dim = 64 self.mla_qk_rope_head_dim = 64 self.intermediate_size = 1536 self.expert_intermediate_size = 768 self.n_routed_experts = 0 self.moe_topk = 2 self.n_shared_experts = 0 self.top_k_layer_dense = 8 elif lvl == -2: self.hidden_size = 512 self.num_hidden_layers = 4 self.num_heads = 4 self.mla_kv_lora_rank = 128 self.mla_qk_nope_head_dim = 64 self.mla_qk_rope_head_dim = 32 self.intermediate_size = 1024 self.expert_intermediate_size = 512 self.n_routed_experts = 0 self.moe_topk = 2 self.n_shared_experts = 0 self.top_k_layer_dense = 4 else: raise ValueError(f"invalid ymodel3 scale level: {lvl}") return self @classmethod def from_json_file(cls, path: str) -> "YConfig3": with open(path, "r", encoding="utf-8") as f: data = json.load(f) return cls(**data) @classmethod def from_dict(cls, data: dict) -> "YConfig3": return cls(**data) # ── Basic modules ────────────────────────────────────────────────── class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32)) def forward(self, x: torch.Tensor) -> torch.Tensor: out = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps) return (out * self.weight.float()).to(x.dtype) class SEBlock(nn.Module): def __init__(self, dim: int, reduction: int = 16, act: Optional[nn.Module] = None): super().__init__() reduction = max(reduction, dim // reduction) self.se = nn.Sequential( nn.Linear(dim, reduction, bias=False), act or nn.SiLU(), nn.Linear(reduction, dim, bias=False), nn.Sigmoid(), ) def forward(self, x: torch.Tensor) -> torch.Tensor: return x * self.se(x) # ── RoPE helpers ────────────────────────────────────────────────── def _yarn_linear_ramp(low: float, high: float, dim: int) -> torch.Tensor: if low == high: high += 0.001 linear = (torch.arange(dim, dtype=torch.float32) - low) / (high - low) return torch.clamp(linear, 0.0, 1.0) def _yarn_correction_dim(num_rotations: float, dim: int, theta: float, max_position_embeddings: int) -> float: return dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi)) / (2 * math.log(theta)) def precompute_freqs_cis( dim: int, end: int, theta: float, rope_scaling: Optional[dict] = None, ) -> tuple[torch.Tensor, torch.Tensor]: freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) attention_factor = 1.0 if rope_scaling and str(rope_scaling.get("type", "yarn")).lower() == "yarn": factor = float(rope_scaling.get("factor", 1.0)) if factor > 1.0: original = int(rope_scaling.get("original_max_position_embeddings", end)) beta_fast = float(rope_scaling.get("beta_fast", 32.0)) beta_slow = float(rope_scaling.get("beta_slow", 1.0)) low = math.floor(_yarn_correction_dim(beta_fast, dim, theta, original)) high = math.ceil(_yarn_correction_dim(beta_slow, dim, theta, original)) ramp = _yarn_linear_ramp(low, high, dim // 2) freqs = freqs / factor * (1.0 - ramp) + freqs * ramp attention_factor = float(rope_scaling.get("attention_factor", 1.0)) t = torch.arange(end) freqs = torch.outer(t, freqs).float() freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1) * attention_factor freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1) * attention_factor return freqs_cos, freqs_sin def rotate_half(x: torch.Tensor) -> torch.Tensor: return torch.cat((-x[..., x.shape[-1] // 2 :], x[..., : x.shape[-1] // 2]), dim=-1) def apply_rope_to_single(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: if cos.dim() == 2: cos = cos.unsqueeze(0).unsqueeze(0) sin = sin.unsqueeze(0).unsqueeze(0) elif cos.dim() == 3: cos = cos.unsqueeze(1) sin = sin.unsqueeze(1) return (x * cos) + (rotate_half(x) * sin) # ── Attention ────────────────────────────────────────────────────── class MLGA(nn.Module): """Multihead Latent Gated Attention""" def __init__(self, config: YConfig3, layer_id: int): super().__init__() self.layer_id = layer_id self.hidden_size = config.hidden_size self.num_heads = config.num_heads self.dropout = config.dropout self.kv_lora_rank = config.mla_kv_lora_rank self.qk_nope_head_dim = config.mla_qk_nope_head_dim self.qk_rope_head_dim = config.mla_qk_rope_head_dim self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim self.attn_impl = config.mla_attn_impl self.softmax_scale = self.qk_head_dim ** -0.5 self.out_dim = self.num_heads * self.kv_lora_rank self.wq = nn.Linear(self.hidden_size, self.num_heads * self.qk_head_dim, bias=False) self.wkv_a = nn.Linear(self.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim, bias=False) self.kv_norm = RMSNorm(self.kv_lora_rank, config.rms_norm_eps) self.wkv_b = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim, bias=False) self.z_proj = nn.Linear(self.hidden_size, self.out_dim, bias=False) self.o_proj = nn.Linear(self.out_dim, self.hidden_size, bias=False) def _project_q(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: bsz, seq_len, _ = x.shape q = self.wq(x) q = q.reshape(bsz, seq_len, self.num_heads, self.qk_head_dim) return q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) def _project_kv(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: raw = self.wkv_a(x) c_kv, k_pe = raw.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) c_kv = self.kv_norm(c_kv) k_pe = apply_rope_to_single(k_pe.unsqueeze(1), cos, sin).permute(0, 2, 1, 3) return c_kv, k_pe def _explicit_kv(self, c_kv: torch.Tensor, k_pe: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: bsz, seq_len, _ = c_kv.shape k_nope = self.wkv_b(c_kv).reshape(bsz, seq_len, self.num_heads, self.qk_nope_head_dim) k = torch.cat([k_nope, k_pe.expand(-1, -1, self.num_heads, -1)], dim=-1) v = c_kv.unsqueeze(2).expand(-1, -1, self.num_heads, -1) return k, v def _attention_mask(self, attention_mask: Optional[torch.Tensor], bsz: int, seq_len: int, total_len: int): if attention_mask is None: return None if attention_mask.shape[-1] != total_len: attention_mask = attention_mask[..., -total_len:] mask = attention_mask.reshape(bsz, 1, 1, total_len).bool() return mask.expand(bsz, self.num_heads, seq_len, total_len) def _forward_sdpa( self, q_nope: torch.Tensor, q_pe: torch.Tensor, c_kv: torch.Tensor, k_pe: torch.Tensor, z: torch.Tensor, attention_mask: Optional[torch.Tensor], ) -> torch.Tensor: bsz, seq_len, _, _ = q_nope.shape total_len = c_kv.shape[1] k, v = self._explicit_kv(c_kv, k_pe) q = torch.cat([q_nope, q_pe], dim=-1).permute(0, 2, 1, 3) k = k.permute(0, 2, 1, 3) v = v.permute(0, 2, 1, 3) attn_mask = self._attention_mask(attention_mask, bsz, seq_len, total_len) is_causal = attention_mask is None and seq_len == total_len out = F.scaled_dot_product_attention( q, k, v, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal, scale=self.softmax_scale, ) out = out.permute(0, 2, 1, 3).reshape(bsz, seq_len, self.out_dim) out = out * torch.sigmoid(z) return self.o_proj(out) def _forward_absorb( self, q_nope: torch.Tensor, q_pe: torch.Tensor, c_kv: torch.Tensor, k_pe: torch.Tensor, z: torch.Tensor, attention_mask: Optional[torch.Tensor], ) -> torch.Tensor: bsz, seq_len, _, _ = q_nope.shape total_len = c_kv.shape[1] w = self.wkv_b.weight.reshape(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank) q_nope_c = torch.einsum("bshd,hdc->bshc", q_nope, w) scores = torch.einsum("bshc,btc->bsht", q_nope_c, c_kv) scores = scores + torch.einsum("bshr,btr->bsht", q_pe, k_pe.squeeze(2)) scores = scores * self.softmax_scale causal = torch.full((seq_len, seq_len), float("-inf"), device=scores.device, dtype=scores.dtype) causal = torch.triu(causal, diagonal=1).reshape(1, seq_len, 1, seq_len) scores = scores + F.pad(causal, (total_len - seq_len, 0), value=0.0) if attention_mask is not None: if attention_mask.shape[-1] != total_len: attention_mask = attention_mask[..., -total_len:] scores = scores + (1.0 - attention_mask.reshape(bsz, 1, 1, total_len).float()) * -1e9 probs = torch.softmax(scores.float(), dim=-1).to(q_nope.dtype) out = torch.einsum("bsht,btc->bshc", probs, c_kv).reshape(bsz, seq_len, self.out_dim) out = out * torch.sigmoid(z) return self.o_proj(out) def forward( self, x: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], past_key_values: Optional[tuple[torch.Tensor, torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, use_cache: bool = False, **kwargs, ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]: bsz, seq_len, _ = x.shape cos, sin = position_embeddings if cos.dim() == 2: cos = cos[:seq_len, : self.qk_rope_head_dim] sin = sin[:seq_len, : self.qk_rope_head_dim] else: cos = cos[:, :seq_len, : self.qk_rope_head_dim] sin = sin[:, :seq_len, : self.qk_rope_head_dim] q_nope, q_pe = self._project_q(x) q_pe = apply_rope_to_single(q_pe.permute(0, 2, 1, 3), cos, sin).permute(0, 2, 1, 3) c_kv, k_pe = self._project_kv(x, cos, sin) z = self.z_proj(x) if past_key_values is not None: past_c, past_pe = past_key_values c_kv = torch.cat([past_c, c_kv], dim=1) k_pe = torch.cat([past_pe, k_pe], dim=1) new_past = (c_kv, k_pe) if use_cache else None if self.attn_impl == "naive": out = self._forward_sdpa(q_nope, q_pe, c_kv, k_pe, z, attention_mask) else: out = self._forward_absorb(q_nope, q_pe, c_kv, k_pe, z, attention_mask) out = F.dropout(out, p=self.dropout, training=self.training) return out, new_past # ── FFN / MoE ────────────────────────────────────────────────────── _ACT_FNS = { "silu": F.silu, "swish": F.silu, "relu": F.relu, "gelu": lambda x: F.gelu(x, approximate="tanh"), "sigmoid": torch.sigmoid, } _ACT_MODULES = { "silu": nn.SiLU, "swish": nn.SiLU, "relu": nn.ReLU, "gelu": lambda: nn.GELU(approximate="tanh"), "sigmoid": nn.Sigmoid, } class DenseFFN(nn.Module): def __init__(self, config: YConfig3, intermediate_size: Optional[int] = None): super().__init__() inter = intermediate_size or config.intermediate_size self.up_proj = nn.Linear(config.hidden_size, inter, bias=False) self.gate_proj = nn.Linear(config.hidden_size, inter, bias=False) self.down_proj = nn.Linear(inter, config.hidden_size, bias=False) self.hidden_act = config.hidden_act self.act = _ACT_FNS.get(config.hidden_act, F.silu) self.dropout = config.dropout def forward(self, x: torch.Tensor) -> torch.Tensor: up, gate = self.up_proj(x), self.gate_proj(x) up = self.act(gate) * up up = F.dropout(up, p=self.dropout, training=self.training) return self.down_proj(up) class MoEGate(nn.Module): def __init__(self, config: YConfig3): super().__init__() self.n_routed_experts = config.n_routed_experts self.topk = min(config.moe_topk, max(1, config.n_routed_experts)) self.score_func = config.score_func self.norm_topk_prob = config.norm_topk_prob self.aux_loss_alpha = config.aux_loss_alpha self.seq_aux = config.seq_aux self.router_bias_enabled = config.router_bias_enabled self.router_bias_update_rate = config.router_bias_update_rate self.router_bias_clamp = config.router_bias_clamp self.weight = nn.Linear(int(config.hidden_size), int(self.n_routed_experts), bias=False) if self.router_bias_enabled: self.register_buffer("router_bias", torch.zeros(self.n_routed_experts), persistent=True) else: self.register_buffer("router_bias", None, persistent=False) def forward(self, x: torch.Tensor, aux_mask: Optional[torch.Tensor] = None): bsz, seq_len, hidden = x.shape flat = x.reshape(-1, hidden) route_logits = self.weight(flat) if self.score_func == "softmax": route_scores = torch.softmax(route_logits.float(), dim=-1).to(x.dtype) elif self.score_func == "sigmoid": route_scores = torch.sigmoid(route_logits.float()).to(x.dtype) else: raise ValueError(f"unsupported MoE score_func: {self.score_func}") choice_scores = route_scores if self.router_bias is not None: choice_scores = choice_scores + self.router_bias.to(dtype=choice_scores.dtype).unsqueeze(0) topk_idx = torch.topk(choice_scores, k=self.topk, dim=-1, sorted=False).indices topk_weight = route_scores.gather(1, topk_idx) if self.topk > 1 and self.norm_topk_prob: denom = topk_weight.float().sum(dim=-1, keepdim=True) + 1e-20 topk_weight = (topk_weight.float() / denom).to(x.dtype) aux_loss = x.new_zeros((), dtype=x.dtype) return ( topk_idx.reshape(bsz, seq_len, self.topk), topk_weight.reshape(bsz, seq_len, self.topk), aux_loss, ) def _torch_moe_swiglu( x: torch.Tensor, topk_idx: torch.Tensor, topk_weight: torch.Tensor, w_up: torch.Tensor, w_down: torch.Tensor, activation: str = "silu", ) -> torch.Tensor: """Pure PyTorch MoE SwiGLU forward (inference only, no noisy_expert).""" original_shape = x.shape x_flat = x.reshape(-1, x.shape[-1]) idx = topk_idx.reshape(x_flat.shape[0], -1) weight = topk_weight.reshape(x_flat.shape[0], -1) y = torch.zeros_like(x_flat) n_experts = w_up.shape[0] inter = w_down.shape[-1] act_fn = _ACT_FNS.get(activation, F.silu) for expert_id in range(n_experts): token_pos, choice_pos = torch.where(idx == expert_id) if token_pos.numel() == 0: continue inp = x_flat[token_pos] uv = F.linear(inp, w_up[expert_id]) up, gate = uv.split(inter, dim=-1) hidden = act_fn(gate) * up out = F.linear(hidden, w_down[expert_id]) route_w = weight[token_pos, choice_pos].unsqueeze(-1) y.index_add_(0, token_pos, out * route_w) return y.reshape(original_shape) class YMoE(nn.Module): """Pure PyTorch eval MoE (no Triton dependency).""" def __init__(self, config: YConfig3, layer_id: int): super().__init__() self.layer_id = layer_id self.hidden_size = config.hidden_size self.expert_intermediate_size = config.expert_intermediate_size self.intermediate_size = self.expert_intermediate_size self.n_routed_experts = config.n_routed_experts self.use_moe = self.n_routed_experts > 0 and layer_id >= config.top_k_layer_dense self.noisy_expert = config.noisy_expert if not self.use_moe: self.dense = DenseFFN(config) self.gate = None self.w_up = None self.w_down = None return self.dense = None self.gate = MoEGate(config) self.w_up = nn.Parameter(torch.empty(self.n_routed_experts, 2 * self.expert_intermediate_size, self.hidden_size)) self.w_down = nn.Parameter(torch.empty(self.n_routed_experts, self.hidden_size, self.expert_intermediate_size)) nn.init.kaiming_uniform_(self.w_up, a=math.sqrt(5)) nn.init.kaiming_uniform_(self.w_down, a=math.sqrt(5)) def forward(self, x: torch.Tensor, aux_mask: Optional[torch.Tensor] = None): if not self.use_moe: return self.dense(x), None topk_idx, topk_weight, aux_loss = self.gate(x, aux_mask) y = _torch_moe_swiglu(x, topk_idx, topk_weight, self.w_up, self.w_down, activation="silu") return y, aux_loss # ── Transformer block ────────────────────────────────────────────── class YBlock3(nn.Module): def __init__(self, config: YConfig3, layer_id: int): super().__init__() self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps) self.attn = MLGA(config, layer_id) self.ffn = YMoE(config, layer_id) act_module = _ACT_MODULES.get(config.hidden_act, nn.SiLU) self.se1 = SEBlock(config.hidden_size, act=act_module() if isinstance(act_module, type) else act_module()) self.se2 = SEBlock(config.hidden_size, act=nn.SiLU()) def forward( self, x: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], past_key_values=None, use_cache: bool = False, attention_mask: Optional[torch.Tensor] = None, aux_mask: Optional[torch.Tensor] = None, **kwargs, ): x0 = self.se1(self.input_layernorm(x)) attn_out, past = self.attn( x0, position_embeddings, past_key_values=past_key_values, attention_mask=attention_mask, use_cache=use_cache, ) x = x + attn_out x0 = self.se2(self.post_attention_layernorm(x)) ffn_out, aux_loss = self.ffn(x0, aux_mask) x = x + ffn_out return x, past, aux_loss # ── Full model ──────────────────────────────────────────────────── class YModel3(nn.Module): def __init__(self, config: YConfig3): super().__init__() self.config = config self.vocab_size = config.vocab_size self.num_layers = config.num_hidden_layers self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) self.dropout = config.dropout self.use_self_distill = config.self_distill self.layers = nn.ModuleList([YBlock3(config, i) for i in range(config.num_hidden_layers)]) self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps) freqs_cos, freqs_sin = precompute_freqs_cis( dim=config.mla_qk_rope_head_dim, end=config.max_position_embeddings, theta=config.rope_theta, rope_scaling=config.rope_scaling, ) self.register_buffer("freqs_cos", freqs_cos, persistent=False) self.register_buffer("freqs_sin", freqs_sin, persistent=False) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[list] = None, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, **kwargs, ): bsz, seq_len = input_ids.shape if use_cache and past_key_values is None: past_key_values = [None] * self.num_layers if cache_position is None: if past_key_values is not None and past_key_values[0] is not None: past_seen = past_key_values[0][0].shape[1] else: past_seen = 0 cache_position = torch.arange(past_seen, past_seen + seq_len, device=input_ids.device) x = F.dropout(self.embed_tokens(input_ids), p=self.dropout, training=self.training) if position_ids is None: position_ids = cache_position position_embeddings = (self.freqs_cos[position_ids].to(x.device), self.freqs_sin[position_ids].to(x.device)) aux_mask = None new_past = [] if use_cache else None aux_loss = None for i, layer in enumerate(self.layers): past = past_key_values[i] if past_key_values is not None else None x, layer_past, layer_aux = layer( x, position_embeddings=position_embeddings, past_key_values=past, attention_mask=attention_mask, use_cache=use_cache, aux_mask=aux_mask, ) if use_cache: new_past.append(layer_past) if self.training and layer_aux is not None: aux_loss = layer_aux if aux_loss is None else aux_loss + layer_aux return self.norm(x), new_past, None, aux_loss class _InferenceOutput: """Simple container for model outputs (replaces transformers CausalLMOutputWithPast).""" __slots__ = ("last_hidden_state", "logits", "past_key_values", "dist_loss", "aux_loss") def __init__(self): self.last_hidden_state = None self.logits = None self.past_key_values = None self.dist_loss = None self.aux_loss = None def __setitem__(self, key, value): setattr(self, key, value) class YForCausalLM3(nn.Module): """Pure PyTorch CausalLM wrapper for ymodel3 inference (no transformers dependency).""" config_class = YConfig3 def __init__(self, config: Optional[YConfig3] = None): super().__init__() self.config = config or YConfig3() self.model = YModel3(self.config) self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False) self.model.embed_tokens.weight = self.lm_head.weight self.OUT = _InferenceOutput() dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}.get(self.config.dtype) if dtype is not None: self.to(dtype) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[list] = None, use_cache: bool = False, logits_to_keep: Union[int, torch.Tensor] = 0, cache_position: Optional[torch.LongTensor] = None, **kwargs, ): h, past_kvs, dist_loss, aux_loss = self.model( input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_ids=kwargs.get("position_ids", None), ) slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(h[:, slice_indices, :]) self.OUT.__setitem__("last_hidden_state", h) self.OUT.__setitem__("logits", logits) self.OUT.__setitem__("past_key_values", past_kvs) self.OUT.__setitem__("dist_loss", dist_loss) self.OUT.__setitem__("aux_loss", aux_loss) return self.OUT def generate( self, inputs, attention_mask=None, max_new_tokens=8192, temperature=0.85, top_p=0.85, top_k=50, eos_token_id=None, streamer=None, use_cache=True, num_return_sequences=1, do_sample=True, repetition_penalty=1.0, **kwargs, ): input_ids = kwargs.get("input_ids", inputs).repeat(num_return_sequences, 1) attention_mask = attention_mask.repeat(num_return_sequences, 1) if attention_mask is not None else None past_key_values = None if streamer: streamer.put(input_ids.cpu()) with torch.no_grad(): for _ in range(max_new_tokens): if use_cache and past_key_values is not None: outputs = self.forward(input_ids[:, -1:], None, past_key_values, use_cache=use_cache) else: outputs = self.forward(input_ids, attention_mask, past_key_values, use_cache=use_cache) logits = outputs.logits[:, -1, :] / temperature if repetition_penalty != 1.0: for i in range(input_ids.shape[0]): logits[i, torch.unique(input_ids[i])] /= repetition_penalty if top_k > 0: logits[logits < torch.topk(logits, top_k)[0][..., -1, None]] = -float("inf") if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(logits, descending=True) mask = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) > top_p mask[..., 1:], mask[..., 0] = mask[..., :-1].clone(), 0 logits[mask.scatter(1, sorted_indices, mask)] = -float("inf") next_token = torch.multinomial(torch.softmax(logits, dim=-1), 1) if do_sample else torch.argmax(logits, dim=-1, keepdim=True) input_ids = torch.cat([input_ids, next_token], dim=-1) past_key_values = outputs.past_key_values if use_cache else None if streamer: streamer.put(next_token.cpu()) if eos_token_id and (next_token == eos_token_id).any(): break if streamer: streamer.end() return input_ids # ── Loading utilities ────────────────────────────────────────────── def _load_state_dict(path: Union[str, Path]) -> dict[str, torch.Tensor]: path = Path(path) if path.is_dir(): safetensors_path = path / "model.safetensors" bin_path = path / "pytorch_model.bin" if safetensors_path.exists(): path = safetensors_path elif bin_path.exists(): path = bin_path else: raise FileNotFoundError(f"no model.safetensors or pytorch_model.bin found in {path}") if path.suffix == ".safetensors": return load_safetensors(str(path), device="cpu") return torch.load(path, map_location="cpu", weights_only=True) def load_ymodel3_eval(path: Union[str, Path], config: Optional[YConfig3] = None, strict: bool = True) -> YForCausalLM3: if config is None: config_path = Path(path) / "config.json" if Path(path).is_dir() else Path(path).with_name("config.json") if not config_path.exists(): raise FileNotFoundError("config is required when config.json is not next to the checkpoint") config = YConfig3.from_json_file(str(config_path)) model = YForCausalLM3(config) state = _load_state_dict(path) model.load_state_dict(state, strict=strict) model.eval() return model # ── Backward-compatible aliases ──────────────────────────────────── YModel3Eval = YModel3 YForCausalLM3Eval = YForCausalLM3