import os
import sys

import torch
import torch.nn as nn
from transformers import PreTrainedModel

from .configuration_my_gpt import MyGPTConfig
from .untrained_model import GPTModel

# Make the parent of the current working directory importable so that modules
# the GPT implementation depends on can be resolved. Note that this is based on
# os.getcwd(), so it assumes the process is started from the package directory.
curr_dir = os.getcwd()
parent_dir = os.path.dirname(curr_dir)
sys.path.insert(0, parent_dir)


class MyGPTForCausalLM(PreTrainedModel):
    # Tells the transformers auto classes which configuration class belongs
    # to this model.
    config_class = MyGPTConfig

    def __init__(self, config):
        super().__init__(config)

        # Wrap the plain GPT implementation. Its constructor expects a dict of
        # hyperparameters, so translate the config fields into that format.
        self.model = GPTModel({
            "vocab_size": config.vocab_size,
            "context_length": config.context_length,
            "emb_dim": config.emb_dim,
            "n_heads": config.n_heads,
            "n_layers": config.n_layers,
            "drop_rate": config.drop_rate,
            "qkv_bias": config.qkv_bias,
        })

        # Standard Hugging Face post-initialization (weight init and final setup).
        self.post_init()

    def forward(self, input_ids, labels=None, **kwargs):
        # Extra keyword arguments (e.g. attention_mask passed by the Trainer)
        # are accepted but ignored by this minimal wrapper.
        logits = self.model(input_ids)

        loss = None
        if labels is not None:
            # Next-token prediction: token t predicts token t + 1, so drop the
            # last logit and the first label before computing cross-entropy.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        return {
            "loss": loss,
            "logits": logits,
        }
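

# A minimal usage sketch for a quick sanity check, assuming MyGPTConfig accepts
# the fields read in __init__ above as keyword arguments. The hyperparameter
# values below are illustrative placeholders, not the real model's settings.
if __name__ == "__main__":
    config = MyGPTConfig(
        vocab_size=100,
        context_length=16,
        emb_dim=32,
        n_heads=2,
        n_layers=2,
        drop_rate=0.0,
        qkv_bias=False,
    )
    model = MyGPTForCausalLM(config)

    # Random token ids of shape (batch_size, sequence_length).
    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    outputs = model(input_ids, labels=input_ids)

    print("logits shape:", outputs["logits"].shape)  # expected: (1, 8, vocab_size)
    print("loss:", outputs["loss"].item())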