| | |
| | |
| | |
| | |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from modules.general.utils import Conv1d, normalization, zero_module |
| | from .basic import UNetBlock |
| |
|
| |
|
class AttentionBlock(UNetBlock):
    r"""Attention wrapper that flattens its input to ``[B, channels, -1]`` and
    applies self-attention, cross-attention, or self-attention followed by
    cross-attention. Reference from `latent diffusion repo
    <https://github.com/Stability-AI/generative-models/blob/main/sgm/modules/attention.py#L531>`_.

    Args:
        channels: Number of channels in the input.
        num_head_channels: Number of channels per attention head. When this is
            ``-1`` the head layout is derived from ``num_heads`` instead.
        num_heads: Number of attention heads. Only consulted when
            ``num_head_channels`` is ``-1``.
        encoder_channels: Number of channels in the encoder output for
            cross-attention. If ``None``, only self-attention is performed.
        use_self_attention: Whether to run self-attention before
            cross-attention; only applicable if ``encoder_channels`` is set.
        dims: Number of spatial dimensions, i.e. 1 for temporal signals, 2 for images.
        h_dim: The dimension of the height; multiplied into ``channels`` when
            ``dims`` is 2.
        encoder_hdim: The dimension of the height of the encoder output;
            multiplied into ``encoder_channels`` when ``dims`` is 2.
        p_dropout: Dropout probability forwarded to the attention sub-blocks.

    Raises:
        ValueError: If ``dims`` is neither 1 nor 2.
    """

    def __init__(
        self,
        channels: int,
        num_head_channels: int = 32,
        num_heads: int = -1,
        encoder_channels: int = None,
        use_self_attention: bool = False,
        dims: int = 1,
        h_dim: int = 100,
        encoder_hdim: int = 384,
        p_dropout: float = 0.0,
    ):
        super().__init__()

        if dims not in (1, 2):
            raise ValueError(f"invalid number of dimensions: {dims}")

        self.p_dropout = p_dropout
        self.dims = dims
        # For 2-D inputs the height axis is folded into the channel axis.
        self.channels = channels if dims == 1 else channels * h_dim

        if num_head_channels != -1:
            assert (
                self.channels % num_head_channels == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_head_channels = num_head_channels
            self.num_heads = self.channels // num_head_channels
        else:
            assert (
                self.channels % num_heads == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_heads {num_heads}"
            self.num_heads = num_heads
            self.num_head_channels = self.channels // num_heads

        def build_attention(context_channels=None):
            # All sub-blocks share the same width, head layout and dropout.
            return BasicAttentionBlock(
                self.channels,
                self.num_head_channels,
                self.num_heads,
                context_channels,
                p_dropout=self.p_dropout,
            )

        if encoder_channels is None:
            # Pure self-attention configuration.
            self.encoder_channels = None
            self.self_attention = build_attention()
            return

        self.use_self_attention = use_self_attention
        self.encoder_channels = (
            encoder_channels if dims == 1 else encoder_channels * encoder_hdim
        )
        if use_self_attention:
            self.self_attention = build_attention()
        self.cross_attention = build_attention(self.encoder_channels)

    def forward(self, x: torch.Tensor, encoder_output: torch.Tensor = None):
        r"""
        Args:
            x: input tensor with shape [B x ``channels`` x ...]
            encoder_output: feature tensor with shape [B x ``encoder_channels`` x ...];
                must be ``None`` when the block was built without ``encoder_channels``.

        Returns:
            output tensor with the same shape as ``x``
        """
        input_shape = x.size()
        batch = input_shape[0]
        flat = x.reshape(batch, self.channels, -1).contiguous()

        if self.encoder_channels is None:
            assert (
                encoder_output is None
            ), "encoder_output must be None for self-attention."
            attended = self.self_attention(flat)
            return attended.reshape(*input_shape).contiguous()

        assert (
            encoder_output is not None
        ), "encoder_output must be given for cross-attention."
        context = encoder_output.reshape(
            batch, self.encoder_channels, -1
        ).contiguous()

        if self.use_self_attention:
            flat = self.self_attention(flat)
        attended = self.cross_attention(flat, context)

        return attended.reshape(*input_shape).contiguous()
| |
|
| |
|
class BasicAttentionBlock(nn.Module):
    r"""Multi-head attention over ``[B, C, L]`` tensors followed by a residual
    feed-forward projection.

    Self-attention is performed when ``context_channels`` is ``None``;
    otherwise queries come from the input and keys/values from a separate
    context tensor (cross-attention).

    Args:
        channels: Number of channels of the query input (and of the output).
        num_head_channels: Number of channels per attention head. When this is
            ``-1`` the head layout is derived from ``num_heads`` instead.
        num_heads: Number of attention heads. Only consulted when
            ``num_head_channels`` is ``-1``.
        context_channels: Number of channels of the key/value source. If
            ``None``, self-attention is performed.
        p_dropout: Dropout probability, used for attention dropout (training
            mode only) and inside the output projection.
    """

    def __init__(
        self,
        channels: int,
        num_head_channels: int = 32,
        num_heads: int = -1,
        context_channels: int = None,
        p_dropout: float = 0.0,
    ):
        super().__init__()

        self.channels = channels
        self.p_dropout = p_dropout
        self.context_channels = context_channels

        if num_head_channels == -1:
            assert (
                self.channels % num_heads == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_heads {num_heads}"
            self.num_heads = num_heads
            self.num_head_channels = self.channels // num_heads
        else:
            assert (
                self.channels % num_head_channels == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = self.channels // num_head_channels
            self.num_head_channels = num_head_channels

        if context_channels is not None:
            self.to_q = nn.Sequential(
                normalization(self.channels),
                Conv1d(self.channels, self.channels, 1),
            )
            self.to_kv = Conv1d(context_channels, 2 * self.channels, 1)
        else:
            self.to_qkv = nn.Sequential(
                normalization(self.channels),
                Conv1d(self.channels, 3 * self.channels, 1),
            )

        # A point-wise (kernel size 1) projection keeps the sequence length
        # unchanged, which the residual addition in ``forward`` requires.
        # NOTE(review): the kernel size was previously omitted here; 1 matches
        # every other projection in this block -- confirm that the project's
        # ``Conv1d`` helper has no different default.
        self.linear = Conv1d(self.channels, self.channels, 1)

        self.proj_out = nn.Sequential(
            normalization(self.channels),
            Conv1d(self.channels, self.channels, 1),
            nn.GELU(),
            nn.Dropout(p=self.p_dropout),
            zero_module(Conv1d(self.channels, self.channels, 1)),
        )

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        """Reshape ``[B, C, S]`` into ``[B, num_heads, S, num_head_channels]``."""
        return (
            t.reshape(t.size(0), self.num_heads, self.num_head_channels, -1)
            .transpose(-1, -2)
            .contiguous()
        )

    def forward(self, q: torch.Tensor, kv: torch.Tensor = None):
        r"""
        Args:
            q: input tensor with shape [B, ``channels``, L]
            kv: feature tensor with shape [B, ``context_channels``, T]; required
                when the block was built with ``context_channels`` and ignored
                otherwise.

        Returns:
            output tensor with shape [B, ``channels``, L]
        """
        N, C, L = q.size()
        # Keep the untouched input for the residual connection; the projected
        # per-head query has a different layout and must not be used for it.
        residual = q

        # Split along the channel axis (dim=1) so the batch axis stays intact
        # and samples are never mixed with each other.
        if self.context_channels is not None:
            assert kv is not None, "kv must be given for cross-attention."
            query = self.to_q(q)
            key, value = self.to_kv(kv).chunk(2, dim=1)
        else:
            query, key, value = self.to_qkv(q).chunk(3, dim=1)

        # scaled_dot_product_attention applies dropout whenever dropout_p > 0,
        # regardless of module mode, so it has to be disabled explicitly in eval.
        h = F.scaled_dot_product_attention(
            self._split_heads(query),
            self._split_heads(key),
            self._split_heads(value),
            dropout_p=self.p_dropout if self.training else 0.0,
        )
        # [B, heads, L, head_channels] -> [B, C, L]
        h = h.transpose(-1, -2).reshape(N, C, L).contiguous()
        h = self.linear(h)

        x = residual + h
        h = self.proj_out(x)

        return x + h
| |
|