WeatherPEFT / aurora /model /fourier.py
bidulki-99's picture
Add files using upload-large-folder tool
9d66a40 verified
"""Copyright (c) Microsoft Corporation. Licensed under the MIT license."""
import math
import numpy as np
import torch
import torch.nn as nn
from aurora.area import area, radius_earth
__all__ = [
"FourierExpansion",
"pos_expansion",
"scale_expansion",
"lead_time_expansion",
"levels_expansion",
"absolute_time_expansion",
]
class FourierExpansion(nn.Module):
"""A Fourier series-style expansion into a high-dimensional space.
Attributes:
lower (float): Lower wavelength.
upper (float): Upper wavelength.
assert_range (bool): Assert that the encoded tensor is within the specified wavelength
range.
"""
def __init__(self, lower: float, upper: float, assert_range: bool = True) -> None:
"""Initialise.
Args:
lower (float): Lower wavelength.
upper (float): Upper wavelength.
assert_range (bool, optional): Assert that the encoded tensor is within the specified
wavelength range. Defaults to `True`.
"""
super().__init__()
self.lower = lower
self.upper = upper
self.assert_range = assert_range
def forward(self, x: torch.Tensor, d: int) -> torch.Tensor:
"""Perform the expansion.
Adds a dimension of length `d` to the end of the shape of `x`.
Args:
x (:class:`torch.Tensor`): Input to expand of shape `(..., n)`. All elements of `x` must
lie within `[self.lower, self.upper]` if `self.assert_range` is `True`.
d (int): Dimensionality. Must be a multiple of two.
Raises:
AssertionError: If `self.assert_range` is `True` and not all elements of `x` are not
within `[self.lower, self.upper]`.
ValueError: If `d` is not a multiple of two.
Returns:
torch.Tensor: Fourier series-style expansion of `x` of shape `(..., n, d)`.
"""
# If the input is not within the configured range, the embedding might be ambiguous!
in_range = torch.logical_and(self.lower <= x.abs(), torch.all(x.abs() <= self.upper))
in_range_or_zero = torch.all(
torch.logical_or(in_range, x == 0)
) # Allow zeros to pass through.
if self.assert_range and not in_range_or_zero:
raise AssertionError(
f"The input tensor is not within the configured range"
f" `[{self.lower}, {self.upper}]`."
)
# We will use half of the dimensionality for `sin` and the other half for `cos`.
if not (d % 2 == 0):
raise ValueError("The dimensionality must be a multiple of two.")
# Always perform the expansion with `float64`s to avoid numerical accuracy shenanigans.
x = x.double()
wavelengths = torch.logspace(
math.log10(self.lower),
math.log10(self.upper),
d // 2,
base=10,
device=x.device,
dtype=x.dtype,
)
prod = torch.einsum("...i,j->...ij", x, 2 * np.pi / wavelengths)
encoding = torch.cat((torch.sin(prod), torch.cos(prod)), dim=-1)
return encoding.float() # Cast to `float32` to avoid incompatibilities.
# Determine a reasonable smallest value for the scale embedding by assuming a smallest delta in
# latitudes and longitudes.
_delta = 0.01 # Reasonable smallest delta in latitude and longitude
_min_patch_area: float = area(
torch.tensor(
[
# The smallest patches will be at the poles. Just use the north pole.
[90, 0],
[90, _delta],
[90 - _delta, _delta],
[90 - _delta, 0],
],
dtype=torch.float64,
)
).item()
_area_earth = 4 * np.pi * radius_earth * radius_earth
pos_expansion = FourierExpansion(_delta, 720)
""":class:`.FourierExpansion`: Fourier expansion for the encoding of latitudes and longitudes in
degrees."""
scale_expansion = FourierExpansion(_min_patch_area, _area_earth)
""":class:`.FourierExpansion`: Fourier expansion for the encoding of patch areas in squared
kilometers."""
lead_time_expansion = FourierExpansion(1 / 60, 24 * 7 * 3)
""":class:`.FourierExpansion`: Fourier expansion for the lead time encoding in hours."""
levels_expansion = FourierExpansion(0.01, 1e5)
""":class:`.FourierExpansion`: Fourier expansion for the pressure level encoding in hPa."""
absolute_time_expansion = FourierExpansion(1, 24 * 365.25, assert_range=False)
""":class:`.FourierExpansion`: Fourier expansion for the absolute time encoding in hours."""