| import torch |
| import torch.nn as nn |
| from transformers import PreTrainedModel, PretrainedConfig |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
class EnigmaConfig(PretrainedConfig):
    """Configuration object for the Enigma models.

    Stores the hyperparameters given to the constructor as attributes and
    always flags the configuration as a decoder.
    """

    model_type = "enigma"

    def __init__(
        self,
        hidden_size=128,
        vocab_size=50257,
        num_hidden_layers=1,
        num_attention_heads=1,
        **kwargs,
    ):
        """Build the configuration.

        Args:
            hidden_size: Width of the token embeddings / hidden states.
            vocab_size: Number of entries in the token vocabulary.
            num_hidden_layers: Stored on the config as given.
            num_attention_heads: Stored on the config as given.
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
        """
        super().__init__(**kwargs)

        # Assigned after the base initializer has run, so this value is the
        # one that ends up on the instance.
        self.is_decoder = True

        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
|
|
class EnigmaModel(PreTrainedModel):
    """Backbone network: a token embedding followed by one linear layer."""

    config_class = EnigmaConfig

    def __init__(self, config):
        """Create the embedding table and the square linear projection.

        Args:
            config: Configuration providing ``vocab_size`` and ``hidden_size``.
        """
        super().__init__(config)
        width = config.hidden_size
        self.embedding = nn.Embedding(config.vocab_size, width)
        self.linear = nn.Linear(width, width)
        self.post_init()

    def forward(self, input_ids, **kwargs):
        """Embed ``input_ids`` and apply the linear layer.

        Args:
            input_ids: Token indices fed to the embedding table.
            **kwargs: Accepted for call compatibility; not used here.

        Returns:
            The output of the linear layer applied to the embedded tokens.
        """
        return self.linear(self.embedding(input_ids))
|
|
| from transformers.generation import GenerationMixin |
|
|
class EnigmaForCausalLM(PreTrainedModel, GenerationMixin):
    """Enigma backbone with a language-modeling head on top."""

    config_class = EnigmaConfig

    def __init__(self, config):
        """Create the backbone and the bias-free vocabulary projection.

        Args:
            config: Configuration providing ``hidden_size`` and ``vocab_size``.
        """
        super().__init__(config)
        self.model = EnigmaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Compute next-token logits and, when labels are given, the LM loss.

        Args:
            input_ids: Token indices passed to the backbone.
            attention_mask: Accepted for API compatibility; not used by this
                model.
            labels: Optional target token ids aligned with ``input_ids``
                (``labels = input_ids`` is the expected usage). They are
                shifted inside this method so that position ``t`` is scored
                against the token at position ``t + 1``.
            **kwargs: Accepted for call compatibility; not used here.

        Returns:
            ``CausalLMOutputWithPast`` carrying ``logits`` for every input
            position and ``loss`` (``None`` when ``labels`` is not provided).
        """
        hidden_states = self.model(input_ids)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Causal LM objective: the prediction made at position t is
            # compared with the token at position t + 1, so drop the last
            # prediction and the first label before computing the loss.
            shift_logits = logits[..., :-1, :]
            shift_labels = labels[..., 1:]
            loss_fct = nn.CrossEntropyLoss()
            # reshape (not view): the shifted slices are not contiguous.
            loss = loss_fct(
                shift_logits.reshape(-1, self.config.vocab_size),
                shift_labels.reshape(-1),
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        """Select the inputs that ``forward`` receives at each generation step.

        Only ``input_ids`` and ``attention_mask`` are passed on; the
        ``past_key_values`` and ``inputs_embeds`` arguments are ignored.
        """
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
|
|
| |
# Register each custom class with its corresponding auto class.
# "AutoConfig" is the default target for configuration classes; it is spelled
# out here so all three registrations read the same way.
EnigmaConfig.register_for_auto_class("AutoConfig")
EnigmaModel.register_for_auto_class("AutoModel")
EnigmaForCausalLM.register_for_auto_class("AutoModelForCausalLM")
|
|