"""A small Transformer language model packaged as a Hugging Face ``PreTrainedModel``."""

import math

import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel


class TransformerLMConfig(PretrainedConfig):
    """Configuration for :class:`TransformerLM`.

    Args:
        vocab_size: Number of token ids covered by the embedding and output layers.
        hidden_size: Model (embedding) dimension.
        num_hidden_layers: Number of stacked ``nn.TransformerEncoderLayer`` layers.
        num_attention_heads: Attention heads per layer; ``nn.MultiheadAttention``
            requires ``hidden_size`` to be divisible by it.
        intermediate_size: Feed-forward width inside each layer.
        max_position_embeddings: Size of the learned position table, i.e. the
            longest sequence ``TransformerLM.forward`` accepts.
        pad_token_id, bos_token_id, eos_token_id: Special token ids, forwarded to
            ``PretrainedConfig``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "transformer_lm"

    def __init__(
        self,
        vocab_size=40,
        hidden_size=256,
        num_hidden_layers=4,
        num_attention_heads=8,
        intermediate_size=1024,
        max_position_embeddings=64,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        **kwargs
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs
        )
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings


class TransformerLM(PreTrainedModel):
    """Decoder-only language model built from ``nn.TransformerEncoder`` layers.

    Every forward pass applies a causal attention mask, so position ``i`` only
    attends to positions ``<= i``. Without it the shifted next-token loss would
    let each position see the token it is asked to predict.
    """

    config_class = TransformerLMConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.pos_embedding = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_size,
            nhead=config.num_attention_heads,
            dim_feedforward=config.intermediate_size,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=config.num_hidden_layers
        )
        self.output_layer = nn.Linear(config.hidden_size, config.vocab_size)
        self.max_position_embeddings = config.max_position_embeddings

    def _build_attention_mask(self, attention_mask, seq_len, device):
        """Return a boolean self-attention mask where ``True`` means "do not attend".

        With no ``attention_mask`` this is the ``(seq_len, seq_len)`` causal mask.
        With one (1/True = real token, 0/False = padding), causal masking and
        key padding are merged into a per-sample
        ``(batch * num_heads, seq_len, seq_len)`` mask. That 3-D layout is
        accepted by ``nn.MultiheadAttention``, which expects batch-major /
        head-minor ordering, hence ``repeat_interleave``.
        """
        causal = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=1
        )
        if attention_mask is None:
            return causal

        key_is_pad = ~attention_mask.to(device=device, dtype=torch.bool)  # (batch, seq_len)
        # blocked[b, q, k] is True when query q must not look at key k.
        blocked = causal.unsqueeze(0) | key_is_pad.unsqueeze(1)
        # Always let a position attend to itself. Otherwise a padding position
        # that only "sees" padding (e.g. the first slot under left padding) has
        # every key masked; its softmax can become NaN, and that NaN then leaks
        # into other positions in later layers.
        eye = torch.eye(seq_len, dtype=torch.bool, device=device)
        blocked = blocked & ~eye
        return blocked.repeat_interleave(self.config.num_attention_heads, dim=0)

    def forward(self, input_ids, attention_mask=None, labels=None, position_ids=None):
        """Run the model and optionally compute the next-token cross-entropy loss.

        Args:
            input_ids: Token ids indexed as ``(batch, seq_len)``.
            attention_mask: Optional ``(batch, seq_len)`` mask, 1/True for real
                tokens and 0/False for padding. Padding keys are never attended to.
            labels: Optional target ids laid out like ``input_ids``. They are
                shifted internally (logits at step ``t`` predict ``labels[t + 1]``).
                Entries equal to -100 (``CrossEntropyLoss``'s default
                ``ignore_index``) do not contribute, so padding labels should be -100.
            position_ids: Optional position indices, ``(batch, seq_len)`` or
                ``(1, seq_len)``. Defaults to ``0 .. seq_len - 1``.

        Returns:
            dict with ``"loss"`` (scalar tensor, or ``None`` when ``labels`` is
            not given) and ``"logits"`` of shape ``(batch, seq_len, vocab_size)``.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_position_embeddings``.
        """
        seq_len = input_ids.size(1)
        if seq_len > self.max_position_embeddings:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_position_embeddings "
                f"({self.max_position_embeddings})"
            )
        device = input_ids.device
        if position_ids is None:
            position_ids = torch.arange(0, seq_len, device=device).unsqueeze(0)

        # Token embeddings scaled by sqrt(d_model), plus learned positional embeddings.
        hidden = self.embedding(input_ids) * math.sqrt(self.config.hidden_size)
        hidden = hidden + self.pos_embedding(position_ids)

        # Causal + padding mask (True = masked); padding is folded into this mask,
        # so no separate src_key_padding_mask is passed.
        attn_mask = self._build_attention_mask(attention_mask, seq_len, device)
        output = self.transformer_encoder(hidden, mask=attn_mask)

        logits = self.output_layer(output)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

        return {"loss": loss, "logits": logits}

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Build the keyword arguments passed to ``forward`` at each generation step.

        ``forward`` neither consumes nor returns a key/value cache. The full
        ``input_ids`` sequence is therefore always passed back in; truncating to
        the last token, as cache-aware models do, would discard all prior
        context here.

        When ``position_ids`` is not supplied but ``attention_mask`` is, real
        tokens are numbered from 0. Left-padded prompts thus get the same
        positions as unpadded ones. Padding slots get position 0; they are
        masked out as attention keys anyway.
        """
        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)
        if position_ids is None and attention_mask is not None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids = position_ids.masked_fill(~attention_mask.bool(), 0)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
        }