import torch
import torch.nn as nn
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.gpt2.modeling_gpt2 import Conv1D, GPT2Block, GPT2Model

from .attention import Attention

class GPT2AccelAttention(nn.Module):
    """GPT-2 self-attention that routes the core attention computation
    through the packed accel kernel from `.attention`."""

    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        # Causal-mask buffers as in the reference GPT-2 attention; the accel
        # forward path below does not read them.
        max_positions = config.max_position_embeddings
        self.register_buffer(
            "bias",
            torch.tril(
                torch.ones((max_positions, max_positions), dtype=torch.bool)
            ).view(1, 1, max_positions, max_positions),
            persistent=False,
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim

        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.scale_attn_weights = config.scale_attn_weights

        # Fused QKV projection and output projection, as in the reference GPT-2.
        self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

        scale = (self.head_dim**-0.5) if self.scale_attn_weights else 1.0
        self.accel_attn = Attention(
            self.num_heads, self.head_dim, scale, self.num_heads
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        layer_past=None,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=False,
        output_attentions=False,
        past_key_value=None,
        **kwargs,
    ):
        if encoder_hidden_states is not None:
            raise NotImplementedError("Cross attention not supported in accel mode")

        # Fused QKV projection, then split into query/key/value.
        qkv = self.c_attn(hidden_states)
        query, key, value = qkv.split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        # The accel kernel consumes packed (batch * seq, heads, head_dim)
        # tensors, so flatten the batch and sequence dimensions together.
        bsz, num_heads, seq_len, head_dim = query.shape
        q_flat = query.transpose(1, 2).contiguous().view(-1, num_heads, head_dim)
        k_flat = key.transpose(1, 2).contiguous().view(-1, num_heads, head_dim)
        v_flat = value.transpose(1, 2).contiguous().view(-1, num_heads, head_dim)

        # Cast non-fp16 CUDA inputs to fp16 for the kernel and remember the
        # original dtype so the output can be cast back.
        if q_flat.device.type == "cuda" and q_flat.dtype != torch.float16:
            orig_dtype = q_flat.dtype
            q_flat = q_flat.to(torch.float16)
            k_flat = k_flat.to(torch.float16)
            v_flat = v_flat.to(torch.float16)
        else:
            orig_dtype = q_flat.dtype

        o_flat = self.accel_attn(q_flat, k_flat, v_flat)

        if o_flat.dtype != orig_dtype:
            o_flat = o_flat.to(orig_dtype)

        # Restore the (batch, heads, seq, head_dim) layout and merge heads.
        attn_output = o_flat.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2)
        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)

        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        # No KV cache and no attention weights are produced on this path;
        # `attention_mask`, `head_mask`, and `layer_past` are accepted for API
        # compatibility but ignored.
        outputs = (attn_output, None)
        if output_attentions:
            outputs += (None,)

        return outputs

    def _split_heads(self, tensor, num_heads, head_dim):
        # (batch, seq, embed) -> (batch, heads, seq, head_dim)
        new_shape = tensor.size()[:-1] + (num_heads, head_dim)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)

    def _merge_heads(self, tensor, num_heads, head_dim):
        # (batch, heads, seq, head_dim) -> (batch, seq, embed)
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * head_dim,)
        return tensor.view(new_shape)

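# Layout note (illustrative, not part of the original interface): the accel
# kernel is fed packed (batch * seq, heads, head_dim) tensors, so the forward
# pass above flattens (batch, heads, seq, head_dim) and inverts the transform
# afterwards. A minimal standalone sanity check of that round trip:
#
#     q = torch.randn(2, 12, 5, 64)                          # (b, h, s, d)
#     flat = q.transpose(1, 2).contiguous().view(-1, 12, 64)  # (b*s, h, d)
#     back = flat.view(2, 5, 12, 64).transpose(1, 2)
#     assert torch.equal(back, q)
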
class GPT2AccelBlock(GPT2Block):
    def __init__(self, config, layer_idx=None):
        super().__init__(config, layer_idx)
        # Swap the reference GPT-2 attention for the accel variant.
        self.attn = GPT2AccelAttention(config, layer_idx)

class GPT2AccelModel(GPT2Model):
    def __init__(self, config):
        super().__init__(config)
        # Rebuild the transformer stack with accel blocks.
        self.h = nn.ModuleList(
            [
                GPT2AccelBlock(config, layer_idx=i)
                for i in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        if inputs_embeds is not None:
            # Fast path: `inputs_embeds` is consumed as-is (no position or
            # token-type embeddings are added here) and run straight through
            # the accel blocks.
            hidden_states = inputs_embeds
            for block in self.h:
                hidden_states = block(hidden_states)[0]
            hidden_states = self.ln_f(hidden_states)

            if return_dict:
                return BaseModelOutputWithPastAndCrossAttentions(
                    last_hidden_state=hidden_states,
                    past_key_values=None,
                    hidden_states=None,
                    attentions=None,
                )
            return (hidden_states,)
        else:
            # Fall back to the stock GPT2Model forward (which still runs the
            # accel blocks in `self.h`); `use_cache` is forced off because the
            # accel attention does not return key/value states.
            return super().forward(
                input_ids=input_ids,
                past_key_values=None,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=None,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                use_cache=False,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
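
# ---------------------------------------------------------------------------
# Minimal usage sketch (assumptions: the accel `Attention` kernel from
# `.attention` is importable and runs on the current CUDA device; the config
# values are placeholders). Illustrative only, not part of the module API.
#
#     from transformers import GPT2Config
#
#     config = GPT2Config(n_embd=768, n_head=12, n_layer=12)
#     model = GPT2AccelModel(config).half().cuda().eval()
#
#     input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")
#     # The accel fast path is taken only when `inputs_embeds` is provided;
#     # note it does not add position embeddings.
#     inputs_embeds = model.wte(input_ids)
#     with torch.no_grad():
#         out = model(inputs_embeds=inputs_embeds, return_dict=True)
#     print(out.last_hidden_state.shape)  # torch.Size([1, 16, 768])
# ---------------------------------------------------------------------------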