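"""GPT-2 modules whose self-attention is routed through the local accelerated
`Attention` kernel from `.attention`. Cross attention and the KV cache are not
supported on this accelerated path."""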
import torch
import torch.nn as nn
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.gpt2.modeling_gpt2 import Conv1D, GPT2Block, GPT2Model
from .attention import Attention


class GPT2AccelAttention(nn.Module):
    """Drop-in replacement for GPT-2's attention that routes Q/K/V through the
    accelerated `Attention` kernel instead of the eager matmul path."""

    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        # Causal-mask buffers kept for parity with the stock GPT2Attention;
        # they are not used by the accelerated forward pass.
        max_positions = config.max_position_embeddings
        self.register_buffer(
            "bias",
            torch.tril(
                torch.ones((max_positions, max_positions), dtype=torch.bool)
            ).view(1, 1, max_positions, max_positions),
            persistent=False,
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale_attn_weights = config.scale_attn_weights

        # Fused QKV projection and output projection, same layout as stock GPT-2.
        self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

        scale = (self.head_dim**-0.5) if self.scale_attn_weights else 1.0
        self.accel_attn = Attention(
            self.num_heads, self.head_dim, scale, self.num_heads
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        layer_past=None,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=False,
        output_attentions=False,
        past_key_value=None,
        **kwargs,
    ):
        if encoder_hidden_states is not None:
            raise NotImplementedError("Cross attention not supported in accel mode")

        qkv = self.c_attn(hidden_states)
        query, key, value = qkv.split(self.split_size, dim=2)

        # [B, T, H*D] -> [B, H, T, D]
        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        # Flatten to [B*T, H, D], the layout consumed by the accelerated kernel.
        bsz, num_heads, seq_len, head_dim = query.shape
        q_flat = query.transpose(1, 2).contiguous().view(-1, num_heads, head_dim)
        k_flat = key.transpose(1, 2).contiguous().view(-1, num_heads, head_dim)
        v_flat = value.transpose(1, 2).contiguous().view(-1, num_heads, head_dim)

        # Ensure fp16 on CUDA; remember the original dtype so the output can be
        # cast back afterwards.
        orig_dtype = q_flat.dtype
        if q_flat.device.type == "cuda" and q_flat.dtype != torch.float16:
            q_flat = q_flat.to(torch.float16)
            k_flat = k_flat.to(torch.float16)
            v_flat = v_flat.to(torch.float16)

        o_flat = self.accel_attn(q_flat, k_flat, v_flat)  # [B*T, H, D]
        if o_flat.dtype != orig_dtype:
            o_flat = o_flat.to(orig_dtype)

        # Reshape back: [B*T, H, D] -> [B, H, T, D]
        attn_output = o_flat.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2)
        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        # Match the GPT2Attention return convention: (attn_output, present[, attentions]).
        # No KV cache entry or attention probabilities are produced in accel mode.
        outputs = (attn_output, None)
        if output_attentions:
            outputs += (None,)
        return outputs

    def _split_heads(self, tensor, num_heads, head_dim):
        new_shape = tensor.size()[:-1] + (num_heads, head_dim)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def _merge_heads(self, tensor, num_heads, head_dim):
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * head_dim,)
        return tensor.view(new_shape)


class GPT2AccelBlock(GPT2Block):
    """`GPT2Block` with its attention swapped for `GPT2AccelAttention`."""

    def __init__(self, config, layer_idx=None):
        super().__init__(config, layer_idx)
        self.attn = GPT2AccelAttention(config, layer_idx)


class GPT2AccelModel(GPT2Model):
    """`GPT2Model` built from `GPT2AccelBlock` layers."""

    def __init__(self, config):
        super().__init__(config)
        self.h = nn.ModuleList(
            [
                GPT2AccelBlock(config, layer_idx=i)
                for i in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        if inputs_embeds is not None:
            # Fast path: run already-embedded inputs straight through the
            # accelerated blocks. Position embeddings, attention masks, and the
            # KV cache are not applied on this path.
            hidden_states = inputs_embeds
            for block in self.h:
                hidden_states = block(hidden_states)[0]
            hidden_states = self.ln_f(hidden_states)
            if return_dict:
                return BaseModelOutputWithPastAndCrossAttentions(
                    last_hidden_state=hidden_states,
                    past_key_values=None,
                    hidden_states=None,
                    attentions=None,
                )
            return (hidden_states,)
        else:
            # Standard path: defer to GPT2Model.forward, with the KV cache
            # disabled since the accelerated attention does not produce one.
            return super().forward(
                input_ids=input_ids,
                past_key_values=None,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=None,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                use_cache=False,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
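

# Minimal usage sketch (not part of the original module), assuming a CUDA
# device, half-precision weights, and that the local `Attention` kernel
# accepts the [B*T, H, D] tensors produced above.
if __name__ == "__main__":
    from transformers import GPT2Config

    config = GPT2Config(n_layer=2, n_head=4, n_embd=128, n_positions=64)
    model = GPT2AccelModel(config).half().cuda().eval()
    input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")
    with torch.no_grad():
        out = model(input_ids, return_dict=True)
    print(out.last_hidden_state.shape)  # expected: torch.Size([1, 16, 128])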