| from typing import Optional, Union |
|
|
| import torch |
| import torch.nn as nn |
| from tokenizers import Tokenizer, decoders, pre_tokenizers |
| from tokenizers.models import BPE |
| from transformers import ( |
| GenerationMixin, |
| PreTrainedConfig, |
| PreTrainedModel, |
| TokenizersBackend, |
| ) |
| from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput |
|
|
|
|
class ZZJRabbit3Config(PreTrainedConfig):
    """Hyper-parameters of the ZZJRabbit3 decoder-only language model.

    Args:
        vocab_size: Number of rows in the token embedding table and number of
            output logits of the LM head.
        hidden_size: Width of the hidden states.
        num_hidden_layers: Number of ``ZZJRabbit3Layer`` blocks in the model.
        num_attention_heads: Number of attention heads; each head has
            ``hidden_size // num_attention_heads`` channels.
        attention_dropout: Attention dropout probability stored on the config.
        pad_token_id: Id of the padding token, or ``None``.
        eos_token_id: Id (or list of ids) of the end-of-sequence token, or
            ``None``.
        **kwargs: Passed through unchanged to ``PreTrainedConfig.__init__``.
    """

    model_type = "zzjrabbit3"

    def __init__(
        self,
        vocab_size: int = 100000,
        hidden_size: int = 1024,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 8,
        attention_dropout: float | int = 0.0,
        pad_token_id: int | None = None,
        eos_token_id: int | list[int] | None = None,
        **kwargs,
    ):
        # All model-specific fields are stored on the instance before the
        # base-class initializer runs with the remaining keyword arguments.
        self.pad_token_id = pad_token_id
        self.eos_token_id = eos_token_id

        self.attention_dropout = attention_dropout
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size

        super().__init__(**kwargs)
|
|
|
|
class ZZJRabbit3RotaryEmbedding(nn.Module):
    """Precomputed cosine/sine tables for rotary position embeddings (RoPE).

    The tables are laid out for the "half-split" rotation implemented by
    ``rotate_half`` in this file, which pairs channel ``i`` with channel
    ``i + dim // 2``: both channels of a pair must be rotated by the same
    angle, so the per-pair angles are repeated by concatenation
    (``[a0, a1, ..., a0, a1, ...]``), not by interleaving.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        """
        Rotary embedding module.

        Args:
            dim: Size of the last axis of the returned tables (the per-token
                embedding dimension that will be rotated). Expected to be even.
            max_position_embeddings: Number of positions to precompute; valid
                position ids are ``0 .. max_position_embeddings - 1``.
            base: Frequency base of the rotary embedding.
        """
        super().__init__()
        self.dim = dim
        self.base = base

        # One inverse frequency per rotated channel pair (dim // 2 entries).
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Angle table: entry [p, j] is p * inv_freq[j]; cache its cos and sin.
        t = torch.arange(max_position_embeddings, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        self.register_buffer("cos_cached", freqs.cos())
        self.register_buffer("sin_cached", freqs.sin())

    def forward(self, position_ids):
        """
        Look up the rotary tables for the given positions.

        Args:
            position_ids: Integer tensor of shape (batch_size, seq_len) with
                values below ``max_position_embeddings``.

        Returns:
            cos: (batch_size, seq_len, dim)
            sin: (batch_size, seq_len, dim)
        """
        # Indexing with position_ids yields (batch_size, seq_len, dim // 2).
        cos = self.cos_cached[position_ids]
        sin = self.sin_cached[position_ids]

        # Repeat the half-width table so that channel i and channel
        # i + dim // 2 share one angle, matching rotate_half's pairing.
        # Interleaving ([a0, a0, a1, a1, ...]) would give the two channels of
        # a pair different angles and the result would not be a rotation.
        cos = torch.cat((cos, cos), dim=-1)
        sin = torch.cat((sin, sin), dim=-1)
        return cos, sin
|
|
|
|
def rotate_half(x):
    """Return ``[-x2, x1]`` where ``x1``/``x2`` are the two halves of the last axis.

    The split point is ``x.shape[-1] // 2``; ``x1`` is everything before it
    and ``x2`` everything from it onwards.
    """
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat((back.neg(), front), dim=-1)
|
|
|
|
def apply_rotary_pos_emb(q, k, sin, cos):
    """Apply rotary position embeddings to the query and key tensors.

    Note the argument order: ``sin`` is passed before ``cos``.

    Args:
        q: Query tensor.
        k: Key tensor.
        sin: Sine table; a singleton axis is inserted at dim 1 so it
            broadcasts against ``q`` and ``k``.
        cos: Cosine table, treated the same way as ``sin``.

    Returns:
        Tuple ``(q_embed, k_embed)`` with the same shapes as ``q`` and ``k``
        after broadcasting.
    """
    # The attention module in this file passes (batch, heads, seq, head_dim)
    # tensors, so dim 1 is the head axis the tables are broadcast over.
    cos_b = cos.unsqueeze(1)
    sin_b = sin.unsqueeze(1)

    def _rotate(t):
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)
|
|
|
|
class ZZJRabbit3Attention(nn.Module):
    """Multi-head self-attention with rotary position embeddings.

    Queries, keys and values are separate linear projections of the input,
    split into ``config.num_attention_heads`` heads of ``head_dim`` channels.
    Attention-weight dropout uses ``config.attention_dropout``.
    """

    def __init__(self, config: ZZJRabbit3Config):
        """
        Args:
            config: Model configuration; reads ``hidden_size``,
                ``num_attention_heads`` and ``attention_dropout``.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by
                ``num_attention_heads``.
        """
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({config.hidden_size}) must be divisible by "
                f"num_attention_heads ({config.num_attention_heads})"
            )
        self.config = config
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)
        # Honour the configured probability instead of a hard-coded constant.
        self.dropout = nn.Dropout(config.attention_dropout)

    def _split_heads(self, projected: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Reshape (batch, seq, hidden) to (batch, heads, seq, head_dim)."""
        return projected.view(
            batch_size, -1, self.config.num_attention_heads, self.head_dim
        ).transpose(1, 2)

    def forward(
        self,
        x: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        key_padding_mask: Optional[torch.BoolTensor] = None,
        attn_mask: Optional[torch.BoolTensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x: Hidden states of shape (batch, seq, hidden_size).
            position_embeddings: ``(cos, sin)`` rotary tables; each must
                broadcast against (batch, heads, seq, head_dim) after a
                singleton head axis is inserted.
            key_padding_mask: Optional boolean mask with batch * seq elements;
                ``True`` marks key positions that must not be attended to.
            attn_mask: Optional boolean mask broadcastable to the score
                matrix (batch, heads, seq, seq); ``True`` marks blocked
                query/key pairs.

        Returns:
            Tensor of shape (batch, seq, hidden_size).
        """
        batch_size = x.size(0)
        Q = self._split_heads(self.q_proj(x), batch_size)
        K = self._split_heads(self.k_proj(x), batch_size)
        V = self._split_heads(self.v_proj(x), batch_size)

        cos, sin = position_embeddings
        # apply_rotary_pos_emb takes sin before cos.
        Q, K = apply_rotary_pos_emb(Q, K, sin.to(Q.dtype), cos.to(Q.dtype))

        scores = torch.matmul(Q, K.transpose(-2, -1)) * (self.head_dim**-0.5)

        # Use the most negative finite value rather than -inf: a row in which
        # every key is masked (e.g. a padded query under the causal mask)
        # would otherwise softmax to NaN, and that NaN spreads to every
        # position through the value matmul of later layers.
        mask_value = torch.finfo(scores.dtype).min
        if key_padding_mask is not None:
            scores = scores.masked_fill(
                key_padding_mask.view(batch_size, 1, 1, -1), mask_value
            )
        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask, mask_value)

        attn_weights = nn.functional.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context = torch.matmul(attn_weights, V)
        # Merge the heads back: (batch, heads, seq, head_dim) -> (batch, seq, hidden).
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, -1, self.config.hidden_size)
        return self.out_proj(context)
|
|
|
|
class ZZJRabbit3Layer(nn.Module):
    """One decoder block: causal self-attention followed by a two-layer ReLU MLP.

    Each sub-block is wrapped as ``norm(x + sublayer(x))``, and the same
    ``RMSNorm`` instance is used after both sub-blocks.
    """

    def __init__(self, config: ZZJRabbit3Config):
        super().__init__()
        self.attn = ZZJRabbit3Attention(config)
        self.l1 = nn.Linear(config.hidden_size, config.hidden_size)
        self.l2 = nn.Linear(config.hidden_size, config.hidden_size)
        self.activate = nn.ReLU()
        self.norm = nn.RMSNorm(config.hidden_size)

    def forward(
        self,
        x: torch.Tensor,
        postition_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x: Hidden states; the second-to-last axis is the sequence axis.
            postition_embeddings: ``(cos, sin)`` rotary tables handed to the
                attention module (the parameter name keeps its original
                spelling so keyword callers continue to work).
            attention_mask: Optional mask; entries smaller than 1 are treated
                as padding and hidden from attention.

        Returns:
            Hidden states with the same shape as ``x``.
        """
        seq_len = x.size(-2)
        # Causal mask: True strictly above the diagonal, i.e. future keys.
        causal_mask = torch.ones(seq_len, seq_len, device=x.device).triu(1) > 0
        padding_mask = None if attention_mask is None else attention_mask < 1

        attn_out = self.attn(
            x,
            postition_embeddings,
            key_padding_mask=padding_mask,
            attn_mask=causal_mask,
        )
        x = self.norm(x + attn_out)

        mlp_out = self.l2(self.activate(self.l1(x)))
        return self.norm(x + mlp_out)
|
|
|
|
class ZZJRabbit3Model(PreTrainedModel):
    """Bare ZZJRabbit3 decoder stack: token embedding plus the decoder layers.

    Positions are always ``0 .. seq_len - 1`` for every row of the batch.
    """

    config_class = ZZJRabbit3Config

    def __init__(self, config: ZZJRabbit3Config, **kwargs):
        super().__init__(config, **kwargs)
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        # The rotary tables are sized per attention head.
        self.rotary_emb = ZZJRabbit3RotaryEmbedding(
            config.hidden_size // config.num_attention_heads
        )
        self.layers = nn.ModuleList(
            [ZZJRabbit3Layer(config) for _ in range(config.num_hidden_layers)]
        )
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        return_dict: Optional[bool] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple | BaseModelOutput:
        """
        Args:
            input_ids: Token ids of shape (batch_size, seq_len).
            return_dict: When truthy a ``BaseModelOutput`` is returned; when
                falsy (including the default ``None``) a 1-tuple is returned.
            attention_mask: Optional mask forwarded to every layer.
            **kwargs: Accepted and ignored.

        Returns:
            The final hidden states, as ``(hidden,)`` or wrapped in a
            ``BaseModelOutput``.

        Raises:
            IndexError: If ``seq_len`` exceeds the number of positions in the
                precomputed rotary table.
        """
        batch_size, seq_len = input_ids.shape

        # Fail early with a readable message: indexing past the rotary table
        # would otherwise raise deep inside tensor indexing (and on CUDA
        # trigger a device-side assert).
        max_positions = self.rotary_emb.cos_cached.size(0)
        if seq_len > max_positions:
            raise IndexError(
                f"sequence length {seq_len} exceeds the {max_positions} "
                "positions supported by the rotary embedding table"
            )

        res = self.embedding(input_ids)
        position_ids = (
            torch.arange(seq_len, device=input_ids.device)
            .unsqueeze(0)
            .expand(batch_size, -1)
        )
        # The tables are computed once and shared by all layers.
        position_embeddings = self.rotary_emb(position_ids)
        for layer in self.layers:
            res = layer(res, position_embeddings, attention_mask)

        if not return_dict:
            return (res,)
        return BaseModelOutput(res)
|
|
|
|
class ZZJRabbit3ForCausalLM(PreTrainedModel, GenerationMixin):
    """ZZJRabbit3 decoder with a linear language-modelling head on top."""

    config_class = ZZJRabbit3Config

    def __init__(self, config: ZZJRabbit3Config, **kwargs):
        super().__init__(config, **kwargs)
        self.model = ZZJRabbit3Model(config, **kwargs)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> tuple | CausalLMOutput:
        """
        Args:
            input_ids: Token ids passed to the backbone.
            return_dict: When truthy a ``CausalLMOutput`` is returned; when
                falsy (including the default ``None``) a tuple is returned.
            labels: Optional targets; when given, ``self.loss_function`` is
                called and its result returned as the loss.
            attention_mask: Optional mask passed to the backbone.
            logits_to_keep: An int ``k`` keeps the positions selected by
                ``slice(-k, None)`` (so ``0`` keeps all of them); any other
                value is used directly as an index on the sequence axis.
            **kwargs: Forwarded to ``self.loss_function`` only.

        Returns:
            ``(loss, logits)`` / ``(logits,)`` as a tuple, or the same fields
            in a ``CausalLMOutput``.
        """
        hidden = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]

        if isinstance(logits_to_keep, int):
            keep = slice(-logits_to_keep, None)
        else:
            keep = logits_to_keep
        logits = self.lm_head(hidden[:, keep, :])

        has_labels = labels is not None
        if has_labels:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        if return_dict:
            if has_labels:
                return CausalLMOutput(logits=logits, loss=loss)
            return CausalLMOutput(logits=logits)
        if has_labels:
            return (loss, logits)
        return (logits,)

    @classmethod
    def can_generate(cls):
        """Report that this class supports generation."""
        return True

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Build the forward inputs for a generation step.

        Only ``input_ids`` is forwarded; every other keyword argument is
        dropped.
        """
        return {"input_ids": input_ids}
|
|
|
|
class ZZJRabbit3Tokenizer(TokenizersBackend):
    """Byte-level BPE tokenizer for ZZJRabbit3.

    ``<eos>`` is the default for the unknown, end-of-sequence and padding
    tokens alike.
    """

    model = BPE

    def __init__(
        self,
        vocab=None,
        merges=None,
        unk_token="<eos>",
        eos_token="<eos>",
        pad_token="<eos>",
        **kwargs,
    ):
        """
        Args:
            vocab: Token-to-id mapping; a falsy value falls back to a
                vocabulary containing only ``<eos>`` with id 0.
            merges: BPE merge rules; a falsy value falls back to an empty list.
            unk_token: Unknown token passed to the base class.
            eos_token: End-of-sequence token passed to the base class.
            pad_token: Padding token passed to the base class.
            **kwargs: Forwarded to ``TokenizersBackend.__init__``.
        """
        self._vocab = vocab or {"<eos>": 0}
        self._merges = merges or []

        # Assemble the backend tokenizer: BPE model with byte-level
        # pre-tokenization and the matching byte-level decoder.
        bpe_model = BPE(vocab=self._vocab, merges=self._merges, fuse_unk=True)
        backend = Tokenizer(bpe_model)
        backend.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
        backend.decoder = decoders.ByteLevel()
        self._tokenizer = backend

        super().__init__(
            unk_token=unk_token, eos_token=eos_token, pad_token=pad_token, **kwargs
        )
|
|