import math
import torch
import torch.nn as nn
from typing import Optional, Tuple, Union, List
from transformers import PreTrainedModel, GenerationMixin
from transformers.activations import ACT2FN
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.configuration_utils import PretrainedConfig


class YConfig1_1(PretrainedConfig):
    """Hyper-parameters of the ``ynet`` causal language model (v1.1).

    The fields are read by the modules defined below; :meth:`scale_lvl`
    overwrites the architecture fields with one of the predefined sizes.
    """

    model_type = "ynet"

    def __init__(
            self,
            dropout: float = 0.1,
            bos_token_id: int = 1,
            eos_token_id: int = 2,
            hidden_act: str = 'gelu_pytorch_tanh',
            exp: float = 3.0,
            ffn_shared: int = 3,
            hidden_size: int = 512,
            intermediate_size: Optional[int] = None,
            max_position_embeddings: int = 8192,
            num_heads: int = 8,
            num_layers: int = 9,
            pe_dim: int = 64,
            head_dim: int = 64,
            groups: int = 4,
            vocab_size: int = 6400,
            rms_norm_eps: float = 1e-7,
            rope_theta: float = 5e4,
            flash_attn: bool = True,
            self_distill: bool = True,
            **kwargs
    ):
        super().__init__(**kwargs)
        self.dropout = dropout
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.hidden_act = hidden_act
        self.exp = exp  # FFN expansion ratio (up width = int(hidden_size * exp))
        self.ffn_shared = ffn_shared  # number of consecutive layers sharing one FFN up/down pair
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size  # NOTE(review): not read by any module visible here
        self.max_position_embeddings = max_position_embeddings  # length of the precomputed RoPE table
        self.num_heads = num_heads  # number of query heads in the no-PE branch
        self.num_layers = num_layers  # number of attention layers (must be divisible by ffn_shared)
        self.pe_dim = pe_dim  # width of the single RoPE query/key head
        self.head_dim = head_dim  # per-head width of the no-PE branch
        self.groups = groups  # query heads per key head (GQA group size)
        self.vocab_size = vocab_size
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.flash_attn = flash_attn  # NOTE(review): not read by any module visible here
        self.self_distill = self_distill  # add the per-layer cosine self-distillation loss in training

    def scale_lvl(self, lvl: int = 0):
        """Overwrite the architecture fields with a predefined size preset.

        Args:
            lvl: preset id. ``0`` is the normal size, ``-1``/``-2``/``-3`` are
                progressively smaller and ``1``/``2`` larger. The approximate
                parameter count is noted next to each preset.

        Raises:
            ValueError: if ``lvl`` is not a known preset; nothing is modified then.
        """
        fields = ("exp", "ffn_shared", "hidden_size", "num_heads",
                  "num_layers", "pe_dim", "head_dim", "groups")
        # NOTE(review): presets -2, -3 and 2 set num_heads not divisible by
        # groups (7 % 6, 7 % 6, 25 % 7), which PEGA's assertion rejects —
        # confirm the intended values.
        presets = {
            0: (3.0, 3, 512, 12, 27, 96, 64, 6),      # normal settings [80.27m]
            -1: (3.0, 3, 512, 8, 12, 64, 64, 8),      # small -1 [24m]
            -2: (2.0, 4, 512, 7, 8, 48, 48, 6),       # small -2 [12m]
            -3: (2.0, 3, 384, 7, 6, 48, 32, 6),       # small -3 [6m]
            1: (2.0, 3, 768, 12, 24, 96, 64, 6),      # large +1 [0.2b]
            2: (3.0, 3, 1344, 25, 24, 192, 96, 7),    # large +2 [0.6b]
        }
        if lvl not in presets:
            raise ValueError(f"Invalid level: {lvl}")
        for name, value in zip(fields, presets[lvl]):
            setattr(self, name, value)


class RMSNorm(torch.nn.Module):
    """Root-mean-square norm with a learnable per-channel scale.

    Normalisation and scaling run in float32; the result is cast back to the
    input dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float())
        output = output * self.weight.float()
        return output.type_as(x)


def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), theta: float = 5e4):
    """Precompute RoPE cosine/sine tables.

    Args:
        dim: rotary width (expected even; an odd value yields ``dim - 1`` columns).
        end: number of positions.
        theta: RoPE base.

    Returns:
        ``(cos, sin)``, each of shape ``[end, dim]``; the half-width frequency
        table is duplicated along the last axis to match ``rotate_half``.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1)
    freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1)
    return freqs_cos, freqs_sin


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0):
    """Apply rotary position embedding (half-split rotation) to ``q`` and ``k``.

    ``cos``/``sin`` are unsqueezed at ``unsqueeze_dim`` so they broadcast
    against ``q``/``k`` (e.g. ``[l, d]`` -> ``[1, l, d]`` for ``[b, l, d]`` inputs).
    """
    def rotate_half(x):
        return torch.cat((-x[..., x.shape[-1] // 2:], x[..., : x.shape[-1] // 2]), dim=-1)

    q_embed = (q * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(q) * sin.unsqueeze(unsqueeze_dim))
    k_embed = (k * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(k) * sin.unsqueeze(unsqueeze_dim))
    return q_embed, k_embed


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each head ``n_rep`` times along dim 1.

    Equivalent to ``torch.repeat_interleave(x, dim=1, repeats=n_rep)`` for an
    ``[b, h, l, c]`` input; returns ``[b, h * n_rep, l, c]``.
    """
    b, h, l, ch = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, None, :, :]
        .expand(b, h, n_rep, l, ch)
        .reshape(b, h * n_rep, l, ch)
    )


class PEGA(nn.Module):
    """Positional-Encoding Gated Attention.

    One fused projection yields:

    * a single RoPE query/key head (``qpe``/``kpe``, width ``pe_dim``),
    * ``num_heads`` no-PE query heads (``q``, width ``head_dim``), and
    * one shared ``kv`` tensor of width ``hidden_kv_dim``.

    The RoPE branch attends with ``kv`` as values and, after a linear layer
    and activation, becomes a gate. The no-PE branch uses ``kv`` reshaped into
    ``num_heads // groups`` key heads as keys (GQA) and the whole ``kv`` as
    values for every head. The gated no-PE output is projected to ``hidden_size``.
    """

    def __init__(self, config: YConfig1_1):
        super().__init__()
        self.dropout = config.dropout  # dropout rate
        self.hidden_size = config.hidden_size  # input channels
        self.num_heads = config.num_heads  # number of no-PE query heads
        self.pe_dim = config.pe_dim  # width of the RoPE query/key head
        self.head_dim = config.head_dim  # per-head width
        self.groups = config.groups  # query heads per key head
        # width of the shared kv: (num_heads // groups) key heads * head_dim
        self.hidden_kv_dim = int(self.head_dim * self.num_heads // self.groups)
        self.gate_act = ACT2FN[config.hidden_act]
        self.delta_kv_only = False
        assert self.num_heads % self.groups == 0, "num_heads must be divisible by groups"

        # Fused equivalent of four bias-free projections hidden_size -> [qpe, kpe, q, kv].
        self.qkv_list = [self.pe_dim, self.pe_dim, self.num_heads * self.head_dim, self.hidden_kv_dim]
        self.qkv = nn.Linear(self.hidden_size, sum(self.qkv_list), bias=False)
        self.o = nn.Linear(self.num_heads * self.hidden_kv_dim, self.hidden_size, bias=False)
        self.gate = nn.Linear(self.hidden_kv_dim, self.num_heads * self.hidden_kv_dim, bias=False)
        # Kept for compatibility; sdpa_math now scales by each query's own width.
        self.rsqrt_dim = 1.0 / math.sqrt(self.head_dim)

    def forward(
            self,
            x: torch.Tensor,
            position_embeddings: Tuple[torch.Tensor, torch.Tensor],
            past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
            attention_mask: Optional[torch.Tensor] = None,
            use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run gated attention over ``x`` (``[b, l, hidden_size]``).

        Args:
            x: input hidden states.
            position_embeddings: ``(cos, sin)`` with at least ``l`` rows of width ``pe_dim``.
            past_key_value: cached ``(kpe, kv)`` of shapes ``[b, l_past, pe_dim]`` and
                ``[b, l_past, hidden_kv_dim]``, or ``None``.
            attention_mask: optional ``[b, l_past + l]`` padding mask (nonzero = keep).
            use_cache: return the concatenated ``(kpe, kv)`` cache.

        Returns:
            ``(out, cache)``. ``out`` is ``[b, l, hidden_size]``. ``cache`` is this
            step's ``(kpe, kv)`` only when delta-kv mode is on, otherwise the
            full cache if ``use_cache`` else ``None``.
        """
        b, l, _ = x.shape
        cos, sin = position_embeddings  # [>= l, pe_dim]

        qkv = self.qkv(x)
        qpe, kpe, q, kv = torch.split(qkv, self.qkv_list, dim=-1)

        # Apply RoPE to the positional query/key head only.
        qpe, kpe = apply_rotary_pos_emb(qpe, kpe, cos[:l], sin[:l])

        # In delta-kv mode only this step's kv is returned (the caller accumulates).
        deltakv = (kpe, kv) if self.delta_kv_only else None

        # KV cache: the sequence axis is dim 1 of the [b, l, c] tensors.
        if past_key_value is not None:
            kpe = torch.cat([past_key_value[0], kpe], dim=1)
            kv = torch.cat([past_key_value[1], kv], dim=1)
        past_kv = (kpe, kv) if use_cache else None
        l_all = kv.shape[1]

        attn_mask = None
        if attention_mask is not None:
            # [b, l_all] padding mask -> boolean [b, 1, l, l_all]
            attn_mask = attention_mask.reshape(b, 1, 1, -1).expand(b, 1, l, -1).bool()

        qpe = qpe.reshape(b, l, 1, self.pe_dim).permute(0, 2, 1, 3)  # [b, 1, l, pe_dim]
        kpe = kpe.reshape(b, l_all, 1, self.pe_dim).permute(0, 2, 1, 3)  # [b, 1, l_all, pe_dim]
        q = q.reshape(b, l, self.num_heads, self.head_dim).permute(0, 2, 1, 3)  # [b, num_heads, l, head_dim]
        nopek = kv.reshape(b, l_all, self.num_heads // self.groups, self.head_dim).permute(0, 2, 1, 3)  # [b, num_heads // groups, l_all, head_dim]
        kv = kv.reshape(b, l_all, 1, self.hidden_kv_dim).permute(0, 2, 1, 3)  # [b, 1, l_all, hidden_kv_dim]

        nope_k = repeat_kv(nopek, self.groups)  # [b, num_heads, l_all, head_dim]
        nope_v = repeat_kv(kv, self.num_heads)  # [b, num_heads, l_all, hidden_kv_dim]

        if self.training:
            sdpa = nn.functional.scaled_dot_product_attention
            if attn_mask is None and l == l_all:
                # Plain causal self-attention: let SDPA build the mask (fast kernels).
                peo = sdpa(qpe, kpe, kv, dropout_p=self.dropout, is_causal=True)
                nopeo = sdpa(q, nope_k, nope_v, dropout_p=self.dropout, is_causal=True)
            else:
                # SDPA rejects attn_mask together with is_causal, and its built-in
                # causal mask is top-left aligned, so pass an explicit mask instead.
                mask = self._causal_padding_mask(l, l_all, attn_mask, x.device)
                peo = sdpa(qpe, kpe, kv, attn_mask=mask, dropout_p=self.dropout)
                nopeo = sdpa(q, nope_k, nope_v, attn_mask=mask, dropout_p=self.dropout)
        else:
            # Explicit-matmul attention: handles the cache (bottom-right causal) and padding.
            peo = self.sdpa_math(qpe, kpe, kv, attn_mask, 0.0)
            nopeo = self.sdpa_math(q, nope_k, nope_v, attn_mask, 0.0)

        peo = peo.permute(0, 2, 1, 3).reshape(b, l, -1)  # [b, l, hidden_kv_dim]
        nopeo = nopeo.permute(0, 2, 1, 3).reshape(b, l, -1)  # [b, l, num_heads * hidden_kv_dim]
        gate = self.gate_act(self.gate(peo))
        out = nopeo * gate
        out = self.o(out)
        out = nn.functional.dropout(out, p=self.dropout, training=self.training)
        return out, (deltakv if self.delta_kv_only else past_kv)

    @staticmethod
    def _causal_padding_mask(l: int, l_all: int, pad_mask: Optional[torch.Tensor],
                             device: torch.device) -> torch.Tensor:
        """Build a boolean SDPA mask (True = attend) combining causality and padding.

        Causality is bottom-right aligned: query ``i`` sits at absolute position
        ``l_all - l + i``. A query may always attend to its own position, so a
        query whose whole visible prefix is padding keeps one key instead of
        producing NaN. Non-padded queries are unaffected by that rule.
        """
        q_pos = torch.arange(l, device=device).unsqueeze(-1) + (l_all - l)  # [l, 1]
        k_pos = torch.arange(l_all, device=device)  # [l_all]
        causal = k_pos <= q_pos  # [l, l_all]
        if pad_mask is None:
            return causal[None, None]  # [1, 1, l, l_all]
        return causal & (pad_mask | (k_pos == q_pos))  # [b, 1, l, l_all]

    def sdpa_math(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                  attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0) -> torch.Tensor:
        """Reference attention with an explicit score matrix, used at inference.

        Args:
            q: ``[b, h, l, c]`` queries.
            k: ``[b, h, l_all, c]`` keys; the last ``l`` keys align with the queries.
            v: ``[b, h, l_all, c_v]`` values.
            attn_mask: optional mask broadcastable to ``[b, h, l, l_all]``;
                zero/False entries are masked with ``-1e9``.
            dropout_p: dropout on the attention probabilities (active only in training).

        Returns:
            ``[b, h, l, c_v]`` attention output in ``q``'s dtype.
        """
        l = q.shape[-2]
        l_all = k.shape[-2]
        # Scale by the query width, matching F.scaled_dot_product_attention's
        # default used in training (pe_dim for the RoPE branch, head_dim otherwise).
        scale = 1.0 / math.sqrt(q.shape[-1])
        # Masks are added in float32 so -1e9 cannot overflow to -inf in half precision.
        scores = (q @ k.transpose(-2, -1)).float() * scale  # [b, h, l, l_all]
        causal_mask = torch.triu(
            torch.full((l, l), float("-inf"), device=scores.device), diagonal=1
        )  # [l, l]
        # Zero-pad on the left to [l, l_all]: cached keys are visible to every query.
        causal_mask = nn.functional.pad(causal_mask, (l_all - l, 0), "constant", 0.0)
        scores = scores + causal_mask
        if attn_mask is not None:
            scores = scores + (1.0 - attn_mask.float()) * -1e9
        probs = nn.functional.softmax(scores, dim=-1).type_as(q)
        probs = nn.functional.dropout(probs, p=dropout_p, training=self.training)
        return probs @ v

    def use_delta_kv_only(self, enable: bool = True):
        """Return only this step's kv from forward instead of the concatenated cache."""
        self.delta_kv_only = enable


class YFFN(nn.Module):
    """GeGLU-style FFN whose up/down weights are shared by ``ffn_shared`` layers.

    Each sharing layer has its own gate projection (``gates[index]``). If
    ``up_res`` is given, it is added to the up projection before gating.
    """

    def __init__(self, config: YConfig1_1):
        super().__init__()
        self.act = ACT2FN[config.hidden_act]
        self.channels = config.hidden_size
        self.exp = config.exp
        self.c_up = int(self.channels * self.exp)
        self.ffn_shared = config.ffn_shared
        self.up = nn.Linear(self.channels, self.c_up, bias=False)
        self.down = nn.Linear(self.c_up, self.channels, bias=False)
        self.gates = nn.ModuleList([
            nn.Linear(self.channels, self.c_up, bias=False) for _ in range(self.ffn_shared)
        ])

    def forward(self, x: torch.Tensor, index: int, up_res: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(down(gated_up), gated_up)``; ``index`` selects the gate."""
        up = self.up(x)
        if up_res is not None:
            up += up_res
        gate = self.gates[index](x)
        gate = self.act(gate)
        up *= gate
        x = self.down(up)
        return x, up


class YBlock(nn.Module):
    """A group of ``ffn_shared`` pre-norm transformer layers sharing one :class:`YFFN`.

    Each layer has its own attention and norms. The gated up activation of
    one layer's FFN is fed as ``up_res`` into the next layer's FFN.
    """

    def __init__(self, config: YConfig1_1):
        super().__init__()
        self.attentions = nn.ModuleList([PEGA(config) for _ in range(config.ffn_shared)])
        self.ffn = YFFN(config)
        self.attn_norms = nn.ModuleList([
            RMSNorm(config.hidden_size, eps=config.rms_norm_eps) for _ in range(config.ffn_shared)
        ])
        self.ffn_norms = nn.ModuleList([
            RMSNorm(config.hidden_size, eps=config.rms_norm_eps) for _ in range(config.ffn_shared)
        ])
        self.use_self_distill = config.self_distill

    def forward(self,
                x: torch.Tensor,
                position_embeddings: Tuple[torch.Tensor, torch.Tensor],
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,  # one cache per layer
                use_cache: bool = False,
                attention_mask: Optional[torch.Tensor] = None
                ):
        """Run all layers of the group.

        Returns:
            ``(x, kv_outs, cos_loss)``: the output hidden states, one cache entry
            per layer, and the summed self-distillation loss (``None`` unless
            training with self-distillation enabled).

        Raises:
            ValueError: if ``past_key_values`` does not have one entry per layer.
        """
        n_layers = len(self.attentions)
        if past_key_values is None:
            past_key_values = [None] * n_layers
        elif len(past_key_values) != n_layers:
            # zip() would otherwise silently skip the trailing layers.
            raise ValueError(
                f"expected {n_layers} past_key_values entries, got {len(past_key_values)}"
            )
        kv_outs = []
        ups = None
        cos_loss = None
        for i, (layer, kv_cache) in enumerate(zip(self.attentions, past_key_values)):
            x0 = x
            res = x
            x = self.attn_norms[i](x)
            x, kv_out = layer(
                x=x,
                position_embeddings=position_embeddings,
                past_key_value=kv_cache,
                attention_mask=attention_mask,
                use_cache=use_cache
            )
            x += res
            res = x
            x = self.ffn_norms[i](x)
            x, ups = self.ffn(x, i, ups)
            x += res
            kv_outs.append(kv_out)
            if self.training and self.use_self_distill:
                # Pull the layer input toward its (detached) output: 1 - mean cosine similarity.
                xd = x.detach()
                c_loss = 1.0 - nn.functional.cosine_similarity(x0, xd, dim=-1).mean()
                cos_loss = c_loss + cos_loss if cos_loss is not None else c_loss
        return x, kv_outs, cos_loss

    def delta_kv_only(self, delta_kv: bool = True):
        """Toggle delta-kv mode on every attention layer of the group."""
        for attention in self.attentions:
            attention.use_delta_kv_only(delta_kv)


class YModel(nn.Module):
    """Token embedding, ``num_layers // ffn_shared`` :class:`YBlock` groups and a final RMSNorm."""

    def __init__(self, config: YConfig1_1):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.num_layers = config.num_layers
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.dropout = config.dropout
        self.ffn_shared = config.ffn_shared
        assert self.num_layers % self.ffn_shared == 0, "num_layers must be divisible by ffn_shared"
        self.blks = nn.ModuleList([
            YBlock(config) for _ in range(self.num_layers // self.ffn_shared)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        freqs_cos, freqs_sin = precompute_freqs_cis(dim=config.pe_dim,
                                                    end=config.max_position_embeddings,
                                                    theta=config.rope_theta)
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                **kwargs
                ):
        """Encode ``input_ids`` (``[b, l]``).

        Returns:
            ``(hidden, presents, dist_loss)``. ``presents`` holds one cache entry
            per layer. ``dist_loss`` is the self-distillation loss averaged over
            ``num_layers``, or ``None``.

        Raises:
            ValueError: if the positions exceed the precomputed RoPE table.
        """
        seq_length = input_ids.shape[1]
        past_key_values = past_key_values or [None] * self.num_layers
        # The cached kpe is [b, l_past, pe_dim]; its length is the start position.
        start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0
        if start_pos + seq_length > self.freqs_cos.shape[0]:
            raise ValueError(
                f"sequence end {start_pos + seq_length} exceeds max_position_embeddings "
                f"({self.freqs_cos.shape[0]})"
            )
        x = self.embed_tokens(input_ids)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)

        position_embeddings = (
            self.freqs_cos[start_pos:start_pos + seq_length],
            self.freqs_sin[start_pos:start_pos + seq_length]
        )

        presents = []
        cos_loss = None
        for layer_idx, block in enumerate(self.blks):
            past_key_value = past_key_values[self.ffn_shared * layer_idx: self.ffn_shared * (layer_idx + 1)]
            x, present, c_loss = block(
                x=x,
                position_embeddings=position_embeddings,
                past_key_values=past_key_value,
                use_cache=use_cache,
                attention_mask=attention_mask
            )
            presents.extend(present)
            cos_loss = c_loss + cos_loss if cos_loss is not None else c_loss

        x = self.norm(x)
        return x, presents, (cos_loss / self.num_layers if cos_loss is not None else None)

    def delta_kv_only(self, delta_kv: bool = True):
        """Toggle delta-kv mode on every block."""
        for blk in self.blks:
            blk.delta_kv_only(delta_kv)


class YForCausalLM1_1(PreTrainedModel, GenerationMixin):
    """Causal LM head on top of :class:`YModel`, with embedding/LM-head weight tying."""

    config_class = YConfig1_1

    def __init__(self, config: YConfig1_1 = None):
        self.config = config or YConfig1_1()
        super().__init__(self.config)
        self.model = YModel(self.config)
        self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
        self.model.embed_tokens.weight = self.lm_head.weight
        self.OUT = CausalLMOutputWithPast()

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                logits_to_keep: Union[int, torch.Tensor] = 0,
                **args):
        """Compute logits.

        Args:
            logits_to_keep: if an int ``n``, keep logits for the last ``n``
                positions (``0`` keeps all); otherwise used directly as the
                sequence index.

        Returns:
            A new :class:`CausalLMOutputWithPast` with keys ``last_hidden_state``,
            ``logits``, ``aux_loss`` (always 0.0), ``past_key_values`` and
            ``dist_loss``, in that order.
        """
        h, past_kvs, cos_loss = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            **args
        )
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(h[:, slice_indices, :])
        # A fresh output per call: reusing one object let a later forward
        # overwrite results a caller still held (e.g. chosen/rejected passes).
        out = CausalLMOutputWithPast()
        out['last_hidden_state'] = h
        out['logits'] = logits
        out['aux_loss'] = 0.0
        out['past_key_values'] = past_kvs
        out['dist_loss'] = cos_loss
        self.OUT = out  # kept for callers that read `model.OUT`
        return out

    def delta_kv_only(self, delta_kv: bool = True):
        """Toggle delta-kv mode on every attention layer."""
        self.model.delta_kv_only(delta_kv)