from torch import nn

from transformers import GPT2LMHeadModel as GPT2LMHeadModelBase
from transformers.models.gpt2.modeling_gpt2 import GPT2Block as GPT2BlockBase
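
# Drop-in GPT-2 variants: GPT2Block re-implements the transformer block's
# forward pass explicitly (pre-LayerNorm attention and MLP sub-layers, each
# wrapped in a residual connection), and GPT2LMHeadModel swaps these blocks
# into the stock Hugging Face model. Written against transformers 4.x, where
# GPT2Block takes a layer_idx constructor argument.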


class GPT2Block(GPT2BlockBase):
    def forward(self, x, layer_past=None,
                attention_mask=None, head_mask=None, use_cache=False,
                encoder_hidden_states=None, encoder_attention_mask=None,
                output_attentions=False):
        # Attention sub-layer (pre-LayerNorm). ln_1 is applied inline so that
        # x itself survives as the residual, matching the stock GPT2Block
        # (overwriting x with ln_1(x) first would add a normalized residual).
        # The encoder_* arguments are accepted for interface compatibility;
        # cross-attention is not applied here.
        output_attn = self.attn(
            self.ln_1(x),
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions)
        a = output_attn[0]  # output_attn: (attn_output, present, (attentions))
        x = x + a

        # MLP sub-layer (pre-LayerNorm) with its own residual connection.
        m = self.mlp(self.ln_2(x))
        x = x + m

        # Match the stock return convention: drop the None `present` slot
        # when no cache was requested, so GPT2Model's indexing stays aligned.
        if use_cache:
            outputs = (x,) + output_attn[1:]
        else:
            outputs = (x,) + output_attn[2:]
        return outputs  # x, (present), (attentions)


class GPT2LMHeadModel(GPT2LMHeadModelBase):
    def __init__(self, config):
        super().__init__(config)
        # Swap every stock block for the re-implemented one above.
        self.transformer.h = nn.ModuleList(
            [GPT2Block(config, layer_idx) for layer_idx in range(config.n_layer)])
        # Re-run weight init so the freshly built blocks are initialized the
        # usual HF way; from_pretrained overwrites them with checkpoint weights.
        self.init_weights()
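

# Minimal usage sketch (illustration only, not part of the original module):
# because the custom blocks keep the stock parameter names, from_pretrained
# can load the standard "gpt2" checkpoint straight into this subclass.
if __name__ == "__main__":
    import torch
    from transformers import GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    model.eval()

    inputs = tokenizer("Hello, my name is", return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    print(logits.shape)  # torch.Size([1, sequence_length, 50257])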