# WeatherPEFT / aurora / model / patchembed.py
# Uploaded by bidulki-99 ("Add files using upload-large-folder tool"), commit 9d66a40.
"""Copyright (c) Microsoft Corporation. Licensed under the MIT license."""
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import to_2tuple
# Public names exported by `from <this module> import *`.
__all__ = ["LevelPatchEmbed"]
class LevelPatchEmbed(nn.Module):
"""At either the surface or at a single pressure level, maps all variables into a single
embedding."""
def __init__(
self,
var_names: tuple[str, ...],
patch_size: int,
embed_dim: int,
history_size: int = 1,
norm_layer: Optional[nn.Module] = None,
flatten: bool = True,
) -> None:
"""Initialise.
Args:
var_names (tuple[str, ...]): Variables to embed.
patch_size (int): Patch size.
embed_dim (int): Embedding dimensionality.
history_size (int, optional): Number of history dimensions. Defaults to `1`.
norm_layer (torch.nn.Module, optional): Normalisation layer to be applied at the very
end. Defaults to no normalisation layer.
flatten (bool): At the end of the forward pass, flatten the two spatial dimensions
into a single dimension. See :meth:`LevelPatchEmbed.forward` for more details.
"""
super().__init__()
self.var_names = var_names
self.kernel_size = (history_size,) + to_2tuple(patch_size)
self.flatten = flatten
self.embed_dim = embed_dim
self.weights = nn.ParameterDict(
{
# Shape (C_out, C_in, T, H, W). `C_in = 1` here because we're embedding every
# variable separately.
name: nn.Parameter(torch.empty(embed_dim, 1, *self.kernel_size))
for name in var_names
}
)
self.bias = nn.Parameter(torch.empty(embed_dim))
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
self.init_weights()
def init_weights(self) -> None:
"""Initialise weights."""
# Setting `a = sqrt(5)` in kaiming_uniform is the same as initialising with
# `uniform(-1/sqrt(k), 1/sqrt(k))`, where `k = weight.size(1) * prod(*kernel_size)`.
# For more details, see
#
# https://github.com/pytorch/pytorch/issues/15314#issuecomment-477448573
#
for weight in self.weights.values():
nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
# The following initialisation is taken from
#
# https://pytorch.org/docs/stable/_modules/torch/nn/modules/conv.html#Conv3d
#
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(next(iter(self.weights.values())))
if fan_in != 0:
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(self.bias, -bound, bound)
def forward(self, x: torch.Tensor, var_names: tuple[str, ...]) -> torch.Tensor:
"""Run the embedding.
Args:
x (:class:`torch.Tensor`): Tensor to embed of a shape of `(B, V, T, H, W)`.
var_names (tuple[str, ...]): Names of the variables in `x`. The length should be equal
to `V`.
Returns:
:class:`torch.Tensor`: Embedded tensor a shape of `(B, L, D]) if flattened,
where `L = H * W / P^2`. Otherwise, the shape is `(B, D, H', W')`.
"""
B, V, T, H, W = x.shape
assert len(var_names) == V, f"{V} != {len(var_names)}."
assert self.kernel_size[0] >= T, f"{T} > {self.kernel_size[0]}."
assert H % self.kernel_size[1] == 0, f"{H} % {self.kernel_size[0]} != 0."
assert W % self.kernel_size[2] == 0, f"{W} % {self.kernel_size[1]} != 0."
assert len(set(var_names)) == len(var_names), f"{var_names} contains duplicates."
# Select the weights of the variables and history dimensions that are present in the batch.
weight = torch.cat(
[
# (C_out, C_in, T, H, W)
self.weights[name][:, :, :T, ...]
for name in var_names
],
dim=1,
)
# Adjust the stride if history is smaller than maximum.
stride = (T,) + self.kernel_size[1:]
# The convolution maps (B, V, T, H, W) to (B, D, 1, H/P, W/P)
proj = F.conv3d(x, weight, self.bias, stride=stride)
if self.flatten:
proj = proj.reshape(B, self.embed_dim, -1) # (B, D, L)
proj = proj.transpose(1, 2) # (B, L, D)
x = self.norm(proj)
return x