"""Copyright (c) Microsoft Corporation. Licensed under the MIT license."""
from datetime import timedelta
import torch
from einops import rearrange
from torch import nn
from aurora.batch import Batch, Metadata
from aurora.model.fourier import levels_expansion
from aurora.model.perceiver import PerceiverResampler
from aurora.model.util import (
check_lat_lon_dtype,
init_weights,
unpatchify,
)
__all__ = ["Perceiver3DDecoder"]
class Perceiver3DDecoder(nn.Module):
    """Multi-scale multi-source multi-variable decoder based on the Perceiver architecture."""

    def __init__(
        self,
        out_surf_vars: tuple[str, ...],
        out_atmos_vars: tuple[str, ...],
        patch_size: int = 4,
        embed_dim: int = 1024,
        depth: int = 1,
        head_dim: int = 64,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.0,
        perceiver_ln_eps: float = 1e-5,
    ) -> None:
        """Initialise.

        Args:
            out_surf_vars (tuple[str, ...]): All supported surface-level variables.
            out_atmos_vars (tuple[str, ...]): All supported atmospheric variables.
            patch_size (int, optional): Patch size. Defaults to `4`.
            embed_dim (int, optional): Embedding dim. Defaults to `1024`.
            depth (int, optional): Number of Perceiver cross-attention and feed-forward blocks.
                Defaults to `1`.
            head_dim (int, optional): Dimension of the attention heads used in the aggregation
                blocks. Defaults to `64`.
            num_heads (int, optional): Number of attention heads used in the aggregation blocks.
                Defaults to `8`.
            mlp_ratio (float, optional): Ratio of MLP hidden dimension to embedding dimensionality.
                Defaults to `4.0`.
            drop_rate (float, optional): Drop-out rate for input patches. Defaults to `0.0`.
            perceiver_ln_eps (float, optional): Layer norm. epsilon for the Perceiver blocks.
                Defaults to `1e-5`.
        """
        super().__init__()

        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.out_surf_vars = out_surf_vars
        self.out_atmos_vars = out_atmos_vars

        # One linear head per output variable. Each head maps an embedding to a
        # flattened `patch_size x patch_size` patch of predictions for that variable.
        if out_surf_vars:
            self.surf_heads = nn.ParameterDict(
                {name: nn.Linear(embed_dim, patch_size**2) for name in out_surf_vars}
            )
        if out_atmos_vars:
            self.atmos_heads = nn.ParameterDict(
                {name: nn.Linear(embed_dim, patch_size**2) for name in out_atmos_vars}
            )

        # Cross-attention decoder that expands the aggregated latent levels into
        # the requested physical pressure levels.
        self.level_decoder = PerceiverResampler(
            latent_dim=embed_dim,
            context_dim=embed_dim,
            depth=depth,
            head_dim=head_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            drop=drop_rate,
            residual_latent=True,
            ln_eps=perceiver_ln_eps,
        )
        self.atmos_levels_embed = nn.Linear(embed_dim, embed_dim)

        self.apply(init_weights)

    def deaggregate_levels(self, level_embed: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Deaggregate pressure level information.

        Args:
            level_embed (torch.Tensor): Level embedding of shape `(B, L, C, D)`.
            x (torch.Tensor): Aggregated input of shape `(B, L, C', D)`.

        Returns:
            torch.Tensor: Deaggregated output of shape `(B, L, C, D)`.
        """
        B, L, C, D = level_embed.shape

        # Fold the batch and token dimensions together so the Perceiver sees a
        # plain `(batch, sequence, dim)` input.
        level_embed = level_embed.flatten(0, 1)  # (BxL, C, D)
        x = x.flatten(0, 1)  # (BxL, C', D)

        _msg = f"Batch size mismatch. Found {level_embed.size(0)} and {x.size(0)}."
        assert level_embed.size(0) == x.size(0), _msg
        # NOTE: was `level_embed.dims()`, which does not exist on `torch.Tensor` and
        # would have raised an `AttributeError` instead of the intended assertion.
        assert level_embed.dim() == 3, f"Expected 3 dims, found {level_embed.dim()}."
        assert x.dim() == 3, f"Expected 3 dims, found {x.dim()}."

        x = self.level_decoder(level_embed, x)  # (BxL, C, D)
        x = x.reshape(B, L, C, D)
        return x

    def forward(
        self,
        x: torch.Tensor,
        batch: Batch,
        patch_res: tuple[int, int, int],
        lead_time: timedelta,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        """Forward pass of the decoder.

        Args:
            x (torch.Tensor): Backbone output of shape `(B, L, D)`.
            batch (:class:`aurora.batch.Batch`): Batch to make predictions for.
            patch_res (tuple[int, int, int]): Patch resolution as `(C, H, W)`.
            lead_time (timedelta): Lead time. Currently unused here; kept for interface
                compatibility with the encoder/backbone.

        Returns:
            tuple[torch.Tensor | None, torch.Tensor | None]: Surface-level predictions of
                shape `(B, V_S, H, W)` (or `None` if no surface variables are configured)
                and atmospheric predictions of shape `(B, V_A, C_A, H, W)` (or `None` if no
                atmospheric variables are configured).
        """
        surf_vars = self.out_surf_vars
        atmos_vars = self.out_atmos_vars
        atmos_levels = batch.metadata.atmos_levels

        B, L, D = x.shape

        # Extract the lat, lon and convert to float32.
        lat, lon = batch.metadata.lat, batch.metadata.lon
        check_lat_lon_dtype(lat, lon)
        lat, lon = lat.to(dtype=torch.float32), lon.to(dtype=torch.float32)
        H, W = lat.shape[0], lon.shape[-1]

        # Unwrap the latent level dimension: the first latent "level" carries the
        # surface state and the rest carry the aggregated atmospheric state.
        x = rearrange(
            x,
            "B (C H W) D -> B (H W) C D",
            C=patch_res[0],
            H=patch_res[1],
            W=patch_res[2],
        )

        surf_preds = None
        atmos_preds = None

        if surf_vars:
            # Decode surface vars. Run the head for every surface-level variable.
            x_surf = torch.stack(
                [self.surf_heads[name](x[..., :1, :]) for name in surf_vars], dim=-1
            )
            x_surf = x_surf.reshape(*x_surf.shape[:3], -1)  # (B, L, 1, V_S*p*p)
            surf_preds = unpatchify(x_surf, len(surf_vars), H, W, self.patch_size)
            surf_preds = surf_preds.squeeze(2)  # (B, V_S, H, W)

        if atmos_vars:
            # Embed the atmospheric levels.
            atmos_levels_encode = levels_expansion(
                torch.tensor(atmos_levels, device=x.device), self.embed_dim
            ).to(dtype=x.dtype)
            levels_embed = self.atmos_levels_embed(atmos_levels_encode)  # (C_A, D)

            # De-aggregate the hidden levels into the physical levels.
            levels_embed = levels_embed.expand(B, x.size(1), -1, -1)
            x_atmos = self.deaggregate_levels(levels_embed, x[..., 1:, :])  # (B, L, C_A, D)

            # Decode the atmospheric vars.
            x_atmos = torch.stack(
                [self.atmos_heads[name](x_atmos) for name in atmos_vars], dim=-1
            )
            x_atmos = x_atmos.reshape(*x_atmos.shape[:3], -1)  # (B, L, C_A, V_A*p*p)
            atmos_preds = unpatchify(x_atmos, len(atmos_vars), H, W, self.patch_size)

        return surf_preds, atmos_preds