File size: 11,757 Bytes
9d66a40 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 | """Copyright (c) Microsoft Corporation. Licensed under the MIT license."""
from datetime import timedelta
import torch
from einops import rearrange
from torch import nn
from aurora.batch import Batch
from aurora.model.fourier import (
absolute_time_expansion,
lead_time_expansion,
levels_expansion,
pos_expansion,
scale_expansion,
)
from aurora.model.patchembed import LevelPatchEmbed
from aurora.model.perceiver import MLP, PerceiverResampler
from aurora.model.posencoding import pos_scale_enc
from aurora.model.util import (
check_lat_lon_dtype,
init_weights,
)
__all__ = ["Perceiver3DEncoder"]
class Perceiver3DEncoder(nn.Module):
"""Multi-scale multi-source multi-variable encoder based on the Perceiver architecture."""
def __init__(
self,
surf_vars: tuple[str, ...],
static_vars: tuple[str, ...] | None,
atmos_vars: tuple[str, ...],
patch_size: int = 4,
latent_levels: int = 8,
embed_dim: int = 1024,
num_heads: int = 16,
head_dim: int = 64,
drop_rate: float = 0.1,
depth: int = 2,
mlp_ratio: float = 4.0,
max_history_size: int = 2,
perceiver_ln_eps: float = 1e-5,
stabilise_level_agg: bool = False,
) -> None:
"""Initialise.
Args:
surf_vars (tuple[str, ...]): All supported surface-level variables.
static_vars (tuple[str, ...], optional): All supported static variables.
atmos_vars (tuple[str, ...]): All supported atmospheric variables.
patch_size (int, optional): Patch size. Defaults to `4`.
latent_levels (int): Number of latent pressure levels. Defaults to `8`.
embed_dim (int, optional): Embedding dim. used in the aggregation blocks. Defaults
to `1024`.
num_heads (int, optional): Number of attention heads used in aggregation blocks.
Defaults to `16`.
head_dim (int, optional): Dimension of attention heads used in aggregation blocks.
Defaults to `64`.
drop_rate (float, optional): Drop out rate for input patches. Defaults to `0.1`.
depth (int, optional): Number of Perceiver cross-attention and feed-forward blocks.
Defaults to `2`.
mlp_ratio (float, optional): Ratio of hidden dimensionality to embedding dimensionality
for MLPs. Defaults to `4.0`.
max_history_size (int, optional): Maximum number of history steps to consider. Defaults
to `2`.
perceiver_ln_eps (float, optional): Epsilon value for layer normalisation in the
Perceiver. Defaults to 1e-5.
stabilise_level_agg (bool, optional): Stabilise the level aggregation by inserting an
additional layer normalisation. Defaults to `False`.
"""
super().__init__()
self.drop_rate = drop_rate
self.embed_dim = embed_dim
self.patch_size = patch_size
# We treat the static variables as surface variables in the model.
surf_vars = surf_vars + static_vars if static_vars is not None else surf_vars
# Latent tokens
assert latent_levels > 1, "At least two latent levels are required."
self.latent_levels = latent_levels
# One latent level will be used by the surface level.
self.atmos_latents = nn.Parameter(torch.randn(latent_levels - 1, embed_dim))
# Learnable embedding to encode the surface level.
self.surf_level_encoding = nn.Parameter(torch.randn(embed_dim))
self.surf_mlp = MLP(embed_dim, int(embed_dim * mlp_ratio), dropout=drop_rate)
self.surf_norm = nn.LayerNorm(embed_dim)
# Position, scale, and time embeddings
self.pos_embed = nn.Linear(embed_dim, embed_dim)
self.scale_embed = nn.Linear(embed_dim, embed_dim)
self.lead_time_embed = nn.Linear(embed_dim, embed_dim)
self.absolute_time_embed = nn.Linear(embed_dim, embed_dim)
self.atmos_levels_embed = nn.Linear(embed_dim, embed_dim)
# Patch embeddings
assert max_history_size > 0, "At least one history step is required."
self.surf_token_embeds = LevelPatchEmbed(
surf_vars,
patch_size,
embed_dim,
max_history_size,
)
self.atmos_token_embeds = LevelPatchEmbed(
atmos_vars,
patch_size,
embed_dim,
max_history_size,
)
# Learnable pressure level aggregation
self.level_agg = PerceiverResampler(
latent_dim=embed_dim,
context_dim=embed_dim,
depth=depth,
head_dim=head_dim,
num_heads=num_heads,
drop=drop_rate,
mlp_ratio=mlp_ratio,
ln_eps=perceiver_ln_eps,
ln_k_q=stabilise_level_agg,
)
# Drop patches after encoding.
self.pos_drop = nn.Dropout(p=drop_rate)
self.apply(init_weights)
# Initialize the latents like in the Huggingface implementation of the Perceiver:
#
# https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/models/perceiver/modeling_perceiver.py#L628
#
torch.nn.init.trunc_normal_(self.atmos_latents, std=0.02)
torch.nn.init.trunc_normal_(self.surf_level_encoding, std=0.02)
def aggregate_levels(self, x: torch.Tensor) -> torch.Tensor:
"""Aggregate pressure level information.
Args:
x (torch.Tensor): Tensor of shape `(B, C_A, L, D)` where `C_A` refers to the number
of pressure levels.
Returns:
torch.Tensor: Tensor of shape `(B, C, L, D)` where `C` is the number of
aggregated pressure levels.
"""
B, _, L, _ = x.shape
latents = self.atmos_latents.to(dtype=x.dtype)
latents = latents.unsqueeze(1).expand(B, -1, L, -1) # (C_A, D) to (B, C_A, L, D)
x = torch.einsum("bcld->blcd", x)
x = x.flatten(0, 1) # (B * L, C_A, D)
latents = torch.einsum("bcld->blcd", latents)
latents = latents.flatten(0, 1) # (B * L, C_A, D)
x = self.level_agg(latents, x) # (B * L, C, D)
x = x.unflatten(dim=0, sizes=(B, L)) # (B, L, C, D)
x = torch.einsum("blcd->bcld", x) # (B, C, L, D)
return x
def forward(self, batch: Batch, lead_time: timedelta) -> torch.Tensor:
"""Peform encoding.
Args:
batch (:class:`.Batch`): Batch to encode.
lead_time (timedelta): Lead time.
Returns:
torch.Tensor: Encoding of shape `(B, L, D)`.
"""
surf_vars = tuple(batch.surf_vars.keys())
static_vars = tuple(batch.static_vars.keys())
atmos_vars = tuple(batch.atmos_vars.keys())
atmos_levels = batch.metadata.atmos_levels
x_surf = torch.stack(tuple(batch.surf_vars.values()), dim=2)
x_static = torch.stack(tuple(batch.static_vars.values()), dim=2)
x_atmos = torch.stack(tuple(batch.atmos_vars.values()), dim=2)
B, T, _, C, H, W = x_atmos.size()
assert x_surf.shape[:2] == (B, T), f"Expected shape {(B, T)}, got {x_surf.shape[:2]}."
if static_vars is None:
assert x_static is None, "Static variables given, but not configured."
else:
assert x_static is not None, "Static variables not given."
x_static = x_static.expand((B, T, -1, -1, -1))
x_surf = torch.cat((x_surf, x_static), dim=2) # (B, T, V_S + V_Static, H, W)
surf_vars = surf_vars + static_vars
lat, lon = batch.metadata.lat, batch.metadata.lon
check_lat_lon_dtype(lat, lon)
lat, lon = lat.to(dtype=torch.float32), lon.to(dtype=torch.float32)
assert lat.shape[0] == H and lon.shape[-1] == W
# Patch embed the surface level.
x_surf = rearrange(x_surf, "b t v h w -> b v t h w")
x_surf = self.surf_token_embeds(x_surf, surf_vars) # (B, L, D)
dtype = x_surf.dtype # When using mixed precision, we need to keep track of the dtype.
# Patch embed the atmospheric levels.
x_atmos = rearrange(x_atmos, "b t v c h w -> (b c) v t h w")
x_atmos = self.atmos_token_embeds(x_atmos, atmos_vars)
x_atmos = rearrange(x_atmos, "(b c) l d -> b c l d", b=B, c=C)
# Add surface level encoding. This helps the model distinguish between surface and
# atmospheric levels.
x_surf = x_surf + self.surf_level_encoding[None, None, :].to(dtype=dtype)
# Since the surface level is not aggregated, we add a Perceiver-like MLP only.
x_surf = x_surf + self.surf_norm(self.surf_mlp(x_surf))
# Add atmospheric pressure encoding of shape (C_A, D) and subsequent embedding.
atmos_levels_tensor = torch.tensor(atmos_levels, device=x_atmos.device)
atmos_levels_encode = levels_expansion(atmos_levels_tensor, self.embed_dim).to(dtype=dtype)
atmos_levels_embed = self.atmos_levels_embed(atmos_levels_encode)[None, :, None, :]
x_atmos = x_atmos + atmos_levels_embed # (B, C_A, L, D)
# Aggregate over pressure levels.
x_atmos = self.aggregate_levels(x_atmos) # (B, C_A, L, D) to (B, C, L, D)
# Concatenate the surface level with the amospheric levels.
x = torch.cat((x_surf.unsqueeze(1), x_atmos), dim=1)
# Add position and scale embeddings to the 3D tensor.
pos_encode, scale_encode = pos_scale_enc(
self.embed_dim,
lat,
lon,
self.patch_size,
pos_expansion=pos_expansion,
scale_expansion=scale_expansion,
)
# Encodings are (L, D).
pos_encode = self.pos_embed(pos_encode[None, None, :].to(dtype=dtype))
scale_encode = self.scale_embed(scale_encode[None, None, :].to(dtype=dtype))
x = x + pos_encode + scale_encode
# pos_encode_pre, scale_encode_pre = pos_scale_enc(
# self.embed_dim,
# torch.linspace(89.625, -89.625, 240).to(device=lat.device,dtype=torch.float32),
# torch.linspace(0, 360, 241)[:-1].to(device=lon.device,dtype=torch.float32),
# (self.patch_size, self.patch_size),
# pos_expansion=pos_expansion,
# scale_expansion=scale_expansion,
# )
# pos_encode_pre = self.pos_embed(pos_encode_pre[None, None, :].to(dtype=dtype))
# scale_encode_pre = self.scale_embed(scale_encode_pre[None, None, :].to(dtype=dtype))
# x = x + pos_encode_pre + scale_encode_pre
# Flatten the tokens.
x = x.reshape(B, -1, self.embed_dim) # (B, C + 1, L, D) to (B, L', D)
# Add lead time embedding.
lead_hours = lead_time.total_seconds() / 3600
lead_times = lead_hours * torch.ones(B, dtype=dtype, device=x.device)
lead_time_encode = lead_time_expansion(lead_times, self.embed_dim).to(dtype=dtype)
lead_time_emb = self.lead_time_embed(lead_time_encode) # (B, D)
x = x + lead_time_emb.unsqueeze(1) # (B, L', D) + (B, 1, D)
# Add absolute time embedding.
absolute_times_list = [t.timestamp() / 3600 for t in batch.metadata.time] # Times in hours
absolute_times = torch.tensor(absolute_times_list, dtype=torch.float32, device=x.device)
absolute_time_encode = absolute_time_expansion(absolute_times, self.embed_dim)
absolute_time_embed = self.absolute_time_embed(absolute_time_encode.to(dtype=dtype))
x = x + absolute_time_embed.unsqueeze(1) # (B, L, D) + (B, 1, D)
x = self.pos_drop(x)
return x
|