| from transformers import PreTrainedTokenizer, PreTrainedModel, PretrainedConfig, GenerationMixin |
| from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput |
| import torch.nn as nn |
| import torch |
|
|
class MSOTConfig(PretrainedConfig):
    """Configuration for the toy MSOT model family.

    Args:
        vocab_size: number of distinct token ids the model accepts.
        hidden_size: width of the embedding and all linear layers.
    """

    model_type = "msot"

    def __init__(self, vocab_size=128, hidden_size=16, **kwargs):
        # Record the model-specific hyperparameters, then hand every
        # generic option off to PretrainedConfig.
        self.vocab_size, self.hidden_size = vocab_size, hidden_size
        super().__init__(**kwargs)
|
|
class MSOTModel(PreTrainedModel):
    """Toy backbone: embeds token ids and mixes positions through a
    bilinear product of three linear projections (attention-like, but
    with no softmax or masking)."""

    config_class = MSOTConfig

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        # NOTE: PreTrainedModel.__init__ already stores `config` on self,
        # so the original redundant `self.config = config` is dropped.
        self.emb = nn.Embedding(config.vocab_size, config.hidden_size)
        self.l1 = nn.Linear(config.hidden_size, config.hidden_size)
        self.l2 = nn.Linear(config.hidden_size, config.hidden_size)
        self.l3 = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, input_ids, return_dict=None, **kwargs):
        """Compute hidden states for a batch of token ids.

        Args:
            input_ids: integer tensor of shape (batch, seq).
            return_dict: when truthy, wrap the result in a
                ``BaseModelOutput``; otherwise return a 1-tuple
                (the original default behavior is preserved).
            **kwargs: extra generation-time arguments (e.g.
                attention_mask) are accepted and ignored.

        Returns:
            ``(hidden,)`` or ``BaseModelOutput(last_hidden_state=hidden)``
            with ``hidden`` of shape (batch, seq, hidden_size).
        """
        hidden = self.emb(input_ids)
        a = self.l1(hidden)
        b = self.l2(hidden).transpose(-2, -1)
        c = self.l3(hidden)
        # (B, S, H) @ (B, H, S) @ (B, S, H) -> (B, S, H)
        res = a @ b @ c

        if not return_dict:
            return (res,)
        # Fix: pass by keyword so the tensor lands in the intended
        # `last_hidden_state` field regardless of dataclass field order.
        return BaseModelOutput(last_hidden_state=res)
|
|
class MSOTModelForCausalLM(PreTrainedModel, GenerationMixin):
    """Causal-LM head on top of ``MSOTModel``: projects hidden states to
    vocabulary logits and optionally computes the shifted LM loss."""

    config_class = MSOTConfig

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.model = MSOTModel(config, **kwargs)
        self.l = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, input_ids, return_dict=None, labels=None, **kwargs):
        """Run the backbone and LM head.

        Args:
            input_ids: integer tensor of shape (batch, seq).
            return_dict: when truthy, return a ``CausalLMOutput``;
                otherwise return a tuple (original default preserved).
            labels: optional target ids of shape (batch, seq); when
                given, a standard next-token cross-entropy loss is
                computed (logits shifted left, labels shifted right).

        Returns:
            ``(logits,)`` / ``(loss, logits)`` or the equivalent
            ``CausalLMOutput``.
        """
        hidden = self.model(input_ids)[0]
        logits = self.l(hidden)

        # Fix: initialize `loss` so it is always bound, and remove the
        # stray debugging `print(loss)` left in the original.
        loss = None
        if labels is not None:
            vocab = self.model.config.vocab_size
            shift_logits = logits[:, :-1, :].contiguous().view(-1, vocab)
            shift_labels = labels[:, 1:].contiguous().view(-1)
            loss = nn.functional.cross_entropy(shift_logits, shift_labels)

        if not return_dict:
            return (loss, logits) if loss is not None else (logits,)
        if loss is not None:
            return CausalLMOutput(loss=loss, logits=logits)
        return CausalLMOutput(logits=logits)

    def can_generate(self):
        # Opt in to GenerationMixin.generate().
        return True

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        # The model is stateless (no KV cache), so the full sequence is
        # fed at every step; attention_mask is accepted but unused.
        return {"input_ids": input_ids}
|
|
|
|
|
|
class MSOTTokenizer(PreTrainedTokenizer):
    """Character-level tokenizer: each BMP code point is its own id."""

    def get_vocab(self):
        # One entry per BMP code point (ids 0..65535).
        return {chr(i): i for i in range(65536)}

    def _tokenize(self, text):
        # Fix: non-BMP characters used to become the *int* 0 mixed into a
        # list of strings, which made _convert_token_to_id crash on
        # ord(0). Map them to the NUL character instead, so every token
        # is a 1-char string and out-of-range input still yields id 0.
        return [c if ord(c) < 65536 else "\x00" for c in text]

    def _convert_token_to_id(self, token):
        return ord(token)

    def _convert_id_to_token(self, id):
        return chr(id)

    @property
    def vocab_size(self):
        return 65536

    def save_vocabulary(self, *args, **kwargs):
        # Nothing to persist: the vocabulary is implicit (code points).
        return ()
|
|
def gen128(model, input):
    """Greedy-generate a continuation of an ASCII prompt.

    The prompt is tokenized byte-per-character (so it must be pure
    ASCII), extended by up to 50 tokens, and the resulting ids are
    decoded as UTF-8.
    """
    prompt_ids = torch.tensor([list(input.encode("ascii"))])
    generated = model.generate(prompt_ids, max_new_tokens=50)[0]
    return bytes(list(generated)).decode("utf-8")
|
|
def gen65536(model, input):
    """Greedy-generate a continuation of a prompt tokenized as BMP
    code points; characters outside the BMP are silently dropped."""
    ids = [ord(ch) for ch in input if ord(ch) < 65536]
    output = model.generate(torch.tensor([ids]), max_new_tokens=50)[0]
    return "".join(chr(int(t)) for t in output)
|
|
if __name__ == "__main__":
    # Register the custom classes with the transformers Auto* factories
    # so this architecture can be loaded via AutoModel/AutoModelForCausalLM
    # (and the class references are exported alongside a saved model).
    MSOTConfig.register_for_auto_class()
    MSOTModel.register_for_auto_class("AutoModel")
    MSOTModelForCausalLM.register_for_auto_class("AutoModelForCausalLM")
    MSOTTokenizer.register_for_auto_class()
|
|