import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput
from transformers.generation import GenerationMixin

from .configuration_my_gpt import MyGPTConfig
from .untrained_model import GPTModel

class MyGPTForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = MyGPTConfig
    main_input_name = "input_ids"

    def __init__(self, config):
        super().__init__(config)
        # Map the Hugging Face config field names onto the keyword names
        # the from-scratch GPTModel expects.
        self.model = GPTModel({
            "vocab_size": config.vocab_size,
            "context_length": config.max_position_embeddings,
            "emb_dim": config.hidden_size,
            "n_heads": config.num_attention_heads,
            "n_layers": config.num_hidden_layers,
            "drop_rate": config.drop_rate,
            "qkv_bias": config.qkv_bias,
        })

        # Hugging Face hook: initializes weights and finalizes setup.
        self.post_init()
    def forward(self, input_ids, labels=None, **kwargs):
        # **kwargs absorbs extra arguments generate() may pass
        # (e.g. attention_mask), which the underlying GPTModel ignores.
        logits = self.model(input_ids)

        loss = None
        if labels is not None:
            # Standard causal LM shift: position t predicts token t + 1,
            # so drop the last logit and the first label before flattening.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )
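

# Optional: expose the wrapper through the Auto classes. A sketch: it
# assumes MyGPTConfig declares model_type = "my_gpt"; adjust the string
# to match whatever the config actually sets.
try:
    from transformers import AutoConfig, AutoModelForCausalLM

    AutoConfig.register("my_gpt", MyGPTConfig)
    AutoModelForCausalLM.register(MyGPTConfig, MyGPTForCausalLM)
except ValueError:
    # register() raises ValueError if the name or config class is already
    # registered (for example, when the module is imported twice).
    pass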
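

if __name__ == "__main__":
    # Minimal smoke test, not part of the model definition. Assumes
    # MyGPTConfig has usable defaults and that the "gpt2" tokenizer is an
    # acceptable stand-in for whichever tokenizer the model trains with.
    from transformers import AutoTokenizer

    config = MyGPTConfig()
    model = MyGPTForCausalLM(config)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    input_ids = tokenizer("Hello, world", return_tensors="pt").input_ids

    # generate() comes from GenerationMixin; with an untrained model the
    # output is gibberish, but it confirms the wrapper is wired correctly.
    output_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False)
    print(tokenizer.decode(output_ids[0]))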