import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput
from transformers.generation import GenerationMixin

from .configuration_my_gpt import MyGPTConfig
from .untrained_model import GPTModel


class MyGPTForCausalLM(PreTrainedModel, GenerationMixin):
    """Hugging Face-compatible causal LM wrapper around the custom GPTModel."""

    config_class = MyGPTConfig
    main_input_name = "input_ids"

    def __init__(self, config):
        super().__init__(config)
        # Instantiate the original GPTModel, translating the Hugging Face
        # config attributes into the dict layout GPTModel expects.
        self.model = GPTModel({
            "vocab_size": config.vocab_size,
            "context_length": config.max_position_embeddings,
            "emb_dim": config.hidden_size,
            "n_heads": config.num_attention_heads,
            "n_layers": config.num_hidden_layers,
            "drop_rate": config.drop_rate,
            "qkv_bias": config.qkv_bias,
        })
        # Initialize weights and run any final config-driven setup.
        self.post_init()

    def forward(self, input_ids, labels=None, **kwargs):
        logits = self.model(input_ids)

        loss = None
        if labels is not None:
            # Standard causal-LM shift: predict token t+1 from tokens <= t,
            # so drop the last logit and the first label before the loss.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )
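

# Usage sketch, not part of the module proper. Assumptions: MyGPTConfig()
# has valid defaults, and the weights are untrained, so the generated token
# ids will be arbitrary. Inheriting GenerationMixin gives the class a
# .generate() method (with a default prepare_inputs_for_generation in recent
# transformers releases). Run as a module, e.g.
# `python -m my_gpt_package.modeling_my_gpt`, so the relative imports resolve.
if __name__ == "__main__":
    config = MyGPTConfig()
    model = MyGPTForCausalLM(config)
    model.eval()

    # Dummy prompt of token ids; a real tokenizer would produce these.
    input_ids = torch.tensor([[1, 2, 3, 4]])
    with torch.no_grad():
        out = model.generate(input_ids, max_new_tokens=8, do_sample=False)
    print(out)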