Spaces:
Paused
Paused
| from __future__ import annotations | |
| import torch | |
| import torch.nn as nn | |
| from torch_geometric.nn import global_mean_pool | |
| from torch_geometric.utils import degree | |
class SpectralFilterLayer(nn.Module):
    """Polynomial graph filter ``relu(sum_{k=0}^{K-1} A_hat^k @ x @ W_k)``.

    ``A_hat`` is the normalised adjacency ``D^-1/2 A D^-1/2`` built from
    ``edge_index``. ``D`` is the per-node count of occurrences in
    ``edge_index[1]`` (the aggregation target), clamped to at least 1 so that
    nodes with no incoming edges never divide by zero.

    Args:
        in_ch: Number of input features per node.
        out_ch: Number of output features per node.
        K: Number of polynomial terms (hops ``0 .. K-1``); must be >= 1.

    Raises:
        ValueError: If ``K < 1`` (there would be no weight for hop 0).
    """

    def __init__(self, in_ch: int, out_ch: int, K: int = 3):
        super().__init__()
        if K < 1:
            raise ValueError(f"K must be >= 1, got {K}")
        # One (in_ch, out_ch) weight matrix per hop, initialised small.
        self.coeffs = nn.ParameterList(
            [nn.Parameter(torch.randn(in_ch, out_ch) * 0.01) for _ in range(K)]
        )
        self.K = K

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Apply the filter to node features ``x``.

        Args:
            x: Node feature matrix; row ``i`` belongs to node ``i`` (assumed
                shape ``(num_nodes, in_ch)``, as required by ``x @ coeffs[k]``).
            edge_index: Two rows of node indices. Features are gathered from
                ``edge_index[0]`` and summed into ``edge_index[1]``.

        Returns:
            ``relu`` of the summed per-hop projections, one row per node.
        """
        out = x @ self.coeffs[0]
        if self.K == 1:
            # Only the hop-0 term exists; no graph propagation needed.
            return torch.relu(out)

        # The normalisation depends only on the graph, not on the hop, so
        # build the per-edge weight once instead of once per hop.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype).clamp(min=1)
        norm = deg.pow(-0.5)
        edge_weight = (norm[col] * norm[row]).unsqueeze(-1)

        x_k = x
        for k in range(1, self.K):
            # One propagation step: x_k <- A_hat @ x_k.
            aggr = torch.zeros_like(x_k)
            aggr.index_add_(0, col, edge_weight * x_k[row])
            x_k = aggr
            out = out + x_k @ self.coeffs[k]
        return torch.relu(out)
class TemporalDiffModule(nn.Module):
    """Turn a per-node sequence into a learned embedding of its FFT magnitude.

    The FFT is taken along dim 1. Its magnitude is averaged over the last dim,
    and the resulting length-``T`` vector is linearly projected to ``out_dim``.

    NOTE(review): despite the name, no frame differencing is visible here;
    only an FFT-magnitude summary is computed.

    Args:
        T: Length of the vector fed to the projection, i.e. the size of
            dim 1 of the input once the last dim has been averaged away.
        out_dim: Width of the produced embedding.
    """

    def __init__(self, T: int, out_dim: int = 32):
        super().__init__()
        self.proj = nn.Linear(in_features=T, out_features=out_dim)

    def forward(self, x_seq: torch.Tensor) -> torch.Tensor:
        """Embed ``x_seq``.

        Presumably ``x_seq`` is ``(num_nodes, T, channels)``, so that averaging
        the last dim leaves ``(num_nodes, T)`` for the projection. TODO confirm
        against the data pipeline.
        """
        spectrum = torch.fft.fft(x_seq, dim=1)
        magnitude = torch.abs(spectrum)
        per_bin = torch.mean(magnitude, dim=-1)
        return self.proj(per_bin)
class SSTGNN(nn.Module):
    """Graph-level scorer combining per-node temporal spectra with graph filtering.

    Pipeline (see ``forward``):

    1. Per-node FFT-magnitude features from ``TemporalDiffModule`` are
       concatenated with the static node features.
    2. The result is projected to ``hidden_dim``.
    3. It is refined by a stack of residual ``SpectralFilterLayer``s.
    4. It is mean-pooled per graph.
    5. A small MLP maps it to a single unnormalised score.

    Args:
        patch_feat_dim: Width of ``data.x`` (static per-node features).
        hidden_dim: Width of the node embeddings inside the graph stack.
        num_frames: Value of ``T`` passed to the temporal module.
        num_spectral_layers: Number of residual spectral layers to stack.
        spectral_K: Number of hops used by each spectral layer.
        fft_dim: Width of the temporal feature vector produced per node.
    """

    def __init__(
        self,
        patch_feat_dim: int = 8,
        hidden_dim: int = 128,
        num_frames: int = 32,
        num_spectral_layers: int = 3,
        spectral_K: int = 3,
        fft_dim: int = 32,
    ):
        super().__init__()
        # Submodules are built in a fixed order. Keep it: the order decides how
        # random initialisation consumes the global RNG and how parameters()
        # is ordered.
        concat_width = patch_feat_dim + fft_dim
        self.input_proj = nn.Linear(concat_width, hidden_dim)
        self.spectral_layers = nn.ModuleList(
            SpectralFilterLayer(hidden_dim, hidden_dim, K=spectral_K)
            for _ in range(num_spectral_layers)
        )
        self.temporal = TemporalDiffModule(T=num_frames, out_dim=fft_dim)
        # Positional Sequential: indices 0..3 are part of the state_dict keys.
        head_layers = [
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, data):
        """Score every graph in ``data``.

        Reads these attributes of ``data``:

        - ``x``: static node features, ``patch_feat_dim`` wide.
        - ``x_temporal``: the sequences fed to ``self.temporal``.
        - ``edge_index``.
        - ``batch``: the node-to-graph assignment used for pooling.

        Returns:
            Raw scores (no sigmoid applied), with the trailing size-1 dim of
            the classifier output squeezed away.
        """
        temporal_feat = self.temporal(data.x_temporal)
        h = self.input_proj(torch.cat((data.x, temporal_feat), dim=-1))
        for spectral in self.spectral_layers:
            # Residual connection around each spectral layer.
            h = h + spectral(h, data.edge_index)
        pooled = global_mean_pool(h, data.batch)
        scores = self.classifier(pooled)
        return scores.squeeze(-1)