# sobinalosious92's picture
# Upload 297 files
# 930ea3d verified
# conv.py
# Clean, dependency-light graph encoder blocks for molecular GNNs.
# - Single source of truth for convolution choices: "gine", "gin", "gcn"
# - Edge attributes are supported for "gine" (recommended for chemistry)
# - No duplication with PyG built-ins; everything wraps torch_geometric.nn
# - Consistent encoder API: GNNEncoder(...).forward(x, edge_index, edge_attr, batch) -> graph embedding [B, emb_dim]
from __future__ import annotations
from typing import Literal, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import (
GINEConv,
GINConv,
GCNConv,
global_mean_pool,
global_add_pool,
global_max_pool,
)
def get_activation(name: str) -> nn.Module:
    """Return a freshly constructed activation module by (case-insensitive) name.

    Supported: "relu", "gelu", "silu", "leaky_relu"/"lrelu"
    (LeakyReLU with negative slope 0.1).

    Raises:
        ValueError: if the name is not one of the supported activations.
    """
    name = name.lower()
    # Dispatch table of zero-arg factories; a new module is built per call.
    factories = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "silu": nn.SiLU,
        "leaky_relu": lambda: nn.LeakyReLU(0.1),
        "lrelu": lambda: nn.LeakyReLU(0.1),
    }
    factory = factories.get(name)
    if factory is None:
        raise ValueError(f"Unknown activation: {name}")
    return factory()
class MLP(nn.Module):
    """Small MLP used inside GNN layers and projections.

    Builds `num_layers` linear layers; activation (and optional dropout)
    is inserted after every layer except the last one.
    """

    def __init__(
        self,
        in_dim: int,
        hidden_dim: int,
        out_dim: int,
        num_layers: int = 2,
        act: str = "relu",
        dropout: float = 0.0,
        bias: bool = True,
    ):
        super().__init__()
        assert num_layers >= 1
        # Layer widths: in_dim -> hidden_dim * (num_layers - 1) -> out_dim.
        widths = [in_dim, *([hidden_dim] * (num_layers - 1)), out_dim]
        modules: list[nn.Module] = []
        last = len(widths) - 2
        for idx, (d_in, d_out) in enumerate(zip(widths[:-1], widths[1:])):
            modules.append(nn.Linear(d_in, d_out, bias=bias))
            if idx < last:  # no activation/dropout after the output layer
                modules.append(get_activation(act))
                if dropout > 0:
                    modules.append(nn.Dropout(dropout))
        self.net = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
class NodeProjector(nn.Module):
    """Maps raw node features into the model embedding space.

    When the input dimensionality already equals `emb_dim`, the projection
    is a no-op (`nn.Identity`); otherwise a Linear + activation is used.
    """

    def __init__(self, in_dim_node: int, emb_dim: int, act: str = "relu"):
        super().__init__()
        self.proj: nn.Module = (
            nn.Identity()
            if in_dim_node == emb_dim
            else nn.Sequential(nn.Linear(in_dim_node, emb_dim), get_activation(act))
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)
class EdgeProjector(nn.Module):
    """Maps raw edge attributes into the model embedding space (for GINE).

    Raises:
        ValueError: if `in_dim_edge` is not positive — edge attributes are
            mandatory for this projector.
    """

    def __init__(self, in_dim_edge: int, emb_dim: int, act: str = "relu"):
        super().__init__()
        # Guard first: a non-positive edge dim means no usable edge features.
        if in_dim_edge <= 0:
            raise ValueError("in_dim_edge must be > 0 when using edge attributes")
        linear = nn.Linear(in_dim_edge, emb_dim)
        activation = get_activation(act)
        self.proj = nn.Sequential(linear, activation)

    def forward(self, e: torch.Tensor) -> torch.Tensor:
        return self.proj(e)
class GNNEncoder(nn.Module):
    """
    Backbone GNN with selectable conv type.

    gnn_type:
        - "gine": chemistry-ready, uses edge_attr (recommended)
        - "gin" : ignores edge_attr, strong node MPNN
        - "gcn" : ignores edge_attr, fast spectral conv
    norm: "batch" | "layer" | "none"
    readout: "mean" | "sum" | "max"

    Raises:
        ValueError: at construction time for an unknown gnn_type / norm /
            readout, or when "gine" is selected without edge attributes.
            (The original validated `readout` lazily, so a typo only
            surfaced at the first forward() call; we fail fast instead.)
    """

    def __init__(
        self,
        in_dim_node: int,
        emb_dim: int,
        num_layers: int = 5,
        gnn_type: Literal["gine", "gin", "gcn"] = "gine",
        in_dim_edge: int = 0,
        act: str = "relu",
        dropout: float = 0.0,
        residual: bool = True,
        norm: Literal["batch", "layer", "none"] = "batch",
        readout: Literal["mean", "sum", "max"] = "mean",
    ):
        super().__init__()
        assert num_layers >= 1
        self.gnn_type = gnn_type.lower()
        self.emb_dim = emb_dim
        self.num_layers = num_layers
        self.residual = residual
        self.dropout_p = float(dropout)
        self.readout = readout.lower()

        # Validate all config up front, before building any submodule.
        # (Previously gnn_type/norm were checked inside the layer loop, so a
        # bad `norm` raised only after the first conv was already built.)
        if self.gnn_type not in ("gine", "gin", "gcn"):
            raise ValueError(f"Unknown gnn_type: {gnn_type}")
        if norm not in ("batch", "layer", "none"):
            raise ValueError(f"Unknown norm: {norm}")
        if self.readout not in ("mean", "sum", "max"):
            raise ValueError(f"Unknown readout: {self.readout}")

        self.node_proj = NodeProjector(in_dim_node, emb_dim, act=act)
        self.edge_proj: Optional[EdgeProjector] = None
        if self.gnn_type == "gine":
            if in_dim_edge <= 0:
                raise ValueError(
                    "gine selected but in_dim_edge <= 0. Provide edge attributes or switch gnn_type."
                )
            self.edge_proj = EdgeProjector(in_dim_edge, emb_dim, act=act)

        # One (conv, norm) pair per message-passing layer.
        self.convs = nn.ModuleList(self._make_conv(act) for _ in range(num_layers))
        self.norms = nn.ModuleList(self._make_norm(norm) for _ in range(num_layers))
        self.act = get_activation(act)

    def _make_conv(self, act: str) -> nn.Module:
        """Build one message-passing layer of the configured `gnn_type`."""
        if self.gnn_type == "gine":
            # GINE consumes edge features; edge_attr is projected to emb_dim
            # by self.edge_proj before being fed to the conv.
            mlp = MLP(self.emb_dim, self.emb_dim, self.emb_dim, num_layers=2, act=act, dropout=0.0)
            return GINEConv(mlp)
        if self.gnn_type == "gin":
            mlp = MLP(self.emb_dim, self.emb_dim, self.emb_dim, num_layers=2, act=act, dropout=0.0)
            return GINConv(mlp)
        # "gcn" — validated in __init__, so no other value can reach here.
        return GCNConv(self.emb_dim, self.emb_dim, add_self_loops=True, normalize=True)

    def _make_norm(self, norm: str) -> nn.Module:
        """Build one per-layer normalization module ("batch" | "layer" | "none")."""
        if norm == "batch":
            return nn.BatchNorm1d(self.emb_dim)
        if norm == "layer":
            return nn.LayerNorm(self.emb_dim)
        return nn.Identity()

    def _readout(self, x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        """Pool node embeddings [N, emb_dim] to graph embeddings [B, emb_dim]."""
        if self.readout == "mean":
            return global_mean_pool(x, batch)
        if self.readout == "sum":
            return global_add_pool(x, batch)
        if self.readout == "max":
            return global_max_pool(x, batch)
        # Unreachable after the eager __init__ validation; kept defensively.
        raise ValueError(f"Unknown readout: {self.readout}")

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: Optional[torch.Tensor],
        batch: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """
        Run message passing and pool to a graph-level embedding.

        Args:
            x: raw node features, cast to float internally.
            edge_index: COO connectivity, shape [2, E].
            edge_attr: edge features; required for "gine", ignored otherwise.
            batch: graph assignment per node; None means a single graph.

        Returns:
            Graph-level embedding of shape [B, emb_dim].

        Raises:
            ValueError: if gnn_type == "gine" and edge_attr is None.
        """
        if batch is None:
            # Single-graph convenience: assign every node to graph 0.
            batch = x.new_zeros(x.size(0), dtype=torch.long)
        # Project features (ensure float dtype).
        x = self.node_proj(x.float())
        e = None
        if self.gnn_type == "gine":
            if edge_attr is None:
                raise ValueError("GINE requires edge_attr, but got None.")
            e = self.edge_proj(edge_attr.float())
        # Message passing: conv -> norm -> activation, optional residual,
        # optional dropout after each layer.
        h = x
        for conv, norm_layer in zip(self.convs, self.norms):
            if self.gnn_type == "gine":
                h_next = conv(h, edge_index, e)
            else:
                # GINConv / GCNConv take no edge features.
                h_next = conv(h, edge_index)
            h_next = self.act(norm_layer(h_next))
            if self.residual and h_next.shape == h.shape:
                # Safe: every layer maps emb_dim -> emb_dim.
                h = h + h_next
            else:
                h = h_next
            if self.dropout_p > 0:
                h = F.dropout(h, p=self.dropout_p, training=self.training)
        return self._readout(h, batch)  # [B, emb_dim]
def build_gnn_encoder(
    in_dim_node: int,
    emb_dim: int,
    num_layers: int = 5,
    gnn_type: Literal["gine", "gin", "gcn"] = "gine",
    in_dim_edge: int = 0,
    act: str = "relu",
    dropout: float = 0.0,
    residual: bool = True,
    norm: Literal["batch", "layer", "none"] = "batch",
    readout: Literal["mean", "sum", "max"] = "mean",
) -> GNNEncoder:
    """
    Factory to create a GNNEncoder with a consistent, minimal API.
    Prefer calling this from model.py so encoder construction is centralized.
    """
    # Collect the configuration once, then expand it into the constructor.
    config = {
        "in_dim_node": in_dim_node,
        "emb_dim": emb_dim,
        "num_layers": num_layers,
        "gnn_type": gnn_type,
        "in_dim_edge": in_dim_edge,
        "act": act,
        "dropout": dropout,
        "residual": residual,
        "norm": norm,
        "readout": readout,
    }
    return GNNEncoder(**config)
__all__ = ["GNNEncoder", "build_gnn_encoder"]