Sage / modeling_transformer_lm.py
itriedcoding's picture
Upload folder using huggingface_hub
66d4b44 verified
Raw
History Blame Contribute Delete
3.98 kB
import torch
import torch.nn as nn
import math
from transformers import PreTrainedModel
from transformers.modeling_utils import PretrainedConfig
class TransformerLMConfig(PretrainedConfig):
    """Hyper-parameter container for the ``transformer_lm`` model type.

    Args:
        vocab_size: Number of entries in the token vocabulary.
        hidden_size: Width of the embeddings and of every encoder layer.
        num_hidden_layers: Number of stacked encoder layers.
        num_attention_heads: Attention heads per encoder layer.
        intermediate_size: Width of the feed-forward sub-layer.
        max_position_embeddings: Number of learned position embeddings.
        pad_token_id: Id of the padding token (forwarded to the base class).
        bos_token_id: Id of the beginning-of-sequence token (forwarded to the
            base class).
        eos_token_id: Id of the end-of-sequence token (forwarded to the base
            class).
        **kwargs: Any further options, passed untouched to
            ``PretrainedConfig``.
    """

    model_type = "transformer_lm"

    def __init__(
        self,
        vocab_size=40,
        hidden_size=256,
        num_hidden_layers=4,
        num_attention_heads=8,
        intermediate_size=1024,
        max_position_embeddings=64,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        **kwargs
    ):
        # The special-token ids are owned by the base class, so hand them over
        # together with whatever extra options the caller supplied.
        special_tokens = {
            "pad_token_id": pad_token_id,
            "bos_token_id": bos_token_id,
            "eos_token_id": eos_token_id,
        }
        super().__init__(**special_tokens, **kwargs)

        # Architecture hyper-parameters live directly on the config instance.
        architecture = {
            "vocab_size": vocab_size,
            "hidden_size": hidden_size,
            "num_hidden_layers": num_hidden_layers,
            "num_attention_heads": num_attention_heads,
            "intermediate_size": intermediate_size,
            "max_position_embeddings": max_position_embeddings,
        }
        for field_name, field_value in architecture.items():
            setattr(self, field_name, field_value)
class TransformerLM(PreTrainedModel):
    """Causal (left-to-right) language model built on ``nn.TransformerEncoder``.

    Token embeddings are scaled by ``sqrt(hidden_size)``, summed with learned
    absolute position embeddings, run through a stack of encoder layers under
    a causal attention mask, and projected onto the vocabulary.
    """

    config_class = TransformerLMConfig

    def __init__(self, config):
        """Build the embeddings, the encoder stack and the output projection.

        Args:
            config: Configuration providing ``vocab_size``, ``hidden_size``,
                ``num_hidden_layers``, ``num_attention_heads``,
                ``intermediate_size`` and ``max_position_embeddings``.
        """
        super().__init__(config)
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.pos_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_size,
            nhead=config.num_attention_heads,
            dim_feedforward=config.intermediate_size,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=config.num_hidden_layers
        )
        self.output_layer = nn.Linear(config.hidden_size, config.vocab_size)
        self.max_position_embeddings = config.max_position_embeddings

    def _build_attention_mask(self, seq_len, attention_mask, device):
        """Return a boolean attention mask in which ``True`` blocks attention.

        Args:
            seq_len: Length of the input sequence.
            attention_mask: Optional ``(batch, seq_len)`` tensor whose non-zero
                entries mark real tokens, or ``None`` when nothing is padded.
            device: Device on which to allocate the mask.

        Returns:
            A ``(seq_len, seq_len)`` causal mask when ``attention_mask`` is
            ``None``; otherwise a ``(batch * num_heads, seq_len, seq_len)``
            mask combining causality with key padding.
        """
        # Strict upper triangle: query i must not look at keys j > i.
        future = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=device),
            diagonal=1,
        )
        if attention_mask is None:
            return future

        padded_keys = ~attention_mask.bool()  # (batch, seq_len), True = padding
        blocked = future.unsqueeze(0) | padded_keys.unsqueeze(1)  # (batch, L, L)
        # Keep the diagonal open so that no query row is fully masked; a fully
        # masked row turns into NaN in softmax and can leak into other rows.
        # Padded positions then only see themselves and are never seen by
        # real tokens, so the outputs at real positions are unaffected.
        blocked = blocked & ~torch.eye(seq_len, dtype=torch.bool, device=device)
        # A 3-D attention mask is laid out as batch-major, head-minor.
        return blocked.repeat_interleave(self.config.num_attention_heads, dim=0)

    def forward(self, input_ids, attention_mask=None, labels=None, position_ids=None):
        """Compute next-token logits and, optionally, the language-model loss.

        Args:
            input_ids: ``(batch, seq_len)`` tensor of token ids.
            attention_mask: Optional ``(batch, seq_len)`` tensor; non-zero
                entries mark tokens that may be attended to.
            labels: Optional ``(batch, seq_len)`` tensor of target ids. The
                shift by one position is done here, so ``labels`` is normally
                a copy of ``input_ids``. Entries equal to ``-100`` are skipped
                by the default ``nn.CrossEntropyLoss`` ``ignore_index``.
            position_ids: Optional tensor of position indices broadcastable to
                ``(batch, seq_len)``. Defaults to ``0 .. seq_len - 1``.

        Returns:
            A dict with ``"loss"`` (scalar tensor, or ``None`` when ``labels``
            is not given) and ``"logits"`` of shape
            ``(batch, seq_len, vocab_size)``.

        Raises:
            ValueError: If ``position_ids`` is not given and ``seq_len``
                exceeds ``max_position_embeddings``.
        """
        seq_len = input_ids.size(1)
        device = input_ids.device

        if position_ids is None:
            # Fail early with a readable message instead of an out-of-range
            # embedding lookup.
            if seq_len > self.max_position_embeddings:
                raise ValueError(
                    f"Sequence length {seq_len} exceeds max_position_embeddings "
                    f"({self.max_position_embeddings})."
                )
            position_ids = torch.arange(seq_len, device=device).unsqueeze(0)

        # Embedding + positional encoding
        token_emb = self.embedding(input_ids) * math.sqrt(self.config.hidden_size)
        hidden = token_emb + self.pos_embedding(position_ids)

        # Causality is required: the loss below asks position i to predict
        # token i + 1, so position i must not be able to read token i + 1.
        attn_mask = self._build_attention_mask(seq_len, attention_mask, device)
        hidden = self.transformer_encoder(hidden, mask=attn_mask)

        # Output projection
        logits = self.output_layer(hidden)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        # NOTE(review): a plain dict offers no `.logits` attribute access; if
        # this model is meant to be driven through `generate()`, confirm
        # whether a `ModelOutput` return type is required.
        return {
            "loss": loss,
            "logits": logits,
        }

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Assemble the keyword arguments for one ``forward`` call.

        The model keeps no key/value cache, so the complete sequence is fed on
        every step regardless of any ``past_key_values`` entry in ``kwargs``.

        Args:
            input_ids: ``(batch, seq_len)`` tensor of the tokens so far.
            **kwargs: May contain ``attention_mask`` and ``position_ids``;
                other entries are ignored.

        Returns:
            A dict with the keys ``"input_ids"``, ``"attention_mask"`` and
            ``"position_ids"``.
        """
        position_ids = kwargs.get("position_ids")
        # Position ids that no longer match the (grown) sequence are stale;
        # drop them so that forward() falls back to 0 .. seq_len - 1.
        if position_ids is not None and position_ids.shape[-1] != input_ids.shape[-1]:
            position_ids = None

        return {
            "input_ids": input_ids,
            "attention_mask": kwargs.get("attention_mask"),
            "position_ids": position_ids,
        }