import torch
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbedding(nn.Module):
    """Embed multi-channel signal patches of shape
    (batch, channels, patches, patch_size) into d_model-dimensional tokens by
    summing a convolutional temporal encoding, a spectral (rFFT magnitude)
    encoding, and a learned channel embedding."""

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model

        # Depthwise conv over the patch axis ("same" padding); applied in
        # forward() as a residual local temporal encoding.
        self.time_encoding = nn.Sequential(
            nn.Conv2d(in_channels=d_model, out_channels=d_model, kernel_size=(1, 5),
                      stride=(1, 1), padding=(0, 2), groups=d_model),
        )

        # Strided conv stack mapping each raw patch to a d_model-dimensional
        # token: for 200-sample patches the first conv yields 8 time steps,
        # so the flattened output is 64 * 8 = 512 and d_model must be 512.
        self.proj_in = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=(1, 49), stride=(1, 25), padding=(0, 24)),
            nn.GroupNorm(8, 64),
            nn.GELU(),

            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1)),
            nn.GroupNorm(8, 128),
            nn.GELU(),

            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1)),
            nn.GroupNorm(8, 64),
            nn.GELU(),
        )

        # Projects the 101 rFFT magnitude bins of a 200-sample patch to d_model.
        self.spectral_proj = nn.Sequential(
            nn.Linear(101, d_model),
            nn.Dropout(0.1),
        )

        # Learned channel embedding over a fixed set of 19 channels.
        self.num_channels = 19
        self.channel_embedding = nn.Linear(self.num_channels, d_model)

    def forward(self, x):
        bz, ch_num, patch_num, patch_size = x.shape
        # Channel indices for the channel embedding; created on the input's
        # device rather than hard-coding .cuda(), and spanning exactly the
        # ch_num channels present instead of num_channels + 1.
        channel_idx = torch.arange(ch_num, device=x.device)

        # Temporal encoding: fold channels and patches into one axis and run
        # the strided conv stack over every patch.
        x = x.contiguous().view(bz, 1, ch_num * patch_num, patch_size)
        patch_emb = self.proj_in(x)
        patch_emb = patch_emb.permute(0, 2, 1, 3).contiguous().view(bz, ch_num, patch_num, self.d_model)

        # Spectral encoding: magnitude of the real FFT of every patch. rfft
        # yields patch_size // 2 + 1 bins; spectral_proj expects 101 of them,
        # i.e. 200-sample patches.
        x = x.contiguous().view(bz * ch_num * patch_num, patch_size)
        spectral = torch.fft.rfft(x, dim=-1, norm='forward')
        spectral = torch.abs(spectral).contiguous().view(bz, ch_num, patch_num, patch_size // 2 + 1)
        spectral_emb = self.spectral_proj(spectral)

        patch_emb = patch_emb + spectral_emb

        # Channel embedding: one-hot channel identities projected to d_model,
        # broadcast over the batch and patch axes (requires
        # ch_num <= self.num_channels).
        one_hot = F.one_hot(channel_idx, num_classes=self.num_channels).float()
        channel_pos = self.channel_embedding(one_hot)        # (ch_num, d_model)
        channel_pos = channel_pos.unsqueeze(0).unsqueeze(2)  # (1, ch_num, 1, d_model)
        channel_pos = channel_pos.expand(bz, -1, patch_num, -1)

        patch_emb = patch_emb + channel_pos

        # Local temporal encoding: depthwise conv over the patch axis, added
        # residually.
        time_embedding = self.time_encoding(patch_emb.permute(0, 3, 1, 2))
        time_embedding = time_embedding.permute(0, 2, 3, 1)

        patch_emb = patch_emb + time_embedding

        return patch_emb
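

# A minimal smoke test, not part of the module above: the shapes assume
# 200-sample patches and d_model == 512, since proj_in's 64 output channels
# times its 8 output time steps must equal d_model, and the rFFT of a
# 200-sample patch yields the 101 bins that spectral_proj expects.
if __name__ == "__main__":
    model = PatchEmbedding(d_model=512)
    x = torch.randn(2, 19, 30, 200)  # (batch, channels, patches, patch_size)
    out = model(x)
    print(out.shape)  # torch.Size([2, 19, 30, 512])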