Enigma / enigma_module.py
dr-tkxx's picture
Upload 8 files
2258d52 verified
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
class EnigmaConfig(PretrainedConfig):
    """Configuration object for the Enigma model classes.

    Args:
        hidden_size: Width of the token embeddings and of the hidden
            projection built by the model classes in this file.
        vocab_size: Number of rows in the embedding table and number of
            output logits per position.
        num_hidden_layers: Stored on the config; not read by the modules
            visible in this file.
        num_attention_heads: Stored on the config; not read by the modules
            visible in this file.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "enigma"

    def __init__(
        self,
        hidden_size=128,
        vocab_size=50257,
        num_hidden_layers=1,
        num_attention_heads=1,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Assigned after the base initialiser, in declaration order.
        own_fields = (
            ("hidden_size", hidden_size),
            ("vocab_size", vocab_size),
            ("num_hidden_layers", num_hidden_layers),
            ("num_attention_heads", num_attention_heads),
        )
        for field_name, field_value in own_fields:
            setattr(self, field_name, field_value)

        # Always forced to True, regardless of what was passed in kwargs.
        self.is_decoder = True
class EnigmaModel(PreTrainedModel):
    """Backbone network: an embedding lookup followed by one linear layer.

    Each token id is mapped to a ``hidden_size`` vector and then passed
    through a single ``hidden_size -> hidden_size`` linear projection.
    """

    config_class = EnigmaConfig

    def __init__(self, config):
        super().__init__(config)
        width = config.hidden_size
        # Creation order (embedding first, then linear) is kept as-is.
        self.embedding = nn.Embedding(config.vocab_size, width)
        self.linear = nn.Linear(width, width)
        self.post_init()

    def forward(self, input_ids, **kwargs):
        """Return the projected embeddings for ``input_ids``.

        Args:
            input_ids: Integer token ids used to index the embedding table.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            Tensor with the shape of ``input_ids`` plus a trailing
            ``hidden_size`` dimension.
        """
        return self.linear(self.embedding(input_ids))
from transformers.generation import GenerationMixin
class EnigmaForCausalLM(PreTrainedModel, GenerationMixin):
    """Enigma backbone with a language-modelling head on top.

    The head is a bias-free linear layer mapping each hidden vector to
    ``vocab_size`` logits.
    """

    config_class = EnigmaConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = EnigmaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Compute per-position logits and, optionally, the LM loss.

        Args:
            input_ids: Integer token ids; the sequence axis is second to
                last in the resulting logits.
            attention_mask: Accepted for API compatibility; not used by
                this forward pass.
            labels: Optional target token ids with the same shape as
                ``input_ids``. They are shifted inside this method, so
                callers may pass ``input_ids`` itself as labels, following
                the usual causal-LM convention. Positions equal to the
                ``nn.CrossEntropyLoss`` default ``ignore_index`` (-100) do
                not contribute to the loss.
            **kwargs: Accepted and ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` always set and
            ``loss`` set only when ``labels`` is given.
        """
        hidden_states = self.model(input_ids)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Next-token objective: the logits at position t are scored
            # against the label at position t + 1. Without this shift the
            # model would be trained to reproduce its own input token.
            shift_logits = logits[..., :-1, :]
            shift_labels = labels[..., 1:]
            loss_fct = nn.CrossEntropyLoss()
            # reshape (not view): the sliced tensors are non-contiguous.
            # NOTE: a length-1 sequence leaves no targets after shifting,
            # so the mean-reduced loss is NaN in that case.
            loss = loss_fct(
                shift_logits.reshape(-1, self.config.vocab_size),
                shift_labels.reshape(-1),
            )

        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        """Build the keyword arguments passed to ``forward`` while generating.

        ``past_key_values`` and ``inputs_embeds`` are accepted but not
        forwarded: ``forward`` returns no cache, so the full ``input_ids``
        sequence is passed through on every call.
        """
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
# Register the classes so they can be used through AutoModel, AutoConfig and AutoModelForCausalLM
EnigmaConfig.register_for_auto_class()
EnigmaModel.register_for_auto_class("AutoModel")
EnigmaForCausalLM.register_for_auto_class("AutoModelForCausalLM")