File size: 2,533 Bytes
eff3d67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from __future__ import annotations

import torch
import torch.nn as nn
from torch_geometric.nn import global_mean_pool
from torch_geometric.utils import degree


class SpectralFilterLayer(nn.Module):
    """Polynomial graph filter: ``relu(sum_{k<K} (A_norm^k x) @ W_k)``.

    ``A_norm`` is the adjacency given by ``edge_index`` where every edge
    ``row -> col`` is weighted by ``deg[col]^-1/2 * deg[row]^-1/2``. ``deg``
    counts how often a node occurs in ``edge_index[1]`` and is clamped to a
    minimum of 1, so nodes without incoming edges never yield ``inf``.

    Args:
        in_ch: Size of the last dimension of the input node features.
        out_ch: Size of the last dimension of the output node features.
        K: Number of polynomial terms (hops ``0 .. K-1``). Must be >= 1.

    Raises:
        ValueError: If ``K`` is smaller than 1 (there would be no hop-0
            weight for ``forward`` to apply).
    """

    def __init__(self, in_ch: int, out_ch: int, K: int = 3):
        super().__init__()
        if K < 1:
            raise ValueError(f"K must be >= 1, got {K}")
        # One (in_ch, out_ch) weight matrix per hop, small random init.
        self.coeffs = nn.ParameterList(
            [nn.Parameter(torch.randn(in_ch, out_ch) * 0.01) for _ in range(K)]
        )
        self.K = K

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Apply the K-term filter to node features.

        Args:
            x: Node features; indexed by node along dim 0 and multiplied by
                ``(in_ch, out_ch)`` weights, so presumably ``(num_nodes, in_ch)``
                -- confirm against the caller.
            edge_index: Unpacked as ``(row, col)``; messages flow from
                ``row`` to ``col``. Not accessed when ``K == 1``.

        Returns:
            ``relu`` of the summed per-hop projections.
        """
        out = x @ self.coeffs[0]
        if self.K == 1:
            # Only the hop-0 term exists; the graph structure is not needed.
            return torch.relu(out)

        # Everything derived from the graph alone is identical for every hop,
        # so compute it once instead of once per loop iteration.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype).clamp(min=1)
        norm = deg.pow(-0.5)
        norm_dst = norm[col].unsqueeze(-1)
        norm_src = norm[row].unsqueeze(-1)

        x_k = x
        for k in range(1, self.K):
            # Sum the normalised messages of all incoming edges per target node.
            x_k = torch.zeros_like(x).index_add_(
                0, col, norm_dst * x_k[row] * norm_src
            )
            out = out + x_k @ self.coeffs[k]
        return torch.relu(out)


class TemporalDiffModule(nn.Module):
    """Project FFT-magnitude summaries of a sequence to ``out_dim`` features.

    Despite the name, no differencing is performed: the module takes the
    magnitude of the FFT along dim 1, averages it over the last dimension,
    and feeds the result through a linear layer.

    Args:
        T: Input width of the linear layer, i.e. the size the averaged
            spectrum must have along its last dimension.
        out_dim: Number of output features.
    """

    def __init__(self, T: int, out_dim: int = 32):
        super().__init__()
        self.proj = nn.Linear(T, out_dim)

    def forward(self, x_seq: torch.Tensor) -> torch.Tensor:
        """Return ``proj(mean(|FFT_dim1(x_seq)|, dim=-1))``.

        Assumes ``x_seq`` is 3-D with dim 1 of length ``T`` (e.g.
        ``(nodes, T, channels)``) -- TODO confirm against the caller.
        """
        spectrum = torch.fft.fft(x_seq, dim=1)
        magnitude = torch.abs(spectrum)
        # Collapse the trailing axis so one magnitude per frequency bin remains.
        per_bin = torch.mean(magnitude, dim=-1)
        return self.proj(per_bin)


class SSTGNN(nn.Module):
    """Graph-level scorer fusing node features with FFT-magnitude features.

    Per forward pass: FFT features from ``TemporalDiffModule`` are
    concatenated to the node features, projected to ``hidden_dim``, refined
    by a stack of residual ``SpectralFilterLayer`` blocks, mean-pooled per
    graph, and mapped to a single value per graph by a two-layer MLP.

    Args:
        patch_feat_dim: Width of the raw node features (``data.x``).
        hidden_dim: Width used inside the spectral stack.
        num_frames: ``T`` passed to ``TemporalDiffModule``.
        num_spectral_layers: Number of residual spectral blocks.
        spectral_K: ``K`` passed to every ``SpectralFilterLayer``.
        fft_dim: Width of the temporal feature vector.
    """

    def __init__(
        self,
        patch_feat_dim: int = 8,
        hidden_dim: int = 128,
        num_frames: int = 32,
        num_spectral_layers: int = 3,
        spectral_K: int = 3,
        fft_dim: int = 32,
    ):
        super().__init__()
        # Creation order fixes both parameter registration order and the order
        # in which random initial values are drawn.
        fused_dim = patch_feat_dim + fft_dim
        self.input_proj = nn.Linear(fused_dim, hidden_dim)

        blocks = []
        for _ in range(num_spectral_layers):
            blocks.append(SpectralFilterLayer(hidden_dim, hidden_dim, K=spectral_K))
        self.spectral_layers = nn.ModuleList(blocks)

        self.temporal = TemporalDiffModule(T=num_frames, out_dim=fft_dim)

        head = (
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
        )
        self.classifier = nn.Sequential(*head)

    def forward(self, data):
        """Score every graph in ``data``.

        Args:
            data: Object exposing ``x``, ``x_temporal``, ``edge_index`` and
                ``batch`` attributes (presumably a torch_geometric batch --
                confirm against the caller).

        Returns:
            The classifier output with its trailing size-1 dimension
            squeezed away; no activation is applied to it here.
        """
        temporal_feat = self.temporal(data.x_temporal)
        fused = torch.cat((data.x, temporal_feat), dim=-1)
        h = self.input_proj(fused)
        for block in self.spectral_layers:
            # Residual connection around each spectral block.
            h = block(h, data.edge_index) + h
        pooled = global_mean_pool(h, data.batch)
        scores = self.classifier(pooled)
        return scores.squeeze(-1)