from __future__ import annotations

from typing import Literal

import torch
import torch.nn as nn
from torch_geometric.nn import GINEConv, GPSConv


class GraphSpatialEncoder(nn.Module):
    """Per-frame spatial graph encoder built on top of PyG.

    Node and edge features are linearly projected to ``hidden_dim`` and then
    passed through ``num_layers`` message-passing layers. After every layer the
    output is dropped out, added to the layer input (residual connection) and
    normalised with a single ``LayerNorm`` shared by all layers.

    Args:
        node_in_dim: Size of the last dimension of ``data.x``.
        edge_in_dim: Size of the last dimension of ``data.edge_attr``.
        hidden_dim: Width of the projected node/edge features and of every layer.
        num_layers: Number of stacked message-passing layers.
        dropout: Dropout probability applied to each layer's output; for the
            ``"gps"`` backbone it is also passed to the attention module.
        backbone: ``"gine"`` for plain ``GINEConv`` layers, or ``"gps"`` for
            ``GPSConv`` layers wrapping a ``GINEConv``.
        num_heads: Number of attention heads (only used by ``"gps"``).

    Raises:
        ValueError: If ``backbone`` is not ``"gine"`` or ``"gps"``.
    """

    def __init__(
        self,
        node_in_dim: int,
        edge_in_dim: int,
        hidden_dim: int = 256,
        num_layers: int = 4,
        dropout: float = 0.1,
        backbone: Literal["gine", "gps"] = "gine",
        num_heads: int = 8,
    ):
        super().__init__()
        # Validate before building anything so an unsupported backbone is
        # rejected even when num_layers == 0 (no layer would be built then).
        if backbone not in ("gine", "gps"):
            raise ValueError(f"Unsupported graph backbone: {backbone}")
        self.backbone = backbone
        self.node_proj = nn.Linear(node_in_dim, hidden_dim)
        self.edge_proj = nn.Linear(edge_in_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList(
            self._make_layer(hidden_dim, dropout, num_heads)
            for _ in range(num_layers)
        )
        # A single LayerNorm instance is reused after every layer.
        self.norm = nn.LayerNorm(hidden_dim)

    def _make_layer(self, hidden_dim: int, dropout: float, num_heads: int) -> nn.Module:
        """Build one message-passing layer for ``self.backbone``."""
        mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        # edge_dim lets GINEConv consume the projected edge features.
        gine = GINEConv(mlp, edge_dim=hidden_dim)
        if self.backbone == "gine":
            return gine
        # NOTE(review): GPSConv applies its own internal residual connections
        # and normalisation; forward() adds another residual + LayerNorm on
        # top of that. Confirm this double residual is intended.
        return GPSConv(
            channels=hidden_dim,
            conv=gine,
            heads=num_heads,
            attn_type="multihead",
            attn_kwargs={"dropout": dropout},
        )

    def forward(self, data) -> torch.Tensor:
        """Encode the nodes of a (possibly batched) graph.

        Args:
            data: A PyG-style data/batch object exposing ``x``, ``edge_index``
                and ``edge_attr``. ``batch`` is also read when the backbone is
                ``"gps"``.

        Returns:
            Node embeddings whose last dimension is ``hidden_dim``. Presumably
            the shape is ``(num_nodes, hidden_dim)``, assuming ``data.x`` is
            2-D (TODO confirm with callers).
        """
        x = self.node_proj(data.x)
        edge_attr = self.edge_proj(data.edge_attr)
        for layer in self.layers:
            residual = x
            if self.backbone == "gps":
                # GPSConv needs the batch vector to group nodes per graph
                # for its global attention.
                x = layer(x, data.edge_index, data.batch, edge_attr=edge_attr)
            else:
                x = layer(x, data.edge_index, edge_attr=edge_attr)
            # Post-norm residual: norm(input + dropout(layer(input))). The
            # layer output must be added exactly once.
            x = self.norm(residual + self.dropout(x))
        return x