import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import CausalLMOutput
from transformers.modeling_utils import PreTrainedModel


class TinyGPTConfig(PretrainedConfig):
    model_type = "tinygpt"

    def __init__(self,
                 vocab_size=30522,
                 d_model=256,
                 n_heads=4,
                 n_layers=4,
                 d_ff=1024,
                 max_seq_len=256,
                 **kwargs):
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.d_ff = d_ff
        self.max_seq_len = max_seq_len
        # PretrainedConfig fills in the standard fields transformers expects
        # and stores any remaining keyword arguments as attributes.
        super().__init__(**kwargs)


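# Illustrative note (not exercised elsewhere in this file): unused keyword
# arguments are forwarded to PretrainedConfig and end up as plain attributes
# on the config. The `dropout` field below is a hypothetical example.
#
#   config = TinyGPTConfig(vocab_size=32000, dropout=0.1)
#   config.dropout  # -> 0.1

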
class TinyGPT(nn.Module):
    """Minimal GPT-style decoder: learned positional embeddings, a stack of
    transformer blocks, a final layer norm, and a vocabulary projection head."""

    def __init__(self, vocab_size=30522, d_model=256, n_heads=4,
                 n_layers=4, d_ff=1024, max_seq_len=256):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.blocks = nn.ModuleList(
            [TransformerBlock(d_model, n_heads, d_ff) for _ in range(n_layers)]
        )
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids):
        # Token + learned positional embeddings, then the decoder stack.
        positions = torch.arange(input_ids.size(1), device=input_ids.device)
        x = self.tok_emb(input_ids) + self.pos_emb(positions)
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        return self.head(x)


class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
        )
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Causal self-attention: a boolean mask blocks attention to future positions.
        seq_len = x.size(1)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
            diagonal=1,
        )
        attn_out, _ = self.attn(x, x, x, attn_mask=causal_mask, need_weights=False)
        x = self.ln1(x + attn_out)
        ff_out = self.ff(x)
        x = self.ln2(x + ff_out)
        return x


class TinyGPTForCausalLM(PreTrainedModel):
    config_class = TinyGPTConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = TinyGPT(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            n_heads=config.n_heads,
            n_layers=config.n_layers,
            d_ff=config.d_ff,
            max_seq_len=config.max_seq_len,
        )
        self.post_init()

    def forward(self, input_ids, labels=None):
        logits = self.model(input_ids)

        loss = None
        if labels is not None:
            # Standard causal-LM objective: predict token t+1 from tokens <= t,
            # so shift logits and labels by one position before the loss.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = CrossEntropyLoss()(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        return CausalLMOutput(
            logits=logits,
            loss=loss,
        )
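

# --- Quick smoke test -------------------------------------------------------
# A minimal sketch showing how the pieces above fit together; the batch size,
# sequence length, and random token ids below are arbitrary illustrative values.
if __name__ == "__main__":
    config = TinyGPTConfig()
    model = TinyGPTForCausalLM(config)

    # Random token ids standing in for a tokenized batch. Passing
    # labels=input_ids is the usual causal-LM setup; the forward pass
    # shifts them internally.
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    outputs = model(input_ids, labels=input_ids)

    print(outputs.logits.shape)  # torch.Size([2, 16, 30522]) with the defaults
    print(outputs.loss.item())   # untrained cross-entropy, roughly ln(vocab_size)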