# sarcastic-model / modeling_my_gpt.py
# (Hugging Face Hub page residue, preserved as comments:
#  "dev-das's picture" / "Update modeling_my_gpt.py" / commit ef3097a verified)
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput
from transformers.generation import GenerationMixin
from .configuration_my_gpt import MyGPTConfig
from .untrained_model import GPTModel
import os
class MyGPTForCausalLM(PreTrainedModel, GenerationMixin):
    """Hugging Face-compatible causal-LM wrapper around the custom ``GPTModel``.

    Adapts the original (non-HF) ``GPTModel`` implementation to the
    ``PreTrainedModel`` API so it can be loaded with ``from_pretrained`` and
    used with ``GenerationMixin.generate``.
    """

    config_class = MyGPTConfig
    main_input_name = "input_ids"

    def __init__(self, config):
        super().__init__(config)
        # Translate HF-style config field names into the plain-dict format
        # the original GPTModel implementation expects.
        self.model = GPTModel({
            "vocab_size": config.vocab_size,
            "context_length": config.max_position_embeddings,
            "emb_dim": config.hidden_size,
            "n_heads": config.num_attention_heads,
            "n_layers": config.num_hidden_layers,
            "drop_rate": config.drop_rate,
            "qkv_bias": config.qkv_bias,
        })
        # Standard HF hook: weight init / final setup.
        self.post_init()

    def forward(self, input_ids, labels=None, **kwargs):
        """Run the underlying model and optionally compute the LM loss.

        Args:
            input_ids: Token id tensor of shape (batch, seq_len).
            labels: Optional target ids, same shape as ``input_ids``.
                Positions with value -100 are ignored by the loss
                (``CrossEntropyLoss`` default ``ignore_index``).
            **kwargs: Extra arguments passed by ``generate`` (e.g.
                ``attention_mask``) — currently ignored by the wrapped model.

        Returns:
            CausalLMOutput with ``logits`` and, when ``labels`` is given,
            the next-token-prediction ``loss``.
        """
        logits = self.model(input_ids)
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                # Move labels to the logits device: with model parallelism /
                # device_map, labels can arrive on a different device and
                # CrossEntropyLoss would raise a device-mismatch error.
                shift_labels.view(-1).to(shift_logits.device),
            )
        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )