sarcastic-model / modeling_my_gpt.py
import os
import sys

import torch
import torch.nn as nn
from transformers import PreTrainedModel

# Make the parent of the working directory importable before pulling in the
# local modules below. Note this resolves against os.getcwd(), i.e. where the
# process was launched from, not where this file lives.
curr_dir = os.getcwd()
parent_dir = os.path.dirname(curr_dir)
sys.path.insert(0, parent_dir)

from .configuration_my_gpt import MyGPTConfig
from .untrained_model import GPTModel
class MyGPTForCausalLM(PreTrainedModel):
    """Hugging Face wrapper around the original GPTModel implementation."""

    config_class = MyGPTConfig

    def __init__(self, config):
        super().__init__(config)
        # Build the original GPTModel from the Hugging Face config fields
        self.model = GPTModel({
            "vocab_size": config.vocab_size,
            "context_length": config.context_length,
            "emb_dim": config.emb_dim,
            "n_heads": config.n_heads,
            "n_layers": config.n_layers,
            "drop_rate": config.drop_rate,
            "qkv_bias": config.qkv_bias,
        })
        self.post_init()
    def forward(self, input_ids, labels=None):
        logits = self.model(input_ids)

        loss = None
        if labels is not None:
            # Standard causal-LM shift: position t predicts token t + 1, so
            # drop the last logit and the first label before comparing.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # CrossEntropyLoss ignores positions labeled -100 by default,
            # matching the usual Hugging Face padding convention.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        return {
            "loss": loss,
            "logits": logits,
        }
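

# A minimal smoke test, not part of the model definition. It assumes
# MyGPTConfig accepts the keyword arguments below (mirroring the dict passed
# to GPTModel above, with placeholder values) and that GPTModel returns
# logits of shape (batch, seq_len, vocab_size). Because of the relative
# imports, run it as a module from the package root
# (python -m <package>.modeling_my_gpt), not as a standalone script.
if __name__ == "__main__":
    config = MyGPTConfig(
        vocab_size=50257,
        context_length=1024,
        emb_dim=768,
        n_heads=12,
        n_layers=12,
        drop_rate=0.1,
        qkv_bias=False,
    )
    model = MyGPTForCausalLM(config)
    model.eval()  # disable dropout for a deterministic sanity check

    # Using the inputs as their own labels exercises the loss path.
    dummy_ids = torch.randint(0, config.vocab_size, (2, 16))
    out = model(input_ids=dummy_ids, labels=dummy_ids)
    print(out["loss"].item(), out["logits"].shape)  # scalar loss, (2, 16, vocab)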