programmerworld commited on
Commit
f0e6efa
·
verified ·
1 Parent(s): b5529b4

Upload mini_gpt.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. mini_gpt.py +98 -0
mini_gpt.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel, PretrainedConfig
4
+
5
+ class MiniGPTConfig(PretrainedConfig):
6
+ model_type = "mini_gpt"
7
+ def __init__(self, vocab_size=50257, n_positions=128, n_embd=128, n_layer=2, n_head=4,
8
+ pad_token_id=0, bos_token_id=1, eos_token_id=2, **kwargs):
9
+ super().__init__(**kwargs)
10
+ self.vocab_size = vocab_size
11
+ self.n_positions = n_positions
12
+ self.n_embd = n_embd
13
+ self.n_layer = n_layer
14
+ self.n_head = n_head
15
+ self.pad_token_id = pad_token_id
16
+ self.bos_token_id = bos_token_id
17
+ self.eos_token_id = eos_token_id
18
+
19
+ class MiniGPT(PreTrainedModel):
20
+ config_class = MiniGPTConfig
21
+ def __init__(self, config):
22
+ super().__init__(config)
23
+ self.transformer = nn.TransformerDecoder(
24
+ nn.TransformerDecoderLayer(
25
+ d_model=config.n_embd,
26
+ nhead=config.n_head,
27
+ dim_feedforward=config.n_embd * 4,
28
+ batch_first=True
29
+ ),
30
+ num_layers=config.n_layer
31
+ )
32
+ self.embedding = nn.Embedding(config.vocab_size, config.n_embd)
33
+ self.pos_embedding = nn.Embedding(config.n_positions, config.n_embd)
34
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
35
+ self.dropout = nn.Dropout(0.1)
36
+
37
+ # Initialize weights
38
+ self.apply(self._init_weights)
39
+
40
+ def _init_weights(self, module):
41
+ if isinstance(module, (nn.Linear, nn.Embedding)):
42
+ module.weight.data.normal_(mean=0.0, std=0.02)
43
+ if isinstance(module, nn.Linear) and module.bias is not None:
44
+ module.bias.data.zero_()
45
+
46
+ def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
47
+ batch_size, seq_len = input_ids.size()
48
+ positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0).expand(batch_size, seq_len)
49
+
50
+ # Embeddings
51
+ x = self.embedding(input_ids) + self.pos_embedding(positions)
52
+ x = self.dropout(x)
53
+
54
+ # Create causal mask (3D: [n_head, seq_len, seq_len])
55
+ causal_mask = torch.triu(
56
+ torch.full((seq_len, seq_len), float('-inf'), device=input_ids.device, dtype=x.dtype),
57
+ diagonal=1
58
+ ).unsqueeze(0).expand(self.config.n_head, -1, -1)
59
+
60
+ # Create key padding mask (2D: [batch_size, seq_len])
61
+ key_padding_mask = None
62
+ if attention_mask is not None:
63
+ key_padding_mask = (attention_mask == 0).to(torch.bool) # True for padded tokens
64
+
65
+ # Pass to transformer
66
+ x = self.transformer(
67
+ tgt=x,
68
+ memory=x,
69
+ tgt_mask=causal_mask,
70
+ tgt_key_padding_mask=key_padding_mask
71
+ )
72
+ logits = self.lm_head(x)
73
+
74
+ loss = None
75
+ if labels is not None:
76
+ # Shift logits and labels for next-token prediction
77
+ shift_logits = logits[..., :-1, :].contiguous()
78
+ shift_labels = labels[..., 1:].contiguous()
79
+
80
+ # Create loss mask to ignore padding tokens
81
+ loss_mask = (shift_labels != self.config.pad_token_id).float()
82
+
83
+ loss_fct = nn.CrossEntropyLoss(reduction='none')
84
+ loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
85
+ loss = (loss * loss_mask.view(-1)).sum() / loss_mask.sum()
86
+
87
+ return {"logits": logits, "loss": loss}
88
+
89
+ def generate(self, input_ids, max_length=50, **kwargs):
90
+ self.eval()
91
+ generated = input_ids
92
+ for _ in range(max_length):
93
+ outputs = self(generated)["logits"]
94
+ next_token = torch.argmax(outputs[:, -1, :], dim=-1).unsqueeze(-1)
95
+ generated = torch.cat([generated, next_token], dim=-1)
96
+ if next_token.item() == self.config.eos_token_id:
97
+ break
98
+ return generated