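"""Minimal custom-model stack for the Hugging Face transformers Auto* APIs:
a small config, a toy base model, a generate-capable causal-LM wrapper, and a
character-level tokenizer, registered for auto-class loading under __main__.
"""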
import torch
import torch.nn as nn

from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput

class MSOTConfig(PretrainedConfig):
    model_type = "msot"
    def __init__(self, vocab_size=128, hidden_size=16, **kwargs):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        super().__init__(**kwargs)

class MSOTModel(PreTrainedModel):
    config_class = MSOTConfig
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.emb = nn.Embedding(config.vocab_size, config.hidden_size)
        self.l1 = nn.Linear(config.hidden_size, config.hidden_size)
        self.l2 = nn.Linear(config.hidden_size, config.hidden_size)
        self.l3 = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, input_ids, return_dict=None, **kwargs):
        hidden = self.emb(input_ids)
        # Token mixing via a (seq x seq) product:
        # (B, T, H) @ (B, H, T) @ (B, T, H) -> (B, T, H)
        a = self.l1(hidden)
        b = self.l2(hidden).transpose(-2, -1)
        c = self.l3(hidden)
        res = a @ b @ c
        if return_dict:
            return BaseModelOutput(last_hidden_state=res)
        return (res,)

class MSOTModelForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = MSOTConfig
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.model = MSOTModel(config, **kwargs)
        self.l = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, input_ids, return_dict=None, labels=None, **kwargs):
        hidden = self.model(input_ids)[0]
        res = self.l(hidden)
        loss = None
        if labels is not None:
            # Standard causal-LM shift: position t predicts token t+1.
            loss = nn.functional.cross_entropy(
                res[:, :-1, :].contiguous().view(-1, self.model.config.vocab_size),
                labels[:, 1:].contiguous().view(-1),
            )
        if return_dict:
            return CausalLMOutput(loss=loss, logits=res)
        return (loss, res) if loss is not None else (res,)

    def can_generate(self):
        return True

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        return {"input_ids": input_ids}

class MSOTTokenizer(PreTrainedTokenizer):
    def get_vocab(self):
        # One token per Basic Multilingual Plane code point.
        return {chr(i): i for i in range(65536)}

    def _tokenize(self, text):
        # Character-level tokenization; map anything outside the BMP to NUL.
        return [c if ord(c) < 65536 else chr(0) for c in text]

    def _convert_token_to_id(self, token):
        return ord(token)

    def _convert_id_to_token(self, index):
        return chr(index)

    def convert_tokens_to_string(self, tokens):
        # Join without spaces; the default space-joining would corrupt decode().
        return "".join(tokens)

    @property
    def vocab_size(self):
        return 65536

    def save_vocabulary(self, *args, **kwargs):
        # Nothing to write: the vocabulary is implicit (Unicode code points).
        return ()
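
# Hedged round-trip sketch for MSOTTokenizer (assumes the definitions above):
#   tok = MSOTTokenizer()
#   tok.encode("hi")        # -> [104, 105]
#   tok.decode([104, 105])  # -> "hi"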

def gen128(model, prompt):
    # For vocab_size=128 models: one token per ASCII byte of the prompt.
    tokens = torch.tensor([list(bytes(prompt, "ascii"))])
    res = model.generate(tokens, max_new_tokens=50)[0].tolist()
    return bytes(res).decode("utf-8")

def gen65536(model, prompt):
    # For vocab_size=65536 models: one token per BMP code point.
    tokens = torch.tensor([[ord(c) for c in prompt if ord(c) < 65536]])
    res = model.generate(tokens, max_new_tokens=50)[0].tolist()
    return "".join(chr(o) for o in res)

if __name__ == "__main__":
    # Register the custom classes so that, once saved or pushed to the Hub,
    # they can be loaded back through the Auto* APIs with trust_remote_code=True.
    MSOTConfig.register_for_auto_class()
    MSOTModel.register_for_auto_class("AutoModel")
    MSOTModelForCausalLM.register_for_auto_class("AutoModelForCausalLM")
    MSOTTokenizer.register_for_auto_class()
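
    # Minimal smoke-test sketch (an assumption, not part of the registration
    # flow): the model is untrained, so the generated text is arbitrary.
    config = MSOTConfig()  # defaults: vocab_size=128, hidden_size=16
    model = MSOTModelForCausalLM(config)
    print(gen128(model, "hello world"))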