WeatherPEFT / aurora /model /posencoding.py
bidulki-99's picture
Add files using upload-large-folder tool
9d66a40 verified
"""Copyright (c) Microsoft Corporation. Licensed under the MIT license.
Parts of this code are inspired by
https://github.com/microsoft/ClimaX/blob/6d5d354ffb4b91bb684f430b98e8f6f8af7c7f7c/src/climax/utils/pos_embed.py
"""
import torch
import torch.nn.functional as F
from timm.models.layers import to_2tuple
from aurora.model.fourier import FourierExpansion
__all__ = ["pos_scale_enc"]
def patch_root_area(
    lat_min: torch.Tensor,
    lon_min: torch.Tensor,
    lat_max: torch.Tensor,
    lon_max: torch.Tensor,
) -> torch.Tensor:
    """Return the square root of the area term of latitude-longitude patches on a sphere.

    The area term is computed with a radius of 6371 (km), so the root has units of km. Working
    with the root rather than the area keeps the value comparable across grid resolutions.

    All inputs are interpreted as degrees (they are passed through `torch.deg2rad`).

    Args:
        lat_min (torch.Tensor): Minimum latitudes of the patches, within `[-90, 90]`.
        lon_min (torch.Tensor): Minimum longitudes of the patches, within `[0, 360]`.
        lat_max (torch.Tensor): Maximum latitudes of the patches, within `[-90, 90]`.
        lon_max (torch.Tensor): Maximum longitudes of the patches, within `[0, 360]`.

    Returns:
        torch.Tensor: Element-wise square root of the area term.

    Raises:
        AssertionError: If any patch has a non-positive latitude or longitude extent, if any
            coordinate lies outside the ranges above, or if any resulting area is not strictly
            positive.
    """
    # Every patch must have strictly positive extent in both directions. The messages report the
    # smallest extent found to ease debugging.
    assert (lat_max > lat_min).all(), f"lat_max - lat_min: {torch.min(lat_max - lat_min)}."
    assert (lon_max > lon_min).all(), f"lon_max - lon_min: {torch.min(lon_max - lon_min)}."

    # Coordinates must be valid latitudes in [-90, 90] and longitudes in [0, 360].
    assert (lat_max.abs() <= 90.0).all()
    assert (lat_min.abs() <= 90.0).all()
    assert (lon_max <= 360.0).all()
    assert (lon_min <= 360.0).all()
    assert (lon_max >= 0.0).all()
    assert (lon_min >= 0.0).all()

    # Area of a latitude-longitude cell, with phis the latitudes and thetas the longitudes:
    #
    #     area = R**2 * pi * (sin(phi_1) - sin(phi_2)) * (theta_1 - theta_2)
    #
    # Taken from
    #
    #     https://www.johndcook.com/blog/2023/02/21/sphere-grid-area/
    #
    # NOTE(review): the longitude span below is already in radians and is still multiplied by
    # pi, so the result may be a factor pi larger than the geometric area. Confirm against the
    # reference before changing anything: consumers of this value may depend on the current
    # scaling.
    radius_km = 6371
    prefactor = radius_km**2 * torch.pi
    sin_lat_span = torch.sin(torch.deg2rad(lat_max)) - torch.sin(torch.deg2rad(lat_min))
    lon_span_rad = torch.deg2rad(lon_max) - torch.deg2rad(lon_min)
    area = prefactor * sin_lat_span * lon_span_rad

    assert (area > 0.0).all()
    return torch.sqrt(area)
def pos_scale_enc_grid(
    encode_dim: int,
    grid: torch.Tensor,
    patch_dims: tuple,
    pos_expansion: FourierExpansion,
    scale_expansion: FourierExpansion,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute the position and scale encodings of a batched latitude-longitude grid.

    The input must carry a batch dimension, and so do the outputs.

    Args:
        encode_dim (int): Output encoding dimension `D`. Must be a multiple of four: it is split
            across latitudes and longitudes and across sines and cosines.
        grid (torch.Tensor): Latitude-longitude grid of dimensions `(B, 2, H, W)`. `grid[:, 0]`
            should be the latitudes and `grid[:, 1]` should be the longitudes.
        patch_dims (tuple): Patch dimensions. Different x-values and y-values are supported.
        pos_expansion (:class:`aurora.model.fourier.FourierExpansion`): Fourier expansion for the
            latitudes and longitudes.
        scale_expansion (:class:`aurora.model.fourier.FourierExpansion`): Fourier expansion for
            the patch areas.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: Positional encoding and scale encoding of shape
            `(B, H/patch[0] * W/patch[1], D)`.

    Raises:
        AssertionError: If `encode_dim` is not a multiple of four or `grid` is not 4-dimensional.
    """
    assert encode_dim % 4 == 0
    assert grid.dim() == 4

    def min_pool(values: torch.Tensor) -> torch.Tensor:
        # Minimum pooling expressed through maximum pooling of the negated values.
        return -F.max_pool2d(-values, patch_dims)

    lat_grid = grid[:, 0]
    lon_grid = grid[:, 1]

    # Patch midpoints: the 2D average over each patch. This equals 1D pooling over the x-axis
    # followed by 1D pooling over the y-axis.
    lat_mid = F.avg_pool2d(lat_grid, patch_dims)
    lon_mid = F.avg_pool2d(lon_grid, patch_dims)

    # Patch extents, from which the square root of each patch area follows.
    lat_hi = F.max_pool2d(lat_grid, patch_dims)
    lat_lo = min_pool(lat_grid)
    lon_hi = F.max_pool2d(lon_grid, patch_dims)
    lon_lo = min_pool(lon_grid)
    root_area = patch_root_area(lat_lo, lon_lo, lat_hi, lon_hi)

    # Flatten the spatial dimensions before encoding. Half of the encoding dimensions go to the
    # midpoint latitudes and the other half to the midpoint longitudes.
    batch_size = lat_mid.shape[0]
    half_dim = encode_dim // 2
    lat_encoding = pos_expansion(lat_mid.reshape(batch_size, -1), half_dim)  # (B, L, D/2)
    lon_encoding = pos_expansion(lon_mid.reshape(batch_size, -1), half_dim)  # (B, L, D/2)
    pos_encode = torch.cat((lat_encoding, lon_encoding), dim=-1)  # (B, L, D)

    # The scale encoding uses the full encoding dimension; nothing to split.
    scale_encode = scale_expansion(root_area.reshape(batch_size, -1), encode_dim)  # (B, L, D)

    return pos_encode, scale_encode
def lat_lon_meshgrid(lat: torch.Tensor, lon: torch.Tensor) -> torch.Tensor:
    """Build a stacked meshgrid of latitude and longitude coordinates.

    The grids are generated with `torch.meshgrid(lat, lon, indexing="xy")`, which produces
    tensors of shape `(len(lon), len(lat))`::

        lat = torch.tensor([1, 2, 3])
        lon = torch.tensor([4, 5, 6])
        grid_x, grid_y = torch.meshgrid(lat, lon, indexing="xy")
        grid_x = tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
        grid_y = tensor([[4, 4, 4], [5, 5, 5], [6, 6, 6]])

    The two grids are stacked and their spatial axes are swapped, so that in the result
    latitude varies along the first spatial axis and longitude along the second.

    Args:
        lat (torch.Tensor): Vector of latitudes.
        lon (torch.Tensor): Vector of longitudes.

    Returns:
        torch.Tensor: Meshgrid of shape `(2, len(lat), len(lon))`. Index 0 of the first axis
            holds the latitudes and index 1 holds the longitudes.

    Raises:
        AssertionError: If `lat` or `lon` is not one-dimensional.
    """
    assert lat.dim() == 1
    assert lon.dim() == 1

    lat_xy, lon_xy = torch.meshgrid(lat, lon, indexing="xy")
    stacked = torch.stack((lat_xy, lon_xy), dim=0)  # (2, len(lon), len(lat))

    # Swap the two spatial axes to obtain (2, len(lat), len(lon)).
    return stacked.transpose(1, 2)
def pos_scale_enc(
    encode_dim: int,
    lat: torch.Tensor,
    lon: torch.Tensor,
    patch_dims: int | list | tuple,
    pos_expansion: FourierExpansion,
    scale_expansion: FourierExpansion,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Positional and scale encoding of latitude-longitude data.

    Does not support batch dimensions in the input and does not return batch dimensions either.

    Args:
        encode_dim (int): Output encoding dimension `D`.
        lat (torch.Tensor): Latitudes, `H`. Can be either a vector or a matrix.
        lon (torch.Tensor): Longitudes, `W`. Can be either a vector or a matrix.
        patch_dims (Union[int, list, tuple]): Patch dimensions. Different x-values and y-values
            are supported. The value is passed through `to_2tuple` before use.
        pos_expansion (:class:`aurora.model.fourier.FourierExpansion`): Fourier expansion for the
            latitudes and longitudes.
        scale_expansion (:class:`aurora.model.fourier.FourierExpansion`): Fourier expansion for
            the patch areas.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: Positional encoding and scale encoding of shape
            `(H/patch[0] * W/patch[1], D)`.

    Raises:
        ValueError: If `lat` and `lon` are not both vectors or both matrices.
    """
    if lat.dim() == lon.dim() == 1:
        # Vectors: expand them into a full `(2, H, W)` grid.
        grid = lat_lon_meshgrid(lat, lon)
    elif lat.dim() == lon.dim() == 2:
        # Matrices: the caller already supplies per-point coordinates, so just stack them.
        grid = torch.stack((lat, lon), dim=0)
    else:
        raise ValueError(
            f"Latitudes and longitudes must either both be vectors or both be matrices, "
            f"but have dimensionalities {lat.dim()} and {lon.dim()} respectively."
        )

    grid = grid[None]  # Add batch dimension.

    pos_encoding, scale_encoding = pos_scale_enc_grid(
        encode_dim,
        grid,
        to_2tuple(patch_dims),
        pos_expansion=pos_expansion,
        scale_expansion=scale_expansion,
    )

    # Return without batch dimension.
    return pos_encoding.squeeze(0), scale_encoding.squeeze(0)