# netfm-training: netfm/models/encoder.py
# Uploaded with huggingface_hub by henribonamy (commit 197b2d3, verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, global_mean_pool
class GraphSAGEEncoder(nn.Module):
    """GraphSAGE-based encoder for learning node representations.

    The encoder stacks ``num_layers`` ``SAGEConv`` layers, each followed by a
    ``LayerNorm``. Every layer except the last is additionally followed by a
    ReLU and dropout, so the final node embeddings are layer-normalized but
    not passed through a non-linearity.

    Args:
        in_channels: Dimensionality of the input node features.
        hidden_channels: Width of every intermediate layer. Unused when
            ``num_layers == 1``.
        out_channels: Dimensionality of the output node embeddings.
        num_layers: Number of ``SAGEConv`` layers; must be at least 1. With a
            single layer the encoder maps ``in_channels`` directly to
            ``out_channels``.
        dropout: Dropout probability applied after each non-final layer
            while training; must lie in ``[0, 1]``.

    Raises:
        ValueError: If ``num_layers < 1`` or ``dropout`` is outside ``[0, 1]``.
    """

    def __init__(
        self,
        in_channels: int = 6,
        hidden_channels: int = 256,
        out_channels: int = 128,
        num_layers: int = 3,
        dropout: float = 0.1,
    ) -> None:
        """Initialize GraphSAGE encoder with configurable depth and width."""
        super().__init__()
        # Validate up front. Previously any num_layers < 2 silently built a
        # two-layer network, because range(num_layers - 2) was just empty.
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        # F.dropout would reject this only at forward time; fail fast instead.
        if not 0.0 <= dropout <= 1.0:
            raise ValueError(f"dropout must be in [0, 1], got {dropout}")

        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.dropout = dropout

        # Channel sizes at each layer boundary, e.g. num_layers=3 gives
        # [in, hidden, hidden, out]. Each consecutive pair defines one
        # SAGEConv. For num_layers >= 2 this yields exactly the same modules,
        # in the same order (hence the same state_dict keys), as before.
        dims = (
            [in_channels]
            + [hidden_channels] * (num_layers - 1)
            + [out_channels]
        )
        for c_in, c_out in zip(dims[:-1], dims[1:]):
            self.convs.append(SAGEConv(c_in, c_out))
            self.norms.append(nn.LayerNorm(c_out))

    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor
    ) -> torch.Tensor:
        """Encode node features into embeddings.

        Args:
            x: Node feature matrix; assumed (num_nodes, in_channels) per the
                PyG ``SAGEConv`` convention -- TODO confirm against callers.
            edge_index: Graph connectivity, passed unchanged to every
                ``SAGEConv`` (PyG typically uses a (2, num_edges) index
                tensor).

        Returns:
            Node embeddings produced by the last layer, after its LayerNorm.
        """
        last = len(self.convs) - 1
        for i, (conv, norm) in enumerate(zip(self.convs, self.norms)):
            x = norm(conv(x, edge_index))
            # Only intermediate layers get the non-linearity and dropout; the
            # final layer's normalized output is returned as-is.
            if i < last:
                x = F.relu(x)
                x = F.dropout(x, p=self.dropout, training=self.training)
        return x

    def graph_embedding(
        self, x: torch.Tensor, edge_index: torch.Tensor, batch: torch.Tensor
    ) -> torch.Tensor:
        """Compute graph-level embedding via mean pooling.

        Runs :meth:`forward` to obtain node embeddings, then averages them per
        graph with ``global_mean_pool``.

        Args:
            x: Node feature matrix (see :meth:`forward`).
            edge_index: Graph connectivity (see :meth:`forward`).
            batch: Per-node graph-assignment vector, as expected by
                ``global_mean_pool``.

        Returns:
            One pooled embedding per graph identified in ``batch``.
        """
        node_embs = self.forward(x, edge_index)
        return global_mean_pool(node_embs, batch)