tiny-gpt / modeling_tiny_gpt.py
alainbrown's picture
Publish trained storyteller checkpoint
f167ee5 verified
Raw
History Blame Contribute Delete
1.78 kB
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutput
from .configuration_tiny_gpt import TinyGPTConfig
from .model import GPTModel
class TinyGPTForCausalLM(PreTrainedModel, GenerationMixin):
    """Causal language-model wrapper exposing ``GPTModel`` through transformers.

    The wrapped network lives on ``self.core_model``; this class supplies the
    embedding accessors and the ``forward`` signature that the
    ``PreTrainedModel`` / ``GenerationMixin`` base classes work with.
    """

    config_class = TinyGPTConfig
    main_input_name = "input_ids"
    # Declares the output projection weight as tied to the token embedding
    # weight (key: tied parameter, value: parameter it is tied to).
    _tied_weights_keys = {"core_model.linear.weight": "core_model.token_embedding.weight"}

    def __init__(self, config):
        """Build the wrapped ``GPTModel`` from the hyperparameters on *config*."""
        super().__init__(config)
        # Each of these config attributes is forwarded to GPTModel as a
        # keyword argument of the same name.
        hparam_names = (
            "context_size",
            "vocab_size",
            "d_model",
            "n_layers",
            "n_heads",
            "dropout",
        )
        self.core_model = GPTModel(
            **{name: getattr(config, name) for name in hparam_names}
        )
        # Run the base-class post-construction hook once all submodules exist.
        self.post_init()

    def get_input_embeddings(self):
        """Return the module stored at ``core_model.token_embedding``."""
        return self.core_model.token_embedding

    def set_input_embeddings(self, value):
        """Replace ``core_model.token_embedding`` with *value*."""
        self.core_model.token_embedding = value

    def get_output_embeddings(self):
        """Return the module stored at ``core_model.linear``."""
        return self.core_model.linear

    def set_output_embeddings(self, new_embeddings):
        """Replace ``core_model.linear`` with *new_embeddings*."""
        self.core_model.linear = new_embeddings

    def forward(self, input_ids=None, labels=None, **kwargs):
        """Compute logits and, when *labels* is given, the next-token loss.

        Args:
            input_ids: Token ids passed straight to ``core_model``. Required.
            labels: Optional targets. They are shifted inside this method, so
                position ``t`` of the logits is scored against position
                ``t + 1`` of *labels* (presumably the same tensor layout as
                ``input_ids`` -- confirm against the training code).
            **kwargs: Accepted and ignored; nothing here is forwarded to
                ``core_model``.

        Returns:
            ``CausalLMOutput`` with ``logits`` set to the unshifted model
            output and ``loss`` set to the cross-entropy value, or ``None``
            when no labels were supplied.

        Raises:
            ValueError: If *input_ids* is ``None``.
        """
        if input_ids is None:
            raise ValueError("input_ids must be provided")

        logits = self.core_model(input_ids)

        # Without labels there is nothing to score; return the logits alone.
        if labels is None:
            return CausalLMOutput(loss=None, logits=logits)

        # Drop the final prediction and the first label so every remaining
        # prediction lines up with the token that follows it.
        predictions = logits[..., :-1, :].contiguous()
        targets = labels[..., 1:].contiguous()

        # Flatten to (num_positions, num_classes) vs. (num_positions,) and
        # rely on F.cross_entropy's default settings.
        num_classes = predictions.size(-1)
        loss = F.cross_entropy(
            predictions.view(-1, num_classes),
            targets.view(-1),
        )
        return CausalLMOutput(loss=loss, logits=logits)