from typing import Optional, Union

import torch
import torch.nn as nn
from tokenizers import Tokenizer, decoders, pre_tokenizers
from tokenizers.models import BPE
from transformers import (
    GenerationMixin,
    PreTrainedConfig,
    PreTrainedModel,
    TokenizersBackend,
)
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput


class ZZJRabbit3Config(PreTrainedConfig):
    """Configuration for the ZZJRabbit3 decoder-only language model.

    Args:
        vocab_size: Number of token ids covered by the embedding and LM head.
        hidden_size: Width of the residual stream; must be divisible by
            ``num_attention_heads``.
        num_hidden_layers: Number of stacked ``ZZJRabbit3Layer`` blocks.
        num_attention_heads: Number of attention heads per layer.
        attention_dropout: Dropout probability applied to attention weights.
        pad_token_id: Id of the padding token, if any.
        eos_token_id: Id (or list of ids) marking end of sequence, if any.
        max_position_embeddings: Number of positions the rotary tables are
            precomputed for; longer inputs are rejected by the model.
        **kwargs: Forwarded to ``PreTrainedConfig``.
    """

    model_type = "zzjrabbit3"

    def __init__(
        self,
        vocab_size: int = 100000,
        hidden_size: int = 1024,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 8,
        attention_dropout: float | int = 0.0,
        pad_token_id: int | None = None,
        eos_token_id: int | list[int] | None = None,
        max_position_embeddings: int = 2048,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.attention_dropout = attention_dropout
        self.pad_token_id = pad_token_id
        self.eos_token_id = eos_token_id
        self.max_position_embeddings = max_position_embeddings
        # Also hand the special-token ids to the base class: a base __init__
        # that assigns them from kwargs with a None default would otherwise
        # clobber the values set above.
        super().__init__(
            pad_token_id=pad_token_id, eos_token_id=eos_token_id, **kwargs
        )


class ZZJRabbit3RotaryEmbedding(nn.Module):
    """Precomputed rotary position embedding (RoPE) cos/sin tables.

    NOTE(review): ``forward`` duplicates each frequency *interleaved*
    (``[f0, f0, f1, f1, ...]``) while ``rotate_half`` pairs dimension ``i``
    with dimension ``i + dim // 2``. A true 2-D rotation needs both members of
    a pair to share one angle, i.e. the *concatenated* layout
    ``torch.cat((freqs, freqs), dim=-1)``. As written the transform is in
    general not a rotation (not norm preserving, and q.k does not depend on
    relative position only). It is intentionally left unchanged because
    switching layouts changes the outputs of any checkpoint trained with this
    code; fix it only together with retraining.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        """
        Args:
            dim: Per-head embedding size the tables are built for.
            max_position_embeddings: Number of positions to precompute.
            base: Frequency base of the rotary embedding.
        """
        super().__init__()
        self.dim = dim
        self.base = base
        self.max_position_embeddings = max_position_embeddings

        # Inverse frequencies base^(-2i/dim), one per pair of dimensions.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Angle table: (max_position_embeddings, dim / 2).
        t = torch.arange(max_position_embeddings, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        self.register_buffer("cos_cached", freqs.cos())
        self.register_buffer("sin_cached", freqs.sin())

    def forward(self, position_ids):
        """Look up cos/sin for the given positions.

        Args:
            position_ids: ``(batch_size, seq_len)`` integer positions, each
                below ``max_position_embeddings``.

        Returns:
            ``(cos, sin)``, each of shape ``(batch_size, seq_len, dim)``.
        """
        cos = self.cos_cached[position_ids]  # (batch, seq_len, dim / 2)
        sin = self.sin_cached[position_ids]

        # Widen (..., dim / 2) to (..., dim) by repeating each value in place.
        cos = torch.stack([cos, cos], dim=-1).flatten(-2)
        sin = torch.stack([sin, sin], dim=-1).flatten(-2)
        return cos, sin


def rotate_half(x):
    """Split the last dimension into halves ``[x1, x2]`` and return ``[-x2, x1]``."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, sin, cos):
    """Apply rotary embeddings to query and key tensors.

    Note the argument order: ``sin`` comes before ``cos``.

    Args:
        q, k: ``(batch, heads, seq_len, head_dim)`` tensors.
        sin, cos: ``(batch, seq_len, head_dim)`` tables; a head axis is
            inserted so they broadcast over heads.

    Returns:
        ``(q_embed, k_embed)`` with the same shapes as ``q`` and ``k``.
    """
    cos = cos.unsqueeze(1)
    sin = sin.unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class ZZJRabbit3Attention(nn.Module):
    """Multi-head scaled dot-product self-attention with rotary embeddings."""

    def __init__(self, config: ZZJRabbit3Config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({config.hidden_size}) must be divisible by "
                f"num_attention_heads ({config.num_attention_heads})"
            )
        self.config = config
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)
        # Honour the configured rate (previously hard-coded to 0.1, which
        # silently ignored config.attention_dropout).
        self.dropout = nn.Dropout(config.attention_dropout)

    def _split_heads(self, t: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Reshape ``(batch, seq, hidden)`` into ``(batch, heads, seq, head_dim)``."""
        return t.view(
            batch_size, -1, self.config.num_attention_heads, self.head_dim
        ).transpose(1, 2)

    def forward(
        self,
        x: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        key_padding_mask: Optional[torch.BoolTensor] = None,
        attn_mask: Optional[torch.BoolTensor] = None,
    ) -> torch.Tensor:
        """Attend over ``x``.

        Args:
            x: ``(batch, seq_len, hidden_size)`` input.
            position_embeddings: ``(cos, sin)`` from the rotary module.
            key_padding_mask: ``(batch, seq_len)``; True marks keys to ignore.
            attn_mask: Mask broadcastable to ``(batch, heads, seq, seq)``;
                True marks (query, key) pairs to ignore.

        Returns:
            ``(batch, seq_len, hidden_size)`` attention output.
        """
        batch_size = x.size(0)
        Q = self._split_heads(self.q_proj(x), batch_size)
        K = self._split_heads(self.k_proj(x), batch_size)
        V = self._split_heads(self.v_proj(x), batch_size)

        cos, sin = position_embeddings
        Q, K = apply_rotary_pos_emb(Q, K, sin.to(Q.dtype), cos.to(Q.dtype))

        scores = torch.matmul(Q, K.transpose(-2, -1)) * (self.head_dim**-0.5)

        # Use a finite "minus infinity". With -inf, a query row whose keys are
        # all masked (e.g. a left-padding position under the causal mask)
        # softmaxes to NaN, and 0 * NaN in the value product of the next layer
        # spreads that NaN to every position. Rows with at least one visible
        # key are unaffected: exp(min - row_max) underflows to exactly 0.
        mask_value = torch.finfo(scores.dtype).min
        if key_padding_mask is not None:
            scores = scores.masked_fill(
                key_padding_mask.view(batch_size, 1, 1, -1), mask_value
            )
        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask, mask_value)

        attn_weights = nn.functional.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context = torch.matmul(attn_weights, V)
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, -1, self.config.hidden_size)
        return self.out_proj(context)


class ZZJRabbit3Layer(nn.Module):
    """Post-norm transformer block: causal self-attention followed by a ReLU MLP."""

    def __init__(self, config: ZZJRabbit3Config):
        super().__init__()
        self.attn = ZZJRabbit3Attention(config)
        self.l1 = nn.Linear(config.hidden_size, config.hidden_size)
        self.l2 = nn.Linear(config.hidden_size, config.hidden_size)
        self.activate = nn.ReLU()
        # NOTE(review): this single RMSNorm (shared weights) is applied after
        # both the attention and the MLP residual. A separate norm per
        # sub-layer is more conventional, but adding one would change the
        # state-dict keys and break existing checkpoints, so it is kept.
        self.norm = nn.RMSNorm(config.hidden_size)

    def forward(
        self,
        x: torch.Tensor,
        postition_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run one block.

        Args:
            x: ``(batch, seq_len, hidden_size)`` hidden states.
            postition_embeddings: ``(cos, sin)`` rotary tables (the misspelled
                name is kept for keyword-argument compatibility).
            attention_mask: Optional ``(batch, seq_len)`` mask; keys whose
                value is < 1 are ignored.

        Returns:
            Updated hidden states, same shape as ``x``.
        """
        seq_len = x.size(-2)
        # True strictly above the diagonal: query i may not see key j > i.
        attn_mask = torch.ones(
            seq_len, seq_len, dtype=torch.bool, device=x.device
        ).triu(1)
        key_padding_mask = None if attention_mask is None else attention_mask < 1

        attn = self.attn(
            x,
            postition_embeddings,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
        )
        x = self.norm(x + attn)

        o = self.l2(self.activate(self.l1(x)))
        return self.norm(x + o)


class ZZJRabbit3Model(PreTrainedModel):
    """Token embedding followed by a stack of ``ZZJRabbit3Layer`` blocks."""

    config_class = ZZJRabbit3Config

    def __init__(self, config: ZZJRabbit3Config, **kwargs):
        super().__init__(config, **kwargs)
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.rotary_emb = ZZJRabbit3RotaryEmbedding(
            config.hidden_size // config.num_attention_heads,
            # getattr keeps configs built without this field on the old default.
            max_position_embeddings=getattr(config, "max_position_embeddings", 2048),
        )
        self.layers = nn.ModuleList(
            [ZZJRabbit3Layer(config) for _ in range(config.num_hidden_layers)]
        )
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        return_dict: Optional[bool] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple | BaseModelOutput:
        """Encode ``input_ids``.

        Args:
            input_ids: ``(batch, seq_len)`` token ids.
            return_dict: When truthy return a ``BaseModelOutput``; otherwise
                (including the default ``None``) return a 1-tuple.
            attention_mask: Optional ``(batch, seq_len)`` padding mask.

        Returns:
            ``(last_hidden_state,)`` or ``BaseModelOutput``; the hidden state
            has shape ``(batch, seq_len, hidden_size)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the rotary table length.
        """
        res = self.embedding(input_ids)
        batch_size, seq_len = input_ids.shape

        max_positions = self.rotary_emb.max_position_embeddings
        if seq_len > max_positions:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_position_embeddings "
                f"({max_positions})"
            )

        position_ids = (
            torch.arange(seq_len, device=input_ids.device)
            .unsqueeze(0)
            .expand(batch_size, -1)
        )
        position_embeddings = self.rotary_emb(position_ids)

        for layer in self.layers:
            res = layer(res, position_embeddings, attention_mask)

        if not return_dict:
            return (res,)
        return BaseModelOutput(last_hidden_state=res)


class ZZJRabbit3ForCausalLM(PreTrainedModel, GenerationMixin):
    """``ZZJRabbit3Model`` with a linear language-modelling head."""

    config_class = ZZJRabbit3Config

    def __init__(self, config: ZZJRabbit3Config, **kwargs):
        super().__init__(config, **kwargs)
        self.model = ZZJRabbit3Model(config, **kwargs)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> tuple | CausalLMOutput:
        """Compute next-token logits and, when ``labels`` is given, the loss.

        Args:
            input_ids: ``(batch, seq_len)`` token ids.
            return_dict: When truthy return a ``CausalLMOutput``; otherwise a
                tuple ``(loss, logits)`` or ``(logits,)``.
            labels: Optional targets passed to ``self.loss_function``.
            attention_mask: Optional ``(batch, seq_len)`` padding mask.
            logits_to_keep: An int ``n`` keeps logits for the last ``n``
                positions (0 keeps all); a tensor selects explicit positions.
            **kwargs: Forwarded to ``self.loss_function``.
        """
        hidden = self.model(
            input_ids=input_ids, attention_mask=attention_mask, return_dict=False
        )[0]

        # Slice before projecting so unused positions never hit the vocab matmul.
        if isinstance(logits_to_keep, int):
            keep = slice(-logits_to_keep, None)
        else:
            keep = logits_to_keep
        logits = self.lm_head(hidden[:, keep, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        if not return_dict:
            return (loss, logits) if labels is not None else (logits,)
        # ModelOutput drops None fields, so loss=None matches omitting it.
        return CausalLMOutput(loss=loss, logits=logits)

    @classmethod
    def can_generate(cls):
        return True

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        """Build forward() inputs for each generation step.

        There is no KV cache, so the full sequence is re-encoded every step.
        The attention mask supplied by the generation loop is forwarded so
        that padded prompts keep ignoring their pad tokens (it was previously
        dropped).
        """
        return {"input_ids": input_ids, "attention_mask": attention_mask}


class ZZJRabbit3Tokenizer(TokenizersBackend):
    """Byte-level BPE tokenizer built on the ``tokenizers`` library.

    Args:
        vocab: Token-to-id mapping; a falsy value selects a one-entry default.
        merges: BPE merge rules; a falsy value selects an empty list.
        unk_token, eos_token, pad_token: Special tokens forwarded to the base
            class.
        **kwargs: Forwarded to ``TokenizersBackend``.

    NOTE(review): the special-token defaults and the default vocabulary key
    are empty strings. That looks like angle-bracketed tokens (e.g. ``<unk>``)
    may have been stripped at some point; confirm the intended values.
    """

    model = BPE

    def __init__(
        self,
        vocab=None,
        merges=None,
        unk_token="",
        eos_token="",
        pad_token="",
        **kwargs,
    ):
        self._vocab = vocab or {
            "": 0,
        }
        self._merges = merges or []
        self._tokenizer = Tokenizer(
            BPE(vocab=self._vocab, merges=self._merges, fuse_unk=True)
        )
        # Byte-level pre-tokenization/decoding, without adding a leading space.
        self._tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
        self._tokenizer.decoder = decoders.ByteLevel()
        super().__init__(
            unk_token=unk_token, eos_token=eos_token, pad_token=pad_token, **kwargs
        )