import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput
from transformers.generation import GenerationMixin

from .configuration_my_gpt import MyGPTConfig
from .untrained_model import GPTModel


class MyGPTForCausalLM(PreTrainedModel, GenerationMixin):
    """Hugging Face-compatible wrapper around the custom GPTModel."""

    config_class = MyGPTConfig
    main_input_name = "input_ids"

    def __init__(self, config):
        super().__init__(config)
        # Instantiate the original GPTModel, mapping the Hugging Face
        # config attribute names to the keys the custom model expects.
        self.model = GPTModel({
            "vocab_size": config.vocab_size,
            "context_length": config.max_position_embeddings,
            "emb_dim": config.hidden_size,
            "n_heads": config.num_attention_heads,
            "n_layers": config.num_hidden_layers,
            "drop_rate": config.drop_rate,
            "qkv_bias": config.qkv_bias,
        })
        # Initialize weights and apply any final processing (HF convention).
        self.post_init()

    def forward(self, input_ids, labels=None, **kwargs):
        # **kwargs absorbs extra arguments that generate() may pass
        # (e.g. attention_mask, use_cache), which this model ignores.
        logits = self.model(input_ids)
        loss = None
        if labels is not None:
            # Shift so that tokens at position t predict the token at t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )
        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )
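

if __name__ == "__main__":
    # Minimal smoke test: a sketch, not part of the original module.
    # It assumes MyGPTConfig() supplies usable defaults for every field
    # read in __init__ above (vocab_size, hidden_size, num_attention_heads,
    # num_hidden_layers, max_position_embeddings, drop_rate, qkv_bias).
    config = MyGPTConfig()
    model = MyGPTForCausalLM(config)
    model.eval()  # disable dropout for a deterministic check

    # Forward pass with labels: forward() returns the shifted next-token
    # cross-entropy loss alongside the logits.
    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    out = model(input_ids, labels=input_ids)
    print("logits:", tuple(out.logits.shape), "loss:", out.loss.item())

    # GenerationMixin supplies .generate(); this assumes a transformers
    # version whose GenerationMixin provides a default
    # prepare_inputs_for_generation (the versions where PreTrainedModel
    # no longer inherits GenerationMixin itself).
    generated = model.generate(input_ids, max_new_tokens=16, do_sample=False)
    print("generated ids:", tuple(generated.shape))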