| """Self-contained ymodel3 inference module. |
| |
| Only depends on: torch, safetensors. |
| No dependency on kernel.*, model.ymodel3, transformers. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import math |
| from pathlib import Path |
| from typing import Optional, Union |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from safetensors.torch import load_file as load_safetensors |
|
|
|
|
| |
|
|
|
|
class YConfig3:
    """Hyper-parameters for the ymodel3 architecture.

    Every keyword argument overrides the default of the same name; keywords that
    are not listed in ``__init__`` are ignored.
    """

    model_type = "ynet3"

    def __init__(self, **kwargs):
        defaults = {
            "dropout": 0.0,
            "bos_token_id": 151644,
            "eos_token_id": 151645,
            "pad_token_id": 151643,
            "hidden_act": "silu",
            "hidden_size": 768,
            "num_hidden_layers": 8,
            "max_position_embeddings": 8192,
            "vocab_size": 6400,
            "rms_norm_eps": 1e-6,
            "rope_theta": 5e4,
            "rope_scaling": None,
            "dtype": "float32",
            "self_distill": True,
            "intermediate_size": 1536,
            "expert_intermediate_size": None,
            "n_routed_experts": 0,
            "moe_topk": 2,
            "score_func": "softmax",
            "n_shared_experts": 0,
            "top_k_layer_dense": 1,
            "aux_loss_alpha": 0.02,
            "seq_aux": False,
            "norm_topk_prob": True,
            "noisy_expert": 0.0,
            "moe_backend": "compact",
            "router_bias_enabled": True,
            "router_bias_update_rate": 1e-3,
            "router_bias_clamp": 5.0,
            "num_heads": 12,
            "mla_kv_lora_rank": 64,
            "mla_qk_nope_head_dim": 64,
            "mla_qk_rope_head_dim": 32,
            "mla_attn_impl": "absorb",
            "qkv_lora": False,
        }
        for name, default in defaults.items():
            setattr(self, name, kwargs.get(name, default))
        # A missing/falsy expert width falls back to the dense FFN width.
        self.expert_intermediate_size = self.expert_intermediate_size or self.intermediate_size

    @property
    def head_dim(self) -> int:
        """Per-head query/key width: non-rotary part plus rotary part."""
        return self.mla_qk_nope_head_dim + self.mla_qk_rope_head_dim

    def scale_lvl(self, lvl: int = 0):
        """Overwrite the size-related fields with preset ``lvl`` (0, -1 or -2); returns ``self``.

        Raises:
            ValueError: if ``lvl`` is not one of the presets.
        """
        presets = (
            (0, {
                "hidden_size": 1024, "num_hidden_layers": 8, "num_heads": 8,
                "mla_kv_lora_rank": 256, "mla_qk_nope_head_dim": 192, "mla_qk_rope_head_dim": 64,
                "intermediate_size": 2048, "expert_intermediate_size": 512,
                "n_routed_experts": 16, "moe_topk": 1, "n_shared_experts": 0,
                "top_k_layer_dense": 1, "router_bias_update_rate": 1e-3,
            }),
            (-1, {
                "hidden_size": 768, "num_hidden_layers": 8, "num_heads": 6,
                "mla_kv_lora_rank": 128, "mla_qk_nope_head_dim": 64, "mla_qk_rope_head_dim": 64,
                "intermediate_size": 1536, "expert_intermediate_size": 768,
                "n_routed_experts": 0, "moe_topk": 2, "n_shared_experts": 0,
                "top_k_layer_dense": 8,
            }),
            (-2, {
                "hidden_size": 512, "num_hidden_layers": 4, "num_heads": 4,
                "mla_kv_lora_rank": 128, "mla_qk_nope_head_dim": 64, "mla_qk_rope_head_dim": 32,
                "intermediate_size": 1024, "expert_intermediate_size": 512,
                "n_routed_experts": 0, "moe_topk": 2, "n_shared_experts": 0,
                "top_k_layer_dense": 4,
            }),
        )
        for level, overrides in presets:
            if lvl == level:
                for name, value in overrides.items():
                    setattr(self, name, value)
                return self
        raise ValueError(f"invalid ymodel3 scale level: {lvl}")

    @classmethod
    def from_json_file(cls, path: str) -> "YConfig3":
        """Build a config from a UTF-8 JSON file whose object keys are constructor kwargs."""
        with open(path, "r", encoding="utf-8") as handle:
            payload = json.load(handle)
        return cls(**payload)

    @classmethod
    def from_dict(cls, data: dict) -> "YConfig3":
        """Build a config from a mapping of constructor kwargs."""
        return cls(**data)
|
|
|
|
| |
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learnable per-channel scale.

    The statistics and scaling are computed in float32; the result is cast back to
    the input's dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + self.eps)
        normed = x32 * inv_rms
        return (normed * self.weight.float()).to(x.dtype)
|
|
|
|
class SEBlock(nn.Module):
    """Squeeze-and-excitation style gate: ``x * sigmoid(W2(act(W1(x))))``.

    The bottleneck width is ``max(reduction, dim // reduction)``; ``act`` defaults to SiLU.
    """

    def __init__(self, dim: int, reduction: int = 16, act: Optional[nn.Module] = None):
        super().__init__()
        bottleneck = max(reduction, dim // reduction)
        layers = [
            nn.Linear(dim, bottleneck, bias=False),
            act or nn.SiLU(),
            nn.Linear(bottleneck, dim, bias=False),
            nn.Sigmoid(),
        ]
        self.se = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = self.se(x)
        return x * gate
|
|
|
|
| |
|
|
|
|
def _yarn_linear_ramp(low: float, high: float, dim: int) -> torch.Tensor:
    """Return a length-``dim`` float32 ramp: 0 at indices <= ``low``, 1 at >= ``high``, linear between."""
    if low == high:
        high += 0.001  # avoid a zero-width ramp (division by zero)
    linear = (torch.arange(dim, dtype=torch.float32) - low) / (high - low)
    return torch.clamp(linear, 0.0, 1.0)


def _yarn_correction_dim(num_rotations: float, dim: int, theta: float, max_position_embeddings: int) -> float:
    """Return the (fractional) rotary frequency-pair index whose wavelength completes
    ``num_rotations`` full turns over ``max_position_embeddings`` positions for base ``theta``."""
    return dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi)) / (2 * math.log(theta))


def precompute_freqs_cis(
    dim: int,
    end: int,
    theta: float,
    rope_scaling: Optional[dict] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Precompute rotary-embedding cos/sin tables, optionally with YaRN frequency scaling.

    Args:
        dim: rotary width (assumed even); one frequency per pair of channels.
        end: number of positions to tabulate (rows ``0 .. end - 1``).
        theta: RoPE base.
        rope_scaling: optional dict. It is used when its ``"type"`` (default ``"yarn"``) is
            ``"yarn"`` and ``"factor"`` > 1. Recognised keys: ``factor``,
            ``original_max_position_embeddings`` (default ``end``), ``beta_fast`` (32),
            ``beta_slow`` (1), ``attention_factor`` (1; multiplies both tables).

    Returns:
        ``(freqs_cos, freqs_sin)``, each of shape ``(end, dim)``. The per-pair angles are
        duplicated across the two halves, matching ``rotate_half``.
    """
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    attention_factor = 1.0
    if rope_scaling and str(rope_scaling.get("type", "yarn")).lower() == "yarn":
        factor = float(rope_scaling.get("factor", 1.0))
        if factor > 1.0:
            original = int(rope_scaling.get("original_max_position_embeddings", end))
            beta_fast = float(rope_scaling.get("beta_fast", 32.0))
            beta_slow = float(rope_scaling.get("beta_slow", 1.0))
            low = math.floor(_yarn_correction_dim(beta_fast, dim, theta, original))
            high = math.ceil(_yarn_correction_dim(beta_slow, dim, theta, original))
            # ramp == 0 for high-frequency pairs (index <= low), 1 for low-frequency pairs (>= high).
            ramp = _yarn_linear_ramp(low, high, dim // 2)
            # YaRN "NTK-by-parts": keep high-frequency pairs as-is (extrapolate) and divide
            # low-frequency pairs by `factor` (interpolate), blending linearly in between.
            inv_freq = inv_freq * (1.0 - ramp) + (inv_freq / factor) * ramp
            attention_factor = float(rope_scaling.get("attention_factor", 1.0))
    positions = torch.arange(end, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)
    freqs_cos = torch.cat([torch.cos(angles), torch.cos(angles)], dim=-1) * attention_factor
    freqs_sin = torch.cat([torch.sin(angles), torch.sin(angles)], dim=-1) * attention_factor
    return freqs_cos, freqs_sin
|
|
|
|
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dimension and negate the (new) first half: ``[a, b] -> [-b, a]``."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)


def apply_rope_to_single(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate ``x`` by the given cos/sin tables: ``x * cos + rotate_half(x) * sin``.

    A 2-D table ``(seq, dim)`` is broadcast as ``(1, 1, seq, dim)``. A 3-D table
    ``(batch, seq, dim)`` is broadcast as ``(batch, 1, seq, dim)``. Any other rank is used
    as given. ``x`` is presumably laid out as ``(batch, heads, seq, dim)``, which is how the
    callers in this module permute it.
    """
    rank = cos.dim()
    if rank == 2:
        cos, sin = cos[None, None], sin[None, None]
    elif rank == 3:
        cos, sin = cos[:, None], sin[:, None]
    return x * cos + rotate_half(x) * sin
|
|
|
|
| |
|
|
|
|
class MLGA(nn.Module):
    """Multihead Latent Gated Attention.

    For each token, ``wkv_a`` produces a compressed latent ``c_kv`` (width
    ``mla_kv_lora_rank``, RMS-normalised) and one rotary key ``k_pe`` shared by all heads.
    Only ``(c_kv, k_pe)`` is cached. Keys are ``[wkv_b(c_kv), k_pe]`` per head and the
    values are ``c_kv`` itself, so every head emits ``kv_lora_rank`` channels. The
    concatenated head outputs are gated by ``sigmoid(z_proj(x))`` and projected by ``o_proj``.

    ``mla_attn_impl == "naive"`` materialises per-head K/V and calls SDPA. Any other value
    uses the weight-absorbed path, which folds ``wkv_b`` into the query. Both compute the
    same attention (the absorbed path applies no attention-probability dropout).
    """

    def __init__(self, config: YConfig3, layer_id: int):
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.dropout = config.dropout
        self.kv_lora_rank = config.mla_kv_lora_rank
        self.qk_nope_head_dim = config.mla_qk_nope_head_dim
        self.qk_rope_head_dim = config.mla_qk_rope_head_dim
        self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
        self.attn_impl = config.mla_attn_impl
        self.softmax_scale = self.qk_head_dim ** -0.5
        # Values are the latent itself, so each head outputs kv_lora_rank channels.
        self.out_dim = self.num_heads * self.kv_lora_rank

        self.wq = nn.Linear(self.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
        self.wkv_a = nn.Linear(self.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim, bias=False)
        self.kv_norm = RMSNorm(self.kv_lora_rank, config.rms_norm_eps)
        self.wkv_b = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim, bias=False)
        self.z_proj = nn.Linear(self.hidden_size, self.out_dim, bias=False)
        self.o_proj = nn.Linear(self.out_dim, self.hidden_size, bias=False)

    def _project_q(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(q_nope, q_pe)``, each ``(bsz, seq, heads, ·)``."""
        bsz, seq_len, _ = x.shape
        q = self.wq(x)
        q = q.reshape(bsz, seq_len, self.num_heads, self.qk_head_dim)
        return q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

    def _project_kv(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the cacheable ``(c_kv, k_pe)``: ``(bsz, seq, kv_lora_rank)`` and ``(bsz, seq, 1, rope_dim)``."""
        raw = self.wkv_a(x)
        c_kv, k_pe = raw.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        c_kv = self.kv_norm(c_kv)
        k_pe = apply_rope_to_single(k_pe.unsqueeze(1), cos, sin).permute(0, 2, 1, 3)
        return c_kv, k_pe

    def _explicit_kv(self, c_kv: torch.Tensor, k_pe: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Materialise per-head keys ``(bsz, t, heads, qk_head_dim)`` and values ``(bsz, t, heads, kv_lora_rank)``."""
        bsz, seq_len, _ = c_kv.shape
        k_nope = self.wkv_b(c_kv).reshape(bsz, seq_len, self.num_heads, self.qk_nope_head_dim)
        k = torch.cat([k_nope, k_pe.expand(-1, -1, self.num_heads, -1)], dim=-1)
        v = c_kv.unsqueeze(2).expand(-1, -1, self.num_heads, -1)
        return k, v

    def _attention_mask(
        self,
        attention_mask: Optional[torch.Tensor],
        bsz: int,
        seq_len: int,
        total_len: int,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
    ) -> Optional[torch.Tensor]:
        """Build an additive SDPA bias of shape ``(bsz, 1, seq_len, total_len)``, or ``None``.

        Combines:
          * a causal mask aligned to the END of the keys: the ``seq_len`` queries occupy
            positions ``total_len - seq_len ..`` (they follow any cached prefix);
          * the optional key-padding ``attention_mask`` (nonzero = keep). Only its last
            ``total_len`` entries are used.

        Future keys get ``-inf``. Padded keys get the dtype's most negative finite value,
        so a query whose visible keys are all padding still gets finite weights instead of
        NaN (NaN would otherwise propagate through the values of later layers).
        """
        needs_causal = seq_len > 1
        if attention_mask is None and not needs_causal:
            return None
        if device is None and attention_mask is not None:
            device = attention_mask.device
        bias = torch.zeros(bsz, 1, seq_len, total_len, dtype=dtype, device=device)
        if attention_mask is not None:
            if attention_mask.shape[-1] != total_len:
                attention_mask = attention_mask[..., -total_len:]
            padded = ~attention_mask.reshape(bsz, 1, 1, total_len).to(device=device).bool()
            bias = bias.masked_fill(padded, torch.finfo(dtype).min)
        if needs_causal:
            # Key j is in the future of query i when j - i > total_len - seq_len.
            future = torch.ones(seq_len, total_len, dtype=torch.bool, device=device).triu(total_len - seq_len + 1)
            bias = bias.masked_fill(future, float("-inf"))
        return bias

    def _forward_sdpa(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        c_kv: torch.Tensor,
        k_pe: torch.Tensor,
        z: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Attention with materialised K/V via ``F.scaled_dot_product_attention``."""
        bsz, seq_len, _, _ = q_nope.shape
        total_len = c_kv.shape[1]
        k, v = self._explicit_kv(c_kv, k_pe)
        q = torch.cat([q_nope, q_pe], dim=-1).permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)
        if attention_mask is None and seq_len == total_len:
            # Plain causal prefill: let SDPA apply its built-in causal mask.
            attn_mask, is_causal = None, True
        else:
            # Padding and/or a cached prefix. SDPA's is_causal aligns to the top-left corner
            # for non-square scores, so the causal part is built explicitly here.
            attn_mask = self._attention_mask(
                attention_mask, bsz, seq_len, total_len, dtype=q.dtype, device=q.device
            )
            is_causal = False
        out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal,
            scale=self.softmax_scale,
        )
        out = out.permute(0, 2, 1, 3).reshape(bsz, seq_len, self.out_dim)
        out = out * torch.sigmoid(z)
        return self.o_proj(out)

    def _forward_absorb(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        c_kv: torch.Tensor,
        k_pe: torch.Tensor,
        z: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Attention computed in latent space: ``wkv_b`` is folded into the query."""
        bsz, seq_len, _, _ = q_nope.shape
        total_len = c_kv.shape[1]
        w = self.wkv_b.weight.reshape(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
        q_nope_c = torch.einsum("bshd,hdc->bshc", q_nope, w)
        scores = torch.einsum("bshc,btc->bsht", q_nope_c, c_kv)
        scores = scores + torch.einsum("bshr,btr->bsht", q_pe, k_pe.squeeze(2))
        scores = scores * self.softmax_scale

        # Causal mask aligned to the end of the keys (left-padded for the cached prefix).
        causal = torch.full((seq_len, seq_len), float("-inf"), device=scores.device, dtype=scores.dtype)
        causal = torch.triu(causal, diagonal=1).reshape(1, seq_len, 1, seq_len)
        scores = scores + F.pad(causal, (total_len - seq_len, 0), value=0.0)
        if attention_mask is not None:
            if attention_mask.shape[-1] != total_len:
                attention_mask = attention_mask[..., -total_len:]
            # Finite penalty for padded keys avoids all -inf rows (NaN softmax).
            scores = scores + (1.0 - attention_mask.reshape(bsz, 1, 1, total_len).float()) * -1e9
        probs = torch.softmax(scores.float(), dim=-1).to(q_nope.dtype)
        out = torch.einsum("bsht,btc->bshc", probs, c_kv).reshape(bsz, seq_len, self.out_dim)
        out = out * torch.sigmoid(z)
        return self.o_proj(out)

    def forward(
        self,
        x: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        past_key_values: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
        """Attend ``x`` (bsz, seq, hidden) over the cached prefix plus itself.

        Args:
            position_embeddings: ``(cos, sin)``, 2-D ``(seq, ≥rope_dim)`` or 3-D
                ``(bsz, seq, ≥rope_dim)``; sliced to ``seq`` rows and ``qk_rope_head_dim`` columns.
            past_key_values: previous ``(c_kv, k_pe)``; the new ones are appended along dim 1.
            attention_mask: optional key-padding mask (nonzero = keep) covering at least
                the cached prefix plus the new tokens.
            use_cache: when True, also return the concatenated ``(c_kv, k_pe)``.

        Returns:
            ``(out, new_past)`` where ``out`` is ``(bsz, seq, hidden)``.
        """
        bsz, seq_len, _ = x.shape
        cos, sin = position_embeddings
        if cos.dim() == 2:
            cos = cos[:seq_len, : self.qk_rope_head_dim]
            sin = sin[:seq_len, : self.qk_rope_head_dim]
        else:
            cos = cos[:, :seq_len, : self.qk_rope_head_dim]
            sin = sin[:, :seq_len, : self.qk_rope_head_dim]
        q_nope, q_pe = self._project_q(x)
        q_pe = apply_rope_to_single(q_pe.permute(0, 2, 1, 3), cos, sin).permute(0, 2, 1, 3)
        c_kv, k_pe = self._project_kv(x, cos, sin)
        z = self.z_proj(x)

        if past_key_values is not None:
            past_c, past_pe = past_key_values
            c_kv = torch.cat([past_c, c_kv], dim=1)
            k_pe = torch.cat([past_pe, k_pe], dim=1)
        new_past = (c_kv, k_pe) if use_cache else None

        if self.attn_impl == "naive":
            out = self._forward_sdpa(q_nope, q_pe, c_kv, k_pe, z, attention_mask)
        else:
            out = self._forward_absorb(q_nope, q_pe, c_kv, k_pe, z, attention_mask)
        out = F.dropout(out, p=self.dropout, training=self.training)
        return out, new_past
|
|
|
|
| |
|
|
|
|
| _ACT_FNS = { |
| "silu": F.silu, |
| "swish": F.silu, |
| "relu": F.relu, |
| "gelu": lambda x: F.gelu(x, approximate="tanh"), |
| "sigmoid": torch.sigmoid, |
| } |
|
|
| _ACT_MODULES = { |
| "silu": nn.SiLU, |
| "swish": nn.SiLU, |
| "relu": nn.ReLU, |
| "gelu": lambda: nn.GELU(approximate="tanh"), |
| "sigmoid": nn.Sigmoid, |
| } |
|
|
|
|
class DenseFFN(nn.Module):
    """Gated feed-forward layer: ``down(act(gate(x)) * up(x))`` with dropout on the hidden.

    ``intermediate_size`` overrides ``config.intermediate_size`` when truthy. Unknown
    ``config.hidden_act`` names fall back to SiLU.
    """

    def __init__(self, config: YConfig3, intermediate_size: Optional[int] = None):
        super().__init__()
        width = intermediate_size or config.intermediate_size
        hidden = config.hidden_size
        self.up_proj = nn.Linear(hidden, width, bias=False)
        self.gate_proj = nn.Linear(hidden, width, bias=False)
        self.down_proj = nn.Linear(width, hidden, bias=False)
        self.hidden_act = config.hidden_act
        self.act = _ACT_FNS.get(self.hidden_act, F.silu)
        self.dropout = config.dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        up = self.up_proj(x)
        gate = self.gate_proj(x)
        hidden = F.dropout(self.act(gate) * up, p=self.dropout, training=self.training)
        return self.down_proj(hidden)
|
|
|
|
| class MoEGate(nn.Module): |
| def __init__(self, config: YConfig3): |
| super().__init__() |
| self.n_routed_experts = config.n_routed_experts |
| self.topk = min(config.moe_topk, max(1, config.n_routed_experts)) |
| self.score_func = config.score_func |
| self.norm_topk_prob = config.norm_topk_prob |
| self.aux_loss_alpha = config.aux_loss_alpha |
| self.seq_aux = config.seq_aux |
| self.router_bias_enabled = config.router_bias_enabled |
| self.router_bias_update_rate = config.router_bias_update_rate |
| self.router_bias_clamp = config.router_bias_clamp |
| self.weight = nn.Linear(int(config.hidden_size), int(self.n_routed_experts), bias=False) |
| if self.router_bias_enabled: |
| self.register_buffer("router_bias", torch.zeros(self.n_routed_experts), persistent=True) |
| else: |
| self.register_buffer("router_bias", None, persistent=False) |
|
|
| def forward(self, x: torch.Tensor, aux_mask: Optional[torch.Tensor] = None): |
| bsz, seq_len, hidden = x.shape |
| flat = x.reshape(-1, hidden) |
| route_logits = self.weight(flat) |
| if self.score_func == "softmax": |
| route_scores = torch.softmax(route_logits.float(), dim=-1).to(x.dtype) |
| elif self.score_func == "sigmoid": |
| route_scores = torch.sigmoid(route_logits.float()).to(x.dtype) |
| else: |
| raise ValueError(f"unsupported MoE score_func: {self.score_func}") |
|
|
| choice_scores = route_scores |
| if self.router_bias is not None: |
| choice_scores = choice_scores + self.router_bias.to(dtype=choice_scores.dtype).unsqueeze(0) |
|
|
| topk_idx = torch.topk(choice_scores, k=self.topk, dim=-1, sorted=False).indices |
| topk_weight = route_scores.gather(1, topk_idx) |
| if self.topk > 1 and self.norm_topk_prob: |
| denom = topk_weight.float().sum(dim=-1, keepdim=True) + 1e-20 |
| topk_weight = (topk_weight.float() / denom).to(x.dtype) |
|
|
| aux_loss = x.new_zeros((), dtype=x.dtype) |
| return ( |
| topk_idx.reshape(bsz, seq_len, self.topk), |
| topk_weight.reshape(bsz, seq_len, self.topk), |
| aux_loss, |
| ) |
|
|
|
|
def _torch_moe_swiglu(
    x: torch.Tensor,
    topk_idx: torch.Tensor,
    topk_weight: torch.Tensor,
    w_up: torch.Tensor,
    w_down: torch.Tensor,
    activation: str = "silu",
) -> torch.Tensor:
    """Pure PyTorch MoE SwiGLU forward (inference only, no noisy_expert).

    For each expert ``e`` and each (token, slot) routed to it:
    ``y[token] += weight * w_down[e] @ (act(gate) * up)``, where ``[up, gate]`` is the
    split of ``w_up[e] @ x[token]``. Shapes are expected as YMoE allocates them:
    ``w_up`` (E, 2*I, H) and ``w_down`` (E, H, I). Returns a tensor shaped like ``x``.
    """
    flat = x.reshape(-1, x.shape[-1])
    n_tokens = flat.shape[0]
    expert_ids = topk_idx.reshape(n_tokens, -1)
    route_weights = topk_weight.reshape(n_tokens, -1)
    act_fn = _ACT_FNS.get(activation, F.silu)
    inter = w_down.shape[-1]
    acc = torch.zeros_like(flat)
    for expert in range(w_up.shape[0]):
        rows, slots = (expert_ids == expert).nonzero(as_tuple=True)
        if rows.numel() == 0:
            continue  # no token routed here
        up, gate = F.linear(flat[rows], w_up[expert]).split(inter, dim=-1)
        expert_out = F.linear(act_fn(gate) * up, w_down[expert])
        acc.index_add_(0, rows, expert_out * route_weights[rows, slots].unsqueeze(-1))
    return acc.reshape(x.shape)
|
|
|
|
class YMoE(nn.Module):
    """Pure PyTorch eval MoE (no Triton dependency).

    A dense gated FFN is used when there are no routed experts or when
    ``layer_id < top_k_layer_dense``. Otherwise a top-k routed mixture of SwiGLU experts is
    used, with stacked expert weights ``w_up`` (E, 2*I, H) and ``w_down`` (E, H, I).
    """

    def __init__(self, config: YConfig3, layer_id: int):
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = config.hidden_size
        self.expert_intermediate_size = config.expert_intermediate_size
        self.intermediate_size = self.expert_intermediate_size
        self.n_routed_experts = config.n_routed_experts
        self.use_moe = self.n_routed_experts > 0 and layer_id >= config.top_k_layer_dense
        self.noisy_expert = config.noisy_expert
        if self.use_moe:
            self.dense = None
            self.gate = MoEGate(config)
            n_exp, inter, hidden = self.n_routed_experts, self.expert_intermediate_size, self.hidden_size
            self.w_up = nn.Parameter(torch.empty(n_exp, 2 * inter, hidden))
            self.w_down = nn.Parameter(torch.empty(n_exp, hidden, inter))
            for weight in (self.w_up, self.w_down):
                nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
        else:
            self.dense = DenseFFN(config)
            self.gate = None
            self.w_up = None
            self.w_down = None

    def forward(self, x: torch.Tensor, aux_mask: Optional[torch.Tensor] = None):
        """Return ``(output, aux_loss)``; ``aux_loss`` is ``None`` for the dense variant."""
        if self.use_moe:
            topk_idx, topk_weight, aux_loss = self.gate(x, aux_mask)
            # NOTE(review): experts always use SiLU here, independent of config.hidden_act.
            routed = _torch_moe_swiglu(x, topk_idx, topk_weight, self.w_up, self.w_down, activation="silu")
            return routed, aux_loss
        return self.dense(x), None
|
|
|
|
| |
|
|
|
|
class YBlock3(nn.Module):
    """One decoder layer with pre-norm residual sub-layers.

    ``x + attn(se1(norm1(x)))``, then ``+ ffn(se2(norm2(·)))``. ``se1`` uses the
    configured activation; ``se2`` always uses SiLU.
    """

    def __init__(self, config: YConfig3, layer_id: int):
        super().__init__()
        hidden, eps = config.hidden_size, config.rms_norm_eps
        self.input_layernorm = RMSNorm(hidden, eps)
        self.post_attention_layernorm = RMSNorm(hidden, eps)
        self.attn = MLGA(config, layer_id)
        self.ffn = YMoE(config, layer_id)
        make_act = _ACT_MODULES.get(config.hidden_act, nn.SiLU)
        self.se1 = SEBlock(hidden, act=make_act())
        self.se2 = SEBlock(hidden, act=nn.SiLU())

    def forward(
        self,
        x: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        past_key_values=None,
        use_cache: bool = False,
        attention_mask: Optional[torch.Tensor] = None,
        aux_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Return ``(hidden, layer_cache, aux_loss)``."""
        attn_in = self.se1(self.input_layernorm(x))
        attn_out, layer_cache = self.attn(
            attn_in,
            position_embeddings,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            use_cache=use_cache,
        )
        h = x + attn_out
        ffn_out, aux_loss = self.ffn(self.se2(self.post_attention_layernorm(h)), aux_mask)
        return h + ffn_out, layer_cache, aux_loss
|
|
|
|
| |
|
|
|
|
class YModel3(nn.Module):
    """Decoder stack: token embedding -> ``num_hidden_layers`` YBlock3 layers -> final RMSNorm.

    RoPE cos/sin tables for ``max_position_embeddings`` positions are precomputed into
    non-persistent buffers.
    """

    def __init__(self, config: YConfig3):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.num_layers = config.num_hidden_layers
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.dropout = config.dropout
        self.use_self_distill = config.self_distill
        self.layers = nn.ModuleList(YBlock3(config, idx) for idx in range(config.num_hidden_layers))
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
        cos_table, sin_table = precompute_freqs_cis(
            dim=config.mla_qk_rope_head_dim,
            end=config.max_position_embeddings,
            theta=config.rope_theta,
            rope_scaling=config.rope_scaling,
        )
        self.register_buffer("freqs_cos", cos_table, persistent=False)
        self.register_buffer("freqs_sin", sin_table, persistent=False)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[list] = None,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """Return ``(normed_hidden, new_past, None, aux_loss)``.

        ``input_ids`` is (bsz, seq). ``past_key_values`` is a per-layer list of
        ``(c_kv, k_pe)`` (or ``None`` entries). Positions default to ``cache_position``,
        which defaults to ``past_len .. past_len + seq - 1``. ``new_past`` is a list only
        when ``use_cache``. ``aux_loss`` is summed over layers only in training mode.
        """
        _, seq_len = input_ids.shape
        if use_cache and past_key_values is None:
            past_key_values = [None] * self.num_layers
        if cache_position is None:
            first = past_key_values[0] if past_key_values is not None else None
            offset = first[0].shape[1] if first is not None else 0
            cache_position = torch.arange(offset, offset + seq_len, device=input_ids.device)

        x = F.dropout(self.embed_tokens(input_ids), p=self.dropout, training=self.training)
        positions = cache_position if position_ids is None else position_ids
        position_embeddings = (
            self.freqs_cos[positions].to(x.device),
            self.freqs_sin[positions].to(x.device),
        )
        new_past = [] if use_cache else None
        aux_loss = None

        for idx, layer in enumerate(self.layers):
            layer_past = None if past_key_values is None else past_key_values[idx]
            x, layer_cache, layer_aux = layer(
                x,
                position_embeddings=position_embeddings,
                past_key_values=layer_past,
                attention_mask=attention_mask,
                use_cache=use_cache,
                aux_mask=None,
            )
            if use_cache:
                new_past.append(layer_cache)
            if self.training and layer_aux is not None:
                aux_loss = layer_aux if aux_loss is None else aux_loss + layer_aux

        return self.norm(x), new_past, None, aux_loss
|
|
|
|
| class _InferenceOutput: |
| """Simple container for model outputs (replaces transformers CausalLMOutputWithPast).""" |
|
|
| __slots__ = ("last_hidden_state", "logits", "past_key_values", "dist_loss", "aux_loss") |
|
|
| def __init__(self): |
| self.last_hidden_state = None |
| self.logits = None |
| self.past_key_values = None |
| self.dist_loss = None |
| self.aux_loss = None |
|
|
| def __setitem__(self, key, value): |
| setattr(self, key, value) |
|
|
|
|
class YForCausalLM3(nn.Module):
    """Pure PyTorch CausalLM wrapper for ymodel3 inference (no transformers dependency).

    Wraps ``YModel3`` with a bias-free ``lm_head`` whose weight is the same Parameter as the
    token embedding, and provides a simple greedy/sampling ``generate`` loop.
    """

    config_class = YConfig3

    def __init__(self, config: Optional[YConfig3] = None):
        super().__init__()
        self.config = config or YConfig3()
        self.model = YModel3(self.config)
        self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
        # Weight tying: the embedding table and the output projection share one Parameter.
        self.model.embed_tokens.weight = self.lm_head.weight
        # Output of the most recent forward() call (kept for callers that read it).
        self.OUT = _InferenceOutput()
        dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}.get(self.config.dtype)
        if dtype is not None:
            # Casts parameters and floating-point buffers (including the RoPE tables).
            self.to(dtype)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[list] = None,
        use_cache: bool = False,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """Run the decoder and project hidden states to vocabulary logits.

        Args:
            logits_to_keep: int N computes logits for the last N positions only (0 = all);
                any other value is used directly as the sequence-axis index.
            kwargs: ``position_ids`` is forwarded to the decoder; other keys are ignored.

        Returns:
            A new ``_InferenceOutput`` (also stored as ``self.OUT``) with
            ``last_hidden_state``, ``logits``, ``past_key_values``, ``dist_loss``, ``aux_loss``.
        """
        h, past_kvs, dist_loss, aux_loss = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_ids=kwargs.get("position_ids", None),
        )
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(h[:, slice_indices, :])
        # A fresh container per call, so outputs held by the caller are not overwritten later.
        out = _InferenceOutput()
        out["last_hidden_state"] = h
        out["logits"] = logits
        out["past_key_values"] = past_kvs
        out["dist_loss"] = dist_loss
        out["aux_loss"] = aux_loss
        self.OUT = out
        return out

    @staticmethod
    def _apply_repetition_penalty(logits: torch.Tensor, input_ids: torch.Tensor, penalty: float) -> torch.Tensor:
        """CTRL-style penalty on every token already present in each row of ``input_ids``.

        Positive logits are divided by ``penalty`` and negative ones multiplied, so a
        penalty > 1 always lowers the score of seen tokens. Returns a new tensor.
        """
        index = input_ids.long()
        seen = logits.gather(1, index)
        seen = torch.where(seen < 0, seen * penalty, seen / penalty)
        return logits.scatter(1, index, seen)

    @staticmethod
    def _filter_top_k_top_p(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
        """Set logits outside the top-k / nucleus (top-p) set to ``-inf``; returns a new tensor."""
        if top_k > 0:
            k = min(top_k, logits.size(-1))
            kth_best = torch.topk(logits, k)[0][..., -1, None]
            logits = logits.masked_fill(logits < kth_best, float("-inf"))
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            remove = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) > top_p
            # Shift right so the token that crosses top_p is kept; always keep the best token.
            remove[..., 1:] = remove[..., :-1].clone()
            remove[..., 0] = False
            logits = logits.masked_fill(remove.scatter(1, sorted_indices, remove), float("-inf"))
        return logits

    def generate(
        self,
        inputs,
        attention_mask=None,
        max_new_tokens=8192,
        temperature=0.85,
        top_p=0.85,
        top_k=50,
        eos_token_id=None,
        streamer=None,
        use_cache=True,
        num_return_sequences=1,
        do_sample=True,
        repetition_penalty=1.0,
        **kwargs,
    ):
        """Autoregressively extend the prompt by up to ``max_new_tokens`` tokens.

        Args:
            inputs: prompt ids ``(batch, seq)``; ``kwargs["input_ids"]`` wins when given.
            attention_mask: optional ``(batch, seq)`` mask (1 = token, 0 = padding). It is
                extended with ones for each generated token and passed on every step.
            temperature, top_k, top_p: sampling controls, used only when ``do_sample``.
                Greedy decoding does not depend on them.
            eos_token_id: int, sequence of ints or tensor. A row finishes when it emits one
                of them; finished rows keep emitting the first EOS id, and the loop stops
                once every row has finished.
            streamer: object with ``put(tensor)``/``end()``; receives the prompt, then each
                new token column.
            repetition_penalty: CTRL-style penalty (> 1 discourages already-seen tokens).

        Returns:
            Token ids of shape ``(batch * num_return_sequences, seq + generated)``.
        """
        input_ids = kwargs.get("input_ids", inputs).repeat(num_return_sequences, 1)
        if attention_mask is not None:
            attention_mask = attention_mask.repeat(num_return_sequences, 1)
        eos_ids = None
        if eos_token_id is not None:
            eos_ids = torch.as_tensor(eos_token_id, dtype=torch.long, device=input_ids.device).reshape(-1)
            if eos_ids.numel() == 0:
                eos_ids = None
        unfinished = torch.ones(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)
        past_key_values = None
        if streamer:
            streamer.put(input_ids.cpu())
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # With a warm cache only the newest token has to be fed.
                step_ids = input_ids[:, -1:] if use_cache and past_key_values is not None else input_ids
                outputs = self.forward(
                    step_ids, attention_mask, past_key_values, use_cache=use_cache, logits_to_keep=1
                )
                logits = outputs.logits[:, -1, :].float()
                if repetition_penalty != 1.0:
                    logits = self._apply_repetition_penalty(logits, input_ids, repetition_penalty)
                if do_sample:
                    logits = self._filter_top_k_top_p(logits / temperature, top_k, top_p)
                    next_token = torch.multinomial(torch.softmax(logits, dim=-1), 1)
                else:
                    next_token = torch.argmax(logits, dim=-1, keepdim=True)
                if eos_ids is not None:
                    next_token = torch.where(unfinished.unsqueeze(-1), next_token, eos_ids[0])
                    unfinished = unfinished & ~(next_token == eos_ids).any(dim=-1)
                input_ids = torch.cat([input_ids, next_token], dim=-1)
                if attention_mask is not None:
                    attention_mask = torch.cat(
                        [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                    )
                past_key_values = outputs.past_key_values if use_cache else None
                if streamer:
                    streamer.put(next_token.cpu())
                if eos_ids is not None and not bool(unfinished.any()):
                    break
        if streamer:
            streamer.end()
        return input_ids
|
|
|
|
| |
|
|
|
|
| def _load_state_dict(path: Union[str, Path]) -> dict[str, torch.Tensor]: |
| path = Path(path) |
| if path.is_dir(): |
| safetensors_path = path / "model.safetensors" |
| bin_path = path / "pytorch_model.bin" |
| if safetensors_path.exists(): |
| path = safetensors_path |
| elif bin_path.exists(): |
| path = bin_path |
| else: |
| raise FileNotFoundError(f"no model.safetensors or pytorch_model.bin found in {path}") |
| if path.suffix == ".safetensors": |
| return load_safetensors(str(path), device="cpu") |
| return torch.load(path, map_location="cpu", weights_only=True) |
|
|
|
|
def load_ymodel3_eval(path: Union[str, Path], config: Optional[YConfig3] = None, strict: bool = True) -> YForCausalLM3:
    """Build a ``YForCausalLM3``, load checkpoint weights into it and return it in eval mode.

    When ``config`` is omitted, ``config.json`` is read from ``path`` (if ``path`` is a
    directory) or from the checkpoint file's directory.

    Raises:
        FileNotFoundError: if no config is given and ``config.json`` is missing.
    """
    ckpt = Path(path)
    if config is None:
        config_path = ckpt / "config.json" if ckpt.is_dir() else ckpt.with_name("config.json")
        if not config_path.exists():
            raise FileNotFoundError("config is required when config.json is not next to the checkpoint")
        config = YConfig3.from_json_file(str(config_path))
    model = YForCausalLM3(config)
    weights = _load_state_dict(path)
    model.load_state_dict(weights, strict=strict)
    model.eval()
    return model
|
|
|
|
| |
|
|
# The *Eval names are aliases of the same classes, not separate implementations.
YModel3Eval, YForCausalLM3Eval = YModel3, YForCausalLM3
|
|