# deepdetection/modules/sstgnn_model.py
# Commit eff3d67 (akagtag): "Implement ZeroGPU Space runtime"
from __future__ import annotations
import torch
import torch.nn as nn
from torch_geometric.nn import global_mean_pool
from torch_geometric.utils import degree
class SpectralFilterLayer(nn.Module):
    """Polynomial graph filter over a symmetrically normalized adjacency.

    Computes ``relu(sum_{k=0}^{K-1} (A_hat^k x) @ W_k)``, where one
    application of ``A_hat`` sends each edge ``row -> col`` the source
    features ``x[row]`` scaled by ``deg[row]^-1/2 * deg[col]^-1/2``.
    ``deg`` counts edges per target (``col``) node and is clamped to at
    least 1, so isolated nodes never produce ``inf``. No self-loops are
    added and no bias is used.

    Args:
        in_ch: Number of input features per node.
        out_ch: Number of output features per node.
        K: Number of polynomial terms (hop orders ``0 .. K-1``); must be >= 1.

    Raises:
        ValueError: If ``K < 1``.
    """

    def __init__(self, in_ch: int, out_ch: int, K: int = 3):
        super().__init__()
        if K < 1:
            # forward() always uses coeffs[0]; rejecting here gives a clear
            # error instead of an IndexError at the first forward pass.
            raise ValueError(f"K must be >= 1, got {K}")
        # One (in_ch, out_ch) weight matrix per hop order, with a small init.
        self.coeffs = nn.ParameterList(
            [nn.Parameter(torch.randn(in_ch, out_ch) * 0.01) for _ in range(K)]
        )
        self.K = K

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Apply the filter to node features.

        Args:
            x: Node features of shape ``(num_nodes, in_ch)``.
            edge_index: Integer tensor of shape ``(2, num_edges)``; row 0
                holds source nodes, row 1 holds target nodes.

        Returns:
            Tensor of shape ``(num_nodes, out_ch)`` after ReLU.
        """
        out = x @ self.coeffs[0]
        if self.K == 1:
            # Only the 0-hop term: the graph is not needed at all.
            return torch.relu(out)

        # The edge normalization depends only on the graph, not on the hop,
        # so it is computed once instead of on every iteration.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype).clamp(min=1)
        deg_inv_sqrt = deg.pow(-0.5)
        edge_weight = (deg_inv_sqrt[row] * deg_inv_sqrt[col]).unsqueeze(-1)

        x_k = x
        for k in range(1, self.K):
            # One propagation hop: scatter-add weighted source features
            # into their target nodes.
            x_k = torch.zeros_like(x_k).index_add_(0, col, edge_weight * x_k[row])
            out = out + x_k @ self.coeffs[k]
        return torch.relu(out)
class TemporalDiffModule(nn.Module):
    """Embed FFT magnitudes of a sequence into a fixed-width feature vector.

    The input is FFT-transformed along dim 1, its magnitudes are averaged over
    the last dim, and the resulting ``T`` values are linearly projected to
    ``out_dim`` features.

    NOTE(review): despite the name, no frame differencing happens here; only
    FFT magnitudes are used.

    Args:
        T: Size of the axis left after averaging; presumably the number of
            frames (dim 1 of the input). Must equal ``proj.in_features``.
        out_dim: Width of the produced feature vector.
    """

    def __init__(self, T: int, out_dim: int = 32):
        super().__init__()
        # Linear map from T frequency-magnitude bins to out_dim features.
        self.proj = nn.Linear(T, out_dim)

    def forward(self, x_seq: torch.Tensor) -> torch.Tensor:
        """Return projected FFT-magnitude features for ``x_seq``.

        Assumes ``x_seq`` is ``(num_nodes, T, feat)`` (TODO confirm against
        the data pipeline). After averaging the last dim, the trailing size
        must be ``T`` for ``proj`` to accept it.
        """
        magnitude = torch.abs(torch.fft.fft(x_seq, dim=1))
        return self.proj(torch.mean(magnitude, dim=-1))
class SSTGNN(nn.Module):
    """Graph classifier combining spatial spectral filtering with temporal FFT features.

    Each node's static features (``data.x``) are concatenated with FFT-based
    temporal features computed from ``data.x_temporal``. The result is
    projected to ``hidden_dim`` and refined by a stack of residual
    ``SpectralFilterLayer`` blocks. Nodes are then mean-pooled per graph and
    scored by a small MLP.

    Args:
        patch_feat_dim: Width of ``data.x`` (last dim).
        hidden_dim: Width of node embeddings inside the spectral stack.
        num_frames: Passed as ``T`` to the temporal FFT module.
        num_spectral_layers: Number of residual spectral filter blocks.
        spectral_K: Polynomial order (number of hop terms) per spectral block.
        fft_dim: Width of the temporal feature vector appended to ``data.x``.
    """

    def __init__(
        self,
        patch_feat_dim: int = 8,
        hidden_dim: int = 128,
        num_frames: int = 32,
        num_spectral_layers: int = 3,
        spectral_K: int = 3,
        fft_dim: int = 32,
    ):
        super().__init__()
        # Submodules are created in a fixed order; this order determines the
        # random initialization sequence and the parameter ordering.
        fused_width = patch_feat_dim + fft_dim
        self.input_proj = nn.Linear(fused_width, hidden_dim)
        self.spectral_layers = nn.ModuleList(
            SpectralFilterLayer(hidden_dim, hidden_dim, K=spectral_K)
            for _ in range(num_spectral_layers)
        )
        self.temporal = TemporalDiffModule(T=num_frames, out_dim=fft_dim)
        # Index layout (0: Linear, 1: ReLU, 2: Dropout, 3: Linear) defines
        # the checkpoint keys, so it must stay as is.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
        )

    def forward(self, data):
        """Score every graph in ``data``.

        Reads ``data.x``, ``data.x_temporal``, ``data.edge_index`` and
        ``data.batch`` (presumably a PyG ``Data``/``Batch``; verify with
        callers).

        Returns:
            Tensor with one unnormalized score per pooled graph. No sigmoid
            is applied here; the trailing singleton dim is squeezed away.
        """
        spectrum_feats = self.temporal(data.x_temporal)
        h = self.input_proj(torch.cat([data.x, spectrum_feats], dim=-1))
        for spectral in self.spectral_layers:
            # Residual connection around each spectral filter block.
            h = spectral(h, data.edge_index) + h
        graph_repr = global_mean_pool(h, data.batch)
        scores = self.classifier(graph_repr)
        return scores.squeeze(-1)