import warnings
from typing import Optional, Tuple

import torch
import torch.jit
import torch.nn.functional as nnF
from torch import nn, Tensor


class MultiheadAttention(nn.MultiheadAttention):
r"""Quantizable implementation of the MultiheadAttention. |
|
|
|
|
|
Note:: |
|
|
Please, refer to :class:`~torch.nn.MultiheadAttention` for more |
|
|
information |
|
|
|
|
|
Allows the model to jointly attend to information from different |
|
|
representation subspaces. |
|
|
See reference: Attention Is All You Need |
|
|
|
|
|
The original MHA module is not quantizable. |
|
|
This reimplements it by explicitly instantiating the linear layers. |
|
|
|
|
|
.. math:: |
|
|
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O |
|
|
\text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) |
|
|
|
|
|
Args: |
|
|
embed_dim: total dimension of the model. |
|
|
num_heads: parallel attention heads. |
|
|
dropout: a Dropout layer on attn_output_weights. Default: 0.0. |
|
|
bias: add bias as module parameter. Default: True. |
|
|
add_bias_kv: add bias to the key and value sequences at dim=0. |
|
|
add_zero_attn: add a new batch of zeros to the key and |
|
|
value sequences at dim=1. |
|
|
kdim: total number of features in key. Default: None. |
|
|
vdim: total number of features in value. Default: None. |
|
|
batch_first: If ``True``, then the input and output tensors are provided |
|
|
as (batch, seq, feature). Default: ``False`` (seq, batch, feature). |
|
|
|
|
|
Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set |
|
|
to :attr:`embed_dim` such that query, key, and value have the same |
|
|
number of features. |
|
|
|
|
|
Examples:: |
|
|
|
|
|
>>> import torch.nn.quantizable as nnqa |
|
|
>>> multihead_attn = nnqa.MultiheadAttention(embed_dim, num_heads) |
|
|
>>> attn_output, attn_output_weights = multihead_attn(query, key, value) |
|
|
|
|
|
Note:: |
|
|
Please, follow the quantization flow to convert the quantizable MHA. |
|
|
""" |

    _FLOAT_MODULE = nn.MultiheadAttention

    __constants__ = ['batch_first']
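
    # Quantization flow, as a sketch (assumes eager-mode static quantization;
    # ``float_model`` is a hypothetical module whose forward takes q/k/v and
    # contains an ``nn.MultiheadAttention``):
    #
    #   float_model.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
    #   observed = torch.ao.quantization.prepare(
    #       float_model,
    #       prepare_custom_config_dict={
    #           "float_to_observed_custom_module_class": {
    #               nn.MultiheadAttention: MultiheadAttention,
    #           },
    #       },
    #   )
    #   observed(example_query, example_key, example_value)  # calibration pass
    #   quantized = torch.ao.quantization.convert(observed)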

    def __init__(self, embed_dim: int, num_heads: int,
                 dropout: float = 0., bias: bool = True,
                 add_bias_kv: bool = False, add_zero_attn: bool = False,
                 kdim: Optional[int] = None, vdim: Optional[int] = None,
                 batch_first: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(embed_dim, num_heads, dropout,
                         bias, add_bias_kv,
                         add_zero_attn, kdim, vdim, batch_first,
                         **factory_kwargs)
        # Explicit linear projections, so that each one can be observed and
        # quantized individually.
        self.linear_Q = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs)
        self.linear_K = nn.Linear(self.kdim, self.embed_dim, bias=bias, **factory_kwargs)
        self.linear_V = nn.Linear(self.vdim, self.embed_dim, bias=bias, **factory_kwargs)

        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs)
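
        # The query-scaling multiply below goes through FloatFunctional so that
        # it can be observed during calibration and swapped for a quantized op
        # at convert() time; the Quant/DeQuant stubs mark where activations
        # cross between the quantized and float domains (they are identity ops
        # until conversion).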
        self.q_scaling_product = torch.nn.quantized.FloatFunctional()

        self.quant_attn_output = torch.ao.quantization.QuantStub()
        self.quant_attn_output_weights = torch.ao.quantization.QuantStub()
        self.dequant_q = torch.ao.quantization.DeQuantStub()
        self.dequant_k = torch.ao.quantization.DeQuantStub()
        self.dequant_v = torch.ao.quantization.DeQuantStub()

    def _get_name(self):
        return 'QuantizableMultiheadAttention'

    @classmethod
    def from_float(cls, other):
        assert type(other) == cls._FLOAT_MODULE
        assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'"
        observed = cls(other.embed_dim, other.num_heads, other.dropout,
                       (other.in_proj_bias is not None),
                       (other.bias_k is not None),
                       other.add_zero_attn, other.kdim, other.vdim,
                       other.batch_first)
        observed.bias_k = other.bias_k
        observed.bias_v = other.bias_v
        observed.qconfig = other.qconfig

        # Set the linear weights
        observed.out_proj.weight = other.out_proj.weight
        observed.out_proj.bias = other.out_proj.bias
        if other._qkv_same_embed_dim:
            # The float module packs Q/K/V into a single in_proj_weight;
            # split it into the three explicit projections.
            bias = other.in_proj_bias
            _start = 0
            _end = _start + other.embed_dim
            weight = other.in_proj_weight[_start:_end, :]
            if bias is not None:
                bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad)
            observed.linear_Q.weight = torch.nn.Parameter(weight,
                                                          weight.requires_grad)
            observed.linear_Q.bias = bias

            bias = other.in_proj_bias
            _start = _end
            _end = _start + other.embed_dim
            weight = other.in_proj_weight[_start:_end, :]
            if bias is not None:
                bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad)
            observed.linear_K.weight = torch.nn.Parameter(weight,
                                                          weight.requires_grad)
            observed.linear_K.bias = bias

            bias = other.in_proj_bias
            _start = _end
            weight = other.in_proj_weight[_start:, :]
            if bias is not None:
                bias = torch.nn.Parameter(bias[_start:], bias.requires_grad)
            observed.linear_V.weight = torch.nn.Parameter(weight,
                                                          weight.requires_grad)
            observed.linear_V.bias = bias
        else:
            observed.linear_Q.weight = nn.Parameter(other.q_proj_weight)
            observed.linear_K.weight = nn.Parameter(other.k_proj_weight)
            observed.linear_V.weight = nn.Parameter(other.v_proj_weight)
            if other.in_proj_bias is None:
                observed.linear_Q.bias = None
                observed.linear_K.bias = None
                observed.linear_V.bias = None
            else:
                observed.linear_Q.bias = nn.Parameter(other.in_proj_bias[0:other.embed_dim])
                observed.linear_K.bias = nn.Parameter(other.in_proj_bias[other.embed_dim:(other.embed_dim * 2)])
                observed.linear_V.bias = nn.Parameter(other.in_proj_bias[(other.embed_dim * 2):])
        observed.eval()
        # Attach the observers
        observed = torch.ao.quantization.prepare(observed, inplace=True)
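        # The caller still has to run representative (calibration) data through
        # the observed module, e.g. ``observed(q, k, v)``, before
        # ``torch.ao.quantization.convert`` can compute quantization parameters.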
        return observed

    @torch.jit.unused
    def dequantize(self):
        r"""Utility to convert the quantized MHA back to float.

        The motivation for this is that it is not trivial to convert the
        weights from the format that is used in the quantized version back
        to the float one.
        """
        fp = self._FLOAT_MODULE(self.embed_dim, self.num_heads, self.dropout,
                                (self.in_proj_bias is not None),
                                (self.bias_k is not None),
                                self.add_zero_attn, self.kdim, self.vdim, self.batch_first)
        assert fp._qkv_same_embed_dim == self._qkv_same_embed_dim
        if self.bias_k is not None:
            fp.bias_k = nn.Parameter(self.bias_k.dequantize())
        if self.bias_v is not None:
            fp.bias_v = nn.Parameter(self.bias_v.dequantize())

        # Set the linear weights
        w, b = self.out_proj._weight_bias()
        fp.out_proj.weight = nn.Parameter(w.dequantize())
        if b is not None:
            fp.out_proj.bias = nn.Parameter(b)

        wQ, bQ = self.linear_Q._weight_bias()
        wQ = wQ.dequantize()
        wK, bK = self.linear_K._weight_bias()
        wK = wK.dequantize()
        wV, bV = self.linear_V._weight_bias()
        wV = wV.dequantize()
        if fp._qkv_same_embed_dim:
            # The float module packs Q/K/V into a single in_proj_weight;
            # copy the three dequantized weights back into their slices.
            with torch.no_grad():
                _start = 0
                _end = _start + fp.embed_dim
                fp.in_proj_weight[_start:_end, :] = wQ
                if fp.in_proj_bias is not None:
                    fp.in_proj_bias[_start:_end] = bQ

                _start = _end
                _end = _start + fp.embed_dim
                fp.in_proj_weight[_start:_end, :] = wK
                if fp.in_proj_bias is not None:
                    fp.in_proj_bias[_start:_end] = bK

                _start = _end
                fp.in_proj_weight[_start:, :] = wV
                if fp.in_proj_bias is not None:
                    fp.in_proj_bias[_start:] = bV
        else:
            fp.q_proj_weight = nn.Parameter(wQ)
            fp.k_proj_weight = nn.Parameter(wK)
            fp.v_proj_weight = nn.Parameter(wV)
            # If the float module has no input-projection bias, there is
            # nothing to copy back.
            if fp.in_proj_bias is not None:
                with torch.no_grad():
                    fp.in_proj_bias[0:fp.embed_dim] = bQ
                    fp.in_proj_bias[fp.embed_dim:(fp.embed_dim * 2)] = bK
                    fp.in_proj_bias[(fp.embed_dim * 2):] = bV

        return fp
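
    # Usage sketch (assumed workflow): after convert(), calling
    # ``quantized_mha.dequantize()`` recovers an ``nn.MultiheadAttention``
    # with dequantized weights, e.g. to compare outputs against the original
    # float module.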

    @classmethod
    def from_observed(cls, other):
        # The whole flow is float -> observed -> quantized; the
        # observed -> quantized step is driven by the convert() machinery
        # rather than by calling this method directly.
        raise NotImplementedError("It looks like you are trying to prepare an "
                                  "MHA module. Please, see "
                                  "the examples on quantizable MHAs.")

    def forward(self,
                query: Tensor,
                key: Tensor,
                value: Tensor,
                key_padding_mask: Optional[Tensor] = None,
                need_weights: bool = True,
                attn_mask: Optional[Tensor] = None,
                average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        .. note::
            Please refer to :func:`~torch.nn.MultiheadAttention.forward` for more
            information.

        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. When given a binary mask, positions with a
                ``True`` value will be ignored by the attention layer. When given a byte
                mask, non-zero positions will be ignored.
            need_weights: output attn_output_weights. Default: ``True``.
            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcast for all
                the batches while a 3D mask allows specifying a different mask for the entries of each batch.

        Shape:
            - Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
              If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
              will be unchanged. If a BoolTensor is provided, the positions with the value of ``True`` will
              be ignored while the positions with the value of ``False`` will be unchanged.
            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
              3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
              S is the source sequence length. attn_mask ensures that position i is allowed to attend to the
              unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
              while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
              are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
              is provided, it will be added to the attention weight.
            - average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
              heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
              effect when ``need_weights=True``. Default: True (i.e. average weights across heads)

            - Outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
              E is the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
            - attn_output_weights: If ``average_attn_weights=True``, returns attention weights averaged
              across heads of shape :math:`(N, L, S)`, where N is the batch size, L is the target sequence length,
              S is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
              head of shape :math:`(N, num_heads, L, S)`.
        """
        return self._forward_impl(query, key, value, key_padding_mask,
                                  need_weights, attn_mask, average_attn_weights)

    def _forward_impl(self,
                      query: Tensor,
                      key: Tensor,
                      value: Tensor,
                      key_padding_mask: Optional[Tensor] = None,
                      need_weights: bool = True,
                      attn_mask: Optional[Tensor] = None,
                      average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
        # This version does not deal with the static key/value pairs.
        # Keeping it here for future changes.
        static_k = None
        static_v = None

        if self.batch_first:
            query, key, value = [x.transpose(0, 1) for x in (query, key, value)]

        tgt_len, bsz, embed_dim_to_check = query.size()
        assert self.embed_dim == embed_dim_to_check
        # Allow MHA to have different sizes for the feature dimension
        assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

        head_dim = self.embed_dim // self.num_heads
        assert head_dim * self.num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
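
        # Scale the queries by 1/sqrt(head_dim), as in "Attention Is All You Need".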
        scaling = float(head_dim) ** -0.5

        q = self.linear_Q(query)
        k = self.linear_K(key)
        v = self.linear_V(value)

        q = self.q_scaling_product.mul_scalar(q, scaling)

        if attn_mask is not None:
            assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
                attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
                'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
            if attn_mask.dtype == torch.uint8:
                warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
                attn_mask = attn_mask.to(torch.bool)

            if attn_mask.dim() == 2:
                attn_mask = attn_mask.unsqueeze(0)
                if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                    raise RuntimeError('The size of the 2D attn_mask is not correct.')
            elif attn_mask.dim() == 3:
                if list(attn_mask.size()) != [bsz * self.num_heads, query.size(0), key.size(0)]:
                    raise RuntimeError('The size of the 3D attn_mask is not correct.')
            else:
                raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
            # attn_mask's dim is 3 now.

        # Convert a ByteTensor key_padding_mask to bool
        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
            warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
            key_padding_mask = key_padding_mask.to(torch.bool)
        if self.bias_k is not None and self.bias_v is not None:
            if static_k is None and static_v is None:
                # Explicitly assert that bias_k and bias_v are not None
                # in a way that TorchScript can understand.
                bias_k = self.bias_k
                assert bias_k is not None
                bias_v = self.bias_v
                assert bias_v is not None

                k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
                v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
                if attn_mask is not None:
                    attn_mask = nnF.pad(attn_mask, (0, 1))
                if key_padding_mask is not None:
                    key_padding_mask = nnF.pad(key_padding_mask, (0, 1))
            else:
                assert static_k is None, "bias cannot be added to static key."
                assert static_v is None, "bias cannot be added to static value."
        else:
            assert self.bias_k is None
            assert self.bias_v is None

        q = q.contiguous().view(tgt_len, bsz * self.num_heads, head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)

        if static_k is not None:
            assert static_k.size(0) == bsz * self.num_heads
            assert static_k.size(2) == head_dim
            k = static_k

        if static_v is not None:
            assert static_v.size(0) == bsz * self.num_heads
            assert static_v.size(2) == head_dim
            v = static_v

        src_len = k.size(1)

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            src_len += 1
            k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:])
            if k.is_quantized:
                k_zeros = torch.quantize_per_tensor(k_zeros, k.q_scale(), k.q_zero_point(), k.dtype)
            k = torch.cat([k, k_zeros], dim=1)
            v_zeros = torch.zeros((v.size(0), 1) + v.size()[2:])
            if v.is_quantized:
                v_zeros = torch.quantize_per_tensor(v_zeros, v.q_scale(), v.q_zero_point(), v.dtype)
            v = torch.cat([v, v_zeros], dim=1)

            if attn_mask is not None:
                attn_mask = nnF.pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = nnF.pad(key_padding_mask, (0, 1))
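
        # The attention scores and softmax below run in floating point:
        # quantized bmm/softmax kernels are not available in this flow, so
        # q/k/v are dequantized here and the outputs re-quantized afterwards.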
        q = self.dequant_q(q)
        k = self.dequant_k(k)
        v = self.dequant_v(v)
        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_output_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_output_weights.masked_fill_(attn_mask, float('-inf'))
            else:
                attn_output_weights += attn_mask

        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_output_weights = attn_output_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            )
            attn_output_weights = attn_output_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_output_weights = nnF.softmax(attn_output_weights, dim=-1)
        attn_output_weights = nnF.dropout(attn_output_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_output_weights, v)
        assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, head_dim]
        if self.batch_first:
            # (bsz * num_heads, tgt_len, head_dim) -> (bsz, tgt_len, embed_dim).
            # A direct view would interleave the heads with the sequence
            # positions, so separate the heads first.
            attn_output = attn_output.view(bsz, self.num_heads, tgt_len, head_dim)
            attn_output = attn_output.transpose(1, 2).reshape(bsz, tgt_len, self.embed_dim)
        else:
            attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
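
        # Leave the float domain: re-quantize the attention output before the
        # (quantized) output projection, and quantize the returned weights so
        # downstream consumers see a consistent dtype.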
        attn_output = self.quant_attn_output(attn_output)
        attn_output = self.out_proj(attn_output)
        attn_output_weights = self.quant_attn_output_weights(attn_output_weights)

        if need_weights:
            # Separate the heads, then optionally average the weights across them.
            attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
            if average_attn_weights:
                attn_output_weights = attn_output_weights.mean(dim=1)
            return attn_output, attn_output_weights
        else:
            return attn_output, None
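

# ---------------------------------------------------------------------------
# Minimal smoke test of the float path (a sketch, not part of the API): before
# prepare()/convert(), the Quant/DeQuant stubs are identity ops, so the module
# should behave like a float MHA with explicit Q/K/V projections.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    embed_dim, num_heads = 16, 4
    mha = MultiheadAttention(embed_dim, num_heads)
    mha.eval()
    query = torch.randn(5, 2, embed_dim)        # (L, N, E)
    key = value = torch.randn(7, 2, embed_dim)  # (S, N, E)
    with torch.no_grad():
        attn_output, attn_weights = mha(query, key, value)
    assert attn_output.shape == (5, 2, embed_dim)
    assert attn_weights.shape == (2, 5, 7)  # weights averaged across heads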