| import math |
| import os.path |
| from typing import Optional, Union |
|
|
| import torch |
| import torch.nn as nn |
| from tokenizers import Tokenizer |
| from transformers import ( |
| GenerationMixin, |
| PretrainedConfig, |
| PreTrainedModel, |
| PreTrainedTokenizer, |
| ) |
| from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput |
|
|
|
|
class ZZJRabbit22Config(PretrainedConfig):
    """Hyper-parameters for the ZZJRabbit22 model family.

    Args:
        num_layers: Number of transformer layers stacked in the model.
        num_attention_heads: Number of attention heads in each layer.
        vocab_size: Number of token ids covered by the embedding / LM head.
        hidden_size: Width of the hidden states. Must be divisible by
            ``num_attention_heads`` so it splits evenly across heads.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``num_attention_heads`` is not positive or does not
            divide ``hidden_size``.
    """

    model_type = "zzjrabbit22"

    def __init__(
        self,
        num_layers: int = 12,
        num_attention_heads: int = 8,
        vocab_size: int = 10000,
        hidden_size: int = 1024,
        **kwargs,
    ):
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would silently let an invalid head split through.
        if num_attention_heads <= 0 or hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by a positive "
                f"num_attention_heads ({num_attention_heads})"
            )
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        super().__init__(**kwargs)
|
|
|
|
class ZZJRabbit22PE(nn.Module):
    """Fixed sinusoidal positional encoding in sequence-first layout.

    The table is stored in the ``pe`` buffer with shape
    ``(max_len, 1, hidden_size)``: even feature columns hold ``sin`` and odd
    columns hold ``cos`` of ``position / 10000 ** (2i / hidden_size)``.
    ``forward`` adds the first ``x.size(0)`` rows to ``x`` and broadcasts over
    dim 1, so ``x`` must be laid out as ``(seq_len, batch, hidden_size)``.

    Args:
        hidden_size: Feature width of the encoding; odd widths are supported.
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, hidden_size: int, max_len: int = 32768):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, hidden_size, 2).float() * (-math.log(10000.0) / hidden_size)
        )
        # (max_len, ceil(hidden_size / 2)) angles, one per even column.
        angles = position * div_term
        pe = torch.zeros(max_len, hidden_size)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd hidden_size there is one fewer odd column than even
        # column, so drop the last angle for the cos half.
        pe[:, 1::2] = torch.cos(angles[:, : hidden_size // 2])
        # Same (max_len, 1, hidden_size) shape as before, so saved state dicts
        # containing the ``pe`` buffer still load.
        self.register_buffer("pe", pe.unsqueeze(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the encoding of positions ``0 .. x.size(0) - 1``.

        Raises:
            ValueError: If ``x.size(0)`` exceeds the precomputed ``max_len``.
        """
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds positional table size {self.pe.size(0)}"
            )
        return x + self.pe[:seq_len]
|
|
|
|
class ZZJRabbit22Attention(nn.Module):
    """Multi-head scaled dot-product self-attention over batch-first input.

    ``x`` is ``(batch, seq_len, hidden_size)``. It is projected to Q/K/V and
    split into ``num_attention_heads`` heads of ``hidden_size // num_heads``
    features. The result is projected back to ``(batch, seq_len, hidden_size)``.

    Args:
        config: Object exposing ``hidden_size`` and ``num_attention_heads``.
            It may also expose ``attention_dropout`` (default 0.1).
    """

    def __init__(self, config: "ZZJRabbit22Config"):
        super().__init__()
        self.config = config
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)
        # Dropout on the attention probabilities; 0.1 unless the config says otherwise.
        self.dropout = nn.Dropout(getattr(config, "attention_dropout", 0.1))

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        """Reshape ``(batch, seq, hidden)`` into ``(batch, heads, seq, head_dim)``."""
        return t.view(
            t.size(0), -1, self.config.num_attention_heads, self.head_dim
        ).transpose(1, 2)

    def forward(
        self,
        x: torch.Tensor,
        key_padding_mask: Optional[torch.BoolTensor] = None,
        attn_mask: Optional[torch.BoolTensor] = None,
    ):
        """Attend over ``x`` and return a tensor of the same shape.

        Args:
            x: ``(batch, seq_len, hidden_size)`` input.
            key_padding_mask: Boolean ``(batch, seq_len)``; ``True`` marks keys
                to ignore (e.g. padding).
            attn_mask: Boolean mask broadcastable to
                ``(batch, heads, seq_len, seq_len)``; ``True`` marks blocked
                query/key pairs (e.g. future positions).
        """
        batch_size = x.size(0)
        q = self._split_heads(self.q_proj(x))
        k = self._split_heads(self.k_proj(x))
        v = self._split_heads(self.v_proj(x))

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Use the most negative finite value rather than -inf. If every key of a
        # query is masked (e.g. causal mask + padded first token), -inf makes
        # softmax return NaN, and that NaN then leaks into other positions
        # through ``attn @ V`` in later layers (0 * NaN = NaN).
        mask_value = torch.finfo(scores.dtype).min
        if key_padding_mask is not None:
            scores = scores.masked_fill(
                key_padding_mask.reshape(batch_size, 1, 1, -1), mask_value
            )
        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask, mask_value)

        attn_weights = self.dropout(torch.softmax(scores, dim=-1))
        context = torch.matmul(attn_weights, v).transpose(1, 2).contiguous()
        context = context.view(batch_size, -1, self.config.hidden_size)
        return self.out_proj(context)
|
|
|
|
class ZZJRabbit22Layer(nn.Module):
    """Post-norm decoder layer: causal self-attention, then a feed-forward.

    Each sub-block adds its output to its input and then applies ``self.norm``.
    The same RMSNorm module (one set of weights) is used after both sub-blocks.
    """

    def __init__(self, config: ZZJRabbit22Config):
        super().__init__()
        self.attn = ZZJRabbit22Attention(config)
        self.l1 = nn.Linear(config.hidden_size, config.hidden_size)
        self.l2 = nn.Linear(config.hidden_size, config.hidden_size)
        self.activate = nn.ReLU()
        # Shared by both residual connections; kept as a single module so the
        # existing ``norm.weight`` parameter name (and checkpoints) stay valid.
        self.norm = nn.RMSNorm(config.hidden_size)

    def forward(
        self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Apply the layer to batch-first ``x`` of shape ``(batch, seq, hidden)``.

        Args:
            x: Hidden states, batch-first.
            attention_mask: Optional ``(batch, seq)`` mask where values ``< 1``
                mark padding tokens that must not be attended to.
        """
        seq_len = x.size(-2)
        # Block attention to future positions in train AND eval mode; masking
        # only during training lets eval/generation peek at later tokens.
        attn_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
            diagonal=1,
        )
        key_padding_mask = None if attention_mask is None else attention_mask < 1

        # The attention module returns a plain tensor, not a tuple; indexing
        # it with ``[0]`` would select batch element 0 and broadcast it to
        # every sequence in the batch.
        attn = self.attn(x, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        x = self.norm(x + attn)

        ffn = self.l2(self.activate(self.l1(x)))
        return self.norm(x + ffn)
|
|
|
|
class ZZJRabbit22Model(PreTrainedModel):
    """Decoder trunk: token embedding, positional encoding, and layer stack.

    Produces batch-first hidden states ``(batch, seq_len, hidden_size)``.
    """

    config_class = ZZJRabbit22Config

    def __init__(self, config: ZZJRabbit22Config, **kwargs):
        super().__init__(config, **kwargs)
        self.config = config
        self.emb = nn.Embedding(config.vocab_size, config.hidden_size)
        self.pe = ZZJRabbit22PE(config.hidden_size)
        self.layers = nn.ModuleList(
            [ZZJRabbit22Layer(config) for _ in range(config.num_layers)]
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        return_dict: Optional[bool] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Encode ``input_ids`` (batch-first) into final hidden states.

        Args:
            input_ids: Token ids laid out ``(batch, seq_len)``.
            return_dict: If truthy, return ``BaseModelOutput``; otherwise a
                1-tuple ``(hidden_states,)``.
            attention_mask: Optional ``(batch, seq_len)`` padding mask,
                forwarded to every layer.
            **kwargs: Accepted for API compatibility and ignored.
        """
        hidden = self.emb(input_ids)
        # ZZJRabbit22PE slices its table along dim 0 (sequence-first), while
        # the embeddings here are batch-first. Passing them directly would
        # index positions by batch row, so swap to (seq, batch, hidden) for
        # the encoding and swap back.
        hidden = self.pe(hidden.transpose(0, 1)).transpose(0, 1)
        for layer in self.layers:
            hidden = layer(hidden, attention_mask)
        if not return_dict:
            return (hidden,)
        return BaseModelOutput(last_hidden_state=hidden)
|
|
|
|
class ZZJRabbit22ForCausalLM(PreTrainedModel, GenerationMixin):
    """ZZJRabbit22 trunk with a linear language-modelling head."""

    config_class = ZZJRabbit22Config

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.model = ZZJRabbit22Model(config, **kwargs)
        # Projects hidden states to vocabulary logits.
        self.l = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ):
        """Compute next-token logits and, when ``labels`` is given, the loss.

        Args:
            input_ids: Token ids, ``(batch, seq_len)``.
            return_dict: If truthy, return ``CausalLMOutput``; otherwise a tuple
                ``(loss, logits)`` when labels are given, else ``(logits,)``.
            labels: Optional targets passed to ``self.loss_function``.
            attention_mask: Optional padding mask forwarded to the trunk.
            logits_to_keep: An int N keeps logits for the last N positions
                (0 keeps all); a tensor is used directly as the position index.
            **kwargs: Forwarded to ``self.loss_function``.
        """
        hidden = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
        if isinstance(logits_to_keep, int):
            positions = slice(-logits_to_keep, None)  # -0 == 0 -> all positions
        else:
            positions = logits_to_keep
        logits = self.l(hidden[:, positions, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        if not return_dict:
            return (logits,) if loss is None else (loss, logits)
        return CausalLMOutput(loss=loss, logits=logits)

    @classmethod
    def can_generate(cls):
        return True

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Build the ``forward`` kwargs for one generation step.

        Also forwards ``attention_mask`` when the caller supplied one. Without
        it, padded positions in a batched ``generate`` call would be attended to.
        """
        model_inputs = {"input_ids": input_ids}
        attention_mask = kwargs.get("attention_mask")
        if attention_mask is not None:
            model_inputs["attention_mask"] = attention_mask
        return model_inputs
|
|
|
|
class ZZJRabbit22Tokenizer(PreTrainedTokenizer):
    """``PreTrainedTokenizer`` wrapper around a ``tokenizers.Tokenizer`` JSON file."""

    vocab_files_names = {"tokenizers_file": "tokenizer.json"}

    def __init__(self, tokenizers_file, **kwargs):
        # Load the backing tokenizer before the base initializer runs, so the
        # vocabulary methods below already work if the base class calls them.
        self.internal = Tokenizer.from_file(tokenizers_file)
        super().__init__(**kwargs)

    def get_vocab(self):
        """Return the token -> id mapping of the backing tokenizer."""
        # Reading the mapping directly avoids a ``None`` key that
        # ``id_to_token`` would produce for any unused id below vocab_size.
        return dict(self.internal.get_vocab())

    def tokenize(self, text, **kwargs):
        return self.internal.encode(text).tokens

    def _convert_token_to_id(self, token):
        # Hook used by PreTrainedTokenizer's generic token -> id conversion.
        return self.internal.token_to_id(token)

    def _convert_id_to_token(self, index):
        # Hook used by PreTrainedTokenizer.convert_ids_to_tokens.
        return self.internal.id_to_token(index)

    def convert_tokens_to_ids(self, tokens):
        """Map one token (str) or a sequence of tokens to id(s); unknown -> None."""
        if isinstance(tokens, str):
            return self.internal.token_to_id(tokens)
        return [self.internal.token_to_id(t) for t in tokens]

    def decode(self, tokens, skip_special_tokens=True, **kwargs):
        """Decode a sequence of ids (list, tensor, array) or a single id to text."""
        # Torch tensors, numpy arrays and numpy scalars all expose ``tolist``.
        if hasattr(tokens, "tolist"):
            tokens = tokens.tolist()
        # A single id (e.g. from a 0-d tensor) decodes as a one-token sequence.
        if isinstance(tokens, int):
            tokens = [tokens]
        return self.internal.decode(tokens, skip_special_tokens=skip_special_tokens)

    @property
    def vocab_size(self):
        return self.internal.get_vocab_size()

    def save_vocabulary(
        self, path, filename_prefix: Optional[str] = None, *args, **kwargs
    ) -> tuple[str]:
        """Write the backing tokenizer JSON into directory ``path``.

        ``filename_prefix`` follows the ``PreTrainedTokenizer.save_vocabulary``
        convention and yields ``"<prefix>-tokenizer.json"``.

        Returns:
            A 1-tuple with the written file path.
        """
        filename = (
            f"{filename_prefix}-tokenizer.json" if filename_prefix else "tokenizer.json"
        )
        p = os.path.join(path, filename)
        self.internal.save(p)
        return (p,)
|
|