Spaces:
Paused
Paused
File size: 2,533 Bytes
eff3d67 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 | from __future__ import annotations
import torch
import torch.nn as nn
from torch_geometric.nn import global_mean_pool
from torch_geometric.utils import degree
class SpectralFilterLayer(nn.Module):
    """Polynomial graph filter over a degree-normalised adjacency.

    Computes ``relu(sum_k x_k @ coeffs[k])`` for ``k = 0 .. K-1`` where
    ``x_0 = x`` and ``x_{k}`` is ``x_{k-1}`` propagated once along the edges,
    each message being scaled by ``deg^-0.5`` of both of its endpoints.

    Args:
        in_ch: Width of the incoming node features.
        out_ch: Width of the produced node features.
        K: Number of filter taps, i.e. ``K - 1`` propagation steps. Must be
            at least 1.

    Raises:
        ValueError: If ``K`` is smaller than 1.
    """

    def __init__(self, in_ch: int, out_ch: int, K: int = 3):
        super().__init__()
        # Without at least one tap ``forward`` would fail on ``coeffs[0]``
        # with an unhelpful IndexError; reject the configuration up front.
        if K < 1:
            raise ValueError(f"K must be at least 1, got {K}")
        # One (in_ch, out_ch) weight matrix per tap, small random init.
        self.coeffs = nn.ParameterList(
            [nn.Parameter(torch.randn(in_ch, out_ch) * 0.01) for _ in range(K)]
        )
        self.K = K

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Apply the K-tap filter to the node features.

        Args:
            x: Node features; ``x.size(0)`` is used as the node count and the
                last dimension must match ``in_ch``
                (presumably ``(num_nodes, in_ch)`` -- confirm with callers).
            edge_index: Unpacked as ``row, col``; messages flow from ``row``
                nodes into ``col`` nodes. Not accessed when ``K == 1``.

        Returns:
            ReLU-activated features whose last dimension is ``out_ch``.
        """
        out = x @ self.coeffs[0]
        if self.K <= 1:
            # No propagation steps: the graph structure is never needed.
            return torch.relu(out)

        # Everything below depends only on the graph, not on the step index,
        # so it is computed once instead of once per propagation step.
        row, col = edge_index
        # Degree is counted from ``col`` only; the clamp keeps nodes without
        # incoming edges from producing an infinite scale factor.
        deg = degree(col, x.size(0), dtype=x.dtype).clamp(min=1)
        norm = deg.pow(-0.5)
        scale_col = norm[col].unsqueeze(-1)
        scale_row = norm[row].unsqueeze(-1)

        x_k = x
        for k in range(1, self.K):
            aggr = torch.zeros_like(x_k)
            # Scatter-add every scaled message into its ``col`` node.
            aggr.index_add_(0, col, scale_col * x_k[row] * scale_row)
            x_k = aggr
            out = out + x_k @ self.coeffs[k]
        return torch.relu(out)
class TemporalDiffModule(nn.Module):
    """Turn a per-node sequence into a fixed-size frequency descriptor.

    The input is transformed with an FFT along dimension 1, reduced to
    magnitudes, averaged over the last dimension and finally mapped through a
    linear layer that expects ``T`` input features.

    Args:
        T: Length of the FFT axis (input width of the projection).
        out_dim: Width of the returned descriptor.
    """

    def __init__(self, T: int, out_dim: int = 32):
        super().__init__()
        self.proj = nn.Linear(T, out_dim)

    def forward(self, x_seq: torch.Tensor) -> torch.Tensor:
        """Return the projected mean FFT magnitude of ``x_seq``.

        Args:
            x_seq: Sequence tensor whose dimension 1 has length ``T``
                (presumably ``(num_nodes, T, channels)`` -- confirm with the
                data pipeline).

        Returns:
            Tensor whose last dimension is ``out_dim``.
        """
        spectrum = torch.fft.fft(x_seq, dim=1)
        magnitude = torch.abs(spectrum)
        # Collapse the trailing dimension so one value per frequency bin remains.
        per_bin = torch.mean(magnitude, dim=-1)
        return self.proj(per_bin)
class SSTGNN(nn.Module):
    """Graph classifier combining node features with FFT-derived features.

    Pipeline: temporal descriptor -> concatenate with node features -> linear
    projection -> stack of residual spectral filter layers -> mean pooling per
    graph -> small MLP head producing one raw score per graph.

    Args:
        patch_feat_dim: Width of ``data.x``.
        hidden_dim: Width used inside the spectral layers.
        num_frames: Sequence length handed to the temporal module as ``T``.
        num_spectral_layers: How many spectral filter layers are stacked.
        spectral_K: Number of filter taps per spectral layer.
        fft_dim: Width of the temporal descriptor.
    """

    def __init__(
        self,
        patch_feat_dim: int = 8,
        hidden_dim: int = 128,
        num_frames: int = 32,
        num_spectral_layers: int = 3,
        spectral_K: int = 3,
        fft_dim: int = 32,
    ):
        super().__init__()

        # Node features and the temporal descriptor are concatenated before
        # being projected, hence the summed input width.
        fused_dim = patch_feat_dim + fft_dim
        self.input_proj = nn.Linear(fused_dim, hidden_dim)

        filters = []
        for _ in range(num_spectral_layers):
            # Equal in/out width so the residual sum in ``forward`` is valid.
            filters.append(SpectralFilterLayer(hidden_dim, hidden_dim, K=spectral_K))
        self.spectral_layers = nn.ModuleList(filters)

        self.temporal = TemporalDiffModule(T=num_frames, out_dim=fft_dim)

        head = [
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, data):
        """Score every graph contained in ``data``.

        Args:
            data: Object exposing ``x``, ``x_temporal``, ``edge_index`` and
                ``batch`` attributes (presumably a PyG ``Data``/``Batch``
                -- confirm with the data loader).

        Returns:
            Output of the classifier head with its trailing size-1 dimension
            squeezed away; no sigmoid is applied here.
        """
        temporal_desc = self.temporal(data.x_temporal)
        fused = torch.cat((data.x, temporal_desc), dim=-1)
        h = self.input_proj(fused)

        for block in self.spectral_layers:
            # Residual connection around each spectral filter layer.
            h = block(h, data.edge_index) + h

        pooled = global_mean_pool(h, data.batch)
        scores = self.classifier(pooled)
        return scores.squeeze(-1)
|