import math
import torch
import torch.nn as nn
from typing import Optional, Tuple, Union, List
from transformers import PreTrainedModel, GenerationMixin
from transformers.activations import ACT2FN
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.configuration_utils import PretrainedConfig


class YConfig2(PretrainedConfig):
    """Configuration for the ``ynet2`` decoder-only language model.

    Numbers in trailing comments are the original author's experiment notes;
    the metric they refer to is not recorded in this file.
    """

    model_type = "ynet2"

    def __init__(
        self,
        dropout: float = 0.1,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        hidden_act: str = 'gelu_pytorch_tanh',  # silu 4.687 / gelu 4.662 / mish 4.695 / relu2 4.755 / laplace
        hidden_size: int = 768,
        num_layers: int = 9,
        max_position_embeddings: int = 8192,
        vocab_size: int = 6400,
        rms_norm_eps: float = 1e-8,
        rope_theta: float = 5e4,  # 5e4
        self_distill: bool = True,
        force_flash_attn: bool = False,
        ### FFN ###
        intermediate_size: Optional[int] = None,  # 512 * 4 (full [4] / 256) = 2048 (2 ** 17)
        ### attn ###
        num_heads: int = 4,
        head_dim: int = 64,
        **kwargs
    ):
        """Store all hyper-parameters; extra ``kwargs`` go to ``PretrainedConfig``.

        Args:
            dropout: Dropout rate used for embeddings, attention and outputs.
            hidden_act: Key into ``ACT2FN`` used by FFN and attention gating.
            rope_theta: RoPE base frequency.
            self_distill: If True, the model returns a per-layer cosine loss
                during training (see ``YModel2.forward``).
            force_flash_attn: If True, ``PEGA2`` uses
                ``scaled_dot_product_attention`` in eval mode too.
            intermediate_size: FFN hidden width; ``None`` means
                ``int(2.5 * hidden_size)`` (resolved in ``FFN``).
            num_heads: Number of query heads (must be even).
            head_dim: Dimension of each head.
        """
        super().__init__(**kwargs)
        self.dropout = dropout
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.hidden_act = hidden_act
        self.hidden_size = hidden_size
        self.num_layers = num_layers  # number of transformer layers
        self.max_position_embeddings = max_position_embeddings
        self.vocab_size = vocab_size
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.self_distill = self_distill
        self.force_flash_attn = force_flash_attn
        ### FFN ###
        self.intermediate_size = intermediate_size  # FFN intermediate dimension
        ### attn ###
        self.num_heads = num_heads  # number of query heads
        self.head_dim = head_dim  # per-head dimension

    def scale_lvl(self, lvl: int = 0):
        """Overwrite the size hyper-parameters with a preset.

        Args:
            lvl: 0 (largest preset), -1, or -2 (smallest preset).

        Raises:
            ValueError: if ``lvl`` is not one of the presets.
        """
        if lvl == 0:
            # normal settings [99.312m]
            self.num_layers = 16
            self.hidden_size = 768
            self.num_heads = 16
            self.head_dim = 128
            self.intermediate_size = 2048
        elif lvl == -1:
            self.num_layers = 8
            self.hidden_size = 512  # base = 4.662 16h/64d 30
            self.num_heads = 8  # 2*heads 4.578/20.84
            self.head_dim = 64  # 2*dim 4.576/22.8
            self.intermediate_size = 1536
        elif lvl == -2:
            self.num_layers = 4
            self.hidden_size = 512
            self.num_heads = 8
            self.head_dim = 64
            self.intermediate_size = 1024
        else:
            raise ValueError(f"Invalid level: {lvl}")


class RMSNorm(torch.nn.Module):
    """Root-mean-square layer norm, computed in float32 and cast back to the input dtype."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))

    def _norm(self, x):
        # Divide by the RMS over the last axis; eps guards against division by zero.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float())
        output = output * self.weight.float()
        return output.type_as(x)


def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), theta: float = 5e4):
    """Precompute RoPE cos/sin tables.

    Returns:
        ``(freqs_cos, freqs_sin)``, each of shape ``[end, 2 * (dim // 2)]``. The
        half-dim frequency table is duplicated along the last axis to match
        ``rotate_half`` in ``apply_rotary_pos_emb``.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1)
    freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1)
    return freqs_cos, freqs_sin


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0):
    """Apply rotary position embeddings to ``q`` and ``k``.

    ``cos``/``sin`` are unsqueezed at ``unsqueeze_dim`` and broadcast against
    ``q``/``k``; with the default ``0`` a ``[l, d]`` table broadcasts over
    ``[b, h, l, d]`` inputs.
    """
    def rotate_half(x):
        return torch.cat((-x[..., x.shape[-1] // 2:], x[..., : x.shape[-1] // 2]), dim=-1)

    q_embed = (q * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(q) * sin.unsqueeze(unsqueeze_dim))
    k_embed = (k * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(k) * sin.unsqueeze(unsqueeze_dim))
    return q_embed, k_embed


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each head of ``x`` ``n_rep`` times consecutively.

    Equivalent to ``torch.repeat_interleave(x, repeats=n_rep, dim=1)`` for
    ``x`` of shape ``[b, h, l, ch]``; the result has shape ``[b, h * n_rep, l, ch]``.
    """
    b, h, l, ch = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, None, :, :]
        .expand(b, h, n_rep, l, ch)
        .reshape(b, h * n_rep, l, ch)
    )


def _build_sdpa_mask(
    q_len: int,
    kv_len: int,
    padding_mask: Optional[torch.Tensor],
    device: torch.device,
) -> Optional[torch.Tensor]:
    """Build a boolean mask (True = may attend) for ``scaled_dot_product_attention``.

    ``scaled_dot_product_attention`` refuses ``attn_mask`` together with
    ``is_causal=True``. Its built-in causal mask is also top-left aligned,
    which is wrong when a KV cache makes ``kv_len > q_len``. This helper
    therefore folds a bottom-right-aligned causal mask into the padding mask.

    Args:
        q_len: Number of query positions (the last ``q_len`` of ``kv_len``).
        kv_len: Number of key positions, including cached ones.
        padding_mask: Optional bool tensor broadcastable to ``[b, 1, q_len, kv_len]``.
        device: Device for the causal mask.

    Returns:
        ``None`` when plain ``is_causal=True`` is already correct (no padding
        mask and ``q_len == kv_len``); otherwise a bool mask.
    """
    if padding_mask is None and q_len == kv_len:
        return None
    causal = torch.ones(q_len, kv_len, dtype=torch.bool, device=device)
    causal = causal.tril(diagonal=kv_len - q_len)[None, None]  # [1, 1, q_len, kv_len]
    if padding_mask is None:
        return causal
    allowed = causal & padding_mask
    # A row with no visible key (e.g. a left-padding query) would make softmax
    # produce NaN; fall back to the causal mask for such rows, mirroring the
    # finite padding penalty used by the math path.
    has_key = allowed.any(dim=-1, keepdim=True)
    return torch.where(has_key, allowed, causal)


def _sdpa_math_impl(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    training: bool = False,
) -> torch.Tensor:
    """Reference (explicit matmul) causal attention shared by ``PEGA2`` and ``Attn``.

    The causal mask is bottom-right aligned: queries are the last ``l`` of the
    ``l_all`` key positions, so cached keys on the left are always visible.
    Masking and softmax run in float32 so that the -1e9 padding penalty stays
    finite even for half-precision inputs.

    Args:
        q: ``[b, h, l, c]`` queries.
        k, v: ``[b, h, l_all, c]`` keys / values.
        scale: Multiplier applied to ``q @ k^T``.
        attn_mask: Optional bool padding mask broadcastable to ``[b, 1, l, l_all]``
            (True = keep).
        dropout_p: Dropout on attention probabilities (active only if ``training``).
        training: Whether dropout is applied.

    Returns:
        ``[b, h, l, c]`` attention output in ``q``'s dtype.
    """
    q_len = q.shape[-2]
    scores = ((q @ k.transpose(-2, -1)) * scale).float()
    kv_len = scores.shape[-1]
    causal_mask = torch.triu(
        torch.full((q_len, q_len), float("-inf"), device=scores.device), diagonal=1
    )[None, None]  # [1, 1, l, l]
    # Zero-pad on the left to the scores' shape [1, 1, l, l_all] (cached keys are visible).
    causal_mask = nn.functional.pad(causal_mask, (kv_len - q_len, 0), "constant", 0.0)
    scores = scores + causal_mask
    if attn_mask is not None:
        scores = scores + (1.0 - attn_mask.float()) * -1e9
    probs = nn.functional.softmax(scores, dim=-1).type_as(q)
    probs = nn.functional.dropout(probs, p=dropout_p, training=training)  # [b, h, l, l_all]
    return probs @ v


class FFN(nn.Module):
    """Gated feed-forward block: ``down(act(g) * x)`` where ``[x, g] = up(input)``."""

    def __init__(self, config: YConfig2):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size or int(2.5 * config.hidden_size)
        self.gate_act = ACT2FN[config.hidden_act]
        # One fused projection produces both the value and the gate halves.
        self.up = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
        self.down = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, g = self.up(x).chunk(2, dim=-1)
        x = self.gate_act(g) * x
        x = self.down(x)
        return x


class PEGA2(nn.Module):
    """Attention with a low-rank shared latent key/value and gated head pairs.

    ``x`` is projected through a ``head_dim`` bottleneck to:
    - ``num_heads // 2`` RoPE query heads (``qpe``);
    - ``num_heads // 2`` non-RoPE query heads (``qnope``);
    - one RoPE key (``kpe``);
    - one latent ``kv``.

    RoPE query heads attend with key ``kpe`` and the others with ``kv``; every
    head uses ``kv`` as value. The output of each RoPE head is gated by the
    activation of its paired non-RoPE head. The KV cache stores
    ``cat([kpe, kv])`` of shape ``[b, 2, l_all, head_dim]``.
    """

    def __init__(self, config: YConfig2):
        super().__init__()
        self.dropout = config.dropout  # dropout rate
        self.hidden_size = config.hidden_size  # input channel size
        self.num_heads = config.num_heads  # total number of attention heads
        self.head_dim = config.head_dim  # dimension per head
        self.gate_act = ACT2FN[config.hidden_act]
        self.delta_kv_only = False
        self.force_flash_attn = config.force_flash_attn
        assert self.num_heads % 2 == 0, "num_heads must be even."
        # 2d opt: fused 29.5/4.693 split: 28.7/4.791
        self.qkv_list = [
            self.num_heads // 2 * self.head_dim,  # qpe
            self.num_heads // 2 * self.head_dim,  # qnope
            self.head_dim,  # kpe
            self.head_dim,  # kv
        ]
        # Fused low-rank projection: hidden -> head_dim -> all of qkv_list.
        self.qkv = nn.Sequential(
            nn.Linear(self.hidden_size, self.head_dim, bias=False),
            nn.Linear(self.head_dim, sum(self.qkv_list), bias=False),
        )
        self.o = nn.Linear(self.head_dim // 2 * self.num_heads, self.hidden_size, bias=False)
        self.rsqrt_dim = 1.0 / math.sqrt(self.head_dim)
        # Rescale the second factor's init (note: init 2k 4.693 --> 4.687).
        scale_lora = math.sqrt(
            (sum(self.qkv_list) + self.head_dim) * (self.head_dim + self.head_dim)
            / (2 * self.head_dim * (self.hidden_size + sum(self.qkv_list)))
        )
        self.qkv[1].weight.data *= scale_lora

    def forward(
        self,
        x: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        past_key_value: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run attention for ``x`` of shape ``[b, l, hidden_size]``.

        Args:
            position_embeddings: ``(cos, sin)`` tables; the first ``l`` rows are used.
            past_key_value: Cached ``[b, 2, l_past, head_dim]`` latent keys/values.
            attention_mask: Optional ``[b, l_all]`` padding mask (nonzero = keep).
            use_cache: Whether to return the concatenated cache.

        Returns:
            ``(out, cache)``, where ``out`` is ``[b, l, hidden_size]``. ``cache``
            is the new-step ``[b, 2, l, head_dim]`` tensor when
            ``delta_kv_only`` is enabled. Otherwise it is the full cache if
            ``use_cache`` is set, else ``None``.
        """
        cos, sin = position_embeddings  # [L, head_dim]
        b, l, _ = x.shape

        qkv = self.qkv(x)
        qpe, q, kpe, kv = torch.split(qkv, self.qkv_list, dim=-1)

        # Apply RoPE to the positional query heads and the positional key.
        q = q.view(b, l, self.num_heads // 2, self.head_dim).permute(0, 2, 1, 3)  # [b, h // 2, l, hd]
        qpe = qpe.view(b, l, self.num_heads // 2, self.head_dim).permute(0, 2, 1, 3)  # [b, h // 2, l, hd]
        kv = kv.unsqueeze(1)  # [b, 1, l, hd]
        kpe = kpe.unsqueeze(1)  # [b, 1, l, hd]
        qpe, kpe = apply_rotary_pos_emb(qpe, kpe, cos[:l], sin[:l])

        # Concatenate: RoPE heads first, so repeat_kv pairs them with kpe.
        q = torch.cat([qpe, q], dim=1)  # [b, h, l, hd]
        kv = torch.cat([kpe, kv], dim=1)  # [b, 2, l, hd]

        # Return only the new-step kv when requested (reduces memory overhead).
        deltakv = kv if self.delta_kv_only else None

        # KV cache: append along the sequence axis.
        if past_key_value is not None:
            kv = torch.cat([past_key_value, kv], dim=2)
        past_kv = kv if use_cache else None
        l_all = kv.shape[2]

        attn_mask = None
        if attention_mask is not None:
            attn_mask = attention_mask.view(b, 1, 1, -1).expand(b, 1, l, -1).bool()

        keys = repeat_kv(kv, self.num_heads // 2)  # [b, h, l_all, hd]: kpe for RoPE heads, kv otherwise
        values = repeat_kv(kv[:, 1:, :, :], self.num_heads)  # [b, h, l_all, hd]: kv for every head

        if self.training or self.force_flash_attn:
            sdpa_mask = _build_sdpa_mask(l, l_all, attn_mask, q.device)
            o = nn.functional.scaled_dot_product_attention(
                q, keys, values,
                attn_mask=sdpa_mask,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=sdpa_mask is None,
            )
        else:
            o = self.sdpa_math(q, keys, values, attn_mask, 0.0)
        # o: [b, h, l, hd]

        # gate 2k4b peg: 5.169 nopeg: 5.179 +gate:5.210(4.622)
        # Gating the other way round (onope * act(ope)) was noted as not stable.
        ope, onope = o.permute(0, 2, 1, 3).chunk(2, dim=2)  # each [b, l, h // 2, hd]
        o = ope * self.gate_act(onope)  # [b, l, h // 2, hd]
        out = self.o(o.reshape(b, l, -1))
        out = nn.functional.dropout(out, p=self.dropout, training=self.training)
        return out, (deltakv if self.delta_kv_only else past_kv)

    def sdpa_math(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                  attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0) -> torch.Tensor:
        """Explicit causal attention with a bottom-right-aligned mask (KV-cache safe)."""
        return _sdpa_math_impl(q, k, v, self.rsqrt_dim, attn_mask, dropout_p, self.training)

    def use_delta_kv_only(self, enable: bool = True):
        """Return only the new-step kv from ``forward`` (reduces memory overhead)."""
        self.delta_kv_only = enable


class Attn(nn.Module):
    """Standard RoPE attention with 2 shared key/value heads (grouped-query attention).

    The KV cache is a tuple ``(k, v)`` of ``[b, 2, l_all, head_dim]`` tensors.
    """

    def __init__(self, config: YConfig2):
        super().__init__()
        self.dropout = config.dropout  # dropout rate
        self.hidden_size = config.hidden_size  # input channel size
        self.num_heads = config.num_heads  # total number of attention heads
        self.head_dim = config.head_dim  # dimension per head
        self.gate_act = ACT2FN[config.hidden_act]
        self.delta_kv_only = False
        assert self.num_heads % 2 == 0, "num_heads must be even."
        self.qkv_list = [
            self.num_heads * self.head_dim,  # q
            2 * self.head_dim,  # k
            2 * self.head_dim,  # v
        ]
        self.qkv = nn.Linear(self.hidden_size, sum(self.qkv_list), bias=False)
        self.o = nn.Linear(self.head_dim * self.num_heads, self.hidden_size, bias=False)
        # Previously missing: sdpa_math needs it in eval mode.
        self.rsqrt_dim = 1.0 / math.sqrt(self.head_dim)

    def forward(
        self,
        x: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run attention for ``x`` of shape ``[b, l, hidden_size]``.

        Returns:
            ``(out, cache)``. ``cache`` is the new-step ``(k, v)`` when
            ``delta_kv_only`` is enabled, else the concatenated ``(k, v)`` if
            ``use_cache``, else ``None``.
        """
        cos, sin = position_embeddings  # [L, head_dim]
        b, l, _ = x.shape

        qkv = self.qkv(x)
        q, k, v = torch.split(qkv, self.qkv_list, dim=-1)

        # Apply RoPE.
        q = q.view(b, l, self.num_heads, self.head_dim).permute(0, 2, 1, 3)  # [b, h, l, hd]
        k = k.view(b, l, 2, self.head_dim).permute(0, 2, 1, 3)  # [b, 2, l, hd]
        v = v.view(b, l, 2, self.head_dim).permute(0, 2, 1, 3)  # [b, 2, l, hd]
        q, k = apply_rotary_pos_emb(q, k, cos[:l], sin[:l])

        # Return only the new-step kv when requested (reduces memory overhead).
        deltakv = (k, v) if self.delta_kv_only else None

        # KV cache: append along the sequence axis (dim 2), not the head axis.
        if past_key_value is not None:
            k = torch.cat([past_key_value[0], k], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)
        past_kv = (k, v) if use_cache else None
        l_all = k.shape[2]

        attn_mask = None
        if attention_mask is not None:
            attn_mask = attention_mask.view(b, 1, 1, -1).expand(b, 1, l, -1).bool()

        # Each of the 2 kv heads serves num_heads // 2 query heads.
        keys = repeat_kv(k, self.num_heads // 2)
        values = repeat_kv(v, self.num_heads // 2)

        if self.training:
            sdpa_mask = _build_sdpa_mask(l, l_all, attn_mask, q.device)
            o = nn.functional.scaled_dot_product_attention(
                q, keys, values,
                attn_mask=sdpa_mask,
                dropout_p=self.dropout,
                is_causal=sdpa_mask is None,
            )
        else:
            o = self.sdpa_math(q, keys, values, attn_mask, 0.0)
        # o: [b, h, l, hd]
        out = o.permute(0, 2, 1, 3).reshape(b, l, -1)
        out = self.o(out)
        out = nn.functional.dropout(out, p=self.dropout, training=self.training)
        return out, (deltakv if self.delta_kv_only else past_kv)

    def sdpa_math(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                  attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0) -> torch.Tensor:
        """Explicit causal attention with a bottom-right-aligned mask (KV-cache safe)."""
        return _sdpa_math_impl(q, k, v, self.rsqrt_dim, attn_mask, dropout_p, self.training)

    def use_delta_kv_only(self, enable: bool = True):
        """Return only the new-step kv from ``forward`` (reduces memory overhead)."""
        self.delta_kv_only = enable


class YBlock2(nn.Module):
    """Pre-norm transformer block: ``x + attn(norm1(x))`` then ``x + ffn(norm2(x))``."""

    def __init__(self, config: YConfig2):
        super().__init__()
        self.attn = PEGA2(config)
        self.ffn = FFN(config)
        self.norm1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.norm2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self,
                x: torch.Tensor,
                position_embeddings: Tuple[torch.Tensor, torch.Tensor],
                past_key_value: Optional[torch.Tensor] = None,  # kv cache
                use_cache: bool = False,
                attention_mask: Optional[torch.Tensor] = None
                ):
        """Return ``(x_out, cache)``; ``cache`` is whatever the attention layer returns."""
        # attention
        residual = x
        x = self.norm1(x)
        attn_out, past_kv = self.attn(
            x, position_embeddings,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            use_cache=use_cache,
        )
        x = residual + attn_out

        # ffn
        residual = x
        x = self.norm2(x)
        ffn_out = self.ffn(x)
        x = residual + ffn_out
        return x, past_kv

    def use_delta_kv_only(self, enable: bool = True):
        self.attn.use_delta_kv_only(enable)


class YModel2(nn.Module):
    """Token embedding + stack of ``YBlock2`` + final RMSNorm."""

    def __init__(self, config: YConfig2):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.num_layers = config.num_layers
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.dropout = config.dropout
        self.use_self_distill = config.self_distill
        self.layers = nn.ModuleList([YBlock2(config) for _ in range(config.num_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        freqs_cos, freqs_sin = precompute_freqs_cis(dim=config.head_dim,
                                                    end=config.max_position_embeddings,
                                                    theta=config.rope_theta)
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                past_key_values: Optional[List[torch.Tensor]] = None,
                use_cache: bool = False,
                **kwargs
                ):
        """Encode ``input_ids`` of shape ``[b, l]``.

        Returns:
            ``(hidden, presents, cos_loss)``:
            - ``hidden`` is ``[b, l, hidden_size]``.
            - ``presents`` is the per-layer cache list.
            - ``cos_loss`` is the layer-averaged ``1 - cos(layer_in, layer_out.detach())``
              during training with self-distillation, else ``None``.
        """
        _, seq_length = input_ids.shape
        past_key_values = past_key_values or [None] * self.num_layers
        # Cache entries are tensors (PEGA2) or (k, v) tuples (Attn); the
        # sequence length is the second-to-last axis in both cases.
        first_past = past_key_values[0]
        if isinstance(first_past, (tuple, list)):
            first_past = first_past[0]
        start_pos = first_past.shape[-2] if first_past is not None else 0

        x = self.embed_tokens(input_ids)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)

        position_embeddings = (
            self.freqs_cos[start_pos:start_pos + seq_length],
            self.freqs_sin[start_pos:start_pos + seq_length]
        )

        presents = []
        cos_loss = None
        for i, layer in enumerate(self.layers):
            x0 = x
            x, past_kv = layer(
                x=x,
                position_embeddings=position_embeddings,
                past_key_value=past_key_values[i],
                attention_mask=attention_mask,
                use_cache=use_cache
            )
            if self.training and self.use_self_distill:
                # Pull each layer's input toward its (detached) output.
                xd = x.detach()
                c_loss = 1.0 - nn.functional.cosine_similarity(x0, xd, dim=-1).mean()
                cos_loss = c_loss + cos_loss if cos_loss is not None else c_loss
            presents.append(past_kv)
        if cos_loss is not None:
            cos_loss = cos_loss / self.num_layers

        x = self.norm(x)
        return x, presents, cos_loss

    def delta_kv_only(self, delta_kv: bool = True):
        for layer in self.layers:
            layer.use_delta_kv_only(delta_kv)


class YForCausalLM2(PreTrainedModel, GenerationMixin):
    """Causal LM head over ``YModel2`` with tied input/output embeddings."""

    config_class = YConfig2

    def __init__(self, config: YConfig2 = None, **kwargs):
        """Build the model.

        Args:
            config: Model configuration; defaults to ``YConfig2()``.
            **kwargs: ``dtype`` may be ``'bfloat16'``, ``'float16'``, any other
                string (float32), or a ``torch.dtype``; the model is cast to it.
        """
        self.config = config or YConfig2()
        super().__init__(self.config)
        self.model = YModel2(self.config)
        self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
        self.model.embed_tokens.weight = self.lm_head.weight
        dtype = kwargs.get('dtype')
        if dtype is not None:
            if isinstance(dtype, torch.dtype):
                m_dtype = dtype
            elif dtype == 'bfloat16':
                m_dtype = torch.bfloat16
            elif dtype == 'float16':
                m_dtype = torch.float16
            else:
                m_dtype = torch.float32
            self.model.to(m_dtype)
            self.lm_head.to(m_dtype)

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                logits_to_keep: Union[int, torch.Tensor] = 0,
                **args):
        """Return a fresh ``CausalLMOutputWithPast`` for this call.

        Keys, in order:
        - ``last_hidden_state``;
        - ``logits`` (restricted to the last ``logits_to_keep`` positions, or
          all positions when 0);
        - ``past_key_values``;
        - ``dist_loss``, present only when ``config.self_distill`` is set.

        A new object is built per call so earlier outputs are not overwritten
        and old tensors are not kept alive by the module.
        """
        h, past_kvs, cos_loss = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            **args
        )
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(h[:, slice_indices, :])
        out = CausalLMOutputWithPast()
        out['last_hidden_state'] = h
        out['logits'] = logits
        out['past_key_values'] = past_kvs
        if self.config.self_distill:
            out['dist_loss'] = cos_loss
        return out

    def delta_kv_only(self, delta_kv: bool = True):
        self.model.delta_kv_only(delta_kv)