Spaces:
Sleeping
Sleeping
| # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. | |
| # SPDX-FileCopyrightText: All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import torch | |
| from torch import nn | |
| from ..utils import get_earth_position_index, trunc_normal_ | |
class EarthAttention3D(nn.Module):
    """
    Revise from WeatherLearn https://github.com/lizhuoq/WeatherLearn

    3D window attention with earth position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): [pressure levels, latitude, longitude]
        window_size (tuple[int]): [pressure levels, latitude, longitude]
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(
        self,
        dim,
        input_resolution,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # (Wpl, Wlat, Wlon)
        self.num_heads = num_heads

        # Per-head attention scale; defaults to 1/sqrt(head_dim).
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        # Windows at different (pressure, latitude) positions receive
        # distinct bias entries; longitude is treated as periodic.
        wpl, wlat, wlon = window_size
        self.type_of_windows = (input_resolution[0] // wpl) * (
            input_resolution[1] // wlat
        )

        # Learnable earth position bias table:
        # (Wpl**2 * Wlat**2 * (2*Wlon - 1), Npl//Wpl * Nlat//Wlat, nH)
        table_rows = (wpl * wpl) * (wlat * wlat) * (2 * wlon - 1)
        self.earth_position_bias_table = nn.Parameter(
            torch.zeros(table_rows, self.type_of_windows, num_heads)
        )

        # Precomputed lookup indices into the bias table,
        # shape (Wpl*Wlat*Wlon, Wpl*Wlat*Wlon).
        self.register_buffer(
            "earth_position_index", get_earth_position_index(window_size)
        )

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.earth_position_bias_table = trunc_normal_(
            self.earth_position_bias_table, std=0.02
        )
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x: torch.Tensor, mask=None):
        """
        Args:
            x: input features with shape of (B * num_lon, num_pl*num_lat, N, C)
            mask: (0/-inf) mask with shape of (num_lon, num_pl*num_lat, Wpl*Wlat*Wlon, Wpl*Wlat*Wlon)
        """
        B_, nW_, N, C = x.shape

        # Project to q/k/v and split heads: each of q, k, v is
        # (B_, nH, nW_, N, head_dim).
        qkv = self.qkv(x).reshape(B_, nW_, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv.permute(3, 0, 4, 1, 2, 5).unbind(0)

        attn = (q * self.scale) @ k.transpose(-2, -1)

        # Gather the relative position bias for every window type:
        # (Wpl*Wlat*Wlon, Wpl*Wlat*Wlon, num_pl*num_lat, nH) ->
        # (nH, num_pl*num_lat, Wpl*Wlat*Wlon, Wpl*Wlat*Wlon)
        win_volume = self.window_size[0] * self.window_size[1] * self.window_size[2]
        bias = self.earth_position_bias_table[self.earth_position_index.view(-1)]
        bias = bias.view(win_volume, win_volume, self.type_of_windows, -1)
        bias = bias.permute(3, 2, 0, 1).contiguous()
        attn = attn + bias.unsqueeze(0)

        if mask is not None:
            # Apply the shifted-window mask per longitude group.
            nLon = mask.shape[0]
            attn = attn.view(B_ // nLon, nLon, self.num_heads, nW_, N, N)
            attn = attn + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, nW_, N, N)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        out = (attn @ v).permute(0, 2, 3, 1, 4).reshape(B_, nW_, N, C)
        out = self.proj(out)
        return self.proj_drop(out)
class EarthAttention2D(nn.Module):
    """
    Revise from WeatherLearn https://github.com/lizhuoq/WeatherLearn

    2D window attention with earth position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): [latitude, longitude]
        window_size (tuple[int]): [latitude, longitude]
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(
        self,
        dim,
        input_resolution,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # (Wlat, Wlon)
        self.num_heads = num_heads

        # Per-head attention scale; defaults to 1/sqrt(head_dim).
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        # Windows at different latitudes receive distinct bias entries;
        # longitude is treated as periodic.
        wlat, wlon = window_size
        self.type_of_windows = input_resolution[0] // wlat

        # Learnable earth position bias table:
        # (Wlat**2 * (2*Wlon - 1), Nlat//Wlat, nH)
        table_rows = (wlat * wlat) * (2 * wlon - 1)
        self.earth_position_bias_table = nn.Parameter(
            torch.zeros(table_rows, self.type_of_windows, num_heads)
        )

        # Precomputed lookup indices into the bias table,
        # shape (Wlat*Wlon, Wlat*Wlon).
        self.register_buffer(
            "earth_position_index", get_earth_position_index(window_size, ndim=2)
        )

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.earth_position_bias_table = trunc_normal_(
            self.earth_position_bias_table, std=0.02
        )
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x: torch.Tensor, mask=None):
        """
        Args:
            x: input features with shape of (B * num_lon, num_lat, N, C)
            mask: (0/-inf) mask with shape of (num_lon, num_lat, Wlat*Wlon, Wlat*Wlon)
        """
        B_, nW_, N, C = x.shape

        # Project to q/k/v and split heads: each of q, k, v is
        # (B_, nH, nW_, N, head_dim).
        qkv = self.qkv(x).reshape(B_, nW_, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv.permute(3, 0, 4, 1, 2, 5).unbind(0)

        attn = (q * self.scale) @ k.transpose(-2, -1)

        # Gather the relative position bias for every window type:
        # (Wlat*Wlon, Wlat*Wlon, num_lat, nH) ->
        # (nH, num_lat, Wlat*Wlon, Wlat*Wlon)
        win_area = self.window_size[0] * self.window_size[1]
        bias = self.earth_position_bias_table[self.earth_position_index.view(-1)]
        bias = bias.view(win_area, win_area, self.type_of_windows, -1)
        bias = bias.permute(3, 2, 0, 1).contiguous()
        attn = attn + bias.unsqueeze(0)

        if mask is not None:
            # Apply the shifted-window mask per longitude group.
            nLon = mask.shape[0]
            attn = attn.view(B_ // nLon, nLon, self.num_heads, nW_, N, N)
            attn = attn + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, nW_, N, N)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        out = (attn @ v).permute(0, 2, 3, 1, 4).reshape(B_, nW_, N, C)
        out = self.proj(out)
        return self.proj_drop(out)