| """ | |
| Perceiver code is based on Aurora: https://github.com/microsoft/aurora/blob/main/aurora/model/perceiver.py | |
| Some conventions for notation: | |
| B - Batch | |
| T - Time | |
| H - Height (pixel space) | |
| W - Width (pixel space) | |
| HT - Height (token space) | |
| WT - Width (token space) | |
| ST - Sequence (token space) | |
| C - Input channels | |
| D - Model (embedding) dimension | |
| """ | |

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.layers import trunc_normal_


class PatchEmbed3D(nn.Module):
    """Timeseries image to patch embedding."""

    def __init__(
        self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, time_dim=2
    ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.time_dim = time_dim
        # Fold the time axis into the channel axis, then patchify with a
        # strided convolution.
        self.proj = nn.Conv2d(
            in_chans * time_dim,
            embed_dim,
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
        )

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (B, C, T, H, W)
        Returns:
            Tensor of shape (B, ST, D)
        """
        B, C, T, H, W = x.shape
        x = self.proj(x.flatten(1, 2))  # (B, C, T, H, W) -> (B, D, HT, WT)
        x = rearrange(x, "B D HT WT -> B (HT WT) D")  # (B, ST, D)
        return x
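

# Hypothetical usage sketch (not from the source; shapes are illustrative
# assumptions): PatchEmbed3D folds time into channels and patchifies, so with
# img_size=224 and patch_size=16 there are 14 tokens per side, ST = 196.
def _demo_patch_embed3d() -> None:
    embed = PatchEmbed3D(img_size=224, patch_size=16, in_chans=3, embed_dim=768, time_dim=2)
    x = torch.randn(1, 3, 2, 224, 224)  # (B, C, T, H, W)
    tokens = embed(x)
    assert tokens.shape == (1, 196, 768)  # (B, ST, D)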


class LinearEmbedding(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        time_dim=2,
        embed_dim=768,
        drop_rate=0.0,
    ):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.patch_embed = PatchEmbed3D(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            time_dim=time_dim,
        )
        self._generate_position_encoding(img_size, patch_size, embed_dim)
        self.pos_drop = nn.Dropout(p=drop_rate)

    def _generate_position_encoding(self, img_size, patch_size, embed_dim):
        """
        Generates a fixed 2D Fourier positional encoding for the model. The
        generated signal is stored as a buffer (`self.pos_embed`). Assumes
        `embed_dim` is divisible by 4, since each frequency contributes four
        sin/cos components.
        Args:
            img_size (int): The size of the input image.
            patch_size (int): The size of each patch in the image.
            embed_dim (int): The embedding dimension of the model.
        Returns:
            None.
        """
        # Generate signal of shape (H, W, C)
        x = torch.linspace(0.0, 1.0, img_size // patch_size)
        y = torch.linspace(0.0, 1.0, img_size // patch_size)
        x, y = torch.meshgrid(x, y, indexing="xy")
        fourier_signal = []
        frequencies = torch.linspace(1, (img_size // patch_size) / 2.0, embed_dim // 4)
        for f in frequencies:
            fourier_signal.extend(
                [
                    torch.cos(2.0 * torch.pi * f * x),
                    torch.sin(2.0 * torch.pi * f * x),
                    torch.cos(2.0 * torch.pi * f * y),
                    torch.sin(2.0 * torch.pi * f * y),
                ]
            )
        fourier_signal = torch.stack(fourier_signal, dim=2)
        fourier_signal = rearrange(fourier_signal, "h w c -> 1 (h w) c")
        self.register_buffer("pos_embed", fourier_signal)

    def forward(self, x, dt):
        """
        Args:
            x: Tensor of shape (B, C, T, H, W).
            dt: Tensor of shape (B, T). Not used; accepted for interface
                compatibility with other embedding modules.
        Returns:
            Tensor of shape (B, ST, D)
        """
        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        return x
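

# Hypothetical usage sketch (illustrative shapes, not from the source):
# LinearEmbedding is the patch embedding plus a fixed Fourier positional
# encoding; the buffer broadcasts over the batch dimension.
def _demo_linear_embedding() -> None:
    embed = LinearEmbedding(img_size=224, patch_size=16, in_chans=3, time_dim=2, embed_dim=768)
    x = torch.randn(1, 3, 2, 224, 224)  # (B, C, T, H, W)
    dt = torch.ones(1, 2)  # (B, T), ignored by this embedding
    tokens = embed(x, dt)
    assert tokens.shape == (1, 196, 768)  # (B, ST, D)
    assert embed.pos_embed.shape == (1, 196, 768)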


class LinearDecoder(nn.Module):
    def __init__(
        self,
        patch_size: int,
        out_chans: int,
        embed_dim: int,
    ):
        """
        Args:
            patch_size: patch size
            out_chans: number of output channels
            embed_dim: embedding dimension
        """
        super().__init__()
        # A 1x1 convolution expands each token to patch_size**2 * out_chans
        # values, which PixelShuffle rearranges back into pixel space.
        self.unembed = nn.Sequential(
            nn.Conv2d(
                in_channels=embed_dim,
                out_channels=(patch_size**2) * out_chans,
                kernel_size=1,
            ),
            nn.PixelShuffle(patch_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (B, L, D). For ensembles, we implicitly have B = (B E).
        Returns:
            Tensor of shape (B, C, H, W).
            Here
            - C equals out_chans
            - H == W == sqrt(L) x patch_size
        """
        # Reshape the tokens to 2D token space: (B, D, H_token, W_token).
        # Assumes L is a perfect square.
        _, L, _ = x.shape
        H_token = W_token = int(L**0.5)
        x = rearrange(x, "B (H W) D -> B D H W", H=H_token, W=W_token)
        # Unembed the tokens: convolution + pixel shuffle.
        x = self.unembed(x)
        return x
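

# Hypothetical usage sketch (illustrative shapes, not from the source): the
# decoder inverts the patch embedding, so 196 tokens of width 768 map back to
# a 3-channel 224x224 image.
def _demo_linear_decoder() -> None:
    decoder = LinearDecoder(patch_size=16, out_chans=3, embed_dim=768)
    tokens = torch.randn(1, 196, 768)  # (B, L, D) with L = 14 * 14
    out = decoder(tokens)
    assert out.shape == (1, 3, 224, 224)  # (B, C, H, W), H = W = sqrt(L) * patch_size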


class MLP(nn.Module):
    """A simple one-hidden-layer MLP."""

    def __init__(self, dim: int, hidden_features: int, dropout: float = 0.0) -> None:
        """Initialise.
        Args:
            dim (int): Input dimensionality.
            hidden_features (int): Width of the hidden layer.
            dropout (float, optional): Drop-out rate. Defaults to no drop-out.
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_features),
            nn.GELU(),
            nn.Linear(hidden_features, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the MLP."""
        return self.net(x)


class PerceiverAttention(nn.Module):
    """Cross-attention module from the Perceiver architecture."""

    def __init__(
        self,
        latent_dim: int,
        context_dim: int,
        head_dim: int = 64,
        num_heads: int = 8,
    ) -> None:
        """Initialise.
        Args:
            latent_dim (int): Dimensionality of the latent features given as input.
            context_dim (int): Dimensionality of the context features also given as input.
            head_dim (int): Attention head dimensionality.
            num_heads (int): Number of heads.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.inner_dim = head_dim * num_heads
        self.to_q = nn.Linear(latent_dim, self.inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, self.inner_dim * 2, bias=False)
        self.to_out = nn.Linear(self.inner_dim, latent_dim, bias=False)

    def forward(self, latents: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Run the cross-attention module.
        Args:
            latents (:class:`torch.Tensor`): Latent features of shape `(B, L1, Latent_D)`
                where typically `L1 < L2` and `Latent_D <= Context_D`. `Latent_D` is the
                `latent_dim` given at initialisation.
            x (:class:`torch.Tensor`): Context features of shape `(B, L2, Context_D)`.
        Returns:
            :class:`torch.Tensor`: Latent values of shape `(B, L1, Latent_D)`.
        """
        h = self.num_heads
        q = self.to_q(latents)  # (B, L1, Latent_D) -> (B, L1, D)
        k, v = self.to_kv(x).chunk(2, dim=-1)  # (B, L2, Context_D) -> twice (B, L2, D)
        q, k, v = map(lambda t: rearrange(t, "b l (h d) -> b h l d", h=h), (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v)
        out = rearrange(out, "B H L1 D -> B L1 (H D)")  # (B, L1, D)
        return self.to_out(out)  # (B, L1, Latent_D)
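

# Hypothetical usage sketch (illustrative shapes, not from the source): a short
# latent sequence cross-attends over a longer context; the latent and context
# widths may differ, and the latent shape is preserved.
def _demo_perceiver_attention() -> None:
    attn = PerceiverAttention(latent_dim=256, context_dim=768, head_dim=64, num_heads=8)
    latents = torch.randn(2, 16, 256)  # (B, L1, Latent_D)
    context = torch.randn(2, 196, 768)  # (B, L2, Context_D)
    out = attn(latents, context)
    assert out.shape == (2, 16, 256)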


class PerceiverResampler(nn.Module):
    """Perceiver Resampler module from the Flamingo paper."""

    def __init__(
        self,
        latent_dim: int,
        context_dim: int,
        depth: int = 1,
        head_dim: int = 64,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        drop: float = 0.0,
        residual_latent: bool = True,
        ln_eps: float = 1e-5,
    ) -> None:
        """Initialise.
        Args:
            latent_dim (int): Dimensionality of the latent features given as input.
            context_dim (int): Dimensionality of the context features also given as input.
            depth (int, optional): Number of attention layers. Defaults to `1`.
            head_dim (int, optional): Attention head dimensionality. Defaults to `64`.
            num_heads (int, optional): Number of heads. Defaults to `16`.
            mlp_ratio (float, optional): Dimensionality of the hidden layer divided by that
                of the input for all MLPs. Defaults to `4.0`.
            drop (float, optional): Drop-out rate. Defaults to no drop-out.
            residual_latent (bool, optional): Use residual attention w.r.t. the latent
                features. Defaults to `True`.
            ln_eps (float, optional): Epsilon in the layer normalisation layers. Defaults
                to `1e-5`.
        """
        super().__init__()
        self.residual_latent = residual_latent
        self.layers = nn.ModuleList([])
        mlp_hidden_dim = int(latent_dim * mlp_ratio)
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(
                            latent_dim=latent_dim,
                            context_dim=context_dim,
                            head_dim=head_dim,
                            num_heads=num_heads,
                        ),
                        MLP(
                            dim=latent_dim, hidden_features=mlp_hidden_dim, dropout=drop
                        ),
                        nn.LayerNorm(latent_dim, eps=ln_eps),
                        nn.LayerNorm(latent_dim, eps=ln_eps),
                    ]
                )
            )

    def forward(self, latents: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Run the module.
        Args:
            latents (:class:`torch.Tensor`): Latent features of shape `(B, L1, D1)`.
            x (:class:`torch.Tensor`): Context features of shape `(B, L2, D2)`.
        Returns:
            torch.Tensor: Latent features of shape `(B, L1, D1)`.
        """
        for attn, ff, ln1, ln2 in self.layers:
            # We use post-res-norm like in Swin v2 and most Transformer architectures these
            # days. This empirically works better than the pre-norm used in the original
            # Perceiver.
            attn_out = ln1(attn(latents, x))
            # HuggingFace suggests that using non-residual attention in the Perceiver might
            # work better when the semantics of the query and the output are different:
            #
            # https://github.com/huggingface/transformers/blob/v4.35.2/src/transformers/models/perceiver/modeling_perceiver.py#L398
            #
            latents = attn_out + latents if self.residual_latent else attn_out
            latents = ln2(ff(latents)) + latents
        return latents
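

# Hypothetical usage sketch (illustrative shapes, not from the source): a
# two-layer resampler compresses 196 context tokens into 16 learned latents.
def _demo_perceiver_resampler() -> None:
    resampler = PerceiverResampler(latent_dim=256, context_dim=768, depth=2)
    latents = torch.randn(2, 16, 256)  # (B, L1, D1)
    context = torch.randn(2, 196, 768)  # (B, L2, D2)
    out = resampler(latents, context)
    assert out.shape == (2, 16, 256)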


class PerceiverChannelEmbedding(nn.Module):
    def __init__(
        self,
        in_chans: int,
        img_size: int,
        patch_size: int,
        time_dim: int,
        num_queries: int,
        embed_dim: int,
        drop_rate: float,
    ):
        super().__init__()
        if embed_dim % 2 != 0:
            raise ValueError(
                f"Temporal embeddings require `embed_dim` to be even. Currently we have {embed_dim}."
            )
        self.num_patches = (img_size // patch_size) ** 2
        self.num_queries = num_queries
        self.embed_dim = embed_dim
        # Grouped convolution: each of the `in_chans` input channels gets its
        # own patch embedding of width `embed_dim`.
        self.proj = nn.Conv2d(
            in_channels=in_chans * time_dim,
            out_channels=in_chans * embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            groups=in_chans,
        )
        self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, self.num_patches))
        trunc_normal_(self.pos_embed, std=0.02)
        self.latent_queries = nn.Parameter(torch.zeros(1, num_queries, embed_dim))
        trunc_normal_(self.latent_queries, std=0.02)
        self.perceiver = PerceiverResampler(
            latent_dim=embed_dim,
            context_dim=embed_dim,
            depth=1,
            head_dim=embed_dim // 16,
            num_heads=16,
            mlp_ratio=4.0,
            drop=0.0,
            residual_latent=False,
            ln_eps=1e-5,
        )
        self.latent_aggregation = nn.Linear(num_queries * embed_dim, embed_dim)
        self.pos_drop = nn.Dropout(p=drop_rate)

    def forward(self, x, dt):
        """
        Args:
            x: Tensor of shape (B, C, T, H, W)
            dt: Tensor of shape (B, T) identifying time deltas. Not used here.
        Returns:
            Tensor of shape (B, ST, D)
        """
        B, C, T, H, W = x.shape
        x = rearrange(x, "B C T H W -> B (C T) H W")
        x = self.proj(x)  # B (C T) H W -> B (C D) HT WT
        x = x.flatten(2, 3)  # B (C D) HT WT -> B (C D) ST
        ST = x.shape[2]
        assert ST == self.num_patches
        x = rearrange(x, "B (C D) ST -> (B C) D ST", B=B, ST=ST, C=C, D=self.embed_dim)
        x = x + self.pos_embed
        x = rearrange(x, "(B C) D ST -> (B ST) C D", B=B, ST=ST, C=C, D=self.embed_dim)
        # Cross-attend over channels for every token independently:
        # ((B ST) NQ D), ((B ST) C D) -> ((B ST) NQ D)
        x = self.perceiver(self.latent_queries.expand(B * ST, -1, -1), x)
        x = rearrange(
            x,
            "(B ST) NQ D -> B ST (NQ D)",
            B=B,
            ST=self.num_patches,
            NQ=self.num_queries,
            D=self.embed_dim,
        )
        x = self.latent_aggregation(x)  # B ST (NQ D) -> B ST D
        assert x.shape[1] == self.num_patches
        assert x.shape[2] == self.embed_dim
        x = self.pos_drop(x)
        return x
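

# Hypothetical usage sketch (illustrative shapes, not from the source): each
# of 5 input channels is patchified separately by the grouped convolution, and
# per-token cross-attention aggregates the per-channel embeddings into one
# embedding per token.
def _demo_perceiver_channel_embedding() -> None:
    embed = PerceiverChannelEmbedding(
        in_chans=5, img_size=224, patch_size=16, time_dim=2,
        num_queries=4, embed_dim=768, drop_rate=0.0,
    )
    x = torch.randn(1, 5, 2, 224, 224)  # (B, C, T, H, W)
    dt = torch.ones(1, 2)  # (B, T)
    tokens = embed(x, dt)
    assert tokens.shape == (1, 196, 768)  # (B, ST, D)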


class PerceiverDecoder(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        patch_size: int,
        out_chans: int,
    ):
        """
        Args:
            embed_dim: embedding dimension
            patch_size: patch size
            out_chans: number of output channels. This determines the number of latent
                queries.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.out_chans = out_chans
        self.latent_queries = nn.Parameter(torch.zeros(1, out_chans, embed_dim))
        trunc_normal_(self.latent_queries, std=0.02)
        self.perceiver = PerceiverResampler(
            latent_dim=embed_dim,
            context_dim=embed_dim,
            depth=1,
            head_dim=embed_dim // 16,
            num_heads=16,
            mlp_ratio=4.0,
            drop=0.0,
            residual_latent=False,
            ln_eps=1e-5,
        )
        # Grouped 1x1 convolution: each output channel's latent is projected to
        # its own patch_size**2 pixel values.
        self.proj = nn.Conv2d(
            in_channels=out_chans * embed_dim,
            out_channels=out_chans * patch_size**2,
            kernel_size=1,
            padding=0,
            groups=out_chans,
        )
        self.pixel_shuffle = nn.PixelShuffle(patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (B, L, D). For ensembles, we implicitly have B = (B E).
        Returns:
            Tensor of shape (B, C, H, W).
            Here
            - C equals out_chans
            - H == W == sqrt(L) x patch_size
        """
        B, L, D = x.shape
        H_token = W_token = int(L**0.5)
        # Treat every token as a one-element context sequence.
        x = rearrange(x, "B L D -> (B L) 1 D")
        # (B L) 1 D -> (B L) C D
        x = self.perceiver(self.latent_queries.expand(B * L, -1, -1), x)
        x = rearrange(x, "(B H W) C D -> B (C D) H W", H=H_token, W=W_token)
        # B (C D) H_token W_token -> B (C patch_size patch_size) H_token W_token
        x = self.proj(x)
        # B (C patch_size patch_size) H_token W_token -> B C H W
        x = self.pixel_shuffle(x)
        return x
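

# Hypothetical usage sketch (illustrative shapes, not from the source): the
# decoder reads out `out_chans` latents per token via cross-attention, then
# projects and pixel-shuffles back to pixel space.
def _demo_perceiver_decoder() -> None:
    decoder = PerceiverDecoder(embed_dim=768, patch_size=16, out_chans=3)
    tokens = torch.randn(1, 196, 768)  # (B, L, D) with L = 14 * 14
    out = decoder(tokens)
    assert out.shape == (1, 3, 224, 224)  # (B, C, H, W)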