import os.path
from typing import Optional, Union

import torch
import torch.nn as nn
from tokenizers import Tokenizer
from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput

class ZZJRabbitConfig(PretrainedConfig):
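    """Configuration for ZZJRabbit models: number of transformer layers, vocabulary size, and hidden size."""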
    model_type = "zzjrabbit"

    def __init__(self, num_layers: int = 6, vocab_size: int = 10000, hidden_size: int = 256, **kwargs):
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        super().__init__(**kwargs)

class ZZJRabbitLayer(nn.Module):
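    """One decoder block: causal multi-head self-attention (8 heads, dropout 0.1) and a two-layer
    feed-forward network, each followed by a residual connection and RMSNorm (post-norm)."""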
    def __init__(self, config: ZZJRabbitConfig):
        super().__init__()
        self.attn = nn.MultiheadAttention(config.hidden_size, num_heads=8, dropout=0.1, batch_first=True)
        self.l1 = nn.Linear(config.hidden_size, config.hidden_size)
        self.l2 = nn.Linear(config.hidden_size, config.hidden_size)
        self.activate = nn.ReLU()
        self.norm = nn.RMSNorm(config.hidden_size)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Build the causal mask unconditionally: nn.MultiheadAttention requires an
        # attn_mask whenever the is_causal hint is given, so constructing it only in
        # training mode raises a RuntimeError at inference time and would also let
        # positions attend to future tokens during evaluation.
        seq_len = q.size(-2)
        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device), diagonal=1)
        # Padding positions (attention_mask == 0) must not be attended to.
        key_padding_mask = None
        if attention_mask is not None:
            key_padding_mask = attention_mask.lt(1)
        attn = self.attn(
            q, k, v,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            is_causal=True,
            need_weights=False,
        )[0]
        # Post-norm residual connections around attention and the feed-forward block.
        q = self.norm(q + attn)
        o = self.l2(self.activate(self.l1(q)))
        return self.norm(q + o)

class ZZJRabbitModel(PreTrainedModel):
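    """Backbone model: token embeddings passed through a stack of ZZJRabbitLayer blocks; returns the final hidden states."""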
    config_class = ZZJRabbitConfig

    def __init__(self, config: ZZJRabbitConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.config = config
        self.emb = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([ZZJRabbitLayer(config) for _ in range(config.num_layers)])

    def forward(self, input_ids: torch.Tensor, return_dict: Optional[bool] = None, attention_mask: Optional[torch.Tensor] = None, **kwargs):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        hidden = self.emb(input_ids)
        for layer in self.layers:
            hidden = layer(hidden, hidden, hidden, attention_mask)
        if not return_dict:
            return (hidden,)
        return BaseModelOutput(last_hidden_state=hidden)


class ZZJRabbitModelForCausalLM(PreTrainedModel, GenerationMixin):
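    """ZZJRabbitModel with a linear language-modeling head that maps hidden states to vocabulary logits; usable with GenerationMixin.generate()."""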
    config_class = ZZJRabbitConfig
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.model = ZZJRabbitModel(config, **kwargs)
        self.l = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, input_ids: torch.Tensor, return_dict: Optional[bool] = None, labels: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        hidden = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
        # `logits_to_keep` follows the transformers convention: an int n keeps the last n
        # positions (0 keeps all of them), while a tensor selects positions directly.
        keep = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.l(hidden[:, keep, :])
        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
        if not return_dict:
            return (loss, logits) if loss is not None else (logits,)
        return CausalLMOutput(loss=loss, logits=logits)

    @classmethod
    def can_generate(cls):
        return True

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        # No KV cache is kept, so the full sequence (and its padding mask) is re-fed at every decoding step.
        return {"input_ids": input_ids, "attention_mask": attention_mask}

class ZZJRabbitTokenizer(PreTrainedTokenizer):
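    """Wraps a `tokenizers.Tokenizer` loaded from tokenizer.json behind the PreTrainedTokenizer interface."""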
    vocab_files_names = {"tokenizers_file": "tokenizer.json"}

    def __init__(self, tokenizers_file, **kwargs):
        self.internal = Tokenizer.from_file(tokenizers_file)
        super().__init__(**kwargs)

    def get_vocab(self):
        return {self.internal.id_to_token(i): i for i in range(self.vocab_size)}

    def tokenize(self, text, **kwargs):
        return self.internal.encode(text).tokens

    def convert_tokens_to_ids(self, tokens):
        return self.internal.token_to_id(tokens) if isinstance(tokens, str) else [self.internal.token_to_id(t) for t in tokens]

    def decode(self, tokens, skip_special_tokens=True, **kwargs):
        if isinstance(tokens, torch.Tensor):
            tokens = tokens.tolist()
        return self.internal.decode(tokens, skip_special_tokens=skip_special_tokens)

    @property
    def vocab_size(self):
        return self.internal.get_vocab_size()

    def save_vocabulary(self, path, *args, **kwargs) -> tuple[str]:
        p = os.path.join(path, "tokenizer.json")
        self.internal.save(p)
        return (p,)
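

if __name__ == "__main__":
    # Minimal usage sketch, illustrative only: the hyperparameters and shapes below are
    # placeholders, not values prescribed by this module, and the model is untrained.
    # Loading ZZJRabbitTokenizer would additionally require a trained tokenizer.json on disk.
    config = ZZJRabbitConfig(num_layers=2, vocab_size=1000, hidden_size=64)
    model = ZZJRabbitModelForCausalLM(config).eval()

    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
    print(out.logits.shape)  # torch.Size([1, 8, 1000])

    generated = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=5, do_sample=False)
    print(generated.shape)  # torch.Size([1, 13]) -- 8 prompt tokens + 5 generated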