| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from __future__ import annotations |
| | from typing import Literal, Optional |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch_geometric.nn import ( |
| | GINEConv, |
| | GINConv, |
| | GCNConv, |
| | global_mean_pool, |
| | global_add_pool, |
| | global_max_pool, |
| | ) |
| |
|
| |
|
def get_activation(name: str) -> nn.Module:
    """Return a freshly constructed activation module for *name*.

    Lookup is case-insensitive. "leaky_relu" and "lrelu" both map to
    ``LeakyReLU(0.1)``. Raises ``ValueError`` for unknown names.
    """
    key = name.lower()
    simple: dict[str, type[nn.Module]] = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "silu": nn.SiLU,
    }
    if key in simple:
        return simple[key]()
    if key in ("leaky_relu", "lrelu"):
        return nn.LeakyReLU(0.1)
    raise ValueError(f"Unknown activation: {key}")
| |
|
| |
|
class MLP(nn.Module):
    """Small MLP used inside GNN layers and projections.

    Builds ``num_layers`` Linear layers; the activation (and optional
    dropout) is inserted after every layer except the last, so the output
    is a raw linear projection.
    """

    def __init__(
        self,
        in_dim: int,
        hidden_dim: int,
        out_dim: int,
        num_layers: int = 2,
        act: str = "relu",
        dropout: float = 0.0,
        bias: bool = True,
    ):
        """
        Args:
            in_dim: input feature size.
            hidden_dim: width of every hidden layer.
            out_dim: output feature size.
            num_layers: total number of Linear layers (>= 1).
            act: activation name understood by ``get_activation``.
            dropout: dropout probability applied after hidden activations.
            bias: whether the Linear layers carry bias terms.

        Raises:
            ValueError: if ``num_layers`` < 1.
        """
        super().__init__()
        # Raise instead of assert: asserts are stripped under `python -O`.
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        layers: list[nn.Module] = []
        dims = [in_dim] + [hidden_dim] * (num_layers - 1) + [out_dim]
        for i in range(len(dims) - 1):
            layers.append(nn.Linear(dims[i], dims[i + 1], bias=bias))
            if i < len(dims) - 2:  # no activation/dropout after the last layer
                layers.append(get_activation(act))
                if dropout > 0:
                    layers.append(nn.Dropout(dropout))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the stacked layers to ``x`` (shape [..., in_dim])."""
        return self.net(x)
| |
|
| |
|
class NodeProjector(nn.Module):
    """Maps raw node features into the model's embedding space.

    When the input width already equals ``emb_dim`` the projection is a
    no-op (``Identity``); otherwise a Linear layer followed by the chosen
    activation.
    """

    def __init__(self, in_dim_node: int, emb_dim: int, act: str = "relu"):
        super().__init__()
        needs_projection = in_dim_node != emb_dim
        if needs_projection:
            self.proj = nn.Sequential(
                nn.Linear(in_dim_node, emb_dim),
                get_activation(act),
            )
        else:
            # Dimensions already match — skip the extra parameters.
            self.proj = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project node features ``x`` (shape [N, in_dim_node]) to [N, emb_dim]."""
        return self.proj(x)
| |
|
| |
|
class EdgeProjector(nn.Module):
    """Maps raw edge attributes into the embedding space for GINE layers."""

    def __init__(self, in_dim_edge: int, emb_dim: int, act: str = "relu"):
        super().__init__()
        # GINE needs real edge features; a non-positive width is a config error.
        if in_dim_edge <= 0:
            raise ValueError("in_dim_edge must be > 0 when using edge attributes")
        projection_layers = (
            nn.Linear(in_dim_edge, emb_dim),
            get_activation(act),
        )
        self.proj = nn.Sequential(*projection_layers)

    def forward(self, e: torch.Tensor) -> torch.Tensor:
        """Project edge attributes ``e`` (shape [E, in_dim_edge]) to [E, emb_dim]."""
        return self.proj(e)
| |
|
| |
|
class GNNEncoder(nn.Module):
    """
    Backbone GNN with selectable conv type.

    gnn_type:
        - "gine": chemistry-ready, uses edge_attr (recommended)
        - "gin" : ignores edge_attr, strong node MPNN
        - "gcn" : ignores edge_attr, fast spectral conv
    norm: "batch" | "layer" | "none"
    readout: "mean" | "sum" | "max"
    """

    def __init__(
        self,
        in_dim_node: int,
        emb_dim: int,
        num_layers: int = 5,
        gnn_type: Literal["gine", "gin", "gcn"] = "gine",
        in_dim_edge: int = 0,
        act: str = "relu",
        dropout: float = 0.0,
        residual: bool = True,
        norm: Literal["batch", "layer", "none"] = "batch",
        readout: Literal["mean", "sum", "max"] = "mean",
    ):
        """
        Args:
            in_dim_node: raw node feature size.
            emb_dim: hidden/embedding width used throughout the stack.
            num_layers: number of message-passing layers (>= 1).
            gnn_type: one of "gine" | "gin" | "gcn" (case-insensitive).
            in_dim_edge: raw edge attribute size; must be > 0 for "gine".
            act: activation name understood by ``get_activation``.
            dropout: dropout probability applied after each layer.
            residual: add skip connections when shapes match.
            norm: per-layer normalization ("batch" | "layer" | "none").
            readout: graph pooling ("mean" | "sum" | "max").

        Raises:
            ValueError: for invalid ``num_layers``, ``gnn_type``, ``norm``,
                ``readout``, or a missing edge dimension with "gine".
        """
        super().__init__()
        # Raise instead of assert: asserts are stripped under `python -O`.
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        self.gnn_type = gnn_type.lower()
        self.readout = readout.lower()
        # Fail fast on config errors here, instead of deep inside the layer
        # loop (gnn_type/norm) or at the first forward pass (readout).
        if self.gnn_type not in ("gine", "gin", "gcn"):
            raise ValueError(f"Unknown gnn_type: {gnn_type}")
        if self.readout not in ("mean", "sum", "max"):
            raise ValueError(f"Unknown readout: {self.readout}")
        if norm not in ("batch", "layer", "none"):
            raise ValueError(f"Unknown norm: {norm}")

        self.emb_dim = emb_dim
        self.num_layers = num_layers
        self.residual = residual
        self.dropout_p = float(dropout)

        self.node_proj = NodeProjector(in_dim_node, emb_dim, act=act)
        self.edge_proj: Optional[EdgeProjector] = None

        if self.gnn_type == "gine":
            if in_dim_edge <= 0:
                raise ValueError(
                    "gine selected but in_dim_edge <= 0. Provide edge attributes or switch gnn_type."
                )
            self.edge_proj = EdgeProjector(in_dim_edge, emb_dim, act=act)

        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()

        for _ in range(num_layers):
            if self.gnn_type == "gine":
                # GINE mixes projected edge features into messages; the inner
                # MLP keeps width constant so residual shapes always line up.
                nn_mlp = MLP(emb_dim, emb_dim, emb_dim, num_layers=2, act=act, dropout=0.0)
                conv = GINEConv(nn_mlp)
            elif self.gnn_type == "gin":
                nn_mlp = MLP(emb_dim, emb_dim, emb_dim, num_layers=2, act=act, dropout=0.0)
                conv = GINConv(nn_mlp)
            else:  # "gcn" — validated above
                conv = GCNConv(emb_dim, emb_dim, add_self_loops=True, normalize=True)
            self.convs.append(conv)

            if norm == "batch":
                self.norms.append(nn.BatchNorm1d(emb_dim))
            elif norm == "layer":
                self.norms.append(nn.LayerNorm(emb_dim))
            else:  # "none" — validated above
                self.norms.append(nn.Identity())

        self.act = get_activation(act)

    def _readout(self, x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        """Pool node embeddings to one vector per graph in the batch."""
        if self.readout == "mean":
            return global_mean_pool(x, batch)
        if self.readout == "sum":
            return global_add_pool(x, batch)
        if self.readout == "max":
            return global_max_pool(x, batch)
        # Defensive: readout was already validated in __init__.
        raise ValueError(f"Unknown readout: {self.readout}")

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: Optional[torch.Tensor],
        batch: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """
        Returns a graph-level embedding of shape [B, emb_dim].
        If batch is None, assumes a single graph and creates a zero batch vector.
        """
        if batch is None:
            # Single graph: every node belongs to graph 0.
            batch = x.new_zeros(x.size(0), dtype=torch.long)

        x = x.float()
        x = self.node_proj(x)

        e = None
        if self.gnn_type == "gine":
            if edge_attr is None:
                raise ValueError("GINE requires edge_attr, but got None.")
            e = self.edge_proj(edge_attr.float())

        h = x
        for conv, norm in zip(self.convs, self.norms):
            # Only GINE consumes edge features; GIN/GCN ignore edge_attr.
            if self.gnn_type == "gine":
                h_next = conv(h, edge_index, e)
            else:
                h_next = conv(h, edge_index)

            h_next = norm(h_next)
            h_next = self.act(h_next)

            # Width is constant through the stack, so shapes normally match;
            # keep the guard for safety.
            if self.residual and h_next.shape == h.shape:
                h = h + h_next
            else:
                h = h_next

            if self.dropout_p > 0:
                h = F.dropout(h, p=self.dropout_p, training=self.training)

        return self._readout(h, batch)
| |
|
| |
|
def build_gnn_encoder(
    in_dim_node: int,
    emb_dim: int,
    num_layers: int = 5,
    gnn_type: Literal["gine", "gin", "gcn"] = "gine",
    in_dim_edge: int = 0,
    act: str = "relu",
    dropout: float = 0.0,
    residual: bool = True,
    norm: Literal["batch", "layer", "none"] = "batch",
    readout: Literal["mean", "sum", "max"] = "mean",
) -> GNNEncoder:
    """
    Factory to create a GNNEncoder with a consistent, minimal API.
    Prefer calling this from model.py so encoder construction is centralized.
    """
    config = dict(
        in_dim_node=in_dim_node,
        emb_dim=emb_dim,
        num_layers=num_layers,
        gnn_type=gnn_type,
        in_dim_edge=in_dim_edge,
        act=act,
        dropout=dropout,
        residual=residual,
        norm=norm,
        readout=readout,
    )
    return GNNEncoder(**config)
| |
|
| |
|
# Public surface of this module; the helper classes (MLP, NodeProjector,
# EdgeProjector) and get_activation are internal by convention.
__all__ = ["GNNEncoder", "build_gnn_encoder"]
| |
|