| import torch |
| import torch.nn as nn |
| from transformers import PreTrainedModel, PretrainedConfig |
|
|
class MiniGPTConfig(PretrainedConfig):
    """Configuration for ``MiniGPT``.

    Args:
        vocab_size: rows of the token embedding and width of the output projection.
        n_positions: number of learned position embeddings, i.e. the longest
            sequence the model can embed.
        n_embd: hidden (embedding) size.
        n_layer: number of decoder layers.
        n_head: attention heads per layer.
        pad_token_id: id of the padding token.
        bos_token_id: id of the beginning-of-sequence token.
        eos_token_id: id of the end-of-sequence token.
        **kwargs: passed through to ``PretrainedConfig``.
    """

    model_type = "mini_gpt"

    def __init__(
        self,
        vocab_size=50257,
        n_positions=128,
        n_embd=128,
        n_layer=2,
        n_head=4,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Model dimensions.
        self.vocab_size, self.n_positions = vocab_size, n_positions
        self.n_embd, self.n_layer, self.n_head = n_embd, n_layer, n_head

        # Special token ids.
        # NOTE(review): vocab_size=50257 matches GPT-2's vocabulary, whose
        # tokenizer uses a different EOS id and no pad token; confirm these
        # defaults match the tokenizer actually used.
        self.pad_token_id, self.bos_token_id, self.eos_token_id = (
            pad_token_id,
            bos_token_id,
            eos_token_id,
        )
|
|
class MiniGPT(PreTrainedModel):
    """Small decoder-only language model on top of ``nn.TransformerDecoder``.

    Token embeddings plus learned absolute position embeddings are passed
    through ``config.n_layer`` decoder layers and projected to vocabulary
    logits by a bias-free linear head. The decoder's cross-attention
    "memory" is the embedded input itself, so both self- and cross-attention
    are causally masked to keep the model autoregressive.
    """

    config_class = MiniGPTConfig

    def __init__(self, config):
        super().__init__(config)
        # ``dropout`` is optional on the config (PretrainedConfig keeps extra
        # keyword arguments as attributes); 0.1 is also the default of
        # nn.TransformerDecoderLayer.
        dropout = getattr(config, "dropout", 0.1)
        self.transformer = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=config.n_embd,
                nhead=config.n_head,
                dim_feedforward=config.n_embd * 4,
                dropout=dropout,
                batch_first=True,
            ),
            num_layers=config.n_layer,
        )
        self.embedding = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_embedding = nn.Embedding(config.n_positions, config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.dropout = nn.Dropout(dropout)

        # NOTE(review): transformers' convention is ``self.post_init()`` here;
        # apply() is kept so initialisation does not depend on library version.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Compute next-token logits and, optionally, the language-modelling loss.

        Args:
            input_ids: integer token ids of shape ``(batch, seq_len)``;
                ``seq_len`` must not exceed ``config.n_positions``.
            attention_mask: optional ``(batch, seq_len)`` tensor where 0 marks
                padding. Assumes right padding: with left padding the leading
                pad rows are fully masked and may produce NaN.
            labels: optional ``(batch, seq_len)`` targets aligned with
                ``input_ids`` (shifted by one internally). Targets equal to
                -100 or to ``config.pad_token_id`` do not contribute to the loss.
            **kwargs: accepted for API compatibility and ignored.
                NOTE(review): this includes any Trainer-supplied
                ``num_items_in_batch``; confirm loss scaling under gradient
                accumulation.

        Returns:
            dict with ``"logits"`` of shape ``(batch, seq_len, vocab_size)`` and
            ``"loss"`` (scalar tensor, or None when ``labels`` is None).

        Raises:
            ValueError: if ``seq_len`` exceeds ``config.n_positions``.
        """
        _, seq_len = input_ids.size()
        if seq_len > self.config.n_positions:
            raise ValueError(
                f"sequence length {seq_len} exceeds n_positions={self.config.n_positions}"
            )
        device = input_ids.device

        # (seq_len,) positions -> (seq_len, n_embd), broadcast over the batch.
        positions = torch.arange(seq_len, device=device)
        x = self.dropout(self.embedding(input_ids) + self.pos_embedding(positions))

        # True above the diagonal = "may not attend to a later position".
        # nn.MultiheadAttention broadcasts a 2-D (seq_len, seq_len) mask over
        # batch and heads; a 3-D mask would need batch * n_head rows.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=device),
            diagonal=1,
        )
        # Boolean like the causal mask, so PyTorch sees one mask type.
        key_padding_mask = None if attention_mask is None else attention_mask == 0

        # ``memory`` is the embedded input, so cross-attention needs the same
        # causal and padding masks; unmasked, every position could read the
        # very token it is trained to predict.
        x = self.transformer(
            tgt=x,
            memory=x,
            tgt_mask=causal_mask,
            memory_mask=causal_mask,
            tgt_key_padding_mask=key_padding_mask,
            memory_key_padding_mask=key_padding_mask,
        )
        logits = self.lm_head(x)

        loss = None if labels is None else self._shifted_lm_loss(logits, labels)
        return {"logits": logits, "loss": loss}

    def _shifted_lm_loss(self, logits, labels):
        """Mean next-token cross-entropy over the targets that are kept.

        Targets equal to -100 (CrossEntropyLoss's default ignore index) or to
        ``config.pad_token_id`` are excluded. If no target is kept the result
        is 0 rather than NaN.
        """
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        keep = shift_labels != -100
        if self.config.pad_token_id is not None:
            keep = keep & (shift_labels != self.config.pad_token_id)

        loss_fct = nn.CrossEntropyLoss(reduction="none")
        per_token = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        )
        weights = keep.view(-1).to(per_token.dtype)
        # clamp avoids 0 / 0 when a batch has no kept targets.
        return (per_token * weights).sum() / weights.sum().clamp(min=1)

    def generate(self, input_ids, max_length=50, **kwargs):
        """Greedily extend ``input_ids`` by up to ``max_length`` new tokens.

        ``max_length`` counts generated tokens, not total length. At each step
        only the last ``config.n_positions`` tokens are fed to the model.

        Stopping and padding: a row stops once it emits
        ``config.eos_token_id``. Later steps append ``config.pad_token_id`` to
        it, or the EOS id when no pad id is set. Decoding ends early once
        every row has stopped.

        Extra ``kwargs`` (e.g. ``attention_mask``) are ignored. Runs without
        gradient tracking, and restores the module's train/eval mode on exit.

        Args:
            input_ids: integer token ids of shape ``(batch, seq_len)``.
            max_length: maximum number of tokens to append.

        Returns:
            Tensor of shape ``(batch, seq_len + n_new)``: prompt followed by
            the generated ids.
        """
        eos_id = self.config.eos_token_id
        pad_id = self.config.pad_token_id
        fill_id = pad_id if pad_id is not None else eos_id

        was_training = self.training
        self.eval()
        generated = input_ids
        finished = torch.zeros(
            input_ids.size(0), dtype=torch.bool, device=input_ids.device
        )
        try:
            with torch.no_grad():
                for _ in range(max_length):
                    # Crop so position ids never exceed the embedding table.
                    context = generated[:, -self.config.n_positions:]
                    logits = self(context)["logits"]
                    next_token = torch.argmax(logits[:, -1, :], dim=-1)
                    if eos_id is not None:
                        # Rows that already emitted EOS only receive filler.
                        next_token = next_token.masked_fill(finished, fill_id)
                        finished = finished | (next_token == eos_id)
                    generated = torch.cat(
                        [generated, next_token.unsqueeze(-1)], dim=-1
                    )
                    if finished.all():
                        break
        finally:
            self.train(was_training)
        return generated