from __future__ import annotations import torch import torch.nn as nn from torch_geometric.nn import global_mean_pool from torch_geometric.utils import degree class SpectralFilterLayer(nn.Module): def __init__(self, in_ch: int, out_ch: int, K: int = 3): super().__init__() self.coeffs = nn.ParameterList( [nn.Parameter(torch.randn(in_ch, out_ch) * 0.01) for _ in range(K)] ) self.K = K def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor: out = x @ self.coeffs[0] x_k = x for k in range(1, self.K): row, col = edge_index deg = degree(col, x.size(0), dtype=x.dtype).clamp(min=1) norm = deg.pow(-0.5) aggr = torch.zeros_like(x) aggr.index_add_( 0, col, norm[col].unsqueeze(-1) * x_k[row] * norm[row].unsqueeze(-1), ) x_k = aggr out = out + x_k @ self.coeffs[k] return torch.relu(out) class TemporalDiffModule(nn.Module): def __init__(self, T: int, out_dim: int = 32): super().__init__() self.proj = nn.Linear(T, out_dim) def forward(self, x_seq: torch.Tensor) -> torch.Tensor: fft = torch.fft.fft(x_seq, dim=1).abs() fft_pooled = fft.mean(dim=-1) return self.proj(fft_pooled) class SSTGNN(nn.Module): def __init__( self, patch_feat_dim: int = 8, hidden_dim: int = 128, num_frames: int = 32, num_spectral_layers: int = 3, spectral_K: int = 3, fft_dim: int = 32, ): super().__init__() self.input_proj = nn.Linear(patch_feat_dim + fft_dim, hidden_dim) self.spectral_layers = nn.ModuleList( [ SpectralFilterLayer(hidden_dim, hidden_dim, K=spectral_K) for _ in range(num_spectral_layers) ] ) self.temporal = TemporalDiffModule(T=num_frames, out_dim=fft_dim) self.classifier = nn.Sequential( nn.Linear(hidden_dim, 64), nn.ReLU(), nn.Dropout(0.3), nn.Linear(64, 1), ) def forward(self, data): fft_feat = self.temporal(data.x_temporal) x = torch.cat([data.x, fft_feat], dim=-1) x = self.input_proj(x) for layer in self.spectral_layers: x = layer(x, data.edge_index) + x x = global_mean_pool(x, data.batch) return self.classifier(x).squeeze(-1)