# WeatherPEFT/aurora/model/encoder.py
"""Copyright (c) Microsoft Corporation. Licensed under the MIT license."""
from datetime import timedelta
import torch
from einops import rearrange
from torch import nn
from aurora.batch import Batch
from aurora.model.fourier import (
absolute_time_expansion,
lead_time_expansion,
levels_expansion,
pos_expansion,
scale_expansion,
)
from aurora.model.patchembed import LevelPatchEmbed
from aurora.model.perceiver import MLP, PerceiverResampler
from aurora.model.posencoding import pos_scale_enc
from aurora.model.util import (
check_lat_lon_dtype,
init_weights,
)
__all__ = ["Perceiver3DEncoder"]
class Perceiver3DEncoder(nn.Module):
"""Multi-scale multi-source multi-variable encoder based on the Perceiver architecture."""

    def __init__(
self,
surf_vars: tuple[str, ...],
static_vars: tuple[str, ...] | None,
atmos_vars: tuple[str, ...],
patch_size: int = 4,
latent_levels: int = 8,
embed_dim: int = 1024,
num_heads: int = 16,
head_dim: int = 64,
drop_rate: float = 0.1,
depth: int = 2,
mlp_ratio: float = 4.0,
max_history_size: int = 2,
perceiver_ln_eps: float = 1e-5,
stabilise_level_agg: bool = False,
) -> None:
"""Initialise.
Args:
surf_vars (tuple[str, ...]): All supported surface-level variables.
static_vars (tuple[str, ...], optional): All supported static variables.
atmos_vars (tuple[str, ...]): All supported atmospheric variables.
patch_size (int, optional): Patch size. Defaults to `4`.
            latent_levels (int, optional): Number of latent pressure levels. Defaults to `8`.
            embed_dim (int, optional): Embedding dimension used in the aggregation blocks. Defaults
to `1024`.
num_heads (int, optional): Number of attention heads used in aggregation blocks.
Defaults to `16`.
head_dim (int, optional): Dimension of attention heads used in aggregation blocks.
Defaults to `64`.
            drop_rate (float, optional): Dropout rate for input patches. Defaults to `0.1`.
depth (int, optional): Number of Perceiver cross-attention and feed-forward blocks.
Defaults to `2`.
mlp_ratio (float, optional): Ratio of hidden dimensionality to embedding dimensionality
for MLPs. Defaults to `4.0`.
max_history_size (int, optional): Maximum number of history steps to consider. Defaults
to `2`.
perceiver_ln_eps (float, optional): Epsilon value for layer normalisation in the
                Perceiver. Defaults to `1e-5`.
stabilise_level_agg (bool, optional): Stabilise the level aggregation by inserting an
additional layer normalisation. Defaults to `False`.
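
        Example:
            A minimal construction sketch; the variable names below are illustrative
            placeholders rather than a prescribed set:

            >>> encoder = Perceiver3DEncoder(
            ...     surf_vars=("2t", "10u", "10v", "msl"),
            ...     static_vars=("lsm", "z"),
            ...     atmos_vars=("t", "u", "v", "q", "z"),
            ... )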
"""
super().__init__()
self.drop_rate = drop_rate
self.embed_dim = embed_dim
self.patch_size = patch_size
# We treat the static variables as surface variables in the model.
surf_vars = surf_vars + static_vars if static_vars is not None else surf_vars
# Latent tokens
assert latent_levels > 1, "At least two latent levels are required."
self.latent_levels = latent_levels
# One latent level will be used by the surface level.
self.atmos_latents = nn.Parameter(torch.randn(latent_levels - 1, embed_dim))
# Learnable embedding to encode the surface level.
self.surf_level_encoding = nn.Parameter(torch.randn(embed_dim))
self.surf_mlp = MLP(embed_dim, int(embed_dim * mlp_ratio), dropout=drop_rate)
self.surf_norm = nn.LayerNorm(embed_dim)
# Position, scale, and time embeddings
self.pos_embed = nn.Linear(embed_dim, embed_dim)
self.scale_embed = nn.Linear(embed_dim, embed_dim)
self.lead_time_embed = nn.Linear(embed_dim, embed_dim)
self.absolute_time_embed = nn.Linear(embed_dim, embed_dim)
self.atmos_levels_embed = nn.Linear(embed_dim, embed_dim)
# Patch embeddings
assert max_history_size > 0, "At least one history step is required."
self.surf_token_embeds = LevelPatchEmbed(
surf_vars,
patch_size,
embed_dim,
max_history_size,
)
self.atmos_token_embeds = LevelPatchEmbed(
atmos_vars,
patch_size,
embed_dim,
max_history_size,
)
# Learnable pressure level aggregation
self.level_agg = PerceiverResampler(
latent_dim=embed_dim,
context_dim=embed_dim,
depth=depth,
head_dim=head_dim,
num_heads=num_heads,
drop=drop_rate,
mlp_ratio=mlp_ratio,
ln_eps=perceiver_ln_eps,
ln_k_q=stabilise_level_agg,
)
# Drop patches after encoding.
self.pos_drop = nn.Dropout(p=drop_rate)
self.apply(init_weights)
# Initialize the latents like in the Huggingface implementation of the Perceiver:
#
# https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/models/perceiver/modeling_perceiver.py#L628
#
torch.nn.init.trunc_normal_(self.atmos_latents, std=0.02)
torch.nn.init.trunc_normal_(self.surf_level_encoding, std=0.02)

    def aggregate_levels(self, x: torch.Tensor) -> torch.Tensor:
        """Aggregate pressure level information.

        Args:
x (torch.Tensor): Tensor of shape `(B, C_A, L, D)` where `C_A` refers to the number
of pressure levels.

        Returns:
torch.Tensor: Tensor of shape `(B, C, L, D)` where `C` is the number of
aggregated pressure levels.
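                For example, with `latent_levels=8`, an input over 13 pressure levels of
                shape `(B, 13, L, D)` is resampled to `(B, 7, L, D)`; the surface level is
                concatenated in :meth:`forward` to complete the `latent_levels` levels.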
"""
B, _, L, _ = x.shape
latents = self.atmos_latents.to(dtype=x.dtype)
        latents = latents.unsqueeze(1).expand(B, -1, L, -1)  # (C, D) to (B, C, L, D)
x = torch.einsum("bcld->blcd", x)
x = x.flatten(0, 1) # (B * L, C_A, D)
latents = torch.einsum("bcld->blcd", latents)
        latents = latents.flatten(0, 1)  # (B * L, C, D)
x = self.level_agg(latents, x) # (B * L, C, D)
x = x.unflatten(dim=0, sizes=(B, L)) # (B, L, C, D)
x = torch.einsum("blcd->bcld", x) # (B, C, L, D)
return x

    def forward(self, batch: Batch, lead_time: timedelta) -> torch.Tensor:
        """Perform encoding.

        Args:
batch (:class:`.Batch`): Batch to encode.
lead_time (timedelta): Lead time.

        Returns:
torch.Tensor: Encoding of shape `(B, L, D)`.
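                Here `L` equals `latent_levels * (H / patch_size) * (W / patch_size)`,
                i.e. the latent levels are flattened into the token sequence.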
"""
surf_vars = tuple(batch.surf_vars.keys())
static_vars = tuple(batch.static_vars.keys())
atmos_vars = tuple(batch.atmos_vars.keys())
atmos_levels = batch.metadata.atmos_levels
x_surf = torch.stack(tuple(batch.surf_vars.values()), dim=2)
x_static = torch.stack(tuple(batch.static_vars.values()), dim=2)
x_atmos = torch.stack(tuple(batch.atmos_vars.values()), dim=2)
B, T, _, C, H, W = x_atmos.size()
assert x_surf.shape[:2] == (B, T), f"Expected shape {(B, T)}, got {x_surf.shape[:2]}."
if static_vars is None:
assert x_static is None, "Static variables given, but not configured."
else:
assert x_static is not None, "Static variables not given."
x_static = x_static.expand((B, T, -1, -1, -1))
x_surf = torch.cat((x_surf, x_static), dim=2) # (B, T, V_S + V_Static, H, W)
surf_vars = surf_vars + static_vars
lat, lon = batch.metadata.lat, batch.metadata.lon
check_lat_lon_dtype(lat, lon)
lat, lon = lat.to(dtype=torch.float32), lon.to(dtype=torch.float32)
assert lat.shape[0] == H and lon.shape[-1] == W
# Patch embed the surface level.
x_surf = rearrange(x_surf, "b t v h w -> b v t h w")
x_surf = self.surf_token_embeds(x_surf, surf_vars) # (B, L, D)
dtype = x_surf.dtype # When using mixed precision, we need to keep track of the dtype.
# Patch embed the atmospheric levels.
x_atmos = rearrange(x_atmos, "b t v c h w -> (b c) v t h w")
x_atmos = self.atmos_token_embeds(x_atmos, atmos_vars)
x_atmos = rearrange(x_atmos, "(b c) l d -> b c l d", b=B, c=C)
# Add surface level encoding. This helps the model distinguish between surface and
# atmospheric levels.
x_surf = x_surf + self.surf_level_encoding[None, None, :].to(dtype=dtype)
# Since the surface level is not aggregated, we add a Perceiver-like MLP only.
x_surf = x_surf + self.surf_norm(self.surf_mlp(x_surf))
# Add atmospheric pressure encoding of shape (C_A, D) and subsequent embedding.
atmos_levels_tensor = torch.tensor(atmos_levels, device=x_atmos.device)
atmos_levels_encode = levels_expansion(atmos_levels_tensor, self.embed_dim).to(dtype=dtype)
atmos_levels_embed = self.atmos_levels_embed(atmos_levels_encode)[None, :, None, :]
x_atmos = x_atmos + atmos_levels_embed # (B, C_A, L, D)
# Aggregate over pressure levels.
x_atmos = self.aggregate_levels(x_atmos) # (B, C_A, L, D) to (B, C, L, D)
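        # E.g. 13 input pressure levels are resampled down to `latent_levels - 1` latent
        # levels; the surface level added below brings the total to `latent_levels`.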
        # Concatenate the surface level with the atmospheric levels.
x = torch.cat((x_surf.unsqueeze(1), x_atmos), dim=1)
# Add position and scale embeddings to the 3D tensor.
pos_encode, scale_encode = pos_scale_enc(
self.embed_dim,
lat,
lon,
self.patch_size,
pos_expansion=pos_expansion,
scale_expansion=scale_expansion,
)
# Encodings are (L, D).
pos_encode = self.pos_embed(pos_encode[None, None, :].to(dtype=dtype))
scale_encode = self.scale_embed(scale_encode[None, None, :].to(dtype=dtype))
x = x + pos_encode + scale_encode
# pos_encode_pre, scale_encode_pre = pos_scale_enc(
# self.embed_dim,
# torch.linspace(89.625, -89.625, 240).to(device=lat.device,dtype=torch.float32),
# torch.linspace(0, 360, 241)[:-1].to(device=lon.device,dtype=torch.float32),
# (self.patch_size, self.patch_size),
# pos_expansion=pos_expansion,
# scale_expansion=scale_expansion,
# )
# pos_encode_pre = self.pos_embed(pos_encode_pre[None, None, :].to(dtype=dtype))
# scale_encode_pre = self.scale_embed(scale_encode_pre[None, None, :].to(dtype=dtype))
# x = x + pos_encode_pre + scale_encode_pre
# Flatten the tokens.
x = x.reshape(B, -1, self.embed_dim) # (B, C + 1, L, D) to (B, L', D)
# Add lead time embedding.
lead_hours = lead_time.total_seconds() / 3600
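        # E.g. `lead_time=timedelta(hours=6)` gives `lead_hours == 6.0`.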
lead_times = lead_hours * torch.ones(B, dtype=dtype, device=x.device)
lead_time_encode = lead_time_expansion(lead_times, self.embed_dim).to(dtype=dtype)
lead_time_emb = self.lead_time_embed(lead_time_encode) # (B, D)
x = x + lead_time_emb.unsqueeze(1) # (B, L', D) + (B, 1, D)
# Add absolute time embedding.
absolute_times_list = [t.timestamp() / 3600 for t in batch.metadata.time] # Times in hours
absolute_times = torch.tensor(absolute_times_list, dtype=torch.float32, device=x.device)
absolute_time_encode = absolute_time_expansion(absolute_times, self.embed_dim)
absolute_time_embed = self.absolute_time_embed(absolute_time_encode.to(dtype=dtype))
        x = x + absolute_time_embed.unsqueeze(1)  # (B, L', D) + (B, 1, D)
x = self.pos_drop(x)
return x
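

# Usage sketch (illustrative only; the expected input shapes are inferred from the
# forward pass above rather than from a documented API). Given a `Batch` instance
# `batch` with surf_vars values of shape (B, T, H, W), atmos_vars values of shape
# (B, T, C_A, H, W), metadata.lat satisfying lat.shape[0] == H, metadata.lon
# satisfying lon.shape[-1] == W, and metadata.time a length-B sequence of datetimes:
#
#     encoder = Perceiver3DEncoder(
#         surf_vars=("2t", "10u"), static_vars=("lsm",), atmos_vars=("t", "q")
#     )
#     tokens = encoder(batch, lead_time=timedelta(hours=6))  # (B, L', embed_dim)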