# zzjrabbit2 / zzjrabbit2.py

import math
import os.path
from typing import Optional, Union

import torch
import torch.nn as nn
from tokenizers import Tokenizer
from transformers import PreTrainedTokenizer, PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput

class ZZJRabbit2Config(PretrainedConfig):
    model_type = "zzjrabbit2"

    def __init__(self, num_layers: int = 12, num_attention_heads: int = 8, vocab_size: int = 10000, hidden_size: int = 1024, **kwargs):
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        # hidden_size must split evenly across the attention heads
        assert hidden_size % num_attention_heads == 0
        super().__init__(**kwargs)

class ZZJRabbit2PE(nn.Module):
    # Fixed sinusoidal positional encoding, stored batch-first to match the embedding output.
    def __init__(self, hidden_size: int, max_len: int = 32768):
        super().__init__()
        pe = torch.zeros(max_len, hidden_size)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, hidden_size, 2).float() * (-math.log(10000.0) / hidden_size))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, hidden_size)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor):
        # x: (batch, seq_len, hidden_size); add the encodings for the first seq_len positions
        return x + self.pe[:, :x.size(1), :]

class ZZJRabbit2Attention(nn.Module):
    def __init__(self, config: ZZJRabbit2Config):
        super().__init__()
        self.config = config
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x: torch.Tensor, key_padding_mask: Optional[torch.BoolTensor] = None, attn_mask: Optional[torch.BoolTensor] = None):
        batch_size = x.size(0)
        # Project and reshape to (batch, heads, seq_len, head_dim)
        Q = self.q_proj(x).view(batch_size, -1, self.config.num_attention_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(x).view(batch_size, -1, self.config.num_attention_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(x).view(batch_size, -1, self.config.num_attention_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention scores: (batch, heads, seq_len, seq_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if key_padding_mask is not None:
            # Block attention to padded key positions for every query
            scores = scores.masked_fill(key_padding_mask.view(batch_size, 1, 1, -1), float("-inf"))
        if attn_mask is not None:
            # Causal (upper-triangular) mask
            scores = scores.masked_fill(attn_mask, float("-inf"))
        attn_weights = nn.functional.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context = torch.matmul(attn_weights, V)
        # Merge heads back to (batch, seq_len, hidden_size)
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, -1, self.config.hidden_size)
        return self.out_proj(context)

class ZZJRabbit2Layer(nn.Module):
    def __init__(self, config: ZZJRabbit2Config):
        super().__init__()
        self.attn = ZZJRabbit2Attention(config)
        self.l1 = nn.Linear(config.hidden_size, config.hidden_size)
        self.l2 = nn.Linear(config.hidden_size, config.hidden_size)
        self.activate = nn.ReLU()
        # A single RMSNorm instance is shared by both residual branches
        self.norm = nn.RMSNorm(config.hidden_size)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        key_padding_mask = None
        attn_mask = None
        if self.training:
            # Causal mask (built only in training mode): True above the diagonal blocks attention to future tokens
            attn_mask = torch.gt(torch.triu(torch.ones(x.size(-2), x.size(-2), device=x.device), 1), 0)
        if attention_mask is not None:
            key_padding_mask = torch.lt(attention_mask, 1)
        # The attention module returns a plain tensor, so no tuple indexing is needed here
        attn = self.attn(
            x,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
        )
        x = self.norm(x + attn)
        o = self.l1(x)
        o = self.activate(o)
        o = self.l2(o)
        return self.norm(x + o)

class ZZJRabbit2Model(PreTrainedModel):
    config_class = ZZJRabbit2Config

    def __init__(self, config: ZZJRabbit2Config, **kwargs):
        super().__init__(config, **kwargs)
        self.config = config
        self.emb = nn.Embedding(config.vocab_size, config.hidden_size)
        self.pe = ZZJRabbit2PE(config.hidden_size)
        self.layers = nn.ModuleList([ZZJRabbit2Layer(config) for _ in range(config.num_layers)])

    def forward(self, input_ids: torch.Tensor, return_dict: Optional[bool] = None, attention_mask: Optional[torch.Tensor] = None, **kwargs):
        # Token embeddings plus positional encodings, then the transformer stack
        res = self.emb(input_ids)
        res = self.pe(res)
        for layer in self.layers:
            res = layer(res, attention_mask)
        if not return_dict:
            return (res,)
        return BaseModelOutput(last_hidden_state=res)

class ZZJRabbit2ForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = ZZJRabbit2Config

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.model = ZZJRabbit2Model(config, **kwargs)
        self.l = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, input_ids: torch.Tensor, return_dict: Optional[bool] = None, labels: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs):
        hidden = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
        # logits_to_keep == 0 keeps every position (slice(-0, None) spans the whole sequence);
        # a positive int keeps only the last `logits_to_keep` positions, a tensor is used as an index.
        logits = self.l(hidden[:, slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep, :])
        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
        if not return_dict:
            return (loss, logits) if loss is not None else (logits,)
        return CausalLMOutput(loss=loss, logits=logits)

    @classmethod
    def can_generate(cls):
        return True

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # No KV cache: the full sequence is re-encoded at every generation step
        return {"input_ids": input_ids}

class ZZJRabbit2Tokenizer(PreTrainedTokenizer):
    vocab_files_names = {"tokenizers_file": "tokenizer.json"}

    def __init__(self, tokenizers_file, **kwargs):
        self.internal = Tokenizer.from_file(tokenizers_file)
        super().__init__(**kwargs)

    def get_vocab(self):
        return {self.internal.id_to_token(i): i for i in range(self.vocab_size)}

    def tokenize(self, text, **kwargs):
        return self.internal.encode(text).tokens

    def convert_tokens_to_ids(self, tokens):
        return self.internal.token_to_id(tokens) if isinstance(tokens, str) else [self.internal.token_to_id(t) for t in tokens]

    def decode(self, tokens, skip_special_tokens=True, **kwargs):
        if isinstance(tokens, torch.Tensor):
            tokens = tokens.tolist()
        return self.internal.decode(tokens, skip_special_tokens=skip_special_tokens)

    @property
    def vocab_size(self):
        return self.internal.get_vocab_size()

    def save_vocabulary(self, path, *args, **kwargs) -> tuple[str]:
        p = os.path.join(path, "tokenizer.json")
        self.internal.save(p)
        return (p,)
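

# --- Usage sketch (not part of the original upload) ---
# A minimal, hedged example of how these classes might be wired together with the
# transformers Auto* registry and used for greedy generation. The tokenizer path
# "tokenizer.json" and the small config values below are placeholder assumptions,
# not files or settings shipped with this repository.
if __name__ == "__main__":
    from transformers import AutoConfig, AutoModelForCausalLM

    # Register the custom architecture so Auto* classes can resolve model_type "zzjrabbit2".
    AutoConfig.register("zzjrabbit2", ZZJRabbit2Config)
    AutoModelForCausalLM.register(ZZJRabbit2Config, ZZJRabbit2ForCausalLM)

    # Build a small, randomly initialised model and a tokenizer from a local tokenizers-format file.
    config = ZZJRabbit2Config(num_layers=2, num_attention_heads=4, vocab_size=10000, hidden_size=128)
    model = ZZJRabbit2ForCausalLM(config).eval()
    tokenizer = ZZJRabbit2Tokenizer("tokenizer.json")  # assumes a tokenizer.json exists on disk

    # Encode a prompt, generate a few tokens greedily, and decode the result.
    input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokenizer.tokenize("hello world"))])
    with torch.no_grad():
        output_ids = model.generate(input_ids, max_new_tokens=8, do_sample=False)
    print(tokenizer.decode(output_ids[0]))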