from transformers import PreTrainedTokenizer, PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput
import torch.nn as nn
import torch
class MSOTConfig(PretrainedConfig):
    """Configuration for the toy MSOT model.

    Args:
        vocab_size: number of entries in the token embedding table (default 128).
        hidden_size: width of the hidden representations (default 16).
    """
    model_type = "msot"
    def __init__(self, vocab_size=128, hidden_size=16, **kwargs):
        # Set model-specific attributes before super().__init__ so that
        # PretrainedConfig serialization (to_dict/save_pretrained) sees them.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        super().__init__(**kwargs)
class MSOTModel(PreTrainedModel):
    """Toy base model: token embedding followed by a bilinear mixing of three
    linear projections, res = l1(h) @ l2(h)^T @ l3(h)."""
    config_class = MSOTConfig

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        # NOTE: PreTrainedModel.__init__ already stores `config` on self, so
        # the original redundant `self.config = config` is dropped.
        self.emb = nn.Embedding(config.vocab_size, config.hidden_size)
        self.l1 = nn.Linear(config.hidden_size, config.hidden_size)
        self.l2 = nn.Linear(config.hidden_size, config.hidden_size)
        self.l3 = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, input_ids, return_dict=None, **kwargs):
        """Compute hidden states for `input_ids`.

        Returns:
            A 1-tuple `(hidden,)` when `return_dict` is falsy (including the
            default `None`, preserving the original behavior), otherwise a
            `BaseModelOutput` with `last_hidden_state` set.
        """
        hidden = self.emb(input_ids)
        a = self.l1(hidden)
        b = self.l2(hidden).transpose(-2, -1)  # (…, hidden, seq)
        c = self.l3(hidden)
        res = a @ b @ c
        if not return_dict:
            return (res,)
        # Use the keyword form so the field binding is explicit.
        return BaseModelOutput(last_hidden_state=res)
class MSOTModelForCausalLM(PreTrainedModel, GenerationMixin):
    """Causal-LM head on top of MSOTModel: projects hidden states to vocab logits."""
    config_class = MSOTConfig

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.model = MSOTModel(config, **kwargs)
        self.l = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, input_ids, return_dict=None, labels=None, **kwargs):
        """Compute logits (and, when `labels` is given, the LM loss).

        When `labels` is provided, the loss is next-token cross-entropy with
        the standard causal shift: logits[:, :-1] predict labels[:, 1:].
        Returns tuples when `return_dict` is falsy, `CausalLMOutput` otherwise.
        """
        hidden = self.model(input_ids)[0]
        res = self.l(hidden)
        loss = None  # initialize so no branch can see an unbound name
        if labels is not None:
            # Shift so that position t predicts token t+1.
            loss = nn.functional.cross_entropy(
                res[:, :-1, :].contiguous().view(-1, self.model.config.vocab_size),
                labels[:, 1:].contiguous().view(-1),
            )
            # Leftover debug `print(loss)` removed: it spammed stdout on every
            # training step.
        if not return_dict:
            return (loss, res) if labels is not None else (res,)
        if labels is not None:
            return CausalLMOutput(logits=res, loss=loss)
        return CausalLMOutput(logits=res)

    def can_generate(self):
        # Opt in to GenerationMixin's .generate() support.
        return True

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        # Stateless model (no KV cache): feed the full sequence at every step.
        return {"input_ids": input_ids}
class MSOTTokenizer(PreTrainedTokenizer):
    """Character-level tokenizer: every Unicode code point below 65536 (the
    Basic Multilingual Plane) is its own token, with id == code point."""

    def get_vocab(self):
        # One entry per BMP code point.
        return {chr(i): i for i in range(65536)}

    def _tokenize(self, text):
        # BUG FIX: characters outside the BMP previously mapped to the *int*
        # 0, which made _convert_token_to_id crash (ord(0) -> TypeError).
        # Map them to the NUL *character* so they encode to id 0 instead.
        return [c if ord(c) < 65536 else chr(0) for c in text]

    def _convert_token_to_id(self, token):
        return ord(token)

    def _convert_id_to_token(self, id):
        return chr(id)

    @property
    def vocab_size(self):
        return 65536

    def save_vocabulary(self, *args, **kwargs):
        # The vocabulary is implicit (identity mapping); nothing to write.
        return ()
def gen128(model, input):
    """Generate up to 50 continuation tokens for an ASCII prompt and decode
    the full output sequence (prompt + continuation) as UTF-8."""
    prompt_ids = torch.tensor([list(bytes(input, "ascii"))])
    generated = model.generate(prompt_ids, max_new_tokens=50)[0]
    return bytes(list(generated)).decode("utf-8")
def gen65536(model, input):
    """Generate up to 50 continuation tokens for a BMP-only prompt and return
    the full output as text; characters outside the BMP are dropped first."""
    code_points = [ord(ch) for ch in input if ord(ch) < 65536]
    generated = model.generate(torch.tensor([code_points]), max_new_tokens=50)[0]
    return "".join(chr(int(t)) for t in generated)
if __name__ == "__main__":
    # Register the custom classes for the Auto* factories so that
    # save_pretrained/push_to_hub can ship this file as custom code and
    # AutoModel/AutoModelForCausalLM/AutoTokenizer can resolve these classes.
    MSOTConfig.register_for_auto_class()
    MSOTModel.register_for_auto_class("AutoModel")
    MSOTModelForCausalLM.register_for_auto_class("AutoModelForCausalLM")
    MSOTTokenizer.register_for_auto_class()