Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch_geometric.nn import SAGEConv, global_mean_pool | |
class GraphSAGEEncoder(nn.Module):
    """GraphSAGE-based encoder for learning node and graph representations.

    Stacks ``num_layers`` ``SAGEConv`` layers, each followed by a
    ``nn.LayerNorm``. Every layer except the last is additionally followed by
    ReLU and dropout, so the final node embeddings are layer-normalised but
    not rectified.
    """

    def __init__(
        self,
        in_channels: int = 6,
        hidden_channels: int = 256,
        out_channels: int = 128,
        num_layers: int = 3,
        dropout: float = 0.1,
    ) -> None:
        """Initialize GraphSAGE encoder with configurable depth and width.

        Args:
            in_channels: Dimensionality of the input node features.
            hidden_channels: Width of every intermediate layer.
            out_channels: Dimensionality of the output node embeddings.
            num_layers: Total number of ``SAGEConv`` layers; must be >= 1.
            dropout: Dropout probability applied after each non-final layer
                while the module is in training mode.

        Raises:
            ValueError: If ``num_layers`` is less than 1.
        """
        # Previously num_layers <= 1 silently produced a 2-layer network
        # because the first and last layers were always appended.
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        super().__init__()
        self.dropout = dropout

        # Layer widths: in -> hidden -> ... -> hidden -> out. With a single
        # layer this collapses to in -> out.
        dims = (
            [in_channels]
            + [hidden_channels] * (num_layers - 1)
            + [out_channels]
        )
        self.convs = nn.ModuleList(
            SAGEConv(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])
        )
        # One LayerNorm per conv, sized to that conv's output width.
        self.norms = nn.ModuleList(nn.LayerNorm(d_out) for d_out in dims[1:])

    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor
    ) -> torch.Tensor:
        """Encode node features into embeddings.

        Args:
            x: Node feature matrix passed to the first ``SAGEConv``.
            edge_index: Graph connectivity passed unchanged to every layer.

        Returns:
            Node embeddings produced by the final conv + LayerNorm.
        """
        last = len(self.convs) - 1
        for i, (conv, norm) in enumerate(zip(self.convs, self.norms)):
            x = norm(conv(x, edge_index))
            # The output layer is left linear (no ReLU / dropout) so the
            # embeddings can take negative values.
            if i < last:
                x = F.relu(x)
                x = F.dropout(x, p=self.dropout, training=self.training)
        return x

    def graph_embedding(
        self, x: torch.Tensor, edge_index: torch.Tensor, batch: torch.Tensor
    ) -> torch.Tensor:
        """Compute graph-level embedding via mean pooling.

        Args:
            x: Node feature matrix (see ``forward``).
            edge_index: Graph connectivity (see ``forward``).
            batch: Node-to-graph assignment vector, as expected by
                ``global_mean_pool``.

        Returns:
            One pooled embedding per graph, the mean of its node embeddings.
        """
        node_embs = self.forward(x, edge_index)
        return global_mean_pool(node_embs, batch)