import torch.nn as nn from transformers import PreTrainedModel from .configuration_fireflies import FirefliesConfig class FirefliesModel(PreTrainedModel): config_class = FirefliesConfig def __init__(self, config): super().__init__(config) self.embedding = nn.Embedding(config.vocab_size, config.d_model) encoder_layer = nn.TransformerEncoderLayer(d_model=config.d_model, nhead=config.n_heads) self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=config.num_layers) self.fc = nn.Linear(config.d_model, config.vocab_size) self.init_weights() def forward(self, input_ids, attention_mask=None, labels=None): x = self.embedding(input_ids) x = x.transpose(0, 1) x = self.transformer(x) x = x.transpose(0, 1) logits = self.fc(x) loss = None if labels is not None: loss_fct = nn.CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) return {"loss": loss, "logits": logits}