# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn

from xformers.components.attention import Attention, AttentionConfig, register_attention


@dataclass
class PoolingAttentionConfig(AttentionConfig):
    pool_size: int  # size of the pooling window
    stride: Optional[int]  # stride of the pooling operation
    padding: Optional[int]  # optional padding, defaults to pool_size // 2


@register_attention("pooling", PoolingAttentionConfig)
class Pooling(Attention):
    def __init__(
        self,
        pool_size: int = 3,
        stride: int = 1,
        padding: Optional[int] = None,
        *_,
        **__,
    ):
        """
        Pooling token mixing mechanism, as proposed in
        `Metaformer is actually what you need for vision`_, Yu et al (2021).

        The original notation is kept as is.

        .. _`Metaformer is actually what you need for vision` : https://arxiv.org/pdf/2111.11418v1.pdf
        """
        super().__init__()

        padding = padding if padding is not None else pool_size // 2
        self.pool = nn.AvgPool2d(
            pool_size,
            stride=stride,
            padding=padding,
            count_include_pad=False,
        )
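
        # Note: with an odd pool_size k, stride 1 and the default padding k // 2,
        # AvgPool2d preserves the spatial resolution:
        # H_out = (H + 2 * (k // 2) - k) // stride + 1 = H,
        # so the number of tokens is unchanged by the mixing step.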

        # MHA related flags:
        # This attention does not support attention masks
        self.supports_attention_mask = False

        # This "attention" (token mixing) skips the multihead attention altogether
        self.requires_skip_multi_head = True
        self.requires_input_projection = False

        # This operator does not really handle q, k, v: a single input tensor is
        # pooled, so q and k are expected to share the same dimensions
        self.requires_same_k_q_dimensions = True

        # This attention requires the 2d structure out of the context,
        # implicitly assumed to be a squared length
        self.requires_squared_context = True

    def forward(self, q: torch.Tensor, *_, **__):
        # Expose the 2D token structure
        B, HW, C = q.shape
        H = int(math.sqrt(HW))
        assert H * H == HW, "Pooling token mixing expects a squared context (H * H tokens)"

        q = q.transpose(-2, -1).reshape(B, C, H, H)

        # 2D pool
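        # In PoolFormer the token mixer sits inside a residual branch (x + mixer(x)),
        # so the input is subtracted here: the residual connection adds it back,
        # leaving pure average pooling overall.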
        x_pool = self.pool(q) - q  # compensate for the residual path

        # Get back to B HW C
        return x_pool.flatten(2, 3).transpose(-2, -1)
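

if __name__ == "__main__":
    # Minimal usage sketch (illustrative values, not part of the original file):
    # mix a 14 x 14 grid of tokens, i.e. a 196-token sequence with 384 channels,
    # the kind of feature map a ViT-style backbone would produce.
    B, H, C = 2, 14, 384
    mixer = Pooling(pool_size=3, stride=1)
    tokens = torch.randn(B, H * H, C)  # (batch, sequence, channels), squared context
    mixed = mixer(tokens)
    assert mixed.shape == (B, H * H, C)  # token count and channel width are preserved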