from typing import Optional

import torch
from torch import nn
class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        direction_input_dim: int,
        conditioning_input_dim: int,
        latent_dim: int,
        num_heads: int,
    ):
        """
        Multi-Head Attention module.

        Args:
            direction_input_dim (int): The input dimension of the directional input.
            conditioning_input_dim (int): The input dimension of the conditioning input.
            latent_dim (int): The latent dimension of the module.
            num_heads (int): The number of heads to use in the attention mechanism.
        """
        super().__init__()
        assert latent_dim % num_heads == 0, "latent_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = latent_dim // num_heads
        self.scale = self.head_dim**-0.5
        self.query = nn.Linear(direction_input_dim, latent_dim)
        self.key = nn.Linear(conditioning_input_dim, latent_dim)
        self.value = nn.Linear(conditioning_input_dim, latent_dim)
        self.fc_out = nn.Linear(latent_dim, latent_dim)

    def forward(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass of the Multi-Head Attention module.

        Args:
            query (torch.Tensor): The directional input tensor.
            key (torch.Tensor): The conditioning input tensor for the keys.
            value (torch.Tensor): The conditioning input tensor for the values.

        Returns:
            torch.Tensor: The output tensor of the Multi-Head Attention module.
        """
        batch_size = query.size(0)
        Q = (
            self.query(query)
            .view(batch_size, -1, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        K = (
            self.key(key)
            .view(batch_size, -1, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        V = (
            self.value(value)
            .view(batch_size, -1, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        attention = (
            torch.einsum("bnqk,bnkh->bnqh", [Q, K.transpose(-2, -1)]) * self.scale
        )
        attention = torch.softmax(attention, dim=-1)
        out = torch.einsum("bnqh,bnhv->bnqv", [attention, V])
        out = (
            out.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.num_heads * self.head_dim)
        )
        out = self.fc_out(out).squeeze(1)
        return out
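
# A minimal smoke-test sketch (not part of the original file; the dimensions and
# helper name are illustrative assumptions). It shows the expected shape flow
# through the cross-attention: a single directional embedding per sample attends
# over a set of conditioning tokens, and the length-1 query axis is squeezed away.
def _example_multi_head_attention() -> torch.Tensor:
    mha = MultiHeadAttention(
        direction_input_dim=32, conditioning_input_dim=64, latent_dim=128, num_heads=4
    )
    q = torch.randn(8, 32)       # (batch, direction_input_dim)
    kv = torch.randn(8, 16, 64)  # (batch, num_tokens, conditioning_input_dim)
    return mha(q, kv, kv)        # -> (8, 128)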
class AttentionLayer(nn.Module):
    def __init__(
        self,
        direction_input_dim: int,
        conditioning_input_dim: int,
        latent_dim: int,
        num_heads: int,
    ):
        """
        Attention Layer module.

        Args:
            direction_input_dim (int): The input dimension of the directional input.
            conditioning_input_dim (int): The input dimension of the conditioning input.
            latent_dim (int): The latent dimension of the module.
            num_heads (int): The number of heads to use in the attention mechanism.
        """
        super().__init__()
        self.mha = MultiHeadAttention(
            direction_input_dim, conditioning_input_dim, latent_dim, num_heads
        )
        self.norm1 = nn.LayerNorm(latent_dim)
        self.norm2 = nn.LayerNorm(latent_dim)
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, latent_dim),
        )

    def forward(
        self, directional_input: torch.Tensor, conditioning_input: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass of the Attention Layer module.

        Args:
            directional_input (torch.Tensor): The directional input tensor.
            conditioning_input (torch.Tensor): The conditioning input tensor.

        Returns:
            torch.Tensor: The output tensor of the Attention Layer module.
        """
        attn_output = self.mha(
            directional_input, conditioning_input, conditioning_input
        )
        out1 = self.norm1(attn_output + directional_input)
        fc_output = self.fc(out1)
        out2 = self.norm2(fc_output + out1)
        return out2
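
# A hypothetical usage sketch (not from the original file; dimensions and the
# helper name are illustrative). The directional feature width must match
# latent_dim so the residual additions and LayerNorms line up.
def _example_attention_layer() -> torch.Tensor:
    layer = AttentionLayer(
        direction_input_dim=128, conditioning_input_dim=64, latent_dim=128, num_heads=4
    )
    directional = torch.randn(8, 128)      # (batch, latent_dim)
    conditioning = torch.randn(8, 16, 64)  # (batch, num_tokens, conditioning_input_dim)
    return layer(directional, conditioning)  # -> (8, 128)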
class Decoder(nn.Module):
    def __init__(
        self,
        in_dim: int,
        conditioning_input_dim: int,
        hidden_features: int,
        num_heads: int,
        num_layers: int,
        out_activation: Optional[nn.Module],
    ):
        """
        Decoder module.

        Args:
            in_dim (int): The input dimension of the module.
            conditioning_input_dim (int): The input dimension of the conditioning input.
            hidden_features (int): The number of hidden features in the module.
            num_heads (int): The number of heads to use in the attention mechanism.
            num_layers (int): The number of layers in the module.
            out_activation (Optional[nn.Module]): The activation function to apply to the
                output tensor, or None to return the raw linear output.
        """
        super().__init__()
        self.residual_projection = nn.Linear(
            in_dim, hidden_features
        )  # projection for residual connection
        self.layers = nn.ModuleList(
            [
                AttentionLayer(
                    hidden_features, conditioning_input_dim, hidden_features, num_heads
                )
                for _ in range(num_layers)
            ]
        )
        self.fc = nn.Linear(hidden_features, 3)  # 3 for RGB
        self.out_activation = out_activation

    def forward(
        self, x: torch.Tensor, conditioning_input: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass of the Decoder module.

        Args:
            x (torch.Tensor): The input tensor.
            conditioning_input (torch.Tensor): The conditioning input tensor.

        Returns:
            torch.Tensor: The output tensor of the Decoder module.
        """
        x = self.residual_projection(x)
        for layer in self.layers:
            x = layer(x, conditioning_input)
        x = self.fc(x)
        if self.out_activation is not None:
            x = self.out_activation(x)
        return x
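
if __name__ == "__main__":
    # Hypothetical end-to-end sketch (not part of the original file; the sizes
    # below are illustrative assumptions): project the input feature, run it
    # through the attention layers against a set of conditioning tokens, and
    # decode an RGB value, here squashed to [0, 1] by a sigmoid.
    decoder = Decoder(
        in_dim=32,
        conditioning_input_dim=64,
        hidden_features=128,
        num_heads=4,
        num_layers=2,
        out_activation=nn.Sigmoid(),
    )
    x = torch.randn(8, 32)         # (batch, in_dim)
    cond = torch.randn(8, 16, 64)  # (batch, num_tokens, conditioning_input_dim)
    rgb = decoder(x, cond)
    print(rgb.shape)  # torch.Size([8, 3])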