from typing import Optional, Union
from transformers import PreTrainedTokenizer, PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput
from tokenizers import Tokenizer
import torch.nn as nn
import torch
import os.path
import math
class ZZJRabbit2Config(PretrainedConfig):
    """Configuration for the ZZJRabbit2 transformer family.

    Args:
        num_layers: number of decoder layers.
        num_attention_heads: attention heads per layer; must divide hidden_size.
        vocab_size: size of the token vocabulary.
        hidden_size: embedding / hidden dimension of the model.

    Raises:
        ValueError: if hidden_size is not divisible by num_attention_heads.
    """

    model_type = "zzjrabbit2"

    def __init__(
        self,
        num_layers: int = 12,
        num_attention_heads: int = 8,
        vocab_size: int = 10000,
        hidden_size: int = 1024,
        **kwargs,
    ):
        # Bug fix: the original used `assert`, which is stripped under
        # `python -O`; validate explicitly so bad configs always fail loudly.
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"num_attention_heads ({num_attention_heads})"
            )
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        super().__init__(**kwargs)
class ZZJRabbit2PE(nn.Module):
    """Fixed sinusoidal positional encoding for batch-first inputs.

    Precomputes sin/cos encodings for up to `max_len` positions and adds
    them to the input embeddings in `forward`.
    """

    def __init__(self, hidden_size: int, max_len: int = 32768):
        super().__init__()
        pe = torch.zeros(max_len, hidden_size)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Standard transformer frequency schedule: 10000^(-2i/hidden_size).
        div_term = torch.exp(torch.arange(0, hidden_size, 2).float() * (-math.log(10000.0) / hidden_size))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Kept in the original (max_len, 1, hidden_size) layout so existing
        # checkpoints (which store this buffer) still load.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encodings to `x` of shape (batch, seq_len, hidden).

        Bug fix: the original indexed `self.pe[:x.size(0)]` — the *batch*
        dimension — so every position in a sample received the encoding of
        its batch index. Index by sequence length (dim 1) instead and
        broadcast over the batch.
        """
        return x + self.pe[: x.size(1)].transpose(0, 1)
class ZZJRabbit2Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Projects the input into per-head Q/K/V, applies optional key-padding
    and attention masks, softmax-normalizes the scores (with dropout), and
    merges the heads back through an output projection.
    """

    def __init__(self, config: "ZZJRabbit2Config"):
        super().__init__()
        self.config = config
        self.head_dim = config.hidden_size // config.num_attention_heads
        # Hoisted scaling constant — the original allocated a fresh tensor
        # via torch.sqrt(torch.tensor(...)) on every forward call.
        self.scale = 1.0 / math.sqrt(self.head_dim)
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(0.1)  # NOTE(review): rate hard-coded; consider a config field

    def forward(
        self,
        x: torch.Tensor,
        key_padding_mask: Optional[torch.BoolTensor] = None,
        attn_mask: Optional[torch.BoolTensor] = None,
    ) -> torch.Tensor:
        """Return attention output with the same shape as `x` (batch, seq, hidden).

        key_padding_mask: True marks padded key positions, shape (batch, seq).
        attn_mask: True marks disallowed query/key pairs; broadcast over
            batch and heads.
        """
        batch_size = x.size(0)
        heads = self.config.num_attention_heads
        # Each projection reshaped to (batch, heads, seq, head_dim).
        Q = self.q_proj(x).view(batch_size, -1, heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(x).view(batch_size, -1, heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(x).view(batch_size, -1, heads, self.head_dim).transpose(1, 2)
        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
        if key_padding_mask is not None:
            # Masked keys get -inf so softmax assigns them zero weight.
            scores = scores.masked_fill(key_padding_mask.view(batch_size, 1, 1, -1), float("-inf"))
        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask, float("-inf"))
        attn_weights = self.dropout(nn.functional.softmax(scores, dim=-1))
        context = torch.matmul(attn_weights, V).transpose(1, 2).contiguous()
        context = context.view(batch_size, -1, self.config.hidden_size)
        return self.out_proj(context)
class ZZJRabbit2Layer(nn.Module):
    """One transformer decoder block: causal self-attention followed by a
    two-layer ReLU feed-forward network, each wrapped in residual + RMSNorm.

    NOTE(review): a single RMSNorm instance is shared by both sub-layers;
    kept as-is to preserve checkpoint (state_dict) compatibility.
    """

    def __init__(self, config: "ZZJRabbit2Config"):
        super().__init__()
        self.attn = ZZJRabbit2Attention(config)
        self.l1 = nn.Linear(config.hidden_size, config.hidden_size)
        self.l2 = nn.Linear(config.hidden_size, config.hidden_size)
        self.activate = nn.ReLU()
        self.norm = nn.RMSNorm(config.hidden_size)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        seq_len = x.size(-2)
        # Bug fix: the causal mask was only built when self.training was True,
        # so eval/generation ran with bidirectional attention. A causal LM
        # must hide future positions at inference as well. True above the
        # diagonal = disallowed query/key pair.
        attn_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
            diagonal=1,
        )
        key_padding_mask = None
        if attention_mask is not None:
            # HF convention: attention_mask==0 marks padding.
            key_padding_mask = torch.lt(attention_mask, 1)
        # Bug fix: ZZJRabbit2Attention.forward returns a tensor, not a tuple;
        # the original `[0]` sliced off the batch dimension and broadcast
        # sample 0's attention output onto every sample in the batch.
        attn_out = self.attn(x, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        x = self.norm(x + attn_out)
        ff = self.l2(self.activate(self.l1(x)))
        return self.norm(x + ff)
class ZZJRabbit2Model(PreTrainedModel):
    """Bare ZZJRabbit2 transformer: token embedding + sinusoidal positional
    encoding + a stack of decoder layers. Produces hidden states only
    (no LM head)."""

    config_class = ZZJRabbit2Config

    def __init__(self, config: ZZJRabbit2Config, **kwargs):
        super().__init__(config, **kwargs)
        self.config = config
        self.emb = nn.Embedding(config.vocab_size, config.hidden_size)
        self.pe = ZZJRabbit2PE(config.hidden_size)
        self.layers = nn.ModuleList([ZZJRabbit2Layer(config) for _ in range(config.num_layers)])

    def forward(
        self,
        input_ids: torch.Tensor,
        return_dict: Optional[bool] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Embed `input_ids`, add positional encodings, and run every layer.

        Returns a `(hidden_states,)` tuple, or a BaseModelOutput when
        `return_dict` is truthy (same branch the original took; None falls
        through to the tuple). Removed the dead `res = res` statement.
        """
        hidden = self.pe(self.emb(input_ids))
        for layer in self.layers:
            hidden = layer(hidden, attention_mask)
        if return_dict:
            return BaseModelOutput(last_hidden_state=hidden)
        return (hidden,)
class ZZJRabbit2ForCausalLM(PreTrainedModel, GenerationMixin):
    """ZZJRabbit2 with a linear language-modeling head for next-token
    prediction and `generate()` support."""

    config_class = ZZJRabbit2Config

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.model = ZZJRabbit2Model(config, **kwargs)
        self.l = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ):
        """Run the base model and project hidden states to vocabulary logits.

        logits_to_keep: int k keeps the last k positions (0 keeps all, since
            slice(-0, None) == slice(0, None)); a tensor is used directly as
            the position index. Removed the leftover debug print of the loss.
        """
        hidden = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
        if isinstance(logits_to_keep, int):
            keep = slice(-logits_to_keep, None)
        else:
            keep = logits_to_keep
        logits = self.l(hidden[:, keep, :])
        loss = None
        if labels is not None:
            # transformers' shared causal-LM loss (shift + cross-entropy).
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
        if return_dict:
            return CausalLMOutput(loss=loss, logits=logits)
        return (logits,) if loss is None else (loss, logits)

    @classmethod
    def can_generate(cls):
        return True

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        # Improvement: forward the attention mask when the caller supplies
        # one — the original silently dropped it, so padded batches were
        # generated without masking.
        inputs = {"input_ids": input_ids}
        if attention_mask is not None:
            inputs["attention_mask"] = attention_mask
        return inputs
class ZZJRabbit2Tokenizer(PreTrainedTokenizer):
    """Thin wrapper exposing a `tokenizers.Tokenizer` through the slow
    (PreTrainedTokenizer) API."""

    vocab_files_names = {"tokenizers_file": "tokenizer.json"}

    def __init__(self, tokenizers_file, **kwargs):
        # Must be set before super().__init__, which may query vocab_size.
        self.internal = Tokenizer.from_file(tokenizers_file)
        super().__init__(**kwargs)

    def get_vocab(self):
        """Map every token string to its id."""
        lookup = self.internal.id_to_token
        return {lookup(idx): idx for idx in range(self.vocab_size)}

    def tokenize(self, text, **kwargs):
        """Split `text` into token strings via the wrapped tokenizer."""
        encoding = self.internal.encode(text)
        return encoding.tokens

    def convert_tokens_to_ids(self, tokens):
        """Convert one token (str) or a sequence of tokens to id(s)."""
        if isinstance(tokens, str):
            return self.internal.token_to_id(tokens)
        return [self.internal.token_to_id(tok) for tok in tokens]

    def decode(self, tokens, skip_special_tokens=True, **kwargs):
        """Decode a list (or tensor) of ids back into text."""
        ids = tokens.tolist() if isinstance(tokens, torch.Tensor) else tokens
        return self.internal.decode(ids, skip_special_tokens=skip_special_tokens)

    @property
    def vocab_size(self):
        return self.internal.get_vocab_size()

    def save_vocabulary(self, path, *args, **kwargs) -> tuple[str]:
        """Write tokenizer.json into `path` and return its full path."""
        target = os.path.join(path, "tokenizer.json")
        self.internal.save(target)
        return (target,)