| from typing import Optional |
| from torch import Tensor |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
|
|
class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention over channel-first sequences.

    Queries, keys and values are each projected with a pointwise
    (``kernel_size=1``) ``nn.Conv1d``, split into ``num_heads`` heads,
    combined by softmax attention over the key axis, and finally mapped to
    ``out_channels`` by another pointwise convolution.
    """

    def __init__(self,
                 query_channels: int,
                 key_channels: int,
                 value_channels: int,
                 total_hidden_channels: int,
                 out_channels: int,
                 num_heads: int = 8
                 ) -> None:
        """Build the four pointwise projections.

        Args:
            query_channels: Channel count of the ``queries`` input.
            key_channels: Channel count of the ``keys`` input.
            value_channels: Channel count of the ``values`` input.
            total_hidden_channels: Width of the projected space summed over
                all heads; must be divisible by ``num_heads``.
            out_channels: Channel count of the returned tensor.
            num_heads: Number of attention heads.

        Raises:
            AssertionError: If ``total_hidden_channels`` is not a multiple of
                ``num_heads``.
        """
        super().__init__()
        assert total_hidden_channels % num_heads == 0, f'hidden channel size({total_hidden_channels}) must be factorized by the number of heads({num_heads})'
        self.num_heads: int = num_heads
        # Per-head channel width; also determines the 1/sqrt(d) score scale.
        self.hidden_channels: int = total_hidden_channels // num_heads
        self.projection_query: nn.Module = nn.Conv1d(in_channels=query_channels, out_channels=total_hidden_channels, kernel_size=1)
        self.projection_key: nn.Module = nn.Conv1d(in_channels=key_channels, out_channels=total_hidden_channels, kernel_size=1)
        self.projection_value: nn.Module = nn.Conv1d(in_channels=value_channels, out_channels=total_hidden_channels, kernel_size=1)
        self.projection_out: nn.Module = nn.Conv1d(in_channels=total_hidden_channels, out_channels=out_channels, kernel_size=1)

    def forward(self,
                queries: Tensor,
                keys: Tensor,
                values: Tensor,
                mask: Optional[Tensor] = None
                ) -> Tensor:
        """Attend from ``queries`` to ``keys``/``values``.

        Args:
            queries: Tensor of shape ``(batch, query_channels, num_queries)``.
            keys: Tensor of shape ``(batch, key_channels, num_keys)``.
            values: Tensor of shape ``(batch, value_channels, num_keys)``.
            mask: Optional validity mask; non-zero entries mark valid
                positions. NOTE(review): the indexing below implies a
                ``(batch, num_keys, num_queries)`` layout where column 0
                carries key validity and row 0 carries query validity --
                confirm against callers.

        Returns:
            Tensor of shape ``(batch, out_channels, num_queries)``.

        Raises:
            AssertionError: If ``keys`` and ``values`` differ in length along
                the last axis.
        """
        assert(keys.shape[-1] == values.shape[-1]), f'number of keys({keys.shape[-1]}) and number of values({values.shape[-1]}) must be the same'
        batch_size: int = queries.shape[0]
        number_of_queries: int = queries.shape[-1]
        number_of_keys_and_values: int = keys.shape[-1]

        # Project, then split the channel axis into (heads, channels-per-head).
        queries = self.projection_query(queries).view(batch_size, self.num_heads, -1, number_of_queries)
        keys = self.projection_key(keys).view(batch_size, self.num_heads, -1, number_of_keys_and_values)
        # Values go through their own projection. Reusing the key projection
        # here would leave projection_value untrained and would fail outright
        # whenever value_channels != key_channels.
        values = self.projection_value(values).view(batch_size, self.num_heads, -1, number_of_keys_and_values)

        # score[b, h, k, q] = <key_k, query_q> / sqrt(channels per head)
        score: Tensor = torch.matmul(keys.transpose(2, 3), queries) * (self.hidden_channels ** -0.5)
        if mask is not None:
            # Invalid keys get -inf so that softmax assigns them zero weight.
            # NOTE(review): a batch item whose keys are all masked yields NaN
            # weights for that item.
            score.masked_fill_(~mask[:, None, :, :1].to(torch.bool), -np.inf)

        # Normalise over the key axis (dim=2) independently for each query.
        weights: Tensor = torch.softmax(score, dim=2)

        # (b, h, c, k) @ (b, h, k, q) -> (b, h, c, q); heads are then merged
        # back into a single channel axis before the output projection.
        out = self.projection_out(torch.matmul(values, weights).view(batch_size, -1, number_of_queries))
        if mask is not None:
            # Zero the outputs at query positions the mask marks invalid.
            out = out * mask[:, :1]
        return out