File size: 8,385 Bytes
930ea3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
# conv.py
# Clean, dependency-light graph encoder blocks for molecular GNNs.
# - Single source of truth for convolution choices: "gine", "gin", "gcn"
# - Edge attributes are supported for "gine" (recommended for chemistry)
# - No duplication with PyG built-ins; everything wraps torch_geometric.nn
# - Consistent encoder API: GNNEncoder(...).forward(x, edge_index, edge_attr, batch) -> graph embedding [B, emb_dim]

from __future__ import annotations
from typing import Literal, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import (
    GINEConv,
    GINConv,
    GCNConv,
    global_mean_pool,
    global_add_pool,
    global_max_pool,
)


def get_activation(name: str) -> nn.Module:
    """Return a *fresh* activation module for ``name`` (case-insensitive).

    Supported: "relu", "gelu", "silu", and "leaky_relu"/"lrelu"
    (negative slope fixed at 0.1). Raises ValueError for anything else.
    """
    key = name.lower()
    # Zero-arg constructors handled via a dispatch table; a new instance is
    # built per call so layers never share module objects.
    simple = {"relu": nn.ReLU, "gelu": nn.GELU, "silu": nn.SiLU}
    if key in simple:
        return simple[key]()
    if key in ("leaky_relu", "lrelu"):
        return nn.LeakyReLU(0.1)
    raise ValueError(f"Unknown activation: {name}")


class MLP(nn.Module):
    """Small feed-forward stack used inside GNN layers and projections.

    Layout: Linear -> (act -> [dropout]) repeated, with no activation or
    dropout after the final Linear. ``num_layers`` counts Linear layers.
    """

    def __init__(
        self,
        in_dim: int,
        hidden_dim: int,
        out_dim: int,
        num_layers: int = 2,
        act: str = "relu",
        dropout: float = 0.0,
        bias: bool = True,
    ):
        super().__init__()
        assert num_layers >= 1
        # Widths of every layer boundary: input, (num_layers-1) hidden, output.
        widths = [in_dim, *([hidden_dim] * (num_layers - 1)), out_dim]
        stack: list[nn.Module] = []
        last_idx = len(widths) - 2
        for idx, (d_in, d_out) in enumerate(zip(widths[:-1], widths[1:])):
            stack.append(nn.Linear(d_in, d_out, bias=bias))
            # Activation (and optional dropout) on every layer except the last.
            if idx != last_idx:
                stack.append(get_activation(act))
                if dropout > 0:
                    stack.append(nn.Dropout(dropout))
        self.net = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the stack; shape [..., in_dim] -> [..., out_dim]."""
        return self.net(x)


class NodeProjector(nn.Module):
    """Maps raw node features to the model embedding width.

    When the input width already equals ``emb_dim`` this is a no-op
    (Identity); otherwise a Linear followed by the chosen activation.
    """

    def __init__(self, in_dim_node: int, emb_dim: int, act: str = "relu"):
        super().__init__()
        self.proj: nn.Module
        if in_dim_node != emb_dim:
            self.proj = nn.Sequential(
                nn.Linear(in_dim_node, emb_dim),
                get_activation(act),
            )
        else:
            # Dimensions already match — avoid a pointless learned layer.
            self.proj = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """[N, in_dim_node] -> [N, emb_dim]."""
        return self.proj(x)


class EdgeProjector(nn.Module):
    """Maps raw edge attributes to the model embedding width (for GINE).

    Raises:
        ValueError: if ``in_dim_edge`` is not strictly positive.
    """

    def __init__(self, in_dim_edge: int, emb_dim: int, act: str = "relu"):
        super().__init__()
        # Unlike NodeProjector there is no Identity shortcut: GINE always
        # needs edge embeddings at emb_dim, and a zero/negative width means
        # the caller has no edge features at all.
        if in_dim_edge <= 0:
            raise ValueError("in_dim_edge must be > 0 when using edge attributes")
        self.proj = nn.Sequential(
            nn.Linear(in_dim_edge, emb_dim),
            get_activation(act),
        )

    def forward(self, e: torch.Tensor) -> torch.Tensor:
        """[E, in_dim_edge] -> [E, emb_dim]."""
        return self.proj(e)


class GNNEncoder(nn.Module):
    """
    Backbone GNN with a selectable convolution type.

    gnn_type:
        - "gine": chemistry-ready, uses edge_attr (recommended)
        - "gin" : ignores edge_attr, strong node MPNN
        - "gcn" : ignores edge_attr, fast spectral conv
    norm: "batch" | "layer" | "none"
    readout: "mean" | "sum" | "max"

    forward(x, edge_index, edge_attr, batch) returns a graph-level
    embedding of shape [B, emb_dim].
    """

    def __init__(
        self,
        in_dim_node: int,
        emb_dim: int,
        num_layers: int = 5,
        gnn_type: Literal["gine", "gin", "gcn"] = "gine",
        in_dim_edge: int = 0,
        act: str = "relu",
        dropout: float = 0.0,
        residual: bool = True,
        norm: Literal["batch", "layer", "none"] = "batch",
        readout: Literal["mean", "sum", "max"] = "mean",
    ):
        super().__init__()
        assert num_layers >= 1

        self.gnn_type = gnn_type.lower()
        self.emb_dim = emb_dim
        self.num_layers = num_layers
        self.residual = residual
        self.dropout_p = float(dropout)
        self.readout = readout.lower()

        self.node_proj = NodeProjector(in_dim_node, emb_dim, act=act)
        self.edge_proj: Optional[EdgeProjector] = None

        # GINE is the only variant that consumes edge attributes, so only it
        # gets (and requires) an edge projector.
        if self.gnn_type == "gine":
            if in_dim_edge <= 0:
                raise ValueError(
                    "gine selected but in_dim_edge <= 0. Provide edge attributes or switch gnn_type."
                )
            self.edge_proj = EdgeProjector(in_dim_edge, emb_dim, act=act)

        # Build the conv + norm stacks, one entry of each per layer.
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()

        for _layer in range(num_layers):
            if self.gnn_type in ("gine", "gin"):
                # Both GIN flavours wrap a 2-layer MLP as the node-update net;
                # they differ only in whether edge embeddings are consumed.
                update_net = MLP(emb_dim, emb_dim, emb_dim, num_layers=2, act=act, dropout=0.0)
                conv_cls = GINEConv if self.gnn_type == "gine" else GINConv
                self.convs.append(conv_cls(update_net))
            elif self.gnn_type == "gcn":
                self.convs.append(GCNConv(emb_dim, emb_dim, add_self_loops=True, normalize=True))
            else:
                raise ValueError(f"Unknown gnn_type: {gnn_type}")

            if norm == "batch":
                norm_layer: nn.Module = nn.BatchNorm1d(emb_dim)
            elif norm == "layer":
                norm_layer = nn.LayerNorm(emb_dim)
            elif norm == "none":
                norm_layer = nn.Identity()
            else:
                raise ValueError(f"Unknown norm: {norm}")
            self.norms.append(norm_layer)

        self.act = get_activation(act)

    def _readout(self, x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        """Pool node embeddings to one vector per graph."""
        pools = {
            "mean": global_mean_pool,
            "sum": global_add_pool,
            "max": global_max_pool,
        }
        if self.readout not in pools:
            raise ValueError(f"Unknown readout: {self.readout}")
        return pools[self.readout](x, batch)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: Optional[torch.Tensor],
        batch: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """
        Returns a graph-level embedding of shape [B, emb_dim].
        If batch is None, assumes a single graph and creates a zero batch vector.
        """
        if batch is None:
            batch = x.new_zeros(x.size(0), dtype=torch.long)

        # Project features to emb_dim (cast to float first so integer
        # feature tensors work too).
        h = self.node_proj(x.float())

        edge_emb = None
        if self.gnn_type == "gine":
            if edge_attr is None:
                raise ValueError("GINE requires edge_attr, but got None.")
            edge_emb = self.edge_proj(edge_attr.float())

        # Message passing: conv -> norm -> act -> (residual) -> (dropout).
        for conv, norm_layer in zip(self.convs, self.norms):
            if self.gnn_type == "gine":
                out = conv(h, edge_index, edge_emb)
            else:
                # GIN / GCN variants take no edge attributes.
                out = conv(h, edge_index)

            out = self.act(norm_layer(out))

            # Residual only when shapes line up (always true after layer 1
            # since every layer maps emb_dim -> emb_dim).
            h = h + out if (self.residual and out.shape == h.shape) else out

            if self.dropout_p > 0:
                h = F.dropout(h, p=self.dropout_p, training=self.training)

        return self._readout(h, batch)  # [B, emb_dim]


def build_gnn_encoder(
    in_dim_node: int,
    emb_dim: int,
    num_layers: int = 5,
    gnn_type: Literal["gine", "gin", "gcn"] = "gine",
    in_dim_edge: int = 0,
    act: str = "relu",
    dropout: float = 0.0,
    residual: bool = True,
    norm: Literal["batch", "layer", "none"] = "batch",
    readout: Literal["mean", "sum", "max"] = "mean",
) -> GNNEncoder:
    """
    Factory to create a GNNEncoder with a consistent, minimal API.
    Prefer calling this from model.py so encoder construction is centralized.
    """
    # Pass everything through by keyword; this function exists purely so the
    # rest of the codebase has a single construction point for encoders.
    config = dict(
        in_dim_node=in_dim_node,
        emb_dim=emb_dim,
        num_layers=num_layers,
        gnn_type=gnn_type,
        in_dim_edge=in_dim_edge,
        act=act,
        dropout=dropout,
        residual=residual,
        norm=norm,
        readout=readout,
    )
    return GNNEncoder(**config)


__all__ = ["GNNEncoder", "build_gnn_encoder"]