"""Hugging Face-style config, model and tokenizer classes for the "zzjrabbit" causal LM."""

import os.path
from typing import Optional, Union

import torch
import torch.nn as nn
from tokenizers import Tokenizer
from tokenizers.models import BPE
from transformers import PreTrainedTokenizer, PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput


class ZZJRabbitConfig(PretrainedConfig):
    """Configuration shared by the ZZJRabbit model classes.

    Args:
        num_layers: Number of ``ZZJRabbitLayer`` blocks stacked in the model.
        vocab_size: Size of the token embedding table and of the output projection.
        hidden_size: Width of the embeddings and of every layer.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "zzjrabbit"

    def __init__(
        self,
        num_layers: int = 6,
        vocab_size: int = 10000,
        hidden_size: int = 256,
        **kwargs,
    ):
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        super().__init__(**kwargs)


class ZZJRabbitLayer(nn.Module):
    """One block: causal multi-head attention followed by a two-layer ReLU MLP.

    Both residual connections are normalised by the *same* ``RMSNorm`` instance
    (``self.norm``); this is kept as-is so existing checkpoints keep loading.
    """

    def __init__(self, config: ZZJRabbitConfig):
        super().__init__()
        # Head count (8) and attention dropout (0.1) are fixed, not read from the config.
        self.attn = nn.MultiheadAttention(
            config.hidden_size, num_heads=8, dropout=0.1, batch_first=True
        )
        self.l1 = nn.Linear(config.hidden_size, config.hidden_size)
        self.l2 = nn.Linear(config.hidden_size, config.hidden_size)
        self.activate = nn.ReLU()
        self.norm = nn.RMSNorm(config.hidden_size)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply attention and the MLP, each wrapped in residual + norm.

        Args:
            q: Query tensor; its second-to-last dimension is used as the sequence length.
            k: Key tensor.
            v: Value tensor.
            attention_mask: Optional mask where entries ``< 1`` mark positions to
                ignore as keys (converted to ``key_padding_mask``).

        Returns:
            Tensor produced by ``norm(x + mlp(x))`` with ``x = norm(q + attn(q, k, v))``.
        """
        # The causal mask is built unconditionally.  ``is_causal=True`` is only a
        # hint that ``attn_mask`` *is* the causal mask; it does not replace the
        # mask.  Building it only in training mode left evaluation/generation
        # without any valid causal masking.
        seq_len = q.size(-2)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device),
            diagonal=1,
        )

        key_padding_mask = None
        if attention_mask is not None:
            key_padding_mask = torch.lt(attention_mask, 1)

        # need_weights=False: the attention weights were always discarded here,
        # so skip computing them.
        attn, _ = self.attn(
            q,
            k,
            v,
            key_padding_mask=key_padding_mask,
            attn_mask=causal_mask,
            need_weights=False,
            is_causal=True,
        )

        q = self.norm(q + attn)
        o = self.l1(q)
        o = self.activate(o)
        o = self.l2(o)
        return self.norm(q + o)


class ZZJRabbitModel(PreTrainedModel):
    """Token embedding followed by ``config.num_layers`` ``ZZJRabbitLayer`` blocks."""

    config_class = ZZJRabbitConfig

    def __init__(self, config: ZZJRabbitConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.config = config
        self.emb = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [ZZJRabbitLayer(config) for _ in range(config.num_layers)]
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        return_dict: Optional[bool] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Embed ``input_ids`` and run them through every layer.

        Args:
            input_ids: Token ids to embed.
            return_dict: When truthy a ``BaseModelOutput`` is returned; otherwise
                (including the default ``None``) a 1-tuple is returned.
            attention_mask: Optional padding mask handed to every layer.
            **kwargs: Accepted and ignored.

        Returns:
            ``(hidden,)`` or ``BaseModelOutput(hidden)``; in both cases the hidden
            states are available at index ``0``.
        """
        hidden = self.emb(input_ids)
        for layer in self.layers:
            # Self-attention: the same tensor serves as query, key and value.
            hidden = layer(hidden, hidden, hidden, attention_mask)

        if not return_dict:
            return (hidden,)
        return BaseModelOutput(hidden)


class ZZJRabbitModelForCausalLM(PreTrainedModel, GenerationMixin):
    """``ZZJRabbitModel`` with a linear projection onto the vocabulary."""

    config_class = ZZJRabbitConfig

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.model = ZZJRabbitModel(config, **kwargs)
        self.l = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ):
        """Compute vocabulary logits and, when ``labels`` is given, the loss.

        Args:
            input_ids: Token ids fed to the backbone.
            return_dict: When truthy a ``CausalLMOutput`` is returned; otherwise
                (including the default ``None``) a tuple is returned.
            labels: Optional targets passed to ``self.loss_function``.
            attention_mask: Optional padding mask forwarded to the backbone.
            logits_to_keep: If an ``int``, only the last ``logits_to_keep``
                positions are projected (``0`` keeps all of them); otherwise it
                is used directly as an index along the sequence dimension.
            **kwargs: Forwarded to ``self.loss_function``.

        Returns:
            Without ``return_dict``: ``(loss, logits)`` when ``labels`` is given,
            else ``(logits,)``.  With ``return_dict``: a ``CausalLMOutput``.
        """
        hidden = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]

        # slice(-0, None) == slice(0, None), so the default of 0 keeps every position.
        if isinstance(logits_to_keep, int):
            keep = slice(-logits_to_keep, None)
        else:
            keep = logits_to_keep
        logits = self.l(hidden[:, keep, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        if not return_dict:
            return (logits,) if labels is None else (loss, logits)
        return CausalLMOutput(loss=loss, logits=logits)

    @classmethod
    def can_generate(cls):
        """Report that this class supports ``generate``."""
        return True

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Build the keyword arguments for one generation step.

        Only ``input_ids`` is forwarded; every other keyword is dropped.
        """
        # NOTE(review): ``attention_mask`` is dropped here, so padded batches
        # attend to padding during generation — confirm whether that is intended
        # before forwarding it (fully masked rows would need handling first).
        return {"input_ids": input_ids}


class ZZJRabbitTokenizer(PreTrainedTokenizer):
    """``PreTrainedTokenizer`` facade over a ``tokenizers.Tokenizer`` JSON file."""

    vocab_files_names = {"tokenizers_file": "tokenizer.json"}

    def __init__(self, tokenizers_file, **kwargs):
        # ``self.internal`` is assigned before the base initialiser runs so the
        # overrides below already work if the base class calls them.
        self.internal = Tokenizer.from_file(tokenizers_file)
        super().__init__(**kwargs)

    def get_vocab(self):
        """Return the token -> id mapping of the wrapped tokenizer."""
        return self.internal.get_vocab()

    def tokenize(self, text, **kwargs):
        """Return the token strings produced by encoding ``text``."""
        return self.internal.encode(text).tokens

    def convert_tokens_to_ids(self, tokens):
        """Map a token or an iterable of tokens to ids.

        Returns ``None`` for ``None`` input.  A ``str`` yields a single id;
        anything else is iterated and yields a list.  The id is whatever
        ``Tokenizer.token_to_id`` returns for that token.
        """
        if tokens is None:
            return None
        if isinstance(tokens, str):
            return self.internal.token_to_id(tokens)
        return [self.internal.token_to_id(token) for token in tokens]

    def decode(self, tokens, skip_special_tokens=True, **kwargs):
        """Decode ids back to text.

        Args:
            tokens: A list of ids, a ``torch.Tensor`` of ids, or a single ``int``.
            skip_special_tokens: Forwarded to the wrapped tokenizer.
            **kwargs: Accepted and ignored.
        """
        if isinstance(tokens, torch.Tensor):
            tokens = tokens.tolist()
        if isinstance(tokens, int):
            # A lone id (also what a 0-d tensor becomes) must be wrapped in a list.
            tokens = [tokens]
        return self.internal.decode(tokens, skip_special_tokens=skip_special_tokens)

    @property
    def vocab_size(self):
        """Vocabulary size reported by the wrapped tokenizer."""
        return self.internal.get_vocab_size()

    def save_vocabulary(self, path, *args, **kwargs) -> tuple[str]:
        """Write the wrapped tokenizer to ``<path>/tokenizer.json``.

        Extra positional and keyword arguments are accepted and ignored.

        Returns:
            A 1-tuple containing the path of the written file.
        """
        target = os.path.join(path, "tokenizer.json")
        self.internal.save(target)
        return (target,)