# gnn_wm / Ctrl-World-Graph / graphwm / models / graph_encoder_pyg.py
# (hosting-page residue) EndeavourDD's picture
# (hosting-page residue) Add files using upload-large-folder tool
# (hosting-page residue) da7bf91 verified
from __future__ import annotations
from typing import Literal
import torch
import torch.nn as nn
from torch_geometric.nn import GINEConv, GPSConv
class GraphSpatialEncoder(nn.Module):
    """Per-frame spatial graph encoder built on top of PyG.

    Node and edge features are each linearly projected to ``hidden_dim`` and
    the node features are then refined by ``num_layers`` message-passing
    layers. After every layer the result is combined with the layer input and
    normalized by one ``LayerNorm`` instance that is shared by all layers.

    Args:
        node_in_dim: Size of the last dimension of ``data.x``.
        edge_in_dim: Size of the last dimension of ``data.edge_attr``.
        hidden_dim: Width of the projected node/edge features and of every
            layer.
        num_layers: Number of stacked graph layers. With ``0`` the encoder
            reduces to the node projection.
        dropout: Dropout probability applied to each layer's output; it is
            also forwarded to the attention module when ``backbone="gps"``.
        backbone: ``"gine"`` stacks plain ``GINEConv`` layers; ``"gps"`` wraps
            each ``GINEConv`` in a ``GPSConv`` with multi-head attention.
        num_heads: Number of attention heads; only used when
            ``backbone="gps"``.

    Raises:
        ValueError: If ``backbone`` is neither ``"gine"`` nor ``"gps"``.
    """

    def __init__(
        self,
        node_in_dim: int,
        edge_in_dim: int,
        hidden_dim: int = 256,
        num_layers: int = 4,
        dropout: float = 0.1,
        backbone: Literal["gine", "gps"] = "gine",
        num_heads: int = 8,
    ):
        super().__init__()

        # Validate up front. When this check lived inside the layer loop it
        # never ran for ``num_layers == 0``, so a misspelled backbone was
        # accepted silently and ``forward`` fell through to the non-GPS path.
        if backbone not in ("gine", "gps"):
            raise ValueError(f"Unsupported graph backbone: {backbone}")

        self.backbone = backbone
        self.node_proj = nn.Linear(node_in_dim, hidden_dim)
        self.edge_proj = nn.Linear(edge_in_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            # A fresh MLP per layer so that layers do not share weights.
            mlp = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.SiLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )
            conv = GINEConv(mlp, edge_dim=hidden_dim)
            if backbone == "gps":
                # GPS layer: the local GINE convolution above plus global
                # multi-head attention.
                conv = GPSConv(
                    channels=hidden_dim,
                    conv=conv,
                    heads=num_heads,
                    attn_type="multihead",
                    attn_kwargs={"dropout": dropout},
                )
            self.layers.append(conv)

        # One normalization module reused after every layer.
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, data) -> torch.Tensor:
        """Encode the nodes of ``data``.

        Args:
            data: Object exposing ``x``, ``edge_attr`` and ``edge_index``; for
                the ``"gps"`` backbone an optional ``batch`` attribute is
                passed on to each ``GPSConv`` layer. Presumably a PyG
                ``Data``/``Batch`` instance -- confirm against the caller.

        Returns:
            Node embeddings whose last dimension is ``hidden_dim``.
        """
        x = self.node_proj(data.x)
        edge_attr = self.edge_proj(data.edge_attr)

        use_gps = self.backbone == "gps"
        # ``batch`` is only consumed by GPSConv; a missing attribute is
        # treated as ``None``, which is GPSConv's own default.
        batch = getattr(data, "batch", None) if use_gps else None

        for layer in self.layers:
            residual = x
            if use_gps:
                x = layer(x, data.edge_index, batch, edge_attr=edge_attr)
            else:
                x = layer(x, data.edge_index, edge_attr=edge_attr)
            # NOTE(review): the layer output is added twice here (once raw,
            # once through dropout), so dropout only perturbs half of the
            # update. This looks like it was meant to be
            # ``residual + self.dropout(x)``. Left unchanged because fixing it
            # alters the numerics of already-trained checkpoints -- confirm
            # with the model owner before changing.
            x = self.norm(x + self.dropout(x) + residual)
        return x