File size: 7,519 Bytes
45b319a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
from typing import Optional, Union
from transformers import PreTrainedTokenizer, PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput
from tokenizers import Tokenizer
import torch.nn as nn
import torch
import os.path
import math

class ZZJRabbit2Config(PretrainedConfig):
    """Architecture hyper-parameters for the ZZJRabbit2 model family.

    Args:
        num_layers: Number of stacked transformer layers.
        num_attention_heads: Number of attention heads per layer.
        vocab_size: Size of the token vocabulary.
        hidden_size: Width of the hidden representation; must be divisible
            by ``num_attention_heads``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "zzjrabbit2"

    def __init__(
        self,
        num_layers: int = 12,
        num_attention_heads: int = 8,
        vocab_size: int = 10000,
        hidden_size: int = 1024,
        **kwargs,
    ):
        # The hidden width has to split evenly across the attention heads.
        assert hidden_size % num_attention_heads == 0
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.num_attention_heads = num_attention_heads
        self.num_layers = num_layers
        super().__init__(**kwargs)

class ZZJRabbit2PE(nn.Module):
    """Fixed sinusoidal positional encoding for batch-first inputs.

    Even feature columns hold ``sin(pos * freq)`` and odd columns hold
    ``cos(pos * freq)``. The table is registered as the buffer ``pe`` with
    shape ``(max_len, 1, hidden_size)``.

    Args:
        hidden_size: Feature width of the tensors the encoding is added to.
        max_len: Number of positions to precompute.
    """

    def __init__(self, hidden_size: int, max_len: int = 32768):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, hidden_size, 2).float() * (-math.log(10000.0) / hidden_size))
        angles = position * div_term
        pe = torch.zeros(max_len, hidden_size)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd hidden_size there is one fewer cosine column than sine
        # column, so trim the cosine table to fit.
        pe[:, 1::2] = torch.cos(angles)[:, : hidden_size // 2]
        # Keep the (max_len, 1, hidden_size) layout so the buffer shape stored
        # in state dicts does not change.
        self.register_buffer("pe", pe.unsqueeze(1))

    def forward(self, x: torch.Tensor):
        """Add the encoding for positions ``0..seq_len-1`` to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, hidden_size)``.

        Returns:
            ``x`` plus the positional table, broadcast over the batch axis.

        Raises:
            ValueError: If ``seq_len`` exceeds the precomputed ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds the maximum supported length {self.pe.size(0)}"
            )
        # Slice along the *sequence* axis and move it to dim 1 so the table
        # broadcasts as (1, seq_len, hidden_size) against a batch-first input.
        return x + self.pe[:seq_len].transpose(0, 1)

class ZZJRabbit2Attention(nn.Module):
    """Multi-head scaled dot-product self-attention over batch-first input.

    Queries, keys and values are linear projections of the same input. The
    hidden width is split into ``config.num_attention_heads`` heads of
    ``head_dim`` channels each, and a dropout of 0.1 is applied to the
    attention weights.

    Args:
        config: Object providing ``hidden_size`` and ``num_attention_heads``.
    """

    def __init__(self, config: "ZZJRabbit2Config"):
        super().__init__()
        self.config = config
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(0.1)

    def _split_heads(self, projected: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Reshape ``(batch, seq, hidden)`` to ``(batch, heads, seq, head_dim)``."""
        return projected.view(batch_size, -1, self.config.num_attention_heads, self.head_dim).transpose(1, 2)

    def forward(self, x: torch.Tensor, key_padding_mask: Optional[torch.BoolTensor] = None, attn_mask: Optional[torch.BoolTensor] = None):
        """Run self-attention on ``x``.

        Args:
            x: Input of shape ``(batch, seq, hidden_size)``.
            key_padding_mask: Optional boolean mask viewable as
                ``(batch, 1, 1, seq)``; ``True`` marks keys to ignore.
            attn_mask: Optional boolean mask broadcastable to
                ``(batch, heads, seq, seq)``; ``True`` marks query/key pairs
                that must not attend.

        Returns:
            Tensor of shape ``(batch, seq, hidden_size)``.
        """
        batch_size = x.size(0)
        Q = self._split_heads(self.q_proj(x), batch_size)
        K = self._split_heads(self.k_proj(x), batch_size)
        V = self._split_heads(self.v_proj(x), batch_size)
        # A Python float scale avoids allocating a fresh tensor on every call.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Merge both masks so fully masked query rows can be detected below.
        blocked = None
        if key_padding_mask is not None:
            blocked = key_padding_mask.view(batch_size, 1, 1, -1)
        if attn_mask is not None:
            blocked = attn_mask if blocked is None else blocked | attn_mask
        if blocked is not None:
            scores = scores.masked_fill(blocked, float("-inf"))

        attn_weights = nn.functional.softmax(scores, dim=-1)
        if blocked is not None:
            # Softmax over a row of only -inf yields NaN; give such rows zero
            # weight so they cannot poison the rest of the batch.
            attn_weights = attn_weights.masked_fill(blocked.all(dim=-1, keepdim=True), 0.0)
        attn_weights = self.dropout(attn_weights)

        context = torch.matmul(attn_weights, V)
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, -1, self.config.hidden_size)
        return self.out_proj(context)

class ZZJRabbit2Layer(nn.Module):
    """One transformer layer: causal self-attention plus a two-layer MLP.

    Each sub-block is wrapped in a residual connection followed by
    normalisation. The same ``RMSNorm`` instance is used after both
    sub-blocks.

    Args:
        config: Object providing ``hidden_size`` (and whatever
            ``ZZJRabbit2Attention`` reads from it).
    """

    def __init__(self, config: "ZZJRabbit2Config"):
        super().__init__()
        self.attn = ZZJRabbit2Attention(config)
        self.l1 = nn.Linear(config.hidden_size, config.hidden_size)
        self.l2 = nn.Linear(config.hidden_size, config.hidden_size)
        self.activate = nn.ReLU()
        self.norm = nn.RMSNorm(config.hidden_size)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the layer to ``x``.

        Args:
            x: Hidden states of shape ``(batch, seq, hidden_size)``.
            attention_mask: Optional mask viewable as ``(batch, 1, 1, seq)``;
                entries below 1 mark padding positions.

        Returns:
            Tensor with the same shape as ``x``.
        """
        seq_len = x.size(-2)
        # True above the diagonal = "future" key. Applied in eval mode as well
        # as training so inference sees the same attention pattern the model
        # was trained with.
        blocked = torch.gt(torch.triu(torch.ones(seq_len, seq_len, device=x.device), 1), 0)
        if attention_mask is not None:
            pad_keys = torch.lt(attention_mask, 1).view(x.size(0), 1, 1, -1)
            # Always leave the diagonal visible. For real tokens this changes
            # nothing (their own key is neither future nor padding). For
            # padded queries it guarantees at least one visible key, so the
            # softmax never sees a row of only -inf, even with left padding.
            diagonal = torch.eye(seq_len, device=x.device).bool()
            blocked = (blocked | pad_keys) & ~diagonal
        # The attention module returns a plain (batch, seq, hidden) tensor,
        # so it must not be indexed: `[0]` would select the first sample.
        attn = self.attn(
            x,
            key_padding_mask=None,
            attn_mask=blocked,
        )
        x = self.norm(x + attn)
        o = self.l1(x)
        o = self.activate(o)
        o = self.l2(o)
        return self.norm(x + o)

class ZZJRabbit2Model(PreTrainedModel):
    """Backbone: token embedding, positional encoding and a stack of layers."""

    config_class = ZZJRabbit2Config

    def __init__(self, config: ZZJRabbit2Config, **kwargs):
        super().__init__(config, **kwargs)
        self.config = config
        self.emb = nn.Embedding(config.vocab_size, config.hidden_size)
        self.pe = ZZJRabbit2PE(config.hidden_size)
        blocks = [ZZJRabbit2Layer(config) for _ in range(config.num_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, input_ids: torch.Tensor, return_dict: Optional[bool] = None, attention_mask: Optional[torch.Tensor] = None, **kwargs):
        """Encode ``input_ids`` into hidden states.

        Args:
            input_ids: Token ids to embed.
            return_dict: When truthy, wrap the result in ``BaseModelOutput``;
                otherwise (including ``None``) return a 1-tuple.
            attention_mask: Optional mask handed unchanged to every layer.
            **kwargs: Accepted and ignored.

        Returns:
            ``BaseModelOutput`` or ``(hidden_states,)``.
        """
        hidden_states = self.pe(self.emb(input_ids))
        for block in self.layers:
            hidden_states = block(hidden_states, attention_mask)
        if return_dict:
            return BaseModelOutput(last_hidden_state=hidden_states)
        return (hidden_states,)


class ZZJRabbit2ForCausalLM(PreTrainedModel, GenerationMixin):
    """ZZJRabbit2 backbone with a linear language-modelling head."""

    config_class = ZZJRabbit2Config

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.model = ZZJRabbit2Model(config, **kwargs)
        # Projects hidden states to vocabulary logits.
        self.l = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, input_ids: torch.Tensor, return_dict: Optional[bool] = None, labels: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs):
        """Compute logits and, when ``labels`` is given, the loss.

        Args:
            input_ids: Token ids fed to the backbone.
            return_dict: When truthy, return a ``CausalLMOutput``; otherwise
                (including ``None``) return a tuple.
            labels: Optional targets passed to ``self.loss_function``.
            attention_mask: Optional padding mask forwarded to the backbone.
            logits_to_keep: An int ``k`` keeps the last ``k`` positions
                (``0`` keeps all, because ``slice(-0, None)`` is the full
                range); a tensor is used directly as the sequence index.
            **kwargs: Forwarded to ``self.loss_function`` when labels are given.

        Returns:
            ``CausalLMOutput``, or ``(loss, logits)`` / ``(logits,)``.
        """
        hidden = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
        keep = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.l(hidden[:, keep, :])
        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
        if not return_dict:
            return (loss, logits) if labels is not None else (logits,)
        return CausalLMOutput(loss=loss, logits=logits)

    @classmethod
    def can_generate(cls):
        """Report that this model supports generation."""
        return True

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Build the forward kwargs for one generation step.

        The full sequence is re-encoded each step. The padding mask, when
        present in ``kwargs``, is passed along so padded batch entries do not
        attend to pad tokens.
        """
        return {"input_ids": input_ids, "attention_mask": kwargs.get("attention_mask")}

class ZZJRabbit2Tokenizer(PreTrainedTokenizer):
    """``PreTrainedTokenizer`` adapter around a ``tokenizers.Tokenizer`` file.

    All encoding and decoding work is delegated to ``self.internal``, which
    is loaded from a ``tokenizer.json`` file.
    """

    vocab_files_names = {"tokenizers_file": "tokenizer.json"}

    def __init__(self, tokenizers_file, **kwargs):
        # Loaded before super().__init__(); presumably the base initialiser
        # may already query the vocabulary — keep this order.
        self.internal = Tokenizer.from_file(tokenizers_file)
        super().__init__(**kwargs)

    def get_vocab(self):
        """Return a ``token -> id`` mapping for ids ``0..vocab_size-1``."""
        return {self.internal.id_to_token(i): i for i in range(self.vocab_size)}

    def tokenize(self, text, **kwargs):
        """Split ``text`` into token strings using the wrapped tokenizer."""
        return self.internal.encode(text).tokens

    def convert_tokens_to_ids(self, tokens):
        """Map a token string, or an iterable of them, to id(s).

        ``None`` is passed through as ``None``.
        """
        if tokens is None:
            return None
        if isinstance(tokens, str):
            return self.internal.token_to_id(tokens)
        return [self.internal.token_to_id(t) for t in tokens]

    def decode(self, tokens, skip_special_tokens=True, **kwargs):
        """Decode ids into text.

        Args:
            tokens: A list of ids, a single int id, or a ``torch.Tensor``.
            skip_special_tokens: Forwarded to the wrapped tokenizer.
            **kwargs: Accepted and ignored.
        """
        if isinstance(tokens, torch.Tensor):
            tokens = tokens.tolist()
        # A bare id (or a 0-dim tensor, which tolist() turns into an int)
        # must be wrapped because the wrapped tokenizer expects a sequence.
        if isinstance(tokens, int):
            tokens = [tokens]
        return self.internal.decode(tokens, skip_special_tokens=skip_special_tokens)

    @property
    def vocab_size(self):
        """Vocabulary size as reported by the wrapped tokenizer."""
        return self.internal.get_vocab_size()

    def save_vocabulary(self, path, *args, **kwargs) -> tuple[str]:
        """Write the tokenizer file into directory ``path``.

        An optional ``filename_prefix`` (first extra positional argument or
        keyword) is prepended to the file name as ``"<prefix>-"``.

        Returns:
            A 1-tuple holding the path of the written file.
        """
        filename_prefix = args[0] if args else kwargs.get("filename_prefix")
        file_name = f"{filename_prefix}-tokenizer.json" if filename_prefix else "tokenizer.json"
        p = os.path.join(path, file_name)
        self.internal.save(p)
        return (p,)