File size: 3,972 Bytes
471a2e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig

class MiniGPTConfig(PretrainedConfig):
    """Hyper-parameters for the ``MiniGPT`` model.

    Args:
        vocab_size: Number of token ids covered by the embedding and LM head.
        n_positions: Number of learned positional embeddings, i.e. the longest
            sequence the model can embed.
        n_embd: Model width (embedding / hidden size).
        n_layer: Number of stacked transformer decoder layers.
        n_head: Number of attention heads per layer.
        pad_token_id: Id of the padding token.
        bos_token_id: Id of the beginning-of-sequence token.
        eos_token_id: Id of the end-of-sequence token.
        **kwargs: Passed unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "mini_gpt"

    def __init__(
        self,
        vocab_size=50257,
        n_positions=128,
        n_embd=128,
        n_layer=2,
        n_head=4,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Architecture sizes.
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head

        # Special token ids, assigned after the base initializer so these
        # explicit values are the ones that end up on the config.
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

class MiniGPT(PreTrainedModel):
    """Small GPT-style causal language model built on ``nn.TransformerDecoder``.

    Token embeddings and learned positional embeddings are summed, passed
    through a stack of ``nn.TransformerDecoderLayer`` modules and projected to
    vocabulary logits by a bias-free linear head. ``nn.TransformerDecoder``
    requires a ``memory`` input, so the embedded sequence itself is used as
    memory. Both the self-attention and the cross-attention over that memory
    are causally masked so that no position can attend to later tokens.
    """

    config_class = MiniGPTConfig

    def __init__(self, config):
        super().__init__(config)
        # Registration order and attribute names are kept stable so existing
        # state_dict checkpoints still load.
        self.transformer = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=config.n_embd,
                nhead=config.n_head,
                dim_feedforward=config.n_embd * 4,
                batch_first=True,
            ),
            num_layers=config.n_layer,
        )
        self.embedding = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_embedding = nn.Embedding(config.n_positions, config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.dropout = nn.Dropout(0.1)

        # Re-initialize every nn.Linear / nn.Embedding submodule with N(0, 0.02).
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Compute vocabulary logits and, optionally, the next-token loss.

        Args:
            input_ids: Integer tensor of shape ``(batch, seq_len)``; ``seq_len``
                must not exceed ``config.n_positions``.
            attention_mask: Optional ``(batch, seq_len)`` tensor where 0 marks
                padding positions.
            labels: Optional ``(batch, seq_len)`` token ids aligned with
                ``input_ids``. They are shifted internally, so position ``t``
                is scored against ``labels[:, t + 1]``. Targets equal to
                ``config.pad_token_id`` or ``-100`` are excluded from the loss.
            **kwargs: Ignored; accepted so generic callers can pass extras.

        Returns:
            dict with ``"logits"`` of shape ``(batch, seq_len, vocab_size)``
            and ``"loss"`` (scalar tensor, or ``None`` without ``labels``).

        Raises:
            ValueError: If ``seq_len`` exceeds ``config.n_positions``.
        """
        _, seq_len = input_ids.size()
        if seq_len > self.config.n_positions:
            raise ValueError(
                f"Sequence length {seq_len} exceeds n_positions={self.config.n_positions}"
            )
        device = input_ids.device

        # Embeddings; the (1, seq_len) positions broadcast across the batch.
        positions = torch.arange(seq_len, device=device).unsqueeze(0)
        x = self.embedding(input_ids) + self.pos_embedding(positions)
        x = self.dropout(x)

        # 2-D boolean causal mask (True = blocked), broadcast over batch and
        # heads. nn.MultiheadAttention only accepts 3-D masks shaped
        # (batch * n_head, L, S), so a per-head-only 3-D mask would fail for
        # batch_size > 1. Using bool here also matches the padding mask dtype.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=device),
            diagonal=1,
        )

        # Key padding mask: True marks padded positions to ignore.
        key_padding_mask = None
        if attention_mask is not None:
            key_padding_mask = attention_mask == 0
        # NOTE(review): with left padding, the leading pad queries have every
        # key masked, which can yield NaNs; right padding is assumed.

        # memory is the embedded input itself, so the cross-attention must be
        # masked exactly like the self-attention; otherwise every position
        # could read the very token it is trained to predict.
        x = self.transformer(
            tgt=x,
            memory=x,
            tgt_mask=causal_mask,
            memory_mask=causal_mask,
            tgt_key_padding_mask=key_padding_mask,
            memory_key_padding_mask=key_padding_mask,
        )
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            loss = self._next_token_loss(logits, labels)

        return {"logits": logits, "loss": loss}

    def _next_token_loss(self, logits, labels):
        """Mean cross-entropy of ``logits[:, t]`` vs ``labels[:, t + 1]`` over kept targets."""
        shift_logits = logits[..., :-1, :]
        shift_labels = labels[..., 1:]

        # -100 is CrossEntropyLoss's default ignore_index (the HF convention for
        # "no target"); padding targets are dropped as well.
        valid = shift_labels != -100
        pad_id = self.config.pad_token_id
        if pad_id is not None:
            valid = valid & (shift_labels != pad_id)
        targets = shift_labels.masked_fill(~valid, -100)

        loss_fct = nn.CrossEntropyLoss(reduction="sum", ignore_index=-100)
        total = loss_fct(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            targets.reshape(-1),
        )
        # clamp keeps an all-ignored batch at loss 0 instead of 0/0 = NaN.
        return total / valid.sum().clamp(min=1)

    def generate(self, input_ids, max_length=50, **kwargs):
        """Greedy decoding that appends at most ``max_length`` new tokens.

        Each step feeds the last ``config.n_positions`` tokens through the
        model and appends the arg-max token. A row stops once it emits
        ``config.eos_token_id``. Later positions in that row are filled with
        ``pad_token_id`` (or ``eos_token_id`` when no pad id is set). Decoding
        ends early when every row has finished. The model's train/eval mode is
        restored afterwards.

        Args:
            input_ids: Integer tensor of shape ``(batch, prompt_len)``.
            max_length: Maximum number of tokens to append (not total length).
            **kwargs: Ignored.
                NOTE(review): this overrides ``PreTrainedModel.generate`` with
                a narrower signature.

        Returns:
            Tensor of shape ``(batch, prompt_len + k)`` with ``k <= max_length``.
        """
        was_training = self.training
        self.eval()

        eos_id = self.config.eos_token_id
        fill_id = self.config.pad_token_id if self.config.pad_token_id is not None else eos_id
        window = self.config.n_positions

        generated = input_ids
        finished = torch.zeros(input_ids.size(0), dtype=torch.bool, device=input_ids.device)
        try:
            with torch.no_grad():
                for _ in range(max_length):
                    # Crop to the positional-embedding range; positions restart
                    # at 0 inside the window.
                    context = generated[:, -window:]
                    logits = self(context)["logits"]
                    next_token = logits[:, -1, :].argmax(dim=-1)
                    if fill_id is not None:
                        next_token = next_token.masked_fill(finished, fill_id)
                    generated = torch.cat([generated, next_token.unsqueeze(-1)], dim=-1)
                    if eos_id is not None:
                        finished |= next_token == eos_id
                        if bool(finished.all()):
                            break
        finally:
            self.train(was_training)
        return generated