from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, global_mean_pool


class GraphSAGEEncoder(nn.Module):
    """GraphSAGE-based encoder for learning node representations.

    Stacks ``num_layers`` :class:`SAGEConv` layers, each followed by a
    :class:`~torch.nn.LayerNorm`. ReLU and dropout are applied after every
    layer except the last, so the returned embeddings are layer-normalized
    but not passed through a final non-linearity.
    """

    def __init__(
        self,
        in_channels: int = 6,
        hidden_channels: int = 256,
        out_channels: int = 128,
        num_layers: int = 3,
        dropout: float = 0.1,
    ) -> None:
        """Initialize GraphSAGE encoder with configurable depth and width.

        Args:
            in_channels: Size of each input node feature vector.
            hidden_channels: Width of every intermediate layer.
            out_channels: Size of each output node embedding.
            num_layers: Total number of SAGEConv layers; must be >= 1.
                With ``num_layers == 1`` a single ``in -> out`` conv is used.
            dropout: Dropout probability applied between layers (training
                mode only).

        Raises:
            ValueError: If ``num_layers`` is less than 1.
        """
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        self.dropout = dropout

        # Channel sizes at each layer boundary: in -> hidden x (L-1) -> out.
        # For L >= 2 this reproduces in->hidden, hidden->hidden..., hidden->out;
        # for L == 1 it is a single in->out layer (previously L was forced to 2).
        dims = [in_channels] + [hidden_channels] * (num_layers - 1) + [out_channels]
        self.convs = nn.ModuleList(
            SAGEConv(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])
        )
        self.norms = nn.ModuleList(nn.LayerNorm(d_out) for d_out in dims[1:])

    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor
    ) -> torch.Tensor:
        """Encode node features into embeddings.

        Args:
            x: Node feature matrix; its last dimension must equal
                ``in_channels``.
            edge_index: Graph connectivity in the format expected by
                :class:`SAGEConv`.

        Returns:
            Node embeddings whose last dimension is ``out_channels``.
        """
        last = len(self.convs) - 1
        for i, (conv, norm) in enumerate(zip(self.convs, self.norms)):
            x = norm(conv(x, edge_index))
            # No activation/dropout after the final layer so the output
            # embedding is not clipped at zero.
            if i < last:
                x = F.relu(x)
                x = F.dropout(x, p=self.dropout, training=self.training)
        return x

    def graph_embedding(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        batch: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute graph-level embedding via mean pooling.

        Args:
            x: Node feature matrix (see :meth:`forward`).
            edge_index: Graph connectivity (see :meth:`forward`).
            batch: Per-node graph assignment vector as used by
                :func:`global_mean_pool`. If ``None``, all nodes are treated
                as belonging to a single graph.

        Returns:
            One pooled embedding per graph; with ``batch=None`` the result
            has a leading dimension of 1.
        """
        node_embs = self.forward(x, edge_index)
        if batch is None:
            # Single-graph case handled explicitly so it does not depend on
            # the installed PyG version accepting batch=None.
            return node_embs.mean(dim=0, keepdim=True)
        return global_mean_pool(node_embs, batch)