| """Copyright (c) Microsoft Corporation. Licensed under the MIT license. |
| |
| Parts of this code are inspired by |
| |
| https://github.com/microsoft/ClimaX/blob/6d5d354ffb4b91bb684f430b98e8f6f8af7c7f7c/src/climax/utils/pos_embed.py |
| """ |
|
|
| import torch |
| import torch.nn.functional as F |
| from timm.models.layers import to_2tuple |
|
|
| from aurora.model.fourier import FourierExpansion |
|
|
| __all__ = ["pos_scale_enc"] |
|
|
|
|
def patch_root_area(
    lat_min: torch.Tensor,
    lon_min: torch.Tensor,
    lat_max: torch.Tensor,
    lon_max: torch.Tensor,
) -> torch.Tensor:
    """Compute a root-area length scale for latitude-longitude patches on a sphere.

    For every patch the quantity

        6371**2 * pi * (sin(lat_max) - sin(lat_min)) * (lon_max - lon_min)

    is evaluated, with the latitudes and longitudes converted from degrees to radians, and
    its square root is returned. Taking the root turns an area-like quantity (km^2) into a
    length-like one (km), which keeps the value comparable across grid resolutions.

    NOTE(review): the longitudes are already converted with `torch.deg2rad`, so the extra
    factor `torch.pi` makes the value `pi` times the geometric area
    `R**2 * (sin(lat_max) - sin(lat_min)) * (lon_max - lon_min)`. It is left untouched here
    because changing it would change every returned value -- confirm whether it is intended.

    Args:
        lat_min (torch.Tensor): Minimum latitudes of the patches, in degrees within
            `[-90, 90]`.
        lon_min (torch.Tensor): Minimum longitudes of the patches, in degrees within
            `[0, 360]`.
        lat_max (torch.Tensor): Maximum latitudes of the patches, in degrees within
            `[-90, 90]`. Must be strictly greater than `lat_min`.
        lon_max (torch.Tensor): Maximum longitudes of the patches, in degrees within
            `[0, 360]`. Must be strictly greater than `lon_min`.

    Returns:
        torch.Tensor: Square root of the area term, one value per patch.

    Raises:
        AssertionError: If the bounds are not strictly ordered, fall outside the ranges
            above, or produce a non-positive area.
    """
    # Every patch needs a strictly positive extent in both directions.
    assert (lat_max > lat_min).all(), f"lat_max - lat_min: {torch.min(lat_max - lat_min)}."
    assert (lon_max > lon_min).all(), f"lon_max - lon_min: {torch.min(lon_max - lon_min)}."

    # Latitudes must lie in [-90, 90] and longitudes in [0, 360].
    for lat_bound in (lat_max, lat_min):
        assert (lat_bound.abs() <= 90.0).all()
    for lon_bound in (lon_max, lon_min):
        assert (lon_bound <= 360.0).all()
    for lon_bound in (lon_max, lon_min):
        assert (lon_bound >= 0.0).all()

    sphere_radius_km = 6371
    sin_lat_span = torch.sin(torch.deg2rad(lat_max)) - torch.sin(torch.deg2rad(lat_min))
    lon_span_rad = torch.deg2rad(lon_max) - torch.deg2rad(lon_min)
    area = sphere_radius_km**2 * torch.pi * sin_lat_span * lon_span_rad

    assert (area > 0.0).all()
    return area.sqrt()
|
|
|
|
def pos_scale_enc_grid(
    encode_dim: int,
    grid: torch.Tensor,
    patch_dims: tuple,
    pos_expansion: FourierExpansion,
    scale_expansion: FourierExpansion,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build the position and scale encodings for a batched latitude-longitude grid.

    The input must carry a batch dimension and the outputs carry one too.

    Args:
        encode_dim (int): Output encoding dimension `D`. Must be a multiple of four: it is
            split across latitudes and longitudes and across sines and cosines.
        grid (torch.Tensor): Latitude-longitude grid of shape `(B, 2, H, W)`. `grid[:, 0]`
            should be the latitudes and `grid[:, 1]` should be the longitudes.
        patch_dims (tuple): Patch dimensions, used as the pooling window. Different
            x-values and y-values are supported.
        pos_expansion (:class:`aurora.model.fourier.FourierExpansion`): Fourier expansion for
            the latitudes and longitudes. Called with `encode_dim // 2` for each of them.
        scale_expansion (:class:`aurora.model.fourier.FourierExpansion`): Fourier expansion
            for the patch root areas. Called with `encode_dim`.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: Positional encoding and scale encoding of shape
            `(B, H/patch[0] * W/patch[1], D)`.

    Raises:
        AssertionError: If `encode_dim` is not a multiple of four, if `grid` is not
            four-dimensional, or if :func:`patch_root_area` rejects the patch bounds.
    """
    assert encode_dim % 4 == 0
    assert grid.dim() == 4

    lat_field = grid[:, 0]
    lon_field = grid[:, 1]

    # One representative coordinate per patch: the mean over the pooling window.
    patch_lat = F.avg_pool2d(lat_field, patch_dims)
    patch_lon = F.avg_pool2d(lon_field, patch_dims)

    # Bounds of every patch. The minimum over a window is obtained as `-max_pool2d(-x)`.
    lat_hi = F.max_pool2d(lat_field, patch_dims)
    lat_lo = -F.max_pool2d(-lat_field, patch_dims)
    lon_hi = F.max_pool2d(lon_field, patch_dims)
    lon_lo = -F.max_pool2d(-lon_field, patch_dims)
    root_area = patch_root_area(lat_lo, lon_lo, lat_hi, lon_hi)

    # Flatten the patch grid into one axis per batch element before expanding.
    batch_size = patch_lat.shape[0]
    half_dim = encode_dim // 2
    lat_code = pos_expansion(patch_lat.reshape(batch_size, -1), half_dim)
    lon_code = pos_expansion(patch_lon.reshape(batch_size, -1), half_dim)
    pos_encode = torch.cat((lat_code, lon_code), dim=-1)

    scale_encode = scale_expansion(root_area.reshape(batch_size, -1), encode_dim)

    return pos_encode, scale_encode
|
|
|
|
def lat_lon_meshgrid(lat: torch.Tensor, lon: torch.Tensor) -> torch.Tensor:
    """Construct a meshgrid of latitude and longitude coordinates.

    `torch.meshgrid(lat, lon, indexing="xy")` produces grids of shape `(len(lon), len(lat))`.
    The two grids are stacked and their last two axes are swapped, so that latitude varies
    along the rows and longitude along the columns::

        lat = torch.tensor([1, 2, 3])
        lon = torch.tensor([4, 5])
        grid = lat_lon_meshgrid(lat, lon)
        grid[0] = tensor([[1, 1], [2, 2], [3, 3]])
        grid[1] = tensor([[4, 5], [4, 5], [4, 5]])

    Args:
        lat (torch.Tensor): Vector of latitudes.
        lon (torch.Tensor): Vector of longitudes.

    Returns:
        torch.Tensor: Meshgrid of shape `(2, len(lat), len(lon))`. `grid[0]` holds the
            latitudes and `grid[1]` the longitudes.

    Raises:
        AssertionError: If `lat` or `lon` is not one-dimensional.
    """
    assert lat.dim() == 1
    assert lon.dim() == 1

    lat_xy, lon_xy = torch.meshgrid(lat, lon, indexing="xy")
    stacked = torch.stack((lat_xy, lon_xy), dim=0)

    # Swap the two spatial axes: (2, len(lon), len(lat)) -> (2, len(lat), len(lon)).
    return stacked.transpose(1, 2)
|
|
|
|
def pos_scale_enc(
    encode_dim: int,
    lat: torch.Tensor,
    lon: torch.Tensor,
    patch_dims: int | list | tuple,
    pos_expansion: FourierExpansion,
    scale_expansion: FourierExpansion,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Positional and scale encoding of latitude-longitude data.

    Does not support batch dimensions in the input and does not return batch dimensions
    either: a batch dimension of size one is added internally for
    :func:`pos_scale_enc_grid` and removed again from both outputs.

    Args:
        encode_dim (int): Output encoding dimension `D`. Must be a multiple of four, which
            is asserted by :func:`pos_scale_enc_grid`.
        lat (torch.Tensor): Latitudes, `H`. Can be either a vector or a matrix.
        lon (torch.Tensor): Longitudes, `W`. Can be either a vector or a matrix. Must have
            the same number of dimensions as `lat`.
        patch_dims (Union[int, list, tuple]): Patch dimensions. Normalised with
            `to_2tuple` before use. Different x-values and y-values are supported.
        pos_expansion (:class:`aurora.model.fourier.FourierExpansion`): Fourier expansion for
            the latitudes and longitudes.
        scale_expansion (:class:`aurora.model.fourier.FourierExpansion`): Fourier expansion
            for the patch areas.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: Positional encoding and scale encoding of shape
            `(H/patch[0] * W/patch[1], D)`.

    Raises:
        ValueError: If `lat` and `lon` are not both vectors or both matrices.
    """
    if lat.dim() == lon.dim() == 1:
        # Coordinate vectors: expand them to a full `(2, H, W)` grid.
        grid = lat_lon_meshgrid(lat, lon)
    elif lat.dim() == lon.dim() == 2:
        # Coordinates already given per grid point: just stack them.
        grid = torch.stack((lat, lon), dim=0)
    else:
        raise ValueError(
            "Latitudes and longitudes must either both be vectors or both be matrices, "
            f"but have dimensionalities {lat.dim()} and {lon.dim()} respectively."
        )

    # `pos_scale_enc_grid` requires a leading batch dimension.
    grid = grid[None]

    pos_encoding, scale_encoding = pos_scale_enc_grid(
        encode_dim,
        grid,
        to_2tuple(patch_dims),
        pos_expansion=pos_expansion,
        scale_expansion=scale_expansion,
    )

    # Drop the batch dimension that was added above.
    return pos_encoding.squeeze(0), scale_encoding.squeeze(0)
|
|