Spaces:
Sleeping
Sleeping
| """ | |
| Graph Neural Networks integration for routing. | |
| Provides: | |
| - Node and edge feature generators | |
| - Model wrappers (GCN, GAT, GraphSAGE) | |
| - Training and inference utilities | |
| - Saving/loading weights | |
| - Application to online routing decisions | |
| """ | |
| from abc import ABC, abstractmethod | |
| from enum import Enum | |
| from pathlib import Path | |
| from typing import Any, Literal | |
| import rustworkx as rx | |
| import torch | |
| from pydantic import BaseModel, Field | |
| from torch import nn | |
| from torch.nn import functional | |
| from torch.optim import AdamW | |
| from torch_geometric.loader import DataLoader | |
| from torch_geometric.nn import GATConv, GCNConv, SAGEConv | |
| from config.logging import logger | |
| TORCH_AVAILABLE = True | |
| PYG_AVAILABLE = True | |
| __all__ = [ | |
| # Models | |
| "BaseGNNModel", | |
| "DefaultFeatureGenerator", | |
| # Data classes | |
| "FeatureConfig", | |
| # Feature generation | |
| "FeatureGenerator", | |
| "GATRouter", | |
| "GCNRouter", | |
| # Enums | |
| "GNNModelType", | |
| # Inference | |
| "GNNRouterInference", | |
| # Training | |
| "GNNTrainer", | |
| "GraphSAGERouter", | |
| "RoutingPrediction", | |
| "RoutingStrategy", | |
| "TrainingConfig", | |
| # Factory | |
| "create_gnn_router", | |
| ] | |
| class GNNModelType(str, Enum): | |
| GCN = "gcn" | |
| GAT = "gat" | |
| SAGE = "sage" | |
| class RoutingStrategy(str, Enum): | |
| ARGMAX = "argmax" | |
| SOFTMAX_SAMPLE = "softmax_sample" | |
| TOP_K = "top_k" | |
| THRESHOLD = "threshold" | |
| class FeatureConfig(BaseModel): | |
| node_feature_dim: int = 64 | |
| edge_feature_dim: int = 16 | |
| embedding_dim: int = 128 | |
| use_embeddings: bool = True | |
| use_metrics: bool = True | |
| use_centrality: bool = True | |
| use_structural: bool = True | |
| normalize_features: bool = True | |
| clip_outliers: bool = True | |
| outlier_std: float = 3.0 | |
| class TrainingConfig(BaseModel): | |
| learning_rate: float = 1e-3 | |
| weight_decay: float = 1e-4 | |
| hidden_dim: int = 128 | |
| num_layers: int = 3 | |
| dropout: float = 0.1 | |
| gat_heads: int = 4 | |
| epochs: int = 100 | |
| batch_size: int = 32 | |
| patience: int = 10 | |
| use_batch_norm: bool = True | |
| use_residual: bool = True | |
| task: Literal["node_classification", "edge_prediction", "path_ranking"] = "node_classification" | |
| num_classes: int = 2 | |
| class RoutingPrediction(BaseModel): | |
| recommended_nodes: list[str] | |
| scores: dict[str, float] | |
| confidence: float | |
| strategy: RoutingStrategy | |
| metadata: dict[str, Any] = Field(default_factory=dict) | |
| class FeatureGenerator(ABC): | |
| def generate_node_features( | |
| self, | |
| graph: Any, | |
| node_ids: list[str], | |
| metrics_tracker: Any | None = None, | |
| ) -> torch.Tensor: | |
| """Build a feature matrix for a list of nodes.""" | |
| def generate_edge_features( | |
| self, | |
| graph: Any, | |
| edges: list[tuple[str, str]], | |
| metrics_tracker: Any | None = None, | |
| ) -> torch.Tensor: | |
| """Build a feature matrix for a list of edges.""" | |
| class DefaultFeatureGenerator(FeatureGenerator): | |
| def __init__(self, config: FeatureConfig | None = None): | |
| self.config = config or FeatureConfig() | |
| def generate_node_features( | |
| self, | |
| graph: Any, | |
| node_ids: list[str], | |
| metrics_tracker: Any | None = None, | |
| ) -> torch.Tensor: | |
| """Collect node features: embeddings, metrics, structure, centrality.""" | |
| features_list = [] | |
| if self.config.use_embeddings: | |
| emb_features = self._get_embedding_features(graph, node_ids) | |
| if emb_features is not None: | |
| features_list.append(emb_features) | |
| if self.config.use_metrics and metrics_tracker is not None: | |
| metric_features = metrics_tracker.get_node_features(node_ids) | |
| features_list.append(metric_features) | |
| if self.config.use_structural: | |
| struct_features = self._get_structural_features(graph, node_ids) | |
| features_list.append(struct_features) | |
| if self.config.use_centrality: | |
| centr_features = self._get_centrality_features(graph, node_ids) | |
| features_list.append(centr_features) | |
| if not features_list: | |
| return torch.eye(len(node_ids), dtype=torch.float32) | |
| combined = torch.cat(features_list, dim=1) | |
| if self.config.normalize_features: | |
| combined = self._normalize(combined) | |
| return combined.to(torch.float32) | |
| def generate_edge_features( | |
| self, | |
| graph: Any, | |
| edges: list[tuple[str, str]], | |
| metrics_tracker: Any | None = None, | |
| ) -> torch.Tensor: | |
| """Collect edge features: weight, reliability/latency/traffic metrics.""" | |
| features = [] | |
| for src, tgt in edges: | |
| edge_feat = [] | |
| weight = self._get_edge_weight(graph, src, tgt) | |
| edge_feat.append(weight) | |
| if metrics_tracker is not None: | |
| edge_metrics = metrics_tracker.get_edge_metrics(src, tgt) | |
| if edge_metrics: | |
| edge_feat.extend( | |
| [ | |
| edge_metrics.reliability, | |
| edge_metrics.avg_latency_ms / 1000.0, | |
| edge_metrics.avg_data_volume / 1000.0, | |
| ] | |
| ) | |
| else: | |
| edge_feat.extend([1.0, 0.0, 0.0]) | |
| else: | |
| edge_feat.extend([1.0, 0.0, 0.0]) | |
| features.append(edge_feat) | |
| result = torch.tensor(features, dtype=torch.float32) | |
| if self.config.normalize_features: | |
| result = self._normalize(result) | |
| return result | |
| def _get_embedding_features(self, graph: Any, _node_ids: list[str]) -> torch.Tensor | None: | |
| """Return the node embedding matrix if embeddings are available in the graph.""" | |
| if not hasattr(graph, "embeddings"): | |
| return None | |
| embeddings = graph.embeddings | |
| if embeddings is None or embeddings.numel() == 0: | |
| return None | |
| if isinstance(embeddings, torch.Tensor): | |
| return embeddings.to(torch.float32) | |
| return torch.tensor(embeddings, dtype=torch.float32) | |
| def _get_structural_features(self, graph: Any, node_ids: list[str]) -> torch.Tensor: | |
| """Compute normalised in/out degrees for nodes.""" | |
| features = [] | |
| rx_graph = graph.graph | |
| id_to_idx = {} | |
| for idx in rx_graph.node_indices(): | |
| data = rx_graph.get_node_data(idx) | |
| if isinstance(data, dict): | |
| nid = data.get("id", str(idx)) | |
| elif hasattr(data, "agent_id"): | |
| nid = data.agent_id | |
| else: | |
| nid = str(idx) | |
| id_to_idx[nid] = idx | |
| num_nodes = rx_graph.num_nodes() | |
| for node_id in node_ids: | |
| idx = id_to_idx.get(node_id) | |
| if idx is not None: | |
| in_deg = rx_graph.in_degree(idx) / max(num_nodes - 1, 1) | |
| out_deg = rx_graph.out_degree(idx) / max(num_nodes - 1, 1) | |
| features.append([in_deg, out_deg]) | |
| else: | |
| features.append([0.0, 0.0]) | |
| return torch.tensor(features, dtype=torch.float32) | |
| def _get_centrality_features(self, graph: Any, node_ids: list[str]) -> torch.Tensor: | |
| """Compute simple centrality (PageRank) for nodes.""" | |
| try: | |
| pagerank = rx.pagerank(graph.graph) | |
| id_to_idx = {} | |
| for idx in graph.graph.node_indices(): | |
| data = graph.graph.get_node_data(idx) | |
| if isinstance(data, dict): | |
| nid = data.get("id", str(idx)) | |
| elif hasattr(data, "agent_id"): | |
| nid = data.agent_id | |
| else: | |
| nid = str(idx) | |
| id_to_idx[nid] = idx | |
| features = [] | |
| for node_id in node_ids: | |
| idx = id_to_idx.get(node_id) | |
| if idx is not None and idx in pagerank: | |
| features.append([pagerank[idx]]) | |
| else: | |
| features.append([0.0]) | |
| return torch.tensor(features, dtype=torch.float32) | |
| except (ValueError, RuntimeError, KeyError) as e: | |
| logger.warning(f"Failed to compute centrality features: {e}") | |
| return torch.zeros((len(node_ids), 1), dtype=torch.float32) | |
| def _get_edge_weight(self, graph: Any, src: str, tgt: str) -> float: | |
| """Extract the graph edge weight or return 1.0 by default.""" | |
| try: | |
| rx_graph = graph.graph | |
| src_idx = tgt_idx = None | |
| for idx in rx_graph.node_indices(): | |
| data = rx_graph.get_node_data(idx) | |
| if isinstance(data, dict): | |
| nid = data.get("id", str(idx)) | |
| elif hasattr(data, "agent_id"): | |
| nid = data.agent_id | |
| else: | |
| nid = str(idx) | |
| if nid == src: | |
| src_idx = idx | |
| if nid == tgt: | |
| tgt_idx = idx | |
| if src_idx is not None and tgt_idx is not None: | |
| edge_data = rx_graph.get_edge_data(src_idx, tgt_idx) | |
| if isinstance(edge_data, dict): | |
| return edge_data.get("weight", 1.0) | |
| except (ValueError, RuntimeError, KeyError) as e: | |
| logger.debug(f"Failed to get edge weight: {e}") | |
| return 1.0 | |
| def _normalize(self, features: torch.Tensor) -> torch.Tensor: | |
| """Normalize features column-wise and clip outliers.""" | |
| if features.numel() == 0: | |
| return features | |
| mean = features.mean(dim=0, keepdim=True) | |
| std = features.std(dim=0, keepdim=True) | |
| std = torch.where(std == 0, torch.ones_like(std), std) | |
| normalized = (features - mean) / std | |
| if self.config.clip_outliers: | |
| normalized = torch.clamp( | |
| normalized, | |
| -self.config.outlier_std, | |
| self.config.outlier_std, | |
| ) | |
| return normalized | |
| if TORCH_AVAILABLE and PYG_AVAILABLE: | |
| class BaseGNNModel(nn.Module): | |
| """Base GNN architecture with input/output projection and residual layers.""" | |
| def __init__( | |
| self, | |
| in_channels: int, | |
| hidden_channels: int, | |
| out_channels: int, | |
| num_layers: int = 3, | |
| dropout: float = 0.1, | |
| use_batch_norm: bool = True, | |
| use_residual: bool = True, | |
| ): | |
| super().__init__() | |
| self.in_channels = in_channels | |
| self.hidden_channels = hidden_channels | |
| self.out_channels = out_channels | |
| self.num_layers = num_layers | |
| self.dropout = dropout | |
| self.use_batch_norm = use_batch_norm | |
| self.use_residual = use_residual | |
| self.convs = nn.ModuleList() | |
| self.batch_norms = nn.ModuleList() | |
| self.input_proj = nn.Linear(in_channels, hidden_channels) | |
| self.output_proj = nn.Linear(hidden_channels, out_channels) | |
| def reset_parameters(self): | |
| """Reset the parameters of all model layers.""" | |
| for conv in self.convs: | |
| if hasattr(conv, "reset_parameters") and callable(conv.reset_parameters): | |
| conv.reset_parameters() | |
| for bn in self.batch_norms: | |
| if hasattr(bn, "reset_parameters") and callable(bn.reset_parameters): | |
| bn.reset_parameters() | |
| self.input_proj.reset_parameters() | |
| self.output_proj.reset_parameters() | |
| def forward( | |
| self, x: torch.Tensor, edge_index: torch.Tensor, _edge_attr: torch.Tensor | None = None | |
| ) -> torch.Tensor: | |
| """Forward pass: graph convolutions with Dropout, BatchNorm and residual.""" | |
| x = self.input_proj(x) | |
| x = functional.relu(x) | |
| for i, conv in enumerate(self.convs): | |
| x_prev = x | |
| x = conv(x, edge_index) | |
| if self.use_batch_norm and i < len(self.batch_norms): | |
| x = self.batch_norms[i](x) | |
| x = functional.relu(x) | |
| x = functional.dropout(x, p=self.dropout, training=self.training) | |
| if self.use_residual and x.shape == x_prev.shape: | |
| x = x + x_prev | |
| return self.output_proj(x) | |
| class GCNRouter(BaseGNNModel): | |
| def __init__(self, **kwargs): | |
| super().__init__(**kwargs) | |
| self.convs.append(GCNConv(self.hidden_channels, self.hidden_channels)) | |
| self.batch_norms.append(nn.BatchNorm1d(self.hidden_channels)) | |
| for _ in range(self.num_layers - 1): | |
| self.convs.append(GCNConv(self.hidden_channels, self.hidden_channels)) | |
| self.batch_norms.append(nn.BatchNorm1d(self.hidden_channels)) | |
| class GATRouter(BaseGNNModel): | |
| def __init__(self, heads: int = 4, **kwargs): | |
| super().__init__(**kwargs) | |
| self.heads = heads | |
| self.convs.append( | |
| GATConv( | |
| self.hidden_channels, | |
| self.hidden_channels // heads, | |
| heads=heads, | |
| dropout=self.dropout, | |
| ) | |
| ) | |
| self.batch_norms.append(nn.BatchNorm1d(self.hidden_channels)) | |
| for _ in range(self.num_layers - 1): | |
| self.convs.append( | |
| GATConv( | |
| self.hidden_channels, | |
| self.hidden_channels // heads, | |
| heads=heads, | |
| dropout=self.dropout, | |
| ) | |
| ) | |
| self.batch_norms.append(nn.BatchNorm1d(self.hidden_channels)) | |
| class GraphSAGERouter(BaseGNNModel): | |
| def __init__(self, **kwargs): | |
| super().__init__(**kwargs) | |
| self.convs.append(SAGEConv(self.hidden_channels, self.hidden_channels)) | |
| self.batch_norms.append(nn.BatchNorm1d(self.hidden_channels)) | |
| for _ in range(self.num_layers - 1): | |
| self.convs.append(SAGEConv(self.hidden_channels, self.hidden_channels)) | |
| self.batch_norms.append(nn.BatchNorm1d(self.hidden_channels)) | |
| class TrainingResult(BaseModel): | |
| train_losses: list[float] | |
| val_losses: list[float] | |
| best_epoch: int | |
| best_val_loss: float | |
| metrics: dict[str, float] = Field(default_factory=dict) | |
| class GNNTrainer: | |
| """Training utility for GNN models for routing tasks.""" | |
| def __init__( | |
| self, | |
| model: Any, | |
| config: TrainingConfig | None = None, | |
| device: str | None = None, | |
| ): | |
| """Prepare the trainer, optimizer, and device.""" | |
| if not TORCH_AVAILABLE: | |
| msg = "PyTorch is required for GNNTrainer" | |
| raise ImportError(msg) | |
| self.config = config or TrainingConfig() | |
| self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") | |
| self.model = model.to(self.device) | |
| self.optimizer = AdamW( | |
| self.model.parameters(), | |
| lr=self.config.learning_rate, | |
| weight_decay=self.config.weight_decay, | |
| ) | |
| self._best_model_state = None | |
| self._history: TrainingResult | None = None | |
| def train( | |
| self, | |
| train_data: list[Any], # list[Data] | |
| val_data: list[Any] | None = None, | |
| verbose: bool = True, | |
| ) -> TrainingResult: | |
| """Train the model on a train/val dataset with early stopping.""" | |
| if not PYG_AVAILABLE: | |
| msg = "PyTorch Geometric is required for training" | |
| raise ImportError(msg) | |
| train_loader = DataLoader(train_data, batch_size=self.config.batch_size, shuffle=True) | |
| val_loader = DataLoader(val_data, batch_size=self.config.batch_size) if val_data else None | |
| train_losses = [] | |
| val_losses = [] | |
| best_val_loss = float("inf") | |
| best_epoch = 0 | |
| patience_counter = 0 | |
| for epoch in range(self.config.epochs): | |
| train_loss = self._train_epoch(train_loader) | |
| train_losses.append(train_loss) | |
| if val_loader: | |
| val_loss = self._eval_epoch(val_loader) | |
| val_losses.append(val_loss) | |
| if val_loss < best_val_loss: | |
| best_val_loss = val_loss | |
| best_epoch = epoch | |
| patience_counter = 0 | |
| self._best_model_state = {k: v.cpu().clone() for k, v in self.model.state_dict().items()} | |
| else: | |
| patience_counter += 1 | |
| if patience_counter >= self.config.patience: | |
| if verbose: | |
| pass | |
| break | |
| if verbose and epoch % 10 == 0: | |
| f", val_loss={val_losses[-1]:.4f}" if val_losses else "" | |
| if self._best_model_state: | |
| self.model.load_state_dict(self._best_model_state) | |
| self._history = TrainingResult( | |
| train_losses=train_losses, | |
| val_losses=val_losses, | |
| best_epoch=best_epoch, | |
| best_val_loss=best_val_loss, | |
| ) | |
| return self._history | |
| def _train_epoch(self, loader: Any) -> float: | |
| """One training epoch, return the average loss.""" | |
| self.model.train() | |
| total_loss = 0.0 | |
| for batch_data in loader: | |
| batch_on_device = batch_data.to(self.device) | |
| self.optimizer.zero_grad() | |
| out = self.model(batch_on_device.x, batch_on_device.edge_index, getattr(batch_on_device, "edge_attr", None)) | |
| if self.config.task == "node_classification": | |
| loss = functional.cross_entropy(out, batch_on_device.y) | |
| elif self.config.task == "path_ranking": | |
| loss = self._compute_ranking_loss(out, batch_on_device) | |
| else: | |
| loss = functional.mse_loss(out.squeeze(), batch_on_device.y.float()) | |
| loss.backward() | |
| self.optimizer.step() | |
| total_loss += loss.item() | |
| return total_loss / len(loader) | |
| def _eval_epoch(self, loader: Any) -> float: | |
| """Evaluate the model on the validation dataset, return average loss.""" | |
| self.model.eval() | |
| total_loss = 0.0 | |
| with torch.no_grad(): | |
| for batch_data in loader: | |
| batch_on_device = batch_data.to(self.device) | |
| edge_attr = getattr(batch_on_device, "edge_attr", None) | |
| out = self.model(batch_on_device.x, batch_on_device.edge_index, edge_attr) | |
| if self.config.task == "node_classification": | |
| loss = functional.cross_entropy(out, batch_on_device.y) | |
| else: | |
| loss = functional.mse_loss(out.squeeze(), batch_on_device.y.float()) | |
| total_loss += loss.item() | |
| return total_loss / len(loader) | |
| def _compute_ranking_loss(self, out: Any, batch: Any) -> Any: | |
| """Simplest MSE loss for path ranking.""" | |
| scores = out.squeeze() | |
| targets = batch.y.float() | |
| return functional.mse_loss(scores, targets) | |
| def save(self, path: str | Path) -> None: | |
| """Save model weights and training history to a checkpoint.""" | |
| path = Path(path) | |
| torch.save( | |
| { | |
| "model_state_dict": self.model.state_dict(), | |
| "config": self.config, | |
| "history": self._history, | |
| }, | |
| path, | |
| ) | |
| def load(self, path: str | Path) -> None: | |
| """Load a checkpoint and restore the model/configuration.""" | |
| path = Path(path) | |
| checkpoint = torch.load(path, map_location=self.device) | |
| self.model.load_state_dict(checkpoint["model_state_dict"]) | |
| if "config" in checkpoint: | |
| self.config = checkpoint["config"] | |
| if "history" in checkpoint: | |
| self._history = checkpoint["history"] | |
| class GNNRouterInference: | |
| """Inference module: generates features, runs GNN and selects routes.""" | |
| def __init__( | |
| self, | |
| model: Any, | |
| feature_generator: FeatureGenerator | None = None, | |
| device: str | None = None, | |
| ): | |
| """Prepare the model for inference on the given device.""" | |
| if not TORCH_AVAILABLE: | |
| msg = "PyTorch is required for GNNRouterInference" | |
| raise ImportError(msg) | |
| self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") | |
| self.model = model.to(self.device) | |
| self.model.eval() | |
| self.feature_generator = feature_generator or DefaultFeatureGenerator() | |
| def predict( | |
| self, | |
| graph: Any, # RoleGraph | |
| source: str | None = None, | |
| candidates: list[str] | None = None, | |
| metrics_tracker: Any | None = None, | |
| strategy: RoutingStrategy = RoutingStrategy.ARGMAX, | |
| top_k: int = 3, | |
| threshold: float = 0.5, | |
| ) -> RoutingPrediction: | |
| """Predict the next routing nodes according to the selected strategy.""" | |
| if not PYG_AVAILABLE: | |
| msg = "PyTorch Geometric is required for inference" | |
| raise ImportError(msg) | |
| node_ids = graph.node_ids if hasattr(graph, "node_ids") else [] | |
| if not node_ids: | |
| node_ids = [str(i) for i in range(graph.num_nodes)] | |
| node_features = self.feature_generator.generate_node_features(graph, node_ids, metrics_tracker) | |
| edge_index = graph.edge_index | |
| # node_features and edge_index are already Tensors, just move to device | |
| if isinstance(node_features, torch.Tensor): | |
| x = node_features.to(dtype=torch.float32, device=self.device) | |
| else: | |
| x = torch.tensor(node_features, dtype=torch.float32, device=self.device) | |
| if isinstance(edge_index, torch.Tensor): | |
| edge_index_tensor = edge_index.to(dtype=torch.long, device=self.device) | |
| else: | |
| edge_index_tensor = torch.tensor(edge_index, dtype=torch.long, device=self.device) | |
| with torch.no_grad(): | |
| out = self.model(x, edge_index_tensor) | |
| # out shape: [N] or [N, C] where C = num_classes | |
| if out.dim() > 1 and out.shape[-1] > 1: | |
| # Multi-class: softmax over classes, take probability of the last (positive) class | |
| probs = functional.softmax(out, dim=-1) | |
| scores = probs[:, -1].cpu() | |
| else: | |
| scores = functional.softmax(out.squeeze(-1), dim=0).cpu() | |
| node_scores = {node_ids[i]: float(scores[i].item()) for i in range(len(node_ids))} | |
| if candidates: | |
| node_scores = {k: v for k, v in node_scores.items() if k in candidates} | |
| if source and source in node_scores: | |
| del node_scores[source] | |
| recommended = self._apply_strategy(node_scores, strategy, top_k, threshold) | |
| confidence = torch.mean(torch.tensor([node_scores[n] for n in recommended])).item() if recommended else 0.0 | |
| return RoutingPrediction( | |
| recommended_nodes=recommended, | |
| scores=node_scores, | |
| confidence=float(confidence), | |
| strategy=strategy, | |
| metadata={"source": source, "num_candidates": len(node_scores)}, | |
| ) | |
| def _apply_strategy( | |
| self, | |
| scores: dict[str, float], | |
| strategy: RoutingStrategy, | |
| top_k: int, | |
| threshold: float, | |
| ) -> list[str]: | |
| """Select recommended nodes according to the given selection strategy.""" | |
| if not scores: | |
| return [] | |
| if strategy == RoutingStrategy.ARGMAX: | |
| best_node = max(scores.keys(), key=lambda k: scores[k]) | |
| return [best_node] | |
| if strategy == RoutingStrategy.TOP_K: | |
| sorted_nodes = sorted(scores.keys(), key=lambda k: scores[k], reverse=True) | |
| return sorted_nodes[:top_k] | |
| if strategy == RoutingStrategy.THRESHOLD: | |
| return [n for n, s in scores.items() if s >= threshold] | |
| if strategy == RoutingStrategy.SOFTMAX_SAMPLE: | |
| nodes = list(scores.keys()) | |
| probs = torch.tensor([scores[n] for n in nodes]) | |
| probs = probs / probs.sum() | |
| idx = int(torch.multinomial(probs, 1).item()) | |
| return [nodes[idx]] | |
| return [] | |
| def get_all_scores( | |
| self, | |
| graph: Any, | |
| metrics_tracker: Any | None = None, | |
| ) -> dict[str, float]: | |
| """Get scores for all graph nodes.""" | |
| prediction = self.predict( | |
| graph, | |
| metrics_tracker=metrics_tracker, | |
| strategy=RoutingStrategy.TOP_K, | |
| top_k=graph.num_nodes, | |
| ) | |
| return prediction.scores | |
| def load(cls, path: str | Path, feature_generator: FeatureGenerator | None = None) -> "GNNRouterInference": | |
| """Load the model and create an inference instance from a checkpoint.""" | |
| path = Path(path) | |
| checkpoint = torch.load(path, map_location="cpu") | |
| config = checkpoint.get("config", TrainingConfig()) | |
| model = GCNRouter( | |
| in_channels=64, | |
| hidden_channels=config.hidden_dim, | |
| out_channels=config.num_classes, | |
| num_layers=config.num_layers, | |
| dropout=config.dropout, | |
| ) | |
| model.load_state_dict(checkpoint["model_state_dict"]) | |
| return cls(model, feature_generator) | |
| def create_gnn_router( | |
| model_type: GNNModelType, | |
| in_channels: int, | |
| out_channels: int, | |
| config: TrainingConfig | None = None, | |
| ) -> Any: | |
| """Factory: create a GNN router of the required type (GCN/GAT/SAGE).""" | |
| if not TORCH_AVAILABLE or not PYG_AVAILABLE: | |
| msg = "PyTorch and PyTorch Geometric are required" | |
| raise ImportError(msg) | |
| config = config or TrainingConfig() | |
| kwargs = { | |
| "in_channels": in_channels, | |
| "hidden_channels": config.hidden_dim, | |
| "out_channels": out_channels, | |
| "num_layers": config.num_layers, | |
| "dropout": config.dropout, | |
| "use_batch_norm": config.use_batch_norm, | |
| "use_residual": config.use_residual, | |
| } | |
| if model_type == GNNModelType.GCN: | |
| return GCNRouter(**kwargs) | |
| if model_type == GNNModelType.GAT: | |
| return GATRouter(heads=int(config.gat_heads), **kwargs) | |
| if model_type == GNNModelType.SAGE: | |
| return GraphSAGERouter(**kwargs) | |
| msg = f"Unknown model type: {model_type}" | |
| raise ValueError(msg) | |
| if not TORCH_AVAILABLE or not PYG_AVAILABLE: | |
| class BaseGNNModel: | |
| def __init__(self, *_args, **_kwargs): | |
| msg = "PyTorch and PyTorch Geometric are required for GNN models" | |
| raise ImportError(msg) | |
| class GCNRouter(BaseGNNModel): | |
| pass | |
| class GATRouter(BaseGNNModel): | |
| pass | |
| class GraphSAGERouter(BaseGNNModel): | |
| pass | |