"""Minimal "MSOT" toy model family for Hugging Face auto-class registration.

Defines a tiny config/model/tokenizer trio and registers them for use with
the `Auto*` factories when run as a script.
"""

from transformers import (
    GenerationMixin,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
)
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput
import torch
import torch.nn as nn


class MSOTConfig(PretrainedConfig):
    """Configuration for MSOT models.

    Args:
        vocab_size: size of the embedding table / LM head (default 128).
        hidden_size: width of the hidden representation (default 16).
    """

    model_type = "msot"

    def __init__(self, vocab_size=128, hidden_size=16, **kwargs):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        super().__init__(**kwargs)


class MSOTModel(PreTrainedModel):
    """Bare MSOT backbone: embedding plus three linear projections that are
    combined with a token-mixing matrix product."""

    config_class = MSOTConfig

    def __init__(self, config, **kwargs):
        # PreTrainedModel.__init__ already stores `config` on self.config,
        # so no explicit assignment is needed here.
        super().__init__(config, **kwargs)
        self.emb = nn.Embedding(config.vocab_size, config.hidden_size)
        self.l1 = nn.Linear(config.hidden_size, config.hidden_size)
        self.l2 = nn.Linear(config.hidden_size, config.hidden_size)
        self.l3 = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, input_ids, return_dict=None, **kwargs):
        """Return hidden states for `input_ids`.

        Returns a 1-tuple `(hidden,)` unless `return_dict` is truthy, in
        which case a `BaseModelOutput` is returned.
        """
        hidden = self.emb(input_ids)
        a = self.l1(hidden)
        b = self.l2(hidden).transpose(-2, -1)
        c = self.l3(hidden)
        # (…, seq, h) @ (…, h, seq) @ (…, seq, h) -> (…, seq, h):
        # an attention-like (unnormalized) token-mixing operation.
        res = a @ b @ c
        if not return_dict:
            return (res,)
        return BaseModelOutput(last_hidden_state=res)


class MSOTModelForCausalLM(PreTrainedModel, GenerationMixin):
    """MSOT backbone with a language-modeling head and generation support."""

    config_class = MSOTConfig

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.model = MSOTModel(config, **kwargs)
        self.l = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, input_ids, return_dict=None, labels=None, **kwargs):
        """Compute LM logits (and, if `labels` is given, the shifted
        cross-entropy loss).

        Returns `(logits,)` / `(loss, logits)` tuples by default, or a
        `CausalLMOutput` when `return_dict` is truthy.
        """
        hidden = self.model(input_ids)[0]
        logits = self.l(hidden)
        loss = None
        if labels is not None:
            # Standard causal-LM shift: position t predicts token t+1.
            loss = nn.functional.cross_entropy(
                logits[:, :-1, :].contiguous().view(-1, self.model.config.vocab_size),
                labels[:, 1:].contiguous().view(-1),
            )
        if not return_dict:
            return (loss, logits) if loss is not None else (logits,)
        return CausalLMOutput(loss=loss, logits=logits)

    def can_generate(self):
        return True

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        # The model is stateless across steps; only input_ids are needed.
        return {"input_ids": input_ids}


class MSOTTokenizer(PreTrainedTokenizer):
    """Character-level tokenizer: each BMP codepoint is its own token id."""

    def get_vocab(self):
        return {chr(i): i for i in range(65536)}

    def _tokenize(self, text):
        # Map non-BMP characters to the NUL character (token id 0).
        # Using the string "\x00" keeps the token type consistent so that
        # _convert_token_to_id (which calls ord) never receives an int.
        return [c if ord(c) < 65536 else "\x00" for c in text]

    def _convert_token_to_id(self, token):
        return ord(token)

    def _convert_id_to_token(self, id):
        return chr(id)

    @property
    def vocab_size(self):
        return 65536

    def save_vocabulary(self, *args, **kwargs):
        # The vocabulary is implicit (codepoint == id); nothing to save.
        return ()


def gen128(model, input):
    """Generate 50 new tokens from an ASCII prompt with a byte-level vocab.

    Raises UnicodeEncodeError if `input` contains non-ASCII characters.
    """
    tokens = torch.tensor([list(bytes(input, "ascii"))])
    res = list(model.generate(tokens, max_new_tokens=50)[0])
    return bytes(res).decode("utf-8")


def gen65536(model, input):
    """Generate 50 new tokens from a prompt with a BMP-codepoint vocab.

    Non-BMP characters in the prompt are silently dropped.
    """
    tokens = torch.tensor([[ord(c) for c in input if ord(c) < 65536]])
    res = list(model.generate(tokens, max_new_tokens=50)[0])
    return "".join([chr(o) for o in res])


if __name__ == "__main__":
    MSOTConfig.register_for_auto_class()
    MSOTModel.register_for_auto_class("AutoModel")
    MSOTModelForCausalLM.register_for_auto_class("AutoModelForCausalLM")
    MSOTTokenizer.register_for_auto_class()