|
|
import math
import warnings
from collections.abc import Sequence

import numpy as np
import torch
from torch import nn
|
|
|
|
|
|
|
|
def _trunc_normal_(tensor, mean, std, a, b): |
|
|
|
|
|
|
|
|
def norm_cdf(x): |
|
|
|
|
|
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 |
|
|
|
|
|
if (mean < a - 2 * std) or (mean > b + 2 * std): |
|
|
warnings.warn( |
|
|
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " |
|
|
"The distribution of values may be incorrect.", |
|
|
stacklevel=2, |
|
|
) |
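    # Inverse-transform sampling: map the bounds through the normal CDF, draw
    # uniformly inside [cdf(a), cdf(b)], then invert via erfinv to obtain a
    # standard truncated normal before scaling/shifting to (mean, std).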
|
|
|
|
|
|
|
|
|
|
|
|
|
|
u1 = norm_cdf((a - mean) / std) |
|
|
u2 = norm_cdf((b - mean) / std) |
|
|
|
|
|
|
|
|
|
|
|
tensor.uniform_(2 * u1 - 1, 2 * u2 - 1) |
|
|
|
|
|
|
|
|
|
|
|
tensor.erfinv_() |
|
|
|
|
|
|
|
|
tensor.mul_(std * math.sqrt(2.0)) |
|
|
tensor.add_(mean) |
|
|
|
|
|
|
|
|
tensor.clamp_(min=a, max=b) |
|
|
return tensor |
|
|
|
|
|
|
|
|
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): |
|
|
|
|
|
r"""Cut & paste from timm master |
|
|
Fills the input Tensor with values drawn from a truncated |
|
|
normal distribution. The values are effectively drawn from the |
|
|
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` |
|
|
with values outside :math:`[a, b]` redrawn until they are within |
|
|
the bounds. The method used for generating the random values works |
|
|
best when :math:`a \leq \text{mean} \leq b`. |
|
|
|
|
|
NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are |
|
|
applied while sampling the normal with mean/std applied, therefore a, b args |
|
|
should be adjusted to match the range of mean, std args. |
|
|
""" |
|
|
with torch.no_grad(): |
|
|
return _trunc_normal_(tensor, mean, std, a, b) |
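# Usage sketch (illustrative only; the tensor shape is an arbitrary assumption):
#   w = torch.empty(3, 5)
#   trunc_normal_(w, std=0.02)  # in-place init; values land in [-2, 2]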
|
|
|
|
|
|
|
|
class Mlp(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
in_features, |
|
|
hidden_features=None, |
|
|
out_features=None, |
|
|
act_layer=nn.GELU, |
|
|
drop=0.0, |
|
|
): |
|
|
super().__init__() |
|
|
out_features = out_features or in_features |
|
|
hidden_features = hidden_features or in_features |
|
|
self.fc1 = nn.Linear(in_features, hidden_features) |
|
|
self.act = act_layer() |
|
|
self.fc2 = nn.Linear(hidden_features, out_features) |
|
|
self.drop = nn.Dropout(drop) |
|
|
|
|
|
def forward(self, x: torch.Tensor): |
|
|
x = self.fc1(x) |
|
|
x = self.act(x) |
|
|
x = self.drop(x) |
|
|
x = self.fc2(x) |
|
|
x = self.drop(x) |
|
|
return x |
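# Usage sketch (hypothetical sizes): Mlp(in_features=192, hidden_features=768)
# maps (B, N, 192) -> (B, N, 192) through fc1 -> GELU -> dropout -> fc2 -> dropout.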
|
|
|
|
|
|
|
|
|
|
|
def drop_path( |
|
|
x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True |
|
|
): |
|
|
"""Cut & paste from timm master |
|
|
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). |
|
|
|
|
|
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, |
|
|
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... |
|
|
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for |
|
|
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use |
|
|
'survival rate' as the argument. |
|
|
|
|
|
""" |
|
|
if drop_prob == 0.0 or not training: |
|
|
return x |
|
|
keep_prob = 1 - drop_prob |
|
|
shape = (x.shape[0],) + (1,) * ( |
|
|
x.ndim - 1 |
|
|
) |
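    # One Bernoulli draw per sample, broadcast over all remaining dims,
    # e.g. mask shape (B, 1, 1) for a (B, N, C) input.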
|
|
random_tensor = x.new_empty(shape).bernoulli_(keep_prob) |
|
|
if keep_prob > 0.0 and scale_by_keep: |
|
|
random_tensor.div_(keep_prob) |
|
|
return x * random_tensor |
|
|
|
|
|
|
|
|
class DropPath(nn.Module): |
|
|
"""Cut & paste from timm master |
|
|
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). |
|
|
""" |
|
|
|
|
|
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): |
|
|
super(DropPath, self).__init__() |
|
|
self.drop_prob = drop_prob |
|
|
self.scale_by_keep = scale_by_keep |
|
|
|
|
|
def forward(self, x): |
|
|
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) |
|
|
|
|
|
def extra_repr(self): |
|
|
return f"drop_prob={round(self.drop_prob,3):0.3f}" |
|
|
|
|
|
|
|
|
|
|
|
class PatchEmbed2D(nn.Module): |
|
|
""" |
|
|
    Revised from WeatherLearn https://github.com/lizhuoq/WeatherLearn
|
|
2D Image to Patch Embedding. |
|
|
|
|
|
Args: |
|
|
img_size (tuple[int]): Image size. |
|
|
patch_size (tuple[int]): Patch token size. |
|
|
in_chans (int): Number of input image channels. |
|
|
embed_dim(int): Number of projection output channels. |
|
|
norm_layer (nn.Module, optional): Normalization layer. Default: None |
|
|
""" |
|
|
|
|
|
def __init__(self, img_size, patch_size, in_chans, embed_dim, norm_layer=None): |
|
|
super().__init__() |
|
|
self.img_size = img_size |
|
|
height, width = img_size |
|
|
        h_patch_size, w_patch_size = patch_size
|
|
padding_left = padding_right = padding_top = padding_bottom = 0 |
|
|
|
|
|
h_remainder = height % h_patch_size |
|
|
        w_remainder = width % w_patch_size
|
|
|
|
|
if h_remainder: |
|
|
h_pad = h_patch_size - h_remainder |
|
|
padding_top = h_pad // 2 |
|
|
padding_bottom = int(h_pad - padding_top) |
|
|
|
|
|
if w_remainder: |
|
|
            w_pad = w_patch_size - w_remainder
|
|
padding_left = w_pad // 2 |
|
|
padding_right = int(w_pad - padding_left) |
|
|
|
|
|
        self.pad = nn.ConstantPad2d(
            (padding_left, padding_right, padding_top, padding_bottom), value=0
        )
|
|
self.proj = nn.Conv2d( |
|
|
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size |
|
|
) |
|
|
if norm_layer is not None: |
|
|
self.norm = norm_layer(embed_dim) |
|
|
else: |
|
|
self.norm = None |
|
|
|
|
|
def forward(self, x: torch.Tensor): |
|
|
B, C, H, W = x.shape |
|
|
x = self.pad(x) |
|
|
x = self.proj(x) |
|
|
if self.norm is not None: |
|
|
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) |
|
|
return x |
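# Shape sketch (hypothetical numbers): with img_size=(120, 240) and
# patch_size=(4, 4), a (B, C, 120, 240) input needs no padding and is
# projected to (B, embed_dim, 30, 60).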
|
|
|
|
|
|
|
|
class PatchEmbed3D(nn.Module): |
|
|
""" |
|
|
    Revised from WeatherLearn https://github.com/lizhuoq/WeatherLearn
|
|
3D Image to Patch Embedding. |
|
|
|
|
|
Args: |
|
|
img_size (tuple[int]): Image size. |
|
|
patch_size (tuple[int]): Patch token size. |
|
|
in_chans (int): Number of input image channels. |
|
|
embed_dim(int): Number of projection output channels. |
|
|
norm_layer (nn.Module, optional): Normalization layer. Default: None |
|
|
""" |
|
|
|
|
|
def __init__(self, img_size, patch_size, in_chans, embed_dim, norm_layer=None): |
|
|
super().__init__() |
|
|
self.img_size = img_size |
|
|
level, height, width = img_size |
|
|
l_patch_size, h_patch_size, w_patch_size = patch_size |
|
|
padding_left = ( |
|
|
padding_right |
|
|
) = padding_top = padding_bottom = padding_front = padding_back = 0 |
|
|
|
|
|
l_remainder = level % l_patch_size |
|
|
        h_remainder = height % h_patch_size
|
|
w_remainder = width % w_patch_size |
|
|
|
|
|
if l_remainder: |
|
|
l_pad = l_patch_size - l_remainder |
|
|
padding_front = l_pad // 2 |
|
|
padding_back = l_pad - padding_front |
|
|
if h_remainder: |
|
|
h_pad = h_patch_size - h_remainder |
|
|
padding_top = h_pad // 2 |
|
|
padding_bottom = h_pad - padding_top |
|
|
if w_remainder: |
|
|
w_pad = w_patch_size - w_remainder |
|
|
padding_left = w_pad // 2 |
|
|
padding_right = w_pad - padding_left |
|
|
|
|
|
self.pad = nn.ConstantPad3d( |
|
|
( |
|
|
padding_left, |
|
|
padding_right, |
|
|
padding_top, |
|
|
padding_bottom, |
|
|
padding_front, |
|
|
padding_back, |
|
|
), |
|
|
value=0 |
|
|
) |
|
|
self.proj = nn.Conv3d( |
|
|
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size |
|
|
) |
|
|
if norm_layer is not None: |
|
|
self.norm = norm_layer(embed_dim) |
|
|
else: |
|
|
self.norm = None |
|
|
|
|
|
def forward(self, x: torch.Tensor): |
|
|
B, C, L, H, W = x.shape |
|
|
x = self.pad(x) |
|
|
x = self.proj(x) |
|
|
        if self.norm is not None:
|
|
x = self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3) |
|
|
return x |
|
|
|
|
|
|
|
|
class PatchRecovery2D(nn.Module): |
|
|
""" |
|
|
    Revised from WeatherLearn https://github.com/lizhuoq/WeatherLearn
|
|
Patch Embedding Recovery to 2D Image. |
|
|
|
|
|
Args: |
|
|
img_size (tuple[int]): Lat, Lon |
|
|
patch_size (tuple[int]): Lat, Lon |
|
|
in_chans (int): Number of input channels. |
|
|
out_chans (int): Number of output channels. |
|
|
""" |
|
|
|
|
|
def __init__(self, img_size, patch_size, in_chans, out_chans): |
|
|
super().__init__() |
|
|
self.img_size = img_size |
|
|
self.conv = nn.ConvTranspose2d(in_chans, out_chans, patch_size, patch_size) |
|
|
|
|
|
def forward(self, x): |
|
|
output = self.conv(x) |
|
|
_, _, H, W = output.shape |
|
|
h_pad = H - self.img_size[0] |
|
|
w_pad = W - self.img_size[1] |
|
|
|
|
|
padding_top = h_pad // 2 |
|
|
padding_bottom = int(h_pad - padding_top) |
|
|
|
|
|
padding_left = w_pad // 2 |
|
|
padding_right = int(w_pad - padding_left) |
|
|
|
|
|
return output[ |
|
|
:, :, padding_top : H - padding_bottom, padding_left : W - padding_right |
|
|
] |
|
|
|
|
|
|
|
|
class PatchRecovery3D(nn.Module): |
|
|
""" |
|
|
    Revised from WeatherLearn https://github.com/lizhuoq/WeatherLearn
|
|
Patch Embedding Recovery to 3D Image. |
|
|
|
|
|
Args: |
|
|
img_size (tuple[int]): Pl, Lat, Lon |
|
|
patch_size (tuple[int]): Pl, Lat, Lon |
|
|
in_chans (int): Number of input channels. |
|
|
out_chans (int): Number of output channels. |
|
|
""" |
|
|
|
|
|
def __init__(self, img_size, patch_size, in_chans, out_chans): |
|
|
super().__init__() |
|
|
self.img_size = img_size |
|
|
self.conv = nn.ConvTranspose3d(in_chans, out_chans, patch_size, patch_size) |
|
|
|
|
|
def forward(self, x: torch.Tensor): |
|
|
output = self.conv(x) |
|
|
_, _, Pl, Lat, Lon = output.shape |
|
|
|
|
|
pl_pad = Pl - self.img_size[0] |
|
|
lat_pad = Lat - self.img_size[1] |
|
|
lon_pad = Lon - self.img_size[2] |
|
|
|
|
|
padding_front = pl_pad // 2 |
|
|
padding_back = pl_pad - padding_front |
|
|
|
|
|
padding_top = lat_pad // 2 |
|
|
padding_bottom = lat_pad - padding_top |
|
|
|
|
|
padding_left = lon_pad // 2 |
|
|
padding_right = lon_pad - padding_left |
|
|
|
|
|
return output[ |
|
|
:, |
|
|
:, |
|
|
padding_front : Pl - padding_back, |
|
|
padding_top : Lat - padding_bottom, |
|
|
padding_left : Lon - padding_right, |
|
|
] |
|
|
|
|
|
class UpSample3D(nn.Module): |
|
|
""" |
|
|
    Revised from WeatherLearn https://github.com/lizhuoq/WeatherLearn
|
|
3D Up-sampling operation. |
|
|
Implementation from: https://github.com/198808xc/Pangu-Weather/blob/main/pseudocode.py |
|
|
|
|
|
Args: |
|
|
in_dim (int): Number of input channels. |
|
|
out_dim (int): Number of output channels. |
|
|
input_resolution (tuple[int]): [pressure levels, latitude, longitude] |
|
|
output_resolution (tuple[int]): [pressure levels, latitude, longitude] |
|
|
""" |
|
|
|
|
|
def __init__(self, in_dim, out_dim, input_resolution, output_resolution): |
|
|
super().__init__() |
|
|
self.linear1 = nn.Linear(in_dim, out_dim * 4, bias=False) |
|
|
self.linear2 = nn.Linear(out_dim, out_dim, bias=False) |
|
|
self.norm = nn.LayerNorm(out_dim) |
|
|
self.input_resolution = input_resolution |
|
|
self.output_resolution = output_resolution |
|
|
|
|
|
def forward(self, x: torch.Tensor): |
|
|
""" |
|
|
Args: |
|
|
x (torch.Tensor): (B, N, C) |
|
|
""" |
|
|
B, N, C = x.shape |
|
|
in_pl, in_lat, in_lon = self.input_resolution |
|
|
out_pl, out_lat, out_lon = self.output_resolution |
|
|
|
|
|
x = self.linear1(x) |
|
|
x = x.reshape(B, in_pl, in_lat, in_lon, 2, 2, C // 2).permute( |
|
|
0, 1, 2, 4, 3, 5, 6 |
|
|
) |
|
|
x = x.reshape(B, in_pl, in_lat * 2, in_lon * 2, -1) |
|
|
|
|
|
pad_h = in_lat * 2 - out_lat |
|
|
pad_w = in_lon * 2 - out_lon |
|
|
|
|
|
pad_top = pad_h // 2 |
|
|
pad_bottom = pad_h - pad_top |
|
|
|
|
|
pad_left = pad_w // 2 |
|
|
pad_right = pad_w - pad_left |
|
|
|
|
|
x = x[ |
|
|
:, |
|
|
:out_pl, |
|
|
pad_top : 2 * in_lat - pad_bottom, |
|
|
pad_left : 2 * in_lon - pad_right, |
|
|
:, |
|
|
] |
|
|
x = x.reshape(x.shape[0], x.shape[1] * x.shape[2] * x.shape[3], x.shape[4]) |
|
|
x = self.norm(x) |
|
|
x = self.linear2(x) |
|
|
return x |
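# The linear1 + reshape above is a pixel-shuffle-style upsample: each token's
# 4*out_dim features become a 2x2 block of out_dim features, doubling the
# lat/lon grid, which is then cropped down to output_resolution.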
|
|
|
|
|
class UpSample2D(nn.Module): |
|
|
""" |
|
|
    Revised from WeatherLearn https://github.com/lizhuoq/WeatherLearn
|
|
2D Up-sampling operation. |
|
|
|
|
|
Args: |
|
|
in_dim (int): Number of input channels. |
|
|
out_dim (int): Number of output channels. |
|
|
input_resolution (tuple[int]): [latitude, longitude] |
|
|
output_resolution (tuple[int]): [latitude, longitude] |
|
|
""" |
|
|
|
|
|
def __init__(self, in_dim, out_dim, input_resolution, output_resolution): |
|
|
super().__init__() |
|
|
self.linear1 = nn.Linear(in_dim, out_dim * 4, bias=False) |
|
|
self.linear2 = nn.Linear(out_dim, out_dim, bias=False) |
|
|
self.norm = nn.LayerNorm(out_dim) |
|
|
self.input_resolution = input_resolution |
|
|
self.output_resolution = output_resolution |
|
|
|
|
|
def forward(self, x: torch.Tensor): |
|
|
""" |
|
|
Args: |
|
|
x (torch.Tensor): (B, N, C) |
|
|
""" |
|
|
B, N, C = x.shape |
|
|
in_lat, in_lon = self.input_resolution |
|
|
out_lat, out_lon = self.output_resolution |
|
|
|
|
|
x = self.linear1(x) |
|
|
x = x.reshape(B, in_lat, in_lon, 2, 2, C // 2).permute(0, 1, 3, 2, 4, 5) |
|
|
x = x.reshape(B, in_lat * 2, in_lon * 2, -1) |
|
|
|
|
|
pad_h = in_lat * 2 - out_lat |
|
|
pad_w = in_lon * 2 - out_lon |
|
|
|
|
|
pad_top = pad_h // 2 |
|
|
pad_bottom = pad_h - pad_top |
|
|
|
|
|
pad_left = pad_w // 2 |
|
|
pad_right = pad_w - pad_left |
|
|
|
|
|
x = x[ |
|
|
:, pad_top : 2 * in_lat - pad_bottom, pad_left : 2 * in_lon - pad_right, : |
|
|
] |
|
|
x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3]) |
|
|
x = self.norm(x) |
|
|
x = self.linear2(x) |
|
|
return x |
|
|
|
|
|
class DownSample3D(nn.Module): |
|
|
""" |
|
|
    Revised from WeatherLearn https://github.com/lizhuoq/WeatherLearn
|
|
3D Down-sampling operation |
|
|
Implementation from: https://github.com/198808xc/Pangu-Weather/blob/main/pseudocode.py |
|
|
|
|
|
Args: |
|
|
in_dim (int): Number of input channels. |
|
|
input_resolution (tuple[int]): [pressure levels, latitude, longitude] |
|
|
output_resolution (tuple[int]): [pressure levels, latitude, longitude] |
|
|
""" |
|
|
|
|
|
def __init__(self, in_dim, input_resolution, output_resolution): |
|
|
super().__init__() |
|
|
self.linear = nn.Linear(in_dim * 4, in_dim * 2, bias=False) |
|
|
self.norm = nn.LayerNorm(4 * in_dim) |
|
|
self.input_resolution = input_resolution |
|
|
self.output_resolution = output_resolution |
|
|
|
|
|
in_pl, in_lat, in_lon = self.input_resolution |
|
|
out_pl, out_lat, out_lon = self.output_resolution |
|
|
|
|
|
h_pad = out_lat * 2 - in_lat |
|
|
w_pad = out_lon * 2 - in_lon |
|
|
|
|
|
pad_top = h_pad // 2 |
|
|
pad_bottom = h_pad - pad_top |
|
|
|
|
|
pad_left = w_pad // 2 |
|
|
pad_right = w_pad - pad_left |
|
|
|
|
|
pad_front = pad_back = 0 |
|
|
|
|
|
self.pad = nn.ConstantPad3d( |
|
|
(pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back), value=0 |
|
|
) |
|
|
|
|
|
def forward(self, x): |
|
|
B, N, C = x.shape |
|
|
in_pl, in_lat, in_lon = self.input_resolution |
|
|
out_pl, out_lat, out_lon = self.output_resolution |
|
|
x = x.reshape(B, in_pl, in_lat, in_lon, C) |
|
|
|
|
|
|
|
|
x = self.pad(x.permute(0, -1, 1, 2, 3)).permute(0, 2, 3, 4, 1) |
|
|
x = x.reshape(B, in_pl, out_lat, 2, out_lon, 2, C).permute(0, 1, 2, 4, 3, 5, 6) |
|
|
x = x.reshape(B, out_pl * out_lat * out_lon, 4 * C) |
|
|
|
|
|
x = self.norm(x) |
|
|
x = self.linear(x) |
|
|
return x |
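# This mirrors Swin-style patch merging: each 2x2 lat/lon neighborhood is
# concatenated into a 4*C vector, normalized, then projected down to 2*C.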
|
|
|
|
|
|
|
|
class DownSample2D(nn.Module): |
|
|
""" |
|
|
    Revised from WeatherLearn https://github.com/lizhuoq/WeatherLearn
|
|
2D Down-sampling operation |
|
|
|
|
|
Args: |
|
|
in_dim (int): Number of input channels. |
|
|
input_resolution (tuple[int]): [latitude, longitude] |
|
|
output_resolution (tuple[int]): [latitude, longitude] |
|
|
""" |
|
|
|
|
|
def __init__(self, in_dim, input_resolution, output_resolution): |
|
|
super().__init__() |
|
|
self.linear = nn.Linear(in_dim * 4, in_dim * 2, bias=False) |
|
|
self.norm = nn.LayerNorm(4 * in_dim) |
|
|
self.input_resolution = input_resolution |
|
|
self.output_resolution = output_resolution |
|
|
|
|
|
in_lat, in_lon = self.input_resolution |
|
|
out_lat, out_lon = self.output_resolution |
|
|
|
|
|
h_pad = out_lat * 2 - in_lat |
|
|
w_pad = out_lon * 2 - in_lon |
|
|
|
|
|
pad_top = h_pad // 2 |
|
|
pad_bottom = h_pad - pad_top |
|
|
|
|
|
pad_left = w_pad // 2 |
|
|
pad_right = w_pad - pad_left |
|
|
|
|
|
        self.pad = nn.ConstantPad2d((pad_left, pad_right, pad_top, pad_bottom), value=0)
|
|
|
|
|
def forward(self, x: torch.Tensor): |
|
|
B, N, C = x.shape |
|
|
in_lat, in_lon = self.input_resolution |
|
|
out_lat, out_lon = self.output_resolution |
|
|
x = x.reshape(B, in_lat, in_lon, C) |
|
|
|
|
|
|
|
|
x = self.pad(x.permute(0, -1, 1, 2)).permute(0, 2, 3, 1) |
|
|
x = x.reshape(B, out_lat, 2, out_lon, 2, C).permute(0, 1, 3, 2, 4, 5) |
|
|
x = x.reshape(B, out_lat * out_lon, 4 * C) |
|
|
|
|
|
x = self.norm(x) |
|
|
x = self.linear(x) |
|
|
return x |
|
|
|
|
|
|
|
|
def get_earth_position_index(window_size, ndim=3): |
|
|
""" |
|
|
    Revised from WeatherLearn https://github.com/lizhuoq/WeatherLearn
    This function constructs the position index to reuse symmetrical parameters of the position bias.
    Implementation from: https://github.com/198808xc/Pangu-Weather/blob/main/pseudocode.py
|
|
|
|
|
Args: |
|
|
window_size (tuple[int]): [pressure levels, latitude, longitude] or [latitude, longitude] |
|
|
ndim (int): dimension of tensor, 3 or 2 |
|
|
|
|
|
Returns: |
|
|
position_index (torch.Tensor): [win_pl * win_lat * win_lon, win_pl * win_lat * win_lon] or [win_lat * win_lon, win_lat * win_lon] |
|
|
""" |
|
|
if ndim == 3: |
|
|
win_pl, win_lat, win_lon = window_size |
|
|
elif ndim == 2: |
|
|
win_lat, win_lon = window_size |
|
|
|
|
|
if ndim == 3: |
|
|
|
|
|
coords_zi = torch.arange(win_pl) |
|
|
|
|
|
coords_zj = -torch.arange(win_pl) * win_pl |
|
|
|
|
|
|
|
|
coords_hi = torch.arange(win_lat) |
|
|
|
|
|
coords_hj = -torch.arange(win_lat) * win_lat |
|
|
|
|
|
|
|
|
coords_w = torch.arange(win_lon) |
|
|
|
|
|
|
|
|
    if ndim == 3:
        coords_1 = torch.stack(torch.meshgrid([coords_zi, coords_hi, coords_w], indexing="ij"))
        coords_2 = torch.stack(torch.meshgrid([coords_zj, coords_hj, coords_w], indexing="ij"))
    elif ndim == 2:
        coords_1 = torch.stack(torch.meshgrid([coords_hi, coords_w], indexing="ij"))
        coords_2 = torch.stack(torch.meshgrid([coords_hj, coords_w], indexing="ij"))
|
|
coords_flatten_1 = torch.flatten(coords_1, 1) |
|
|
coords_flatten_2 = torch.flatten(coords_2, 1) |
|
|
coords = coords_flatten_1[:, :, None] - coords_flatten_2[:, None, :] |
|
|
coords = coords.permute(1, 2, 0).contiguous() |
|
|
|
|
|
|
|
|
if ndim == 3: |
|
|
coords[:, :, 2] += win_lon - 1 |
|
|
coords[:, :, 1] *= 2 * win_lon - 1 |
|
|
coords[:, :, 0] *= (2 * win_lon - 1) * win_lat * win_lat |
|
|
elif ndim == 2: |
|
|
coords[:, :, 1] += win_lon - 1 |
|
|
coords[:, :, 0] *= 2 * win_lon - 1 |
|
|
|
|
|
|
|
|
position_index = coords.sum(-1) |
|
|
|
|
|
return position_index |
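# Note: unlike a standard Swin relative-position index, the pressure-level and
# latitude terms above encode the absolute positions of both query and key
# inside the window (the index i + j*win is a bijection over pairs, not just
# their offset); only longitude uses a plain relative offset. This is the
# Pangu-Weather "earth-specific" bias construction.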
|
|
|
|
|
|
|
|
def get_pad3d(input_resolution, window_size): |
|
|
""" |
|
|
Args: |
|
|
input_resolution (tuple[int]): (Pl, Lat, Lon) |
|
|
window_size (tuple[int]): (Pl, Lat, Lon) |
|
|
|
|
|
Returns: |
|
|
padding (tuple[int]): (padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back) |
|
|
""" |
|
|
Pl, Lat, Lon = input_resolution |
|
|
win_pl, win_lat, win_lon = window_size |
|
|
|
|
|
padding_left = ( |
|
|
padding_right |
|
|
) = padding_top = padding_bottom = padding_front = padding_back = 0 |
|
|
pl_remainder = Pl % win_pl |
|
|
lat_remainder = Lat % win_lat |
|
|
lon_remainder = Lon % win_lon |
|
|
|
|
|
if pl_remainder: |
|
|
pl_pad = win_pl - pl_remainder |
|
|
padding_front = pl_pad // 2 |
|
|
padding_back = pl_pad - padding_front |
|
|
if lat_remainder: |
|
|
lat_pad = win_lat - lat_remainder |
|
|
padding_top = lat_pad // 2 |
|
|
padding_bottom = lat_pad - padding_top |
|
|
if lon_remainder: |
|
|
lon_pad = win_lon - lon_remainder |
|
|
padding_left = lon_pad // 2 |
|
|
padding_right = lon_pad - padding_left |
|
|
|
|
|
return ( |
|
|
padding_left, |
|
|
padding_right, |
|
|
padding_top, |
|
|
padding_bottom, |
|
|
padding_front, |
|
|
padding_back, |
|
|
) |
|
|
|
|
|
def get_pad2d(input_resolution, window_size): |
|
|
""" |
|
|
Args: |
|
|
input_resolution (tuple[int]): Lat, Lon |
|
|
window_size (tuple[int]): Lat, Lon |
|
|
|
|
|
Returns: |
|
|
padding (tuple[int]): (padding_left, padding_right, padding_top, padding_bottom) |
|
|
""" |
|
|
input_resolution = [2] + list(input_resolution) |
|
|
window_size = [2] + list(window_size) |
|
|
padding = get_pad3d(input_resolution, window_size) |
|
|
    return padding[:4]
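# Example (sketch): get_pad2d((121, 240), (6, 12)) returns (0, 0, 2, 3):
# latitude is padded from 121 up to 126 while longitude already divides evenly.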
|
|
|
|
|
def crop3d(x: torch.Tensor, resolution): |
|
|
""" |
|
|
Args: |
|
|
x (torch.Tensor): B, C, Pl, Lat, Lon |
|
|
resolution (tuple[int]): Pl, Lat, Lon |
|
|
""" |
|
|
_, _, Pl, Lat, Lon = x.shape |
|
|
pl_pad = Pl - resolution[0] |
|
|
lat_pad = Lat - resolution[1] |
|
|
lon_pad = Lon - resolution[2] |
|
|
|
|
|
padding_front = pl_pad // 2 |
|
|
padding_back = pl_pad - padding_front |
|
|
|
|
|
padding_top = lat_pad // 2 |
|
|
padding_bottom = lat_pad - padding_top |
|
|
|
|
|
padding_left = lon_pad // 2 |
|
|
padding_right = lon_pad - padding_left |
|
|
return x[ |
|
|
:, |
|
|
:, |
|
|
padding_front : Pl - padding_back, |
|
|
padding_top : Lat - padding_bottom, |
|
|
padding_left : Lon - padding_right, |
|
|
] |
|
|
|
|
|
def crop2d(x: torch.Tensor, resolution): |
|
|
""" |
|
|
Args: |
|
|
x (torch.Tensor): B, C, Lat, Lon |
|
|
resolution (tuple[int]): Lat, Lon |
|
|
""" |
|
|
_, _, Lat, Lon = x.shape |
|
|
lat_pad = Lat - resolution[0] |
|
|
lon_pad = Lon - resolution[1] |
|
|
|
|
|
padding_top = lat_pad // 2 |
|
|
padding_bottom = lat_pad - padding_top |
|
|
|
|
|
padding_left = lon_pad // 2 |
|
|
padding_right = lon_pad - padding_left |
|
|
|
|
|
return x[ |
|
|
:, :, padding_top : Lat - padding_bottom, padding_left : Lon - padding_right |
|
|
] |
|
|
|
|
|
|
|
|
class EarthAttention2D(nn.Module): |
|
|
""" |
|
|
    Revised from WeatherLearn https://github.com/lizhuoq/WeatherLearn
    2D window attention with earth position bias.
    It supports both shifted and non-shifted windows.
|
|
|
|
|
Args: |
|
|
dim (int): Number of input channels. |
|
|
input_resolution (tuple[int]): [latitude, longitude] |
|
|
window_size (tuple[int]): [latitude, longitude] |
|
|
num_heads (int): Number of attention heads. |
|
|
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True |
|
|
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set |
|
|
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 |
|
|
proj_drop (float, optional): Dropout ratio of output. Default: 0.0 |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
dim, |
|
|
input_resolution, |
|
|
window_size, |
|
|
num_heads, |
|
|
qkv_bias=True, |
|
|
qk_scale=None, |
|
|
attn_drop=0.0, |
|
|
proj_drop=0.0, |
|
|
): |
|
|
super().__init__() |
|
|
self.dim = dim |
|
|
self.window_size = window_size |
|
|
self.num_heads = num_heads |
|
|
head_dim = dim // num_heads |
|
|
self.scale = qk_scale or head_dim**-0.5 |
|
|
|
|
|
self.type_of_windows = input_resolution[0] // window_size[0] |
|
|
|
|
|
self.earth_position_bias_table = nn.Parameter( |
|
|
torch.zeros( |
|
|
(window_size[0] ** 2) * (window_size[1] * 2 - 1), |
|
|
self.type_of_windows, |
|
|
num_heads, |
|
|
) |
|
|
) |
|
|
|
|
|
earth_position_index = get_earth_position_index( |
|
|
window_size, ndim=2 |
|
|
) |
|
|
self.register_buffer("earth_position_index", earth_position_index) |
|
|
|
|
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) |
|
|
self.attn_drop = nn.Dropout(attn_drop) |
|
|
self.proj = nn.Linear(dim, dim) |
|
|
self.proj_drop = nn.Dropout(proj_drop) |
|
|
|
|
|
        trunc_normal_(self.earth_position_bias_table, std=0.02)
|
|
self.softmax = nn.Softmax(dim=-1) |
|
|
|
|
|
def forward(self, x: torch.Tensor, mask=None): |
|
|
""" |
|
|
Args: |
|
|
x: input features with shape of (B * num_lon, num_lat, N, C) |
|
|
            mask: (0/-100) mask with shape of (num_lon, num_lat, Wlat*Wlon, Wlat*Wlon)
|
|
""" |
|
|
B_, nW_, N, C = x.shape |
|
|
qkv = ( |
|
|
self.qkv(x) |
|
|
.reshape(B_, nW_, N, 3, self.num_heads, C // self.num_heads) |
|
|
.permute(3, 0, 4, 1, 2, 5) |
|
|
) |
|
|
q, k, v = qkv[0], qkv[1], qkv[2] |
|
|
|
|
|
q = q * self.scale |
|
|
attn = q @ k.transpose(-2, -1) |
|
|
|
|
|
earth_position_bias = self.earth_position_bias_table[ |
|
|
self.earth_position_index.view(-1) |
|
|
].view( |
|
|
self.window_size[0] * self.window_size[1], |
|
|
self.window_size[0] * self.window_size[1], |
|
|
self.type_of_windows, |
|
|
-1, |
|
|
) |
|
|
earth_position_bias = earth_position_bias.permute( |
|
|
3, 2, 0, 1 |
|
|
).contiguous() |
|
|
attn = attn + earth_position_bias.unsqueeze(0) |
|
|
|
|
|
if mask is not None: |
|
|
nLon = mask.shape[0] |
|
|
attn = attn.view( |
|
|
B_ // nLon, nLon, self.num_heads, nW_, N, N |
|
|
) + mask.unsqueeze(1).unsqueeze(0) |
|
|
attn = attn.view(-1, self.num_heads, nW_, N, N) |
|
|
attn = self.softmax(attn) |
|
|
else: |
|
|
attn = self.softmax(attn) |
|
|
|
|
|
attn = self.attn_drop(attn) |
|
|
|
|
|
x = (attn @ v).permute(0, 2, 3, 1, 4).reshape(B_, nW_, N, C) |
|
|
x = self.proj(x) |
|
|
x = self.proj_drop(x) |
|
|
return x |
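# Unlike Swin's single relative-position table, the bias table above keeps a
# separate slice per latitude band of windows (self.type_of_windows), since on
# a sphere the same window offset means different things at different latitudes.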
|
|
|
|
|
|
|
|
class EarthAttention3D(nn.Module): |
|
|
""" |
|
|
    Revised from WeatherLearn https://github.com/lizhuoq/WeatherLearn
    3D window attention with earth position bias.
    It supports both shifted and non-shifted windows.
|
|
|
|
|
Args: |
|
|
dim (int): Number of input channels. |
|
|
input_resolution (tuple[int]): [pressure levels, latitude, longitude] |
|
|
window_size (tuple[int]): [pressure levels, latitude, longitude] |
|
|
num_heads (int): Number of attention heads. |
|
|
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True |
|
|
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set |
|
|
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 |
|
|
proj_drop (float, optional): Dropout ratio of output. Default: 0.0 |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
dim, |
|
|
input_resolution, |
|
|
window_size, |
|
|
num_heads, |
|
|
qkv_bias=True, |
|
|
qk_scale=None, |
|
|
attn_drop=0.0, |
|
|
proj_drop=0.0, |
|
|
): |
|
|
super().__init__() |
|
|
self.dim = dim |
|
|
self.window_size = window_size |
|
|
self.num_heads = num_heads |
|
|
head_dim = dim // num_heads |
|
|
self.scale = qk_scale or head_dim**-0.5 |
|
|
|
|
|
self.type_of_windows = (input_resolution[0] // window_size[0]) * ( |
|
|
input_resolution[1] // window_size[1] |
|
|
) |
|
|
|
|
|
self.earth_position_bias_table = nn.Parameter( |
|
|
torch.zeros( |
|
|
(window_size[0] ** 2) |
|
|
* (window_size[1] ** 2) |
|
|
* (window_size[2] * 2 - 1), |
|
|
self.type_of_windows, |
|
|
num_heads, |
|
|
) |
|
|
) |
|
|
|
|
|
earth_position_index = get_earth_position_index( |
|
|
window_size |
|
|
) |
|
|
self.register_buffer("earth_position_index", earth_position_index) |
|
|
|
|
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) |
|
|
self.attn_drop = nn.Dropout(attn_drop) |
|
|
self.proj = nn.Linear(dim, dim) |
|
|
self.proj_drop = nn.Dropout(proj_drop) |
|
|
|
|
|
        trunc_normal_(self.earth_position_bias_table, std=0.02)
|
|
self.softmax = nn.Softmax(dim=-1) |
|
|
|
|
|
def forward(self, x: torch.Tensor, mask=None): |
|
|
""" |
|
|
Args: |
|
|
x: input features with shape of (B * num_lon, num_pl*num_lat, N, C) |
|
|
            mask: (0/-100) mask with shape of (num_lon, num_pl*num_lat, Wpl*Wlat*Wlon, Wpl*Wlat*Wlon)
|
|
""" |
|
|
B_, nW_, N, C = x.shape |
|
|
qkv = ( |
|
|
self.qkv(x) |
|
|
.reshape(B_, nW_, N, 3, self.num_heads, C // self.num_heads) |
|
|
.permute(3, 0, 4, 1, 2, 5) |
|
|
) |
|
|
q, k, v = qkv[0], qkv[1], qkv[2] |
|
|
|
|
|
q = q * self.scale |
|
|
attn = q @ k.transpose(-2, -1) |
|
|
|
|
|
earth_position_bias = self.earth_position_bias_table[ |
|
|
self.earth_position_index.view(-1) |
|
|
].view( |
|
|
self.window_size[0] * self.window_size[1] * self.window_size[2], |
|
|
self.window_size[0] * self.window_size[1] * self.window_size[2], |
|
|
self.type_of_windows, |
|
|
-1, |
|
|
) |
|
|
earth_position_bias = earth_position_bias.permute( |
|
|
3, 2, 0, 1 |
|
|
).contiguous() |
|
|
attn = attn + earth_position_bias.unsqueeze(0) |
|
|
|
|
|
if mask is not None: |
|
|
nLon = mask.shape[0] |
|
|
attn = attn.view( |
|
|
B_ // nLon, nLon, self.num_heads, nW_, N, N |
|
|
) + mask.unsqueeze(1).unsqueeze(0) |
|
|
attn = attn.view(-1, self.num_heads, nW_, N, N) |
|
|
attn = self.softmax(attn) |
|
|
else: |
|
|
attn = self.softmax(attn) |
|
|
|
|
|
attn = self.attn_drop(attn) |
|
|
|
|
|
x = (attn @ v).permute(0, 2, 3, 1, 4).reshape(B_, nW_, N, C) |
|
|
x = self.proj(x) |
|
|
x = self.proj_drop(x) |
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
class Transformer3DBlock(nn.Module): |
|
|
""" |
|
|
    Revised from WeatherLearn https://github.com/lizhuoq/WeatherLearn
    3D Transformer Block
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
|
|
num_heads (int): Number of attention heads. |
|
|
window_size (tuple[int]): Window size [pressure levels, latitude, longitude]. |
|
|
shift_size (tuple[int]): Shift size for SW-MSA [pressure levels, latitude, longitude]. |
|
|
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. |
|
|
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True |
|
|
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. |
|
|
drop (float, optional): Dropout rate. Default: 0.0 |
|
|
attn_drop (float, optional): Attention dropout rate. Default: 0.0 |
|
|
drop_path (float, optional): Stochastic depth rate. Default: 0.0 |
|
|
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU |
|
|
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
dim, |
|
|
input_resolution, |
|
|
num_heads, |
|
|
window_size=None, |
|
|
shift_size=None, |
|
|
mlp_ratio=4.0, |
|
|
qkv_bias=True, |
|
|
qk_scale=None, |
|
|
drop=0.0, |
|
|
attn_drop=0.0, |
|
|
drop_path=0.0, |
|
|
act_layer=nn.GELU, |
|
|
norm_layer=nn.LayerNorm, |
|
|
): |
|
|
super().__init__() |
|
|
window_size = (2, 6, 12) if window_size is None else window_size |
|
|
shift_size = (1, 3, 6) if shift_size is None else shift_size |
|
|
self.dim = dim |
|
|
self.input_resolution = input_resolution |
|
|
self.num_heads = num_heads |
|
|
self.window_size = window_size |
|
|
self.shift_size = shift_size |
|
|
self.mlp_ratio = mlp_ratio |
|
|
|
|
|
self.norm1 = norm_layer(dim) |
|
|
padding = get_pad3d(input_resolution, window_size) |
|
|
self.pad = nn.ConstantPad3d(padding, value=0) |
|
|
|
|
|
pad_resolution = list(input_resolution) |
|
|
pad_resolution[0] += padding[-1] + padding[-2] |
|
|
pad_resolution[1] += padding[2] + padding[3] |
|
|
pad_resolution[2] += padding[0] + padding[1] |
|
|
|
|
|
self.attn = EarthAttention3D( |
|
|
dim=dim, |
|
|
input_resolution=pad_resolution, |
|
|
window_size=window_size, |
|
|
num_heads=num_heads, |
|
|
qkv_bias=qkv_bias, |
|
|
qk_scale=qk_scale, |
|
|
attn_drop=attn_drop, |
|
|
proj_drop=drop, |
|
|
) |
|
|
|
|
|
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() |
|
|
self.norm2 = norm_layer(dim) |
|
|
mlp_hidden_dim = int(dim * mlp_ratio) |
|
|
self.mlp = Mlp( |
|
|
in_features=dim, |
|
|
hidden_features=mlp_hidden_dim, |
|
|
act_layer=act_layer, |
|
|
drop=drop, |
|
|
) |
|
|
|
|
|
shift_pl, shift_lat, shift_lon = self.shift_size |
|
|
self.roll = shift_pl and shift_lon and shift_lat |
|
|
|
|
|
if self.roll: |
|
|
attn_mask = get_shift_window_mask(pad_resolution, window_size, shift_size) |
|
|
else: |
|
|
attn_mask = None |
|
|
|
|
|
self.register_buffer("attn_mask", attn_mask) |
|
|
|
|
|
def forward(self, x: torch.Tensor): |
|
|
Pl, Lat, Lon = self.input_resolution |
|
|
B, L, C = x.shape |
|
|
|
|
|
shortcut = x |
|
|
x = self.norm1(x) |
|
|
x = x.view(B, Pl, Lat, Lon, C) |
|
|
|
|
|
|
|
|
x = self.pad(x.permute(0, 4, 1, 2, 3)).permute(0, 2, 3, 4, 1) |
|
|
|
|
|
_, Pl_pad, Lat_pad, Lon_pad, _ = x.shape |
|
|
|
|
|
shift_pl, shift_lat, shift_lon = self.shift_size |
|
|
if self.roll: |
|
|
shifted_x = torch.roll( |
|
|
                x, shifts=(-shift_pl, -shift_lat, -shift_lon), dims=(1, 2, 3)
|
|
) |
|
|
x_windows = window_partition(shifted_x, self.window_size) |
|
|
|
|
|
else: |
|
|
shifted_x = x |
|
|
x_windows = window_partition(shifted_x, self.window_size) |
|
|
|
|
|
|
|
|
win_pl, win_lat, win_lon = self.window_size |
|
|
x_windows = x_windows.view( |
|
|
x_windows.shape[0], x_windows.shape[1], win_pl * win_lat * win_lon, C |
|
|
) |
|
|
|
|
|
|
|
|
attn_windows = self.attn( |
|
|
x_windows, mask=self.attn_mask |
|
|
) |
|
|
|
|
|
attn_windows = attn_windows.view( |
|
|
attn_windows.shape[0], attn_windows.shape[1], win_pl, win_lat, win_lon, C |
|
|
) |
|
|
|
|
|
if self.roll: |
|
|
shifted_x = window_reverse( |
|
|
attn_windows, self.window_size, Pl=Pl_pad, Lat=Lat_pad, Lon=Lon_pad |
|
|
) |
|
|
|
|
|
x = torch.roll( |
|
|
shifted_x, shifts=(shift_pl, shift_lat, shift_lon), dims=(1, 2, 3) |
|
|
) |
|
|
else: |
|
|
shifted_x = window_reverse( |
|
|
attn_windows, self.window_size, Pl=Pl_pad, Lat=Lat_pad, Lon=Lon_pad |
|
|
) |
|
|
x = shifted_x |
|
|
|
|
|
|
|
|
x = crop3d(x.permute(0, 4, 1, 2, 3), self.input_resolution).permute( |
|
|
0, 2, 3, 4, 1 |
|
|
) |
|
|
|
|
|
x = x.reshape(B, Pl * Lat * Lon, C) |
|
|
x = shortcut + self.drop_path(x) |
|
|
|
|
|
x = x + self.drop_path(self.mlp(self.norm2(x))) |
|
|
|
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def window_partition(x: torch.Tensor, window_size, ndim=3): |
|
|
""" |
|
|
Args: |
|
|
x: (B, Pl, Lat, Lon, C) or (B, Lat, Lon, C) |
|
|
window_size (tuple[int]): [win_pl, win_lat, win_lon] or [win_lat, win_lon] |
|
|
ndim (int): dimension of window (3 or 2) |
|
|
|
|
|
Returns: |
|
|
windows: (B*num_lon, num_pl*num_lat, win_pl, win_lat, win_lon, C) or (B*num_lon, num_lat, win_lat, win_lon, C) |
|
|
""" |
|
|
if ndim == 3: |
|
|
B, Pl, Lat, Lon, C = x.shape |
|
|
win_pl, win_lat, win_lon = window_size |
|
|
x = x.view( |
|
|
B, Pl // win_pl, win_pl, Lat // win_lat, win_lat, Lon // win_lon, win_lon, C |
|
|
) |
|
|
windows = ( |
|
|
x.permute(0, 5, 1, 3, 2, 4, 6, 7) |
|
|
.contiguous() |
|
|
.view(-1, (Pl // win_pl) * (Lat // win_lat), win_pl, win_lat, win_lon, C) |
|
|
) |
|
|
return windows |
|
|
elif ndim == 2: |
|
|
B, Lat, Lon, C = x.shape |
|
|
win_lat, win_lon = window_size |
|
|
x = x.view(B, Lat // win_lat, win_lat, Lon // win_lon, win_lon, C) |
|
|
windows = ( |
|
|
x.permute(0, 3, 1, 2, 4, 5) |
|
|
.contiguous() |
|
|
.view(-1, (Lat // win_lat), win_lat, win_lon, C) |
|
|
) |
|
|
return windows |
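# Shape sketch (hypothetical numbers): x of shape (1, 8, 48, 96, C) with
# window_size=(2, 6, 12) yields windows of shape (8, 32, 2, 6, 12, C):
# longitude windows fold into the batch dim (1*8) and the 4 pressure x 8
# latitude windows index the second dim.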
|
|
|
|
|
|
|
|
def window_reverse(windows, window_size, Pl=1, Lat=1, Lon=1, ndim=3): |
|
|
""" |
|
|
Args: |
|
|
windows: (B*num_lon, num_pl*num_lat, win_pl, win_lat, win_lon, C) or (B*num_lon, num_lat, win_lat, win_lon, C) |
|
|
window_size (tuple[int]): [win_pl, win_lat, win_lon] or [win_lat, win_lon] |
|
|
Pl (int): pressure levels |
|
|
Lat (int): latitude |
|
|
Lon (int): longitude |
|
|
ndim (int): dimension of window (3 or 2) |
|
|
|
|
|
Returns: |
|
|
x: (B, Pl, Lat, Lon, C) or (B, Lat, Lon, C) |
|
|
""" |
|
|
if ndim == 3: |
|
|
win_pl, win_lat, win_lon = window_size |
|
|
B = int(windows.shape[0] / (Lon / win_lon)) |
|
|
x = windows.view( |
|
|
B, |
|
|
Lon // win_lon, |
|
|
Pl // win_pl, |
|
|
Lat // win_lat, |
|
|
win_pl, |
|
|
win_lat, |
|
|
win_lon, |
|
|
-1, |
|
|
) |
|
|
x = x.permute(0, 2, 4, 3, 5, 1, 6, 7).contiguous().view(B, Pl, Lat, Lon, -1) |
|
|
return x |
|
|
elif ndim == 2: |
|
|
win_lat, win_lon = window_size |
|
|
B = int(windows.shape[0] / (Lon / win_lon)) |
|
|
x = windows.view(B, Lon // win_lon, Lat // win_lat, win_lat, win_lon, -1) |
|
|
x = x.permute(0, 2, 3, 1, 4, 5).contiguous().view(B, Lat, Lon, -1) |
|
|
return x |
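# window_reverse is the exact inverse of window_partition for the padded grid,
# e.g. window_reverse(window_partition(x, ws), ws, Pl, Lat, Lon) recovers x.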
|
|
|
|
|
|
|
|
def get_shift_window_mask(input_resolution, window_size, shift_size, ndim=3): |
|
|
""" |
|
|
    Along the longitude dimension, the leftmost and rightmost indices are actually close to each other.
    If half windows appear at both the leftmost and rightmost positions, they are directly merged into one window.
|
|
Args: |
|
|
input_resolution (tuple[int]): [pressure levels, latitude, longitude] or [latitude, longitude] |
|
|
window_size (tuple[int]): Window size [pressure levels, latitude, longitude] or [latitude, longitude] |
|
|
shift_size (tuple[int]): Shift size for SW-MSA [pressure levels, latitude, longitude] or [latitude, longitude] |
|
|
ndim (int): dimension of window (3 or 2) |
|
|
|
|
|
Returns: |
|
|
attn_mask: (n_lon, n_pl*n_lat, win_pl*win_lat*win_lon, win_pl*win_lat*win_lon) or (n_lon, n_lat, win_lat*win_lon, win_lat*win_lon) |
|
|
""" |
|
|
if ndim == 3: |
|
|
Pl, Lat, Lon = input_resolution |
|
|
win_pl, win_lat, win_lon = window_size |
|
|
shift_pl, shift_lat, shift_lon = shift_size |
|
|
|
|
|
img_mask = torch.zeros((1, Pl, Lat, Lon + shift_lon, 1)) |
|
|
elif ndim == 2: |
|
|
Lat, Lon = input_resolution |
|
|
win_lat, win_lon = window_size |
|
|
shift_lat, shift_lon = shift_size |
|
|
|
|
|
img_mask = torch.zeros((1, Lat, Lon + shift_lon, 1)) |
|
|
|
|
|
if ndim == 3: |
|
|
pl_slices = ( |
|
|
slice(0, -win_pl), |
|
|
slice(-win_pl, -shift_pl), |
|
|
slice(-shift_pl, None), |
|
|
) |
|
|
lat_slices = ( |
|
|
slice(0, -win_lat), |
|
|
slice(-win_lat, -shift_lat), |
|
|
slice(-shift_lat, None), |
|
|
) |
|
|
lon_slices = ( |
|
|
slice(0, -win_lon), |
|
|
slice(-win_lon, -shift_lon), |
|
|
slice(-shift_lon, None), |
|
|
) |
|
|
|
|
|
cnt = 0 |
|
|
if ndim == 3: |
|
|
for pl in pl_slices: |
|
|
for lat in lat_slices: |
|
|
for lon in lon_slices: |
|
|
img_mask[:, pl, lat, lon, :] = cnt |
|
|
cnt += 1 |
|
|
img_mask = img_mask[:, :, :, :Lon, :] |
|
|
elif ndim == 2: |
|
|
for lat in lat_slices: |
|
|
for lon in lon_slices: |
|
|
img_mask[:, lat, lon, :] = cnt |
|
|
cnt += 1 |
|
|
img_mask = img_mask[:, :, :Lon, :] |
|
|
|
|
|
mask_windows = window_partition( |
|
|
img_mask, window_size, ndim=ndim |
|
|
) |
|
|
if ndim == 3: |
|
|
win_total = win_pl * win_lat * win_lon |
|
|
elif ndim == 2: |
|
|
win_total = win_lat * win_lon |
|
|
mask_windows = mask_windows.view( |
|
|
mask_windows.shape[0], mask_windows.shape[1], win_total |
|
|
) |
|
|
attn_mask = mask_windows.unsqueeze(2) - mask_windows.unsqueeze(3) |
|
|
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( |
|
|
attn_mask == 0, float(0.0) |
|
|
) |
|
|
return attn_mask |
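# The mask grid is built on Lon + shift_lon columns and then cropped back to
# Lon, so the half-windows at the two longitude ends (which are adjacent on
# the sphere) share one region id and are merged rather than masked apart.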
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Transformer2DBlock(nn.Module): |
|
|
""" |
|
|
    Revised from WeatherLearn https://github.com/lizhuoq/WeatherLearn
    2D Transformer Block
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
|
|
num_heads (int): Number of attention heads. |
|
|
window_size (tuple[int]): Window size [latitude, longitude]. |
|
|
shift_size (tuple[int]): Shift size for SW-MSA [latitude, longitude]. |
|
|
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. |
|
|
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True |
|
|
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. |
|
|
drop (float, optional): Dropout rate. Default: 0.0 |
|
|
attn_drop (float, optional): Attention dropout rate. Default: 0.0 |
|
|
drop_path (float, optional): Stochastic depth rate. Default: 0.0 |
|
|
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU |
|
|
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
dim, |
|
|
input_resolution, |
|
|
num_heads, |
|
|
window_size=None, |
|
|
shift_size=None, |
|
|
mlp_ratio=4.0, |
|
|
qkv_bias=True, |
|
|
qk_scale=None, |
|
|
drop=0.0, |
|
|
attn_drop=0.0, |
|
|
drop_path=0.0, |
|
|
act_layer=nn.GELU, |
|
|
norm_layer=nn.LayerNorm, |
|
|
): |
|
|
super().__init__() |
|
|
window_size = (6, 12) if window_size is None else window_size |
|
|
shift_size = (3, 6) if shift_size is None else shift_size |
|
|
self.dim = dim |
|
|
self.input_resolution = input_resolution |
|
|
self.num_heads = num_heads |
|
|
self.window_size = window_size |
|
|
self.shift_size = shift_size |
|
|
self.mlp_ratio = mlp_ratio |
|
|
|
|
|
self.norm1 = norm_layer(dim) |
|
|
padding = get_pad2d(input_resolution, window_size) |
|
|
        self.pad = nn.ConstantPad2d(padding, value=0)
|
|
|
|
|
pad_resolution = list(input_resolution) |
|
|
pad_resolution[0] += padding[2] + padding[3] |
|
|
pad_resolution[1] += padding[0] + padding[1] |
|
|
|
|
|
self.attn = EarthAttention2D( |
|
|
dim=dim, |
|
|
input_resolution=pad_resolution, |
|
|
window_size=window_size, |
|
|
num_heads=num_heads, |
|
|
qkv_bias=qkv_bias, |
|
|
qk_scale=qk_scale, |
|
|
attn_drop=attn_drop, |
|
|
proj_drop=drop, |
|
|
) |
|
|
|
|
|
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() |
|
|
self.norm2 = norm_layer(dim) |
|
|
mlp_hidden_dim = int(dim * mlp_ratio) |
|
|
self.mlp = Mlp( |
|
|
in_features=dim, |
|
|
hidden_features=mlp_hidden_dim, |
|
|
act_layer=act_layer, |
|
|
drop=drop, |
|
|
) |
|
|
|
|
|
shift_lat, shift_lon = self.shift_size |
|
|
self.roll = shift_lon and shift_lat |
|
|
|
|
|
if self.roll: |
|
|
attn_mask = get_shift_window_mask( |
|
|
pad_resolution, window_size, shift_size, ndim=2 |
|
|
) |
|
|
else: |
|
|
attn_mask = None |
|
|
|
|
|
self.register_buffer("attn_mask", attn_mask) |
|
|
|
|
|
def forward(self, x: torch.Tensor): |
|
|
Lat, Lon = self.input_resolution |
|
|
B, L, C = x.shape |
|
|
|
|
|
shortcut = x |
|
|
x = self.norm1(x) |
|
|
x = x.view(B, Lat, Lon, C) |
|
|
|
|
|
|
|
|
x = self.pad(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) |
|
|
_, Lat_pad, Lon_pad, _ = x.shape |
|
|
|
|
|
shift_lat, shift_lon = self.shift_size |
|
|
if self.roll: |
|
|
            shifted_x = torch.roll(x, shifts=(-shift_lat, -shift_lon), dims=(1, 2))
|
|
x_windows = window_partition(shifted_x, self.window_size, ndim=2) |
|
|
|
|
|
else: |
|
|
shifted_x = x |
|
|
x_windows = window_partition(shifted_x, self.window_size, ndim=2) |
|
|
|
|
|
|
|
|
win_lat, win_lon = self.window_size |
|
|
x_windows = x_windows.view( |
|
|
x_windows.shape[0], x_windows.shape[1], win_lat * win_lon, C |
|
|
) |
|
|
|
|
|
|
|
|
attn_windows = self.attn( |
|
|
x_windows, mask=self.attn_mask |
|
|
) |
|
|
|
|
|
attn_windows = attn_windows.view( |
|
|
attn_windows.shape[0], attn_windows.shape[1], win_lat, win_lon, C |
|
|
) |
|
|
|
|
|
if self.roll: |
|
|
shifted_x = window_reverse( |
|
|
attn_windows, self.window_size, Lat=Lat_pad, Lon=Lon_pad, ndim=2 |
|
|
) |
|
|
|
|
|
x = torch.roll(shifted_x, shifts=(shift_lat, shift_lon), dims=(1, 2)) |
|
|
else: |
|
|
shifted_x = window_reverse( |
|
|
attn_windows, self.window_size, Lat=Lat_pad, Lon=Lon_pad, ndim=2 |
|
|
) |
|
|
x = shifted_x |
|
|
|
|
|
|
|
|
x = crop2d(x.permute(0, 3, 1, 2), self.input_resolution).permute(0, 2, 3, 1) |
|
|
|
|
|
x = x.reshape(B, Lat * Lon, C) |
|
|
x = shortcut + self.drop_path(x) |
|
|
|
|
|
x = x + self.drop_path(self.mlp(self.norm2(x))) |
|
|
|
|
|
return x |
|
|
|
|
|
|
|
|
class FuserLayer(nn.Module): |
|
|
"""Revise from WeatherLearn https://github.com/lizhuoq/WeatherLearn |
|
|
A basic 3D Transformer layer for one stage |
|
|
|
|
|
Args: |
|
|
dim (int): Number of input channels. |
|
|
input_resolution (tuple[int]): Input resolution. |
|
|
depth (int): Number of blocks. |
|
|
num_heads (int): Number of attention heads. |
|
|
window_size (tuple[int]): Local window size. |
|
|
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. |
|
|
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True |
|
|
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. |
|
|
drop (float, optional): Dropout rate. Default: 0.0 |
|
|
attn_drop (float, optional): Attention dropout rate. Default: 0.0 |
|
|
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 |
|
|
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
dim, |
|
|
input_resolution, |
|
|
depth, |
|
|
num_heads, |
|
|
window_size, |
|
|
mlp_ratio=4.0, |
|
|
qkv_bias=True, |
|
|
qk_scale=None, |
|
|
drop=0.0, |
|
|
attn_drop=0.0, |
|
|
drop_path=0.0, |
|
|
norm_layer=nn.LayerNorm, |
|
|
): |
|
|
super().__init__() |
|
|
self.dim = dim |
|
|
self.input_resolution = input_resolution |
|
|
self.depth = depth |
|
|
|
|
|
self.blocks = nn.ModuleList( |
|
|
[ |
|
|
Transformer3DBlock( |
|
|
dim=dim, |
|
|
input_resolution=input_resolution, |
|
|
num_heads=num_heads, |
|
|
window_size=window_size, |
|
|
shift_size=(0, 0, 0) if i % 2 == 0 else None, |
|
|
mlp_ratio=mlp_ratio, |
|
|
qkv_bias=qkv_bias, |
|
|
qk_scale=qk_scale, |
|
|
drop=drop, |
|
|
attn_drop=attn_drop, |
|
|
drop_path=drop_path[i] |
|
|
if isinstance(drop_path, Sequence) |
|
|
else drop_path, |
|
|
norm_layer=norm_layer, |
|
|
) |
|
|
for i in range(depth) |
|
|
] |
|
|
) |
|
|
|
|
|
def forward(self, x): |
|
|
for blk in self.blocks: |
|
|
x = blk(x) |
|
|
return x |
|
|
|
|
|
|
|
|
class EncoderLayer(nn.Module): |
|
|
"""A 2D Transformer Encoder Module for one stage |
|
|
|
|
|
Args: |
|
|
        img_size (tuple[int]): Image size (Lat, Lon).
|
|
patch_size (tuple[int]): Patch token size of Patch Embedding. |
|
|
in_chans (int): number of input channels of Patch Embedding. |
|
|
dim (int): Number of input channels of transformer. |
|
|
input_resolution (tuple[int]): Input resolution for transformer before downsampling. |
|
|
middle_resolution (tuple[int]): Input resolution for transformer after downsampling. |
|
|
depth (int): Number of blocks for transformer before downsampling. |
|
|
depth_middle (int): Number of blocks for transformer after downsampling. |
|
|
num_heads (int): Number of attention heads. |
|
|
window_size (tuple[int]): Local window size. |
|
|
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. |
|
|
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True |
|
|
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. |
|
|
drop (float, optional): Dropout rate. Default: 0.0 |
|
|
attn_drop (float, optional): Attention dropout rate. Default: 0.0 |
|
|
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 |
|
|
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
img_size, |
|
|
patch_size, |
|
|
in_chans, |
|
|
dim, |
|
|
input_resolution, |
|
|
middle_resolution, |
|
|
depth, |
|
|
depth_middle, |
|
|
num_heads, |
|
|
window_size, |
|
|
mlp_ratio=4.0, |
|
|
qkv_bias=True, |
|
|
qk_scale=None, |
|
|
drop=0.0, |
|
|
attn_drop=0.0, |
|
|
drop_path=0.0, |
|
|
norm_layer=nn.LayerNorm, |
|
|
): |
|
|
super().__init__() |
|
|
self.in_chans = in_chans |
|
|
self.dim = dim |
|
|
self.input_resolution = input_resolution |
|
|
self.depth = depth |
|
|
self.depth_middle = depth_middle |
|
|
if isinstance(drop_path, Sequence): |
|
|
drop_path_middle = drop_path[depth:] |
|
|
drop_path = drop_path[:depth] |
|
|
else: |
|
|
drop_path_middle = drop_path |
|
|
if isinstance(num_heads, Sequence): |
|
|
num_heads_middle = num_heads[1] |
|
|
num_heads = num_heads[0] |
|
|
else: |
|
|
num_heads_middle = num_heads |
|
|
|
|
|
self.patchembed2d = PatchEmbed2D( |
|
|
img_size=img_size, |
|
|
patch_size=patch_size, |
|
|
in_chans=in_chans, |
|
|
embed_dim=dim, |
|
|
) |
|
|
self.blocks = nn.ModuleList( |
|
|
[ |
|
|
Transformer2DBlock( |
|
|
dim=dim, |
|
|
input_resolution=input_resolution, |
|
|
num_heads=num_heads, |
|
|
window_size=window_size, |
|
|
shift_size=(0, 0) if i % 2 == 0 else None, |
|
|
mlp_ratio=mlp_ratio, |
|
|
qkv_bias=qkv_bias, |
|
|
qk_scale=qk_scale, |
|
|
drop=drop, |
|
|
attn_drop=attn_drop, |
|
|
drop_path=drop_path[i] |
|
|
if isinstance(drop_path, Sequence) |
|
|
else drop_path, |
|
|
norm_layer=norm_layer, |
|
|
) |
|
|
for i in range(depth) |
|
|
] |
|
|
) |
|
|
|
|
|
self.downsample = DownSample2D( |
|
|
in_dim=dim, |
|
|
input_resolution=input_resolution, |
|
|
output_resolution=middle_resolution, |
|
|
) |
|
|
|
|
|
self.blocks_middle = nn.ModuleList( |
|
|
[ |
|
|
Transformer2DBlock( |
|
|
dim=dim * 2, |
|
|
input_resolution=middle_resolution, |
|
|
num_heads=num_heads_middle, |
|
|
window_size=window_size, |
|
|
shift_size=(0, 0) if i % 2 == 0 else None, |
|
|
mlp_ratio=mlp_ratio, |
|
|
qkv_bias=qkv_bias, |
|
|
qk_scale=qk_scale, |
|
|
drop=drop, |
|
|
attn_drop=attn_drop, |
|
|
drop_path=drop_path_middle[i] |
|
|
if isinstance(drop_path_middle, Sequence) |
|
|
else drop_path_middle, |
|
|
norm_layer=norm_layer, |
|
|
) |
|
|
for i in range(depth_middle) |
|
|
] |
|
|
) |
|
|
|
|
|
def forward(self, x): |
|
|
x = self.patchembed2d(x) |
|
|
B, C, Lat, Lon = x.shape |
|
|
x = x.reshape(B, C, -1).transpose(1, 2) |
|
|
for blk in self.blocks: |
|
|
x = blk(x) |
|
|
skip = x.reshape(B, Lat, Lon, C) |
|
|
x = self.downsample(x) |
|
|
for blk in self.blocks_middle: |
|
|
x = blk(x) |
|
|
return x, skip |
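# Flow sketch: (B, in_chans, Lat, Lon) -> PatchEmbed2D -> tokens of width dim
# -> `depth` blocks (output kept as the skip connection) -> DownSample2D (grid
# halved, channels doubled to 2*dim) -> `depth_middle` blocks.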
|
|
|
|
|
|
|
|
class DecoderLayer(nn.Module): |
|
|
"""A 2D Transformer Decoder Module for one stage |
|
|
|
|
|
Args: |
|
|
        img_size (tuple[int]): Image size (Lat, Lon).
|
|
patch_size (tuple[int]): Patch token size of Patch Recovery. |
|
|
out_chans (int): number of output channels of Patch Recovery. |
|
|
dim (int): Number of input channels of transformer. |
|
|
output_resolution (tuple[int]): Input resolution for transformer after upsampling. |
|
|
middle_resolution (tuple[int]): Input resolution for transformer before upsampling. |
|
|
depth (int): Number of blocks for transformer after upsampling. |
|
|
depth_middle (int): Number of blocks for transformer before upsampling. |
|
|
num_heads (int): Number of attention heads. |
|
|
window_size (tuple[int]): Local window size. |
|
|
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. |
|
|
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True |
|
|
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. |
|
|
drop (float, optional): Dropout rate. Default: 0.0 |
|
|
attn_drop (float, optional): Attention dropout rate. Default: 0.0 |
|
|
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 |
|
|
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
img_size, |
|
|
patch_size, |
|
|
out_chans, |
|
|
dim, |
|
|
output_resolution, |
|
|
middle_resolution, |
|
|
depth, |
|
|
depth_middle, |
|
|
num_heads, |
|
|
window_size, |
|
|
mlp_ratio=4.0, |
|
|
qkv_bias=True, |
|
|
qk_scale=None, |
|
|
drop=0.0, |
|
|
attn_drop=0.0, |
|
|
drop_path=0.0, |
|
|
norm_layer=nn.LayerNorm, |
|
|
): |
|
|
super().__init__() |
|
|
self.out_chans = out_chans |
|
|
self.dim = dim |
|
|
self.output_resolution = output_resolution |
|
|
self.depth = depth |
|
|
self.depth_middle = depth_middle |
|
|
if isinstance(drop_path, Sequence): |
|
|
drop_path_middle = drop_path[depth:] |
|
|
drop_path = drop_path[:depth] |
|
|
else: |
|
|
drop_path_middle = drop_path |
|
|
if isinstance(num_heads, Sequence): |
|
|
num_heads_middle = num_heads[1] |
|
|
num_heads = num_heads[0] |
|
|
else: |
|
|
num_heads_middle = num_heads |
|
|
|
|
|
self.blocks_middle = nn.ModuleList( |
|
|
[ |
|
|
Transformer2DBlock( |
|
|
dim=dim * 2, |
|
|
input_resolution=middle_resolution, |
|
|
num_heads=num_heads_middle, |
|
|
window_size=window_size, |
|
|
shift_size=(0, 0) if i % 2 == 0 else None, |
|
|
mlp_ratio=mlp_ratio, |
|
|
qkv_bias=qkv_bias, |
|
|
qk_scale=qk_scale, |
|
|
drop=drop, |
|
|
attn_drop=attn_drop, |
|
|
drop_path=drop_path_middle[i] |
|
|
if isinstance(drop_path_middle, Sequence) |
|
|
else drop_path_middle, |
|
|
norm_layer=norm_layer, |
|
|
) |
|
|
for i in range(depth_middle) |
|
|
] |
|
|
) |
|
|
|
|
|
self.upsample = UpSample2D( |
|
|
in_dim=dim * 2, |
|
|
out_dim=dim, |
|
|
input_resolution=middle_resolution, |
|
|
output_resolution=output_resolution, |
|
|
) |
|
|
|
|
|
self.blocks = nn.ModuleList( |
|
|
[ |
|
|
Transformer2DBlock( |
|
|
dim=dim, |
|
|
input_resolution=output_resolution, |
|
|
num_heads=num_heads, |
|
|
window_size=window_size, |
|
|
shift_size=(0, 0) if i % 2 == 0 else None, |
|
|
mlp_ratio=mlp_ratio, |
|
|
qkv_bias=qkv_bias, |
|
|
qk_scale=qk_scale, |
|
|
drop=drop, |
|
|
attn_drop=attn_drop, |
|
|
drop_path=drop_path[i] |
|
|
if isinstance(drop_path, Sequence) |
|
|
else drop_path, |
|
|
norm_layer=norm_layer, |
|
|
) |
|
|
for i in range(depth) |
|
|
] |
|
|
) |
|
|
|
|
|
self.patchrecovery2d = PatchRecovery2D(img_size, patch_size, 2 * dim, out_chans) |
|
|
|
|
|
def forward(self, x, skip): |
|
|
B, Lat, Lon, C = skip.shape |
|
|
for blk in self.blocks_middle: |
|
|
x = blk(x) |
|
|
x = self.upsample(x) |
|
|
for blk in self.blocks: |
|
|
x = blk(x) |
|
|
output = torch.cat([x, skip.reshape(B, -1, C)], dim=-1) |
|
|
output = output.transpose(1, 2).reshape(B, -1, Lat, Lon) |
|
|
output = self.patchrecovery2d(output) |
|
|
return output |
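# The upsampled decoder tokens are concatenated with the encoder skip along
# channels (dim + dim = 2*dim), which is why PatchRecovery2D above is built
# with 2*dim input channels.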
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Fengwu(nn.Module): |
|
|
""" |
|
|
FengWu PyTorch impl of: `FengWu: Pushing the Skillful Global Medium-range Weather Forecast beyond 10 Days Lead` |
|
|
- https://arxiv.org/pdf/2304.02948.pdf |
|
|
|
|
|
Args: |
|
|
img_size: Image size(Lat, Lon). Default: (721,1440) |
|
|
pressure_level: Number of pressure_level. Default: 37 |
|
|
embed_dim (int): Patch embedding dimension. Default: 192 |
|
|
patch_size (tuple[int]): Patch token size. Default: (4,4) |
|
|
num_heads (tuple[int]): Number of attention heads in different layers. |
|
|
window_size (tuple[int]): Window size. |
|
|
""" |
|
|
|
|
|
    def __init__(
        self,
        in_shape=(1, 69, 120, 240),
        pressure_level=13,
        embed_dim=192,
        patch_size=(4, 4),
        num_heads=(6, 12, 12, 6),
        window_size=(2, 6, 12),
        **kwargs,
    ):
        super().__init__()
        img_size = (in_shape[2], in_shape[3])
|
|
drop_path = np.linspace(0, 0.2, 8).tolist() |
|
|
drop_path_fuser = [0.2] * 6 |
|
|
resolution_down1 = ( |
|
|
math.ceil(img_size[0] / patch_size[0]), |
|
|
math.ceil(img_size[1] / patch_size[1]), |
|
|
) |
|
|
resolution_down2 = ( |
|
|
math.ceil(resolution_down1[0] / 2), |
|
|
math.ceil(resolution_down1[1] / 2), |
|
|
) |
|
|
resolution = (resolution_down1, resolution_down2) |
|
|
self.encoder_surface = EncoderLayer( |
|
|
img_size=img_size, |
|
|
patch_size=patch_size, |
|
|
in_chans=4, |
|
|
dim=embed_dim, |
|
|
input_resolution=resolution[0], |
|
|
middle_resolution=resolution[1], |
|
|
depth=2, |
|
|
depth_middle=6, |
|
|
num_heads=num_heads[:2], |
|
|
window_size=window_size[1:], |
|
|
drop_path=drop_path, |
|
|
) |
|
|
self.encoder_z = EncoderLayer( |
|
|
img_size=img_size, |
|
|
patch_size=patch_size, |
|
|
in_chans=pressure_level, |
|
|
dim=embed_dim, |
|
|
input_resolution=resolution[0], |
|
|
middle_resolution=resolution[1], |
|
|
depth=2, |
|
|
depth_middle=6, |
|
|
num_heads=num_heads[:2], |
|
|
window_size=window_size[1:], |
|
|
drop_path=drop_path, |
|
|
) |
|
|
self.encoder_r = EncoderLayer( |
|
|
img_size=img_size, |
|
|
patch_size=patch_size, |
|
|
in_chans=pressure_level, |
|
|
dim=embed_dim, |
|
|
input_resolution=resolution[0], |
|
|
middle_resolution=resolution[1], |
|
|
depth=2, |
|
|
depth_middle=6, |
|
|
num_heads=num_heads[:2], |
|
|
window_size=window_size[1:], |
|
|
drop_path=drop_path, |
|
|
) |
|
|
self.encoder_u = EncoderLayer( |
|
|
img_size=img_size, |
|
|
patch_size=patch_size, |
|
|
in_chans=pressure_level, |
|
|
dim=embed_dim, |
|
|
input_resolution=resolution[0], |
|
|
middle_resolution=resolution[1], |
|
|
depth=2, |
|
|
depth_middle=6, |
|
|
num_heads=num_heads[:2], |
|
|
window_size=window_size[1:], |
|
|
drop_path=drop_path, |
|
|
) |
|
|
self.encoder_v = EncoderLayer( |
|
|
img_size=img_size, |
|
|
patch_size=patch_size, |
|
|
in_chans=pressure_level, |
|
|
dim=embed_dim, |
|
|
input_resolution=resolution[0], |
|
|
middle_resolution=resolution[1], |
|
|
depth=2, |
|
|
depth_middle=6, |
|
|
num_heads=num_heads[:2], |
|
|
window_size=window_size[1:], |
|
|
drop_path=drop_path, |
|
|
) |
|
|
self.encoder_t = EncoderLayer( |
|
|
img_size=img_size, |
|
|
patch_size=patch_size, |
|
|
in_chans=pressure_level, |
|
|
dim=embed_dim, |
|
|
input_resolution=resolution[0], |
|
|
middle_resolution=resolution[1], |
|
|
depth=2, |
|
|
depth_middle=6, |
|
|
num_heads=num_heads[:2], |
|
|
window_size=window_size[1:], |
|
|
drop_path=drop_path, |
|
|
) |
|
|
|
|
|
self.fuser = FuserLayer( |
|
|
dim=embed_dim * 2, |
|
|
input_resolution=(6, resolution[1][0], resolution[1][1]), |
|
|
depth=6, |
|
|
num_heads=num_heads[1], |
|
|
window_size=window_size, |
|
|
drop_path=drop_path_fuser, |
|
|
) |
|
|
|
|
|
self.decoder_surface = DecoderLayer( |
|
|
img_size=img_size, |
|
|
patch_size=patch_size, |
|
|
out_chans=4, |
|
|
dim=embed_dim, |
|
|
output_resolution=resolution[0], |
|
|
middle_resolution=resolution[1], |
|
|
depth=2, |
|
|
depth_middle=6, |
|
|
num_heads=num_heads[:2], |
|
|
window_size=window_size[1:], |
|
|
drop_path=drop_path, |
|
|
) |
|
|
self.decoder_z = DecoderLayer( |
|
|
img_size=img_size, |
|
|
patch_size=patch_size, |
|
|
out_chans=pressure_level, |
|
|
dim=embed_dim, |
|
|
output_resolution=resolution[0], |
|
|
middle_resolution=resolution[1], |
|
|
depth=2, |
|
|
depth_middle=6, |
|
|
num_heads=num_heads[:2], |
|
|
window_size=window_size[1:], |
|
|
drop_path=drop_path, |
|
|
) |
|
|
self.decoder_r = DecoderLayer( |
|
|
img_size=img_size, |
|
|
patch_size=patch_size, |
|
|
out_chans=pressure_level, |
|
|
dim=embed_dim, |
|
|
output_resolution=resolution[0], |
|
|
middle_resolution=resolution[1], |
|
|
depth=2, |
|
|
depth_middle=6, |
|
|
num_heads=num_heads[:2], |
|
|
window_size=window_size[1:], |
|
|
drop_path=drop_path, |
|
|
) |
|
|
self.decoder_u = DecoderLayer( |
|
|
img_size=img_size, |
|
|
patch_size=patch_size, |
|
|
out_chans=pressure_level, |
|
|
dim=embed_dim, |
|
|
output_resolution=resolution[0], |
|
|
middle_resolution=resolution[1], |
|
|
depth=2, |
|
|
depth_middle=6, |
|
|
num_heads=num_heads[:2], |
|
|
window_size=window_size[1:], |
|
|
drop_path=drop_path, |
|
|
) |
|
|
self.decoder_v = DecoderLayer( |
|
|
img_size=img_size, |
|
|
patch_size=patch_size, |
|
|
out_chans=pressure_level, |
|
|
dim=embed_dim, |
|
|
output_resolution=resolution[0], |
|
|
middle_resolution=resolution[1], |
|
|
depth=2, |
|
|
depth_middle=6, |
|
|
num_heads=num_heads[:2], |
|
|
window_size=window_size[1:], |
|
|
drop_path=drop_path, |
|
|
) |
|
|
self.decoder_t = DecoderLayer( |
|
|
img_size=img_size, |
|
|
patch_size=patch_size, |
|
|
out_chans=pressure_level, |
|
|
dim=embed_dim, |
|
|
output_resolution=resolution[0], |
|
|
middle_resolution=resolution[1], |
|
|
depth=2, |
|
|
depth_middle=6, |
|
|
num_heads=num_heads[:2], |
|
|
window_size=window_size[1:], |
|
|
drop_path=drop_path, |
|
|
) |
|
|
|
|
|
|
|
|
    def forward(self, x):
        """
        Args:
            x (torch.Tensor): (B, 69, Lat, Lon). Channels are ordered as
                z (0:13), r (13:26), t (26:39), u (39:52), v (52:65) for the
                13 pressure levels of each upper-air variable, followed by
                4 surface channels (65:69).
        """
        surface = x[:, 65:69, :, :]
        u = x[:, 39:52, :, :]
        v = x[:, 52:65, :, :]
        t = x[:, 26:39, :, :]
        r = x[:, 13:26, :, :]
        z = x[:, 0:13, :, :]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
surface, skip_surface = self.encoder_surface(surface) |
|
|
z, skip_z = self.encoder_z(z) |
|
|
r, skip_r = self.encoder_r(r) |
|
|
u, skip_u = self.encoder_u(u) |
|
|
v, skip_v = self.encoder_v(v) |
|
|
t, skip_t = self.encoder_t(t) |
|
|
|
|
|
x = torch.cat( |
|
|
[ |
|
|
surface.unsqueeze(1), |
|
|
z.unsqueeze(1), |
|
|
r.unsqueeze(1), |
|
|
u.unsqueeze(1), |
|
|
v.unsqueeze(1), |
|
|
t.unsqueeze(1), |
|
|
], |
|
|
dim=1, |
|
|
) |
|
|
B, PL, L_SIZE, C = x.shape |
|
|
x = x.reshape(B, -1, C) |
|
|
x = self.fuser(x) |
|
|
|
|
|
x = x.reshape(B, PL, L_SIZE, C) |
|
|
surface, z, r, u, v, t = ( |
|
|
x[:, 0, :, :], |
|
|
x[:, 1, :, :], |
|
|
x[:, 2, :, :], |
|
|
x[:, 3, :, :], |
|
|
x[:, 4, :, :], |
|
|
x[:, 5, :, :], |
|
|
) |
|
|
|
|
|
surface = self.decoder_surface(surface, skip_surface) |
|
|
z = self.decoder_z(z, skip_z) |
|
|
r = self.decoder_r(r, skip_r) |
|
|
u = self.decoder_u(u, skip_u) |
|
|
v = self.decoder_v(v, skip_v) |
|
|
t = self.decoder_t(t, skip_t) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        upper_air = torch.cat((z, r, t, u, v), dim=1)
        final_output = torch.cat([upper_air, surface], dim=1)
|
|
return final_output |
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
inputs = torch.randn(1, 69, 120, 240) |
|
|
model = Fengwu() |
|
|
output = model(inputs) |
|
|
print(inputs.shape) |
|
|
print(output.shape) |
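    # Expected (sketch): both prints show torch.Size([1, 69, 120, 240]), since
    # the model reconstructs all 5 x 13 upper-air channels plus 4 surface
    # channels on the input lat/lon grid.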
|
|
|