# tinygpt-base-model / modeling_tinygpt.py
# (Hugging Face Hub page residue preserved as comments: uploader
#  "Abdurrahmanesc", commit message "Update modeling_tinygpt.py",
#  commit 707524c verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutput
from transformers.modeling_utils import PreTrainedModel
from .configuration_tinygpt import TinyGPTConfig #to look inside the folder
# -------------------------
# TinyGPTConfig (Required)
# -------------------------
class TinyGPTConfig:
    """Configuration container for the TinyGPT model.

    Holds the architecture hyperparameters and absorbs any extra
    Hugging Face-style keyword arguments as plain attributes.

    NOTE(review): this local definition shadows the ``TinyGPTConfig``
    imported from ``.configuration_tinygpt`` above, and it does not
    subclass ``transformers.PretrainedConfig`` (which ``PreTrainedModel``
    normally expects) — confirm which definition is intended.
    """

    # Model-type tag used by the Hugging Face auto-class machinery.
    model_type = "tinygpt"

    def __init__(
        self,
        vocab_size=30522,
        d_model=256,
        n_heads=4,
        n_layers=4,
        d_ff=1024,
        max_seq_len=256,
        **kwargs,
    ):
        self.vocab_size = vocab_size    # tokenizer vocabulary size
        self.d_model = d_model          # hidden / embedding width
        self.n_heads = n_heads          # attention heads per block
        self.n_layers = n_layers        # number of transformer blocks
        self.d_ff = d_ff                # feed-forward inner width
        self.max_seq_len = max_seq_len  # maximum supported sequence length
        # Absorb any additional HF config keys as attributes.
        self.__dict__.update(kwargs)
# -------------------------
# Your Original TinyGPT Core
# -------------------------
class TinyGPT(nn.Module):
    """Minimal GPT-style decoder-only language model.

    NOTE(review): the original source was truncated — both ``__init__``
    bodies were missing and only the tail lines of the lost ``forward``
    methods survived (``self.ln_f``/``self.head`` here,
    ``self.ff``/``self.ln2`` in the block). Reconstructed below as a
    standard causal post-LN transformer consistent with those fragments;
    confirm against the trained checkpoint's state dict.
    """

    def __init__(self, vocab_size=30522, d_model=256, n_heads=4,
                 n_layers=4, d_ff=1024, max_seq_len=256):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.tok_emb = nn.Embedding(vocab_size, d_model)   # token embeddings
        self.pos_emb = nn.Embedding(max_seq_len, d_model)  # learned positions
        self.blocks = nn.ModuleList(
            TransformerBlock(d_model, n_heads, d_ff) for _ in range(n_layers)
        )
        self.ln_f = nn.LayerNorm(d_model)                  # final layer norm
        self.head = nn.Linear(d_model, vocab_size, bias=False)  # LM head

    def forward(self, input_ids):
        """Map token ids (batch, seq_len) to logits (batch, seq_len, vocab).

        Raises a device-side indexing error if seq_len > max_seq_len,
        since the positional table has only max_seq_len rows.
        """
        batch, seq_len = input_ids.shape
        pos = torch.arange(seq_len, device=input_ids.device)
        x = self.tok_emb(input_ids) + self.pos_emb(pos)
        # Causal mask: -inf above the diagonal so position i attends
        # only to positions <= i.
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"),
                       device=input_ids.device),
            diagonal=1,
        )
        for block in self.blocks:
            x = block(x, attn_mask=causal_mask)
        x = self.ln_f(x)
        return self.head(x)


class TransformerBlock(nn.Module):
    """Post-LN transformer block: self-attention then feed-forward,
    each wrapped in a residual connection followed by LayerNorm
    (matching the surviving ``x = self.ln2(x + ff_out)`` fragment)."""

    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ln1 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
        )
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x, attn_mask=None):
        """One block pass; ``attn_mask`` is an additive attention mask."""
        attn_out, _ = self.attn(x, x, x, attn_mask=attn_mask,
                                need_weights=False)
        x = self.ln1(x + attn_out)
        ff_out = self.ff(x)
        x = self.ln2(x + ff_out)
        return x
# -------------------------
# HF Wrapper: TinyGPTForCausalLM
# -------------------------
class TinyGPTForCausalLM(PreTrainedModel):
    """Hugging Face wrapper exposing TinyGPT as a causal language model."""

    config_class = TinyGPTConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = TinyGPT(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            n_heads=config.n_heads,
            n_layers=config.n_layers,
            d_ff=config.d_ff,
            max_seq_len=config.max_seq_len,
        )
        # Runs HF weight-init / tying hooks.
        self.post_init()

    def forward(self, input_ids, labels=None, **kwargs):
        """Compute logits and, when ``labels`` is given, the causal LM loss.

        Args:
            input_ids: (batch, seq_len) token ids.
            labels: optional (batch, seq_len) target ids; positions set to
                -100 are ignored (CrossEntropyLoss default ignore_index).
            **kwargs: extra Trainer-supplied arguments (e.g. attention_mask)
                accepted for compatibility; the core model ignores them.

        Returns:
            CausalLMOutput with ``loss`` (or None) and ``logits``.
        """
        logits = self.model(input_ids)
        loss = None
        if labels is not None:
            # Bug fix: a causal LM predicts the NEXT token, so logits at
            # position t must be scored against the label at position t+1.
            # The original compared logits[t] with labels[t] (no shift),
            # which trains an identity copy task instead.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = CrossEntropyLoss()(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )
        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )