File size: 3,976 Bytes
66d4b44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import torch.nn as nn
import math
from transformers import PreTrainedModel
from transformers.modeling_utils import PretrainedConfig

class TransformerLMConfig(PretrainedConfig):
    """Hyper-parameters for :class:`TransformerLM`.

    Args:
        vocab_size: Number of token ids; sizes the token embedding table and
            the output projection.
        hidden_size: Model width (embedding size and encoder ``d_model``).
        num_hidden_layers: Number of stacked encoder layers.
        num_attention_heads: Attention heads per encoder layer.
        intermediate_size: Width of each encoder layer's feed-forward block.
        max_position_embeddings: Rows in the learned position table, i.e. the
            longest sequence the model can embed.
        pad_token_id, bos_token_id, eos_token_id: Special token ids, handed to
            ``PretrainedConfig``.
        **kwargs: Any further ``PretrainedConfig`` options.
    """

    model_type = "transformer_lm"

    def __init__(
        self,
        vocab_size=40,
        hidden_size=256,
        num_hidden_layers=4,
        num_attention_heads=8,
        intermediate_size=1024,
        max_position_embeddings=64,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        **kwargs
    ):
        # The base class owns the special-token ids and any generic options.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

        # Model-specific fields (kept in a stable order for serialization).
        self.vocab_size = vocab_size                            # embedding / output rows
        self.hidden_size = hidden_size                          # d_model
        self.num_hidden_layers = num_hidden_layers              # encoder depth
        self.num_attention_heads = num_attention_heads          # heads per layer
        self.intermediate_size = intermediate_size              # feed-forward width
        self.max_position_embeddings = max_position_embeddings  # max sequence length

class TransformerLM(PreTrainedModel):
    """Causal (decoder-only) language model built on ``nn.TransformerEncoder``.

    Token embeddings, scaled by sqrt(hidden_size), are summed with learned
    absolute position embeddings. The result runs through the encoder stack
    under a causal attention mask and is projected to vocabulary logits.
    The model keeps no key/value cache.
    """

    config_class = TransformerLMConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.pos_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_size,
            nhead=config.num_attention_heads,
            dim_feedforward=config.intermediate_size,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=config.num_hidden_layers)
        self.output_layer = nn.Linear(config.hidden_size, config.vocab_size)

        # Longest sequence the learned position table can embed.
        self.max_position_embeddings = config.max_position_embeddings

    def _build_attention_mask(self, attention_mask, seq_len, device):
        """Return a boolean attention mask for the encoder (True = blocked).

        Applies the causal constraint (query i sees only keys j <= i) and, if
        given, the padding mask. Without padding the result has shape
        (seq_len, seq_len). With padding it has the per-head shape
        (batch * num_heads, seq_len, seq_len) that ``nn.MultiheadAttention``
        accepts.
        """
        # allowed[i, j] is True when query i may attend to key j.
        allowed = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).tril()
        if attention_mask is None:
            return ~allowed

        # Block padding keys, but always leave each query's own diagonal open.
        # With left padding a padded query would otherwise see no key at all,
        # its softmax row could become NaN, and that NaN would reach real tokens
        # through the attention values of later layers.
        key_is_real = attention_mask.bool().unsqueeze(1)  # (batch, 1, seq_len)
        diagonal = torch.eye(seq_len, dtype=torch.bool, device=device)
        allowed = allowed & (key_is_real | diagonal)  # (batch, seq_len, seq_len)
        # MultiheadAttention indexes a 3-D mask as batch * num_heads + head.
        return (~allowed).repeat_interleave(self.config.num_attention_heads, dim=0)

    def forward(self, input_ids, attention_mask=None, labels=None, position_ids=None, **kwargs):
        """Compute logits and, optionally, the next-token prediction loss.

        Args:
            input_ids: Token ids of shape (batch, seq_len).
            attention_mask: Optional (batch, seq_len) mask: nonzero/True for
                real tokens, 0/False for padding.
            labels: Optional target ids aligned with ``input_ids``. Token t is
                predicted from tokens < t. Targets equal to -100, and targets
                at padded positions (when ``attention_mask`` is given), are
                ignored.
            position_ids: Optional position indices broadcastable to
                (batch, seq_len). Defaults to ``0 .. seq_len - 1``.
            **kwargs: Extra keyword arguments (e.g. passed by generation
                utilities) are accepted and ignored.

        Returns:
            dict with ``"logits"`` of shape (batch, seq_len, vocab_size) and
            ``"loss"`` (scalar tensor, or None when ``labels`` is None).

        Raises:
            ValueError: if ``position_ids`` is not given and seq_len exceeds
                ``max_position_embeddings``.
        """
        seq_len = input_ids.size(1)
        device = input_ids.device

        if position_ids is None:
            # Fail with a clear message instead of an embedding IndexError
            # (or an opaque device-side assert on CUDA).
            if seq_len > self.max_position_embeddings:
                raise ValueError(
                    f"Sequence length {seq_len} exceeds max_position_embeddings "
                    f"({self.max_position_embeddings})."
                )
            position_ids = torch.arange(seq_len, device=device).unsqueeze(0)

        # Scale token embeddings by sqrt(hidden_size), then add learned positions.
        hidden = self.embedding(input_ids) * math.sqrt(self.config.hidden_size)
        hidden = hidden + self.pos_embedding(position_ids)

        # The causal mask is essential: the loss below trains each position to
        # predict the next token, so no position may attend to later tokens.
        attn_mask = self._build_attention_mask(attention_mask, seq_len, device)
        hidden = self.transformer_encoder(hidden, mask=attn_mask)

        logits = self.output_layer(hidden)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            if attention_mask is not None:
                # Exclude padding targets. -100 is CrossEntropyLoss's default ignore_index.
                shift_labels = shift_labels.masked_fill(attention_mask[..., 1:] == 0, -100)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return {
            "loss": loss,
            "logits": logits,
        }

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Build the ``forward`` keyword arguments for one generation step.

        This model neither returns nor consumes ``past_key_values``, so the
        entire prefix is re-encoded at every step. Trimming ``input_ids`` to
        the last token, as cached models do, would discard all context.

        Returns:
            dict with ``input_ids`` plus the ``attention_mask`` and
            ``position_ids`` found in ``kwargs`` (None when absent).
        """
        return {
            "input_ids": input_ids,
            "attention_mask": kwargs.get("attention_mask"),
            "position_ids": kwargs.get("position_ids"),
        }