ChristophSchuhmann's picture
Add model code, inference script, and examples
dfd1909 verified
from typing import Optional
from torch import Tensor
import numpy as np
import torch
import torch.nn as nn
class MultiheadAttention(nn.Module):
def __init__(self,
query_channels:int,
key_channels:int,
value_channels:int,
total_hidden_channels:int,
out_channels:int,
num_heads: int = 8
) -> None:
super().__init__()
assert total_hidden_channels % num_heads == 0, f'hidden channel size({total_hidden_channels}) must be factorized by the number of heads({num_heads})'
self.num_heads:int = num_heads
self.hidden_channels:int = total_hidden_channels // num_heads
self.projection_query:nn.Module = nn.Conv1d(in_channels=query_channels, out_channels=total_hidden_channels, kernel_size=1)
self.projection_key:nn.Module = nn.Conv1d(in_channels=key_channels, out_channels=total_hidden_channels, kernel_size=1)
self.projection_value:nn.Module = nn.Conv1d(in_channels=value_channels, out_channels=total_hidden_channels, kernel_size=1)
self.projection_out:nn.Module = nn.Conv1d(in_channels=total_hidden_channels, out_channels=out_channels, kernel_size=1)
def forward(self,
queries:Tensor, #torch.float32 [batch, query_channels, number_of_queries]
keys:Tensor, #torch.float32 [batch, key_channels, number_of_keys/values]
values:Tensor, #torch.float32 [batch, value_channels, number_of_keys/values]
mask: Optional[Tensor] = None #torch.float32 [batch, number_of_keys/values, number_of_queries]
) -> Tensor: #torch.float32 [batch, out_channels, number_of_queries]
batch_size:int = queries.shape[0]
number_of_queries:int = queries.shape[-1]
number_of_keys_and_values:int = keys.shape[-1]
assert(keys.shape[-1] == values.shape[-1]), f'number of keys({keys.shape[-1]}) and number of values({values.shape[-1]}) must be the same'
#[batch, num_heads, hidden_channels, number_of_queries]
queries = self.projection_query(queries).view(batch_size, self.num_heads, -1, number_of_queries)
#[batch, num_heads, hidden_channels, number_of_keys/values]
keys = self.projection_key(keys).view(batch_size, self.num_heads, -1, number_of_keys_and_values)
values = self.projection_key(values).view(batch_size, self.num_heads, -1, number_of_keys_and_values)
#[batch, num_heads, number_of_keys/values,number_of_queries] martix mul of [number_of_keys/values, hidden_channels], [hidden_channels,number_of_queries]
score:Tensor = torch.matmul(keys.transpose(2, 3), queries) * (self.hidden_channels ** -0.5)
if mask is not None:
score.masked_fill_(~mask[:, None, :, :1].to(torch.bool), -np.inf)
#[batch, num_heads, number_of_keys/values,number_of_queries]
weights:Tensor = torch.softmax(score, dim=2)
#[batch, out_channels, number_of_queries]
out = self.projection_out(torch.matmul(values, weights).view(batch_size, -1, number_of_queries))
if mask is not None:
out = out * mask[:, :1]
return out