File size: 2,049 Bytes
da7bf91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from __future__ import annotations

from typing import Literal

import torch
import torch.nn as nn
from torch_geometric.nn import GINEConv, GPSConv


class GraphSpatialEncoder(nn.Module):
    """Per-frame spatial graph encoder built on top of PyG.

    Projects node and edge features to ``hidden_dim``, then applies a stack
    of message-passing layers: plain ``GINEConv`` layers, or ``GPSConv``
    layers wrapping a ``GINEConv`` with multi-head attention. After every
    layer the output goes through dropout, is added to the layer input
    (residual), and is normalized with a LayerNorm.

    Args:
        node_in_dim: Size of the last dimension of ``data.x``.
        edge_in_dim: Size of the last dimension of ``data.edge_attr``.
        hidden_dim: Width of the projected node/edge features and of every
            graph layer.
        num_layers: Number of stacked graph layers.
        dropout: Dropout probability applied to each layer's output. For the
            ``"gps"`` backbone it is also passed to the attention module.
        backbone: ``"gine"`` for GINE convolutions, ``"gps"`` for GPS layers
            (local GINE convolution + global multi-head attention).
        num_heads: Number of attention heads; only used by ``"gps"``.

    Raises:
        ValueError: If ``backbone`` is not ``"gine"`` or ``"gps"``.
    """

    def __init__(
        self,
        node_in_dim: int,
        edge_in_dim: int,
        hidden_dim: int = 256,
        num_layers: int = 4,
        dropout: float = 0.1,
        backbone: Literal["gine", "gps"] = "gine",
        num_heads: int = 8,
    ):
        super().__init__()
        # Validate before building anything so an unsupported backbone is
        # rejected even when num_layers == 0 (the loop body never runs then).
        if backbone not in ("gine", "gps"):
            raise ValueError(f"Unsupported graph backbone: {backbone}")

        self.backbone = backbone
        self.node_proj = nn.Linear(node_in_dim, hidden_dim)
        self.edge_proj = nn.Linear(edge_in_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            # A fresh MLP per layer: GINE's update function is not shared.
            mlp = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.SiLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )
            if backbone == "gine":
                conv = GINEConv(mlp, edge_dim=hidden_dim)
            else:  # backbone == "gps" (validated above)
                conv = GPSConv(
                    channels=hidden_dim,
                    conv=GINEConv(mlp, edge_dim=hidden_dim),
                    heads=num_heads,
                    attn_type="multihead",
                    attn_kwargs={"dropout": dropout},
                )
            self.layers.append(conv)

        # NOTE(review): a single LayerNorm instance is shared by every layer.
        # Per-layer norms are more common, but switching would change the
        # state_dict keys ("norm.*"), so it is kept as-is here.
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, data) -> torch.Tensor:
        """Encode the graph(s) in ``data`` into per-node embeddings.

        Args:
            data: PyG-style data object exposing ``x``, ``edge_index`` and
                ``edge_attr``. For the ``"gps"`` backbone, ``batch`` (the
                node-to-graph assignment) is read when present; a missing or
                ``None`` batch is passed through as ``None``.

        Returns:
            Node embeddings with last dimension ``hidden_dim``, one row per
            row of ``data.x``.
        """
        x = self.node_proj(data.x)
        edge_attr = self.edge_proj(data.edge_attr)

        is_gps = self.backbone == "gps"
        # Only the GPS layers (global attention) need the batch vector.
        batch = getattr(data, "batch", None) if is_gps else None

        for layer in self.layers:
            residual = x
            if is_gps:
                # NOTE(review): GPSConv applies its own internal residuals and
                # norms; confirm this additional outer residual is intended.
                x = layer(x, data.edge_index, batch, edge_attr=edge_attr)
            else:
                x = layer(x, data.edge_index, edge_attr=edge_attr)
            # Post-norm residual block: the (dropped-out) layer output is added
            # to the layer input exactly once. The previous form
            # ``x + dropout(x) + residual`` counted the layer output twice.
            x = self.norm(residual + self.dropout(x))

        return x