"""Minimal "Enigma" language model built on Hugging Face ``transformers``.

Defines a config, a tiny backbone (token embedding followed by one linear
projection) and a causal-LM head, and registers them with the corresponding
``Auto*`` classes via ``register_for_auto_class``.
"""

import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast


class EnigmaConfig(PretrainedConfig):
    """Configuration shared by ``EnigmaModel`` and ``EnigmaForCausalLM``.

    Args:
        hidden_size: Width of the token embeddings and of the backbone's
            linear layer.
        vocab_size: Number of vocabulary entries (embedding rows and LM-head
            output features).
        num_hidden_layers: Stored on the config; the models defined here do
            not read it.
        num_attention_heads: Stored on the config; the models defined here do
            not read it.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "enigma"

    def __init__(
        self,
        hidden_size=128,
        vocab_size=50257,
        num_hidden_layers=1,
        num_attention_heads=1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # Set after the base initializer, so it always wins over any
        # ``is_decoder`` value passed through kwargs.
        self.is_decoder = True


class EnigmaModel(PreTrainedModel):
    """Backbone: token embedding followed by one hidden->hidden linear layer.

    There is no attention or other cross-position mixing, so each position's
    output depends only on the token at that position.
    """

    config_class = EnigmaConfig

    def __init__(self, config):
        super().__init__(config)
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.linear = nn.Linear(config.hidden_size, config.hidden_size)
        self.post_init()

    def forward(self, input_ids, **kwargs):
        """Return hidden states of shape ``input_ids.shape + (hidden_size,)``.

        Extra keyword arguments (e.g. ``attention_mask``) are accepted and
        ignored.
        """
        x = self.embedding(input_ids)
        return self.linear(x)


class EnigmaForCausalLM(PreTrainedModel, GenerationMixin):
    """``EnigmaModel`` backbone plus a bias-free linear head producing vocab logits."""

    config_class = EnigmaConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = EnigmaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Compute vocabulary logits and, if ``labels`` is given, the LM loss.

        Args:
            input_ids: Token ids; the last dimension is the sequence axis.
            attention_mask: Accepted for API compatibility but not used. The
                backbone has no cross-position mixing, so padding cannot
                affect other positions' logits.
            labels: Optional targets with the same shape as ``input_ids``.
                They may simply be a copy of ``input_ids``: the one-position
                shift for next-token prediction happens here. Positions set
                to -100 are ignored by the loss.
            **kwargs: Ignored (e.g. ``return_dict``, ``use_cache`` from
                ``generate``).

        Returns:
            ``CausalLMOutputWithPast`` holding ``logits`` (one row of
            ``vocab_size`` scores per input position) and ``loss`` (``None``
            when no labels are given).
        """
        hidden_states = self.model(input_ids)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Next-token prediction: the logits at position t are scored
            # against the label at position t + 1. Without this shift the
            # model would only learn to copy its current input token.
            shift_logits = logits[..., :-1, :]
            shift_labels = labels[..., 1:].to(logits.device)
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            # ``reshape`` rather than ``view``: the sliced logits are not
            # contiguous in memory.
            loss = loss_fct(
                shift_logits.reshape(-1, self.config.vocab_size),
                shift_labels.reshape(-1),
            )

        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        """Build the ``forward`` kwargs for one ``generate`` step.

        The model keeps no key/value cache, so the full ``input_ids`` are
        passed at every step. ``past_key_values``, ``inputs_embeds`` and any
        extra kwargs are ignored.
        """
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }


# Register so the classes can be loaded via AutoConfig, AutoModel and
# AutoModelForCausalLM.
EnigmaConfig.register_for_auto_class()
EnigmaModel.register_for_auto_class("AutoModel")
EnigmaForCausalLM.register_for_auto_class("AutoModelForCausalLM")