import torch
import torch.nn as nn
import torch.nn.functional as F

from modules.general.utils import Conv1d, normalization, zero_module
from .basic import UNetBlock


class AttentionBlock(UNetBlock):
    r"""A spatial transformer encoder block that allows spatial positions to attend
    to each other. Reference: `latent diffusion repo
    <https://github.com/Stability-AI/generative-models/blob/main/sgm/modules/attention.py#L531>`_.

    Args:
        channels: Number of channels in the input.
        num_head_channels: Number of channels per attention head. If ``-1``,
            ``num_heads`` is used to determine the head size instead.
        num_heads: Number of attention heads. Only used when ``num_head_channels``
            is ``-1``.
        encoder_channels: Number of channels in the encoder output for cross-attention.
            If ``None``, self-attention is performed.
        use_self_attention: Whether to apply self-attention before cross-attention.
            Only used when ``encoder_channels`` is set.
        dims: Number of spatial dimensions, i.e. 1 for temporal signals, 2 for images.
        h_dim: Height of the input; only used when ``dims`` is 2.
        encoder_hdim: Height of the encoder output; only used when ``dims`` is 2.
        p_dropout: Dropout probability.
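
    Example (illustrative sketch only; shapes and channel counts are arbitrary)::

        attn = AttentionBlock(channels=256, num_head_channels=32)
        x = torch.randn(2, 256, 100)  # [B, channels, T]
        y = attn(x)                   # [B, channels, T]

        # Cross-attention to a 384-channel encoder sequence:
        cross = AttentionBlock(channels=256, encoder_channels=384)
        z = cross(x, torch.randn(2, 384, 50))  # [B, channels, T]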
| """ |
|
|

    def __init__(
        self,
        channels: int,
        num_head_channels: int = 32,
        num_heads: int = -1,
        encoder_channels: int = None,
        use_self_attention: bool = False,
        dims: int = 1,
        h_dim: int = 100,
        encoder_hdim: int = 384,
        p_dropout: float = 0.0,
    ):
        super().__init__()

        self.channels = channels
        self.p_dropout = p_dropout
        self.dims = dims

        if dims == 1:
            self.channels = channels
        elif dims == 2:
            # For 2-D inputs the height is folded into the channel axis, so
            # attention runs over the remaining (width/time) axis only.
            self.channels = channels * h_dim
        else:
            raise ValueError(f"invalid number of dimensions: {dims}")

        if num_head_channels == -1:
            assert (
                self.channels % num_heads == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_heads {num_heads}"
            self.num_heads = num_heads
            self.num_head_channels = self.channels // num_heads
        else:
            assert (
                self.channels % num_head_channels == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = self.channels // num_head_channels
            self.num_head_channels = num_head_channels

        if encoder_channels is not None:
            self.use_self_attention = use_self_attention

            if dims == 1:
                self.encoder_channels = encoder_channels
            elif dims == 2:
                self.encoder_channels = encoder_channels * encoder_hdim
            else:
                raise ValueError(f"invalid number of dimensions: {dims}")

            if use_self_attention:
                self.self_attention = BasicAttentionBlock(
                    self.channels,
                    self.num_head_channels,
                    self.num_heads,
                    p_dropout=self.p_dropout,
                )
            self.cross_attention = BasicAttentionBlock(
                self.channels,
                self.num_head_channels,
                self.num_heads,
                self.encoder_channels,
                p_dropout=self.p_dropout,
            )
        else:
            self.encoder_channels = None
            self.self_attention = BasicAttentionBlock(
                self.channels,
                self.num_head_channels,
                self.num_heads,
                p_dropout=self.p_dropout,
            )

    def forward(self, x: torch.Tensor, encoder_output: torch.Tensor = None):
        r"""
        Args:
            x: input tensor with shape [B x ``channels`` x ...]
            encoder_output: feature tensor with shape [B x ``encoder_channels`` x ...].
                If ``None``, self-attention is performed.

        Returns:
            output tensor with shape [B x ``channels`` x ...]
        """
        shape = x.size()
        # Flatten the input to [B, self.channels, T]; for 2-D inputs this folds
        # the height into the channel axis (self.channels = channels * h_dim).
        x = x.reshape(shape[0], self.channels, -1).contiguous()

        if self.encoder_channels is None:
            assert (
                encoder_output is None
            ), "encoder_output must be None for self-attention."
            h = self.self_attention(x)
        else:
            assert (
                encoder_output is not None
            ), "encoder_output must be given for cross-attention."
            encoder_output = encoder_output.reshape(
                shape[0], self.encoder_channels, -1
            ).contiguous()

            if self.use_self_attention:
                x = self.self_attention(x)
            h = self.cross_attention(x, encoder_output)

        return h.reshape(*shape).contiguous()


class BasicAttentionBlock(nn.Module):
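    r"""Multi-head scaled dot-product attention over 1-D feature sequences.

    Performs self-attention when ``context_channels`` is ``None``; otherwise
    queries are computed from the input and keys/values from the context
    (cross-attention). Used internally by :class:`AttentionBlock`.
    """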

    def __init__(
        self,
        channels: int,
        num_head_channels: int = 32,
        num_heads: int = -1,
        context_channels: int = None,
        p_dropout: float = 0.0,
    ):
        super().__init__()

        self.channels = channels
        self.p_dropout = p_dropout
        self.context_channels = context_channels

        if num_head_channels == -1:
            assert (
                self.channels % num_heads == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_heads {num_heads}"
            self.num_heads = num_heads
            self.num_head_channels = self.channels // num_heads
        else:
            assert (
                self.channels % num_head_channels == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = self.channels // num_head_channels
            self.num_head_channels = num_head_channels

        if context_channels is not None:
            # Cross-attention: queries are projected from the (normalized) input,
            # keys and values from the context sequence.
            self.to_q = nn.Sequential(
                normalization(self.channels),
                Conv1d(self.channels, self.channels, 1),
            )
            self.to_kv = Conv1d(context_channels, 2 * self.channels, 1)
        else:
            # Self-attention: a single projection produces q, k and v.
            self.to_qkv = nn.Sequential(
                normalization(self.channels),
                Conv1d(self.channels, 3 * self.channels, 1),
            )

        # 1x1 projection applied to the attention output before the residual.
        self.linear = Conv1d(self.channels, self.channels, 1)

        self.proj_out = nn.Sequential(
            normalization(self.channels),
            Conv1d(self.channels, self.channels, 1),
            nn.GELU(),
            nn.Dropout(p=self.p_dropout),
            zero_module(Conv1d(self.channels, self.channels, 1)),
        )

    def forward(self, q: torch.Tensor, kv: torch.Tensor = None):
        r"""
        Args:
            q: input tensor with shape [B, ``channels``, L]
            kv: feature tensor with shape [B, ``context_channels``, T].
                If ``None``, self-attention is performed.

        Returns:
            output tensor with shape [B, ``channels``, L]
        """
        N, C, L = q.size()
        residual = q

        if self.context_channels is not None:
            assert kv is not None, "kv must be given for cross-attention."

            # Project queries from the input and keys/values from the context,
            # then split into heads: [N, num_heads, length, head_channels].
            q = (
                self.to_q(q)
                .reshape(N, self.num_heads, self.num_head_channels, L)
                .transpose(-1, -2)
                .contiguous()
            )
            kv = (
                self.to_kv(kv)
                .reshape(N, 2, self.num_heads, self.num_head_channels, -1)
                .permute(1, 0, 2, 4, 3)
                .chunk(2)
            )
            k, v = (
                kv[0].squeeze(0).contiguous(),
                kv[1].squeeze(0).contiguous(),
            )
        else:
            # Project q, k and v from the input and split into heads.
            qkv = (
                self.to_qkv(q)
                .reshape(N, 3, self.num_heads, self.num_head_channels, L)
                .permute(1, 0, 2, 4, 3)
                .chunk(3)
            )
            q, k, v = (
                qkv[0].squeeze(0).contiguous(),
                qkv[1].squeeze(0).contiguous(),
                qkv[2].squeeze(0).contiguous(),
            )

        # Attention over the sequence dimension; dropout is disabled at eval time.
        h = F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.p_dropout if self.training else 0.0
        ).transpose(-1, -2)
        h = h.reshape(N, -1, L).contiguous()
        h = self.linear(h)

        # Residual connections around the attention and the output projection,
        # taken from the original (un-projected) input.
        x = residual + h
        h = self.proj_out(x)

        return x + h
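

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only; shapes are arbitrary). It
    # assumes the module is run from within the package so that
    # `modules.general.utils` and `.basic.UNetBlock` resolve.
    B, C, T = 2, 64, 50
    x = torch.randn(B, C, T)

    self_attn = AttentionBlock(channels=C, num_head_channels=32)
    assert self_attn(x).shape == (B, C, T)

    cross_attn = AttentionBlock(
        channels=C,
        num_head_channels=32,
        encoder_channels=32,
        use_self_attention=True,
    )
    encoder_output = torch.randn(B, 32, 30)
    assert cross_attn(x, encoder_output).shape == (B, C, T)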