""" Graph Neural Networks integration for routing. Provides: - Node and edge feature generators - Model wrappers (GCN, GAT, GraphSAGE) - Training and inference utilities - Saving/loading weights - Application to online routing decisions """ from abc import ABC, abstractmethod from enum import Enum from pathlib import Path from typing import Any, Literal import rustworkx as rx import torch from pydantic import BaseModel, Field from torch import nn from torch.nn import functional from torch.optim import AdamW from torch_geometric.loader import DataLoader from torch_geometric.nn import GATConv, GCNConv, SAGEConv from config.logging import logger TORCH_AVAILABLE = True PYG_AVAILABLE = True __all__ = [ # Models "BaseGNNModel", "DefaultFeatureGenerator", # Data classes "FeatureConfig", # Feature generation "FeatureGenerator", "GATRouter", "GCNRouter", # Enums "GNNModelType", # Inference "GNNRouterInference", # Training "GNNTrainer", "GraphSAGERouter", "RoutingPrediction", "RoutingStrategy", "TrainingConfig", # Factory "create_gnn_router", ] class GNNModelType(str, Enum): GCN = "gcn" GAT = "gat" SAGE = "sage" class RoutingStrategy(str, Enum): ARGMAX = "argmax" SOFTMAX_SAMPLE = "softmax_sample" TOP_K = "top_k" THRESHOLD = "threshold" class FeatureConfig(BaseModel): node_feature_dim: int = 64 edge_feature_dim: int = 16 embedding_dim: int = 128 use_embeddings: bool = True use_metrics: bool = True use_centrality: bool = True use_structural: bool = True normalize_features: bool = True clip_outliers: bool = True outlier_std: float = 3.0 class TrainingConfig(BaseModel): learning_rate: float = 1e-3 weight_decay: float = 1e-4 hidden_dim: int = 128 num_layers: int = 3 dropout: float = 0.1 gat_heads: int = 4 epochs: int = 100 batch_size: int = 32 patience: int = 10 use_batch_norm: bool = True use_residual: bool = True task: Literal["node_classification", "edge_prediction", "path_ranking"] = "node_classification" num_classes: int = 2 class RoutingPrediction(BaseModel): recommended_nodes: list[str] scores: dict[str, float] confidence: float strategy: RoutingStrategy metadata: dict[str, Any] = Field(default_factory=dict) class FeatureGenerator(ABC): @abstractmethod def generate_node_features( self, graph: Any, node_ids: list[str], metrics_tracker: Any | None = None, ) -> torch.Tensor: """Build a feature matrix for a list of nodes.""" @abstractmethod def generate_edge_features( self, graph: Any, edges: list[tuple[str, str]], metrics_tracker: Any | None = None, ) -> torch.Tensor: """Build a feature matrix for a list of edges.""" class DefaultFeatureGenerator(FeatureGenerator): def __init__(self, config: FeatureConfig | None = None): self.config = config or FeatureConfig() def generate_node_features( self, graph: Any, node_ids: list[str], metrics_tracker: Any | None = None, ) -> torch.Tensor: """Collect node features: embeddings, metrics, structure, centrality.""" features_list = [] if self.config.use_embeddings: emb_features = self._get_embedding_features(graph, node_ids) if emb_features is not None: features_list.append(emb_features) if self.config.use_metrics and metrics_tracker is not None: metric_features = metrics_tracker.get_node_features(node_ids) features_list.append(metric_features) if self.config.use_structural: struct_features = self._get_structural_features(graph, node_ids) features_list.append(struct_features) if self.config.use_centrality: centr_features = self._get_centrality_features(graph, node_ids) features_list.append(centr_features) if not features_list: return torch.eye(len(node_ids), dtype=torch.float32) combined = torch.cat(features_list, dim=1) if self.config.normalize_features: combined = self._normalize(combined) return combined.to(torch.float32) def generate_edge_features( self, graph: Any, edges: list[tuple[str, str]], metrics_tracker: Any | None = None, ) -> torch.Tensor: """Collect edge features: weight, reliability/latency/traffic metrics.""" features = [] for src, tgt in edges: edge_feat = [] weight = self._get_edge_weight(graph, src, tgt) edge_feat.append(weight) if metrics_tracker is not None: edge_metrics = metrics_tracker.get_edge_metrics(src, tgt) if edge_metrics: edge_feat.extend( [ edge_metrics.reliability, edge_metrics.avg_latency_ms / 1000.0, edge_metrics.avg_data_volume / 1000.0, ] ) else: edge_feat.extend([1.0, 0.0, 0.0]) else: edge_feat.extend([1.0, 0.0, 0.0]) features.append(edge_feat) result = torch.tensor(features, dtype=torch.float32) if self.config.normalize_features: result = self._normalize(result) return result def _get_embedding_features(self, graph: Any, _node_ids: list[str]) -> torch.Tensor | None: """Return the node embedding matrix if embeddings are available in the graph.""" if not hasattr(graph, "embeddings"): return None embeddings = graph.embeddings if embeddings is None or embeddings.numel() == 0: return None if isinstance(embeddings, torch.Tensor): return embeddings.to(torch.float32) return torch.tensor(embeddings, dtype=torch.float32) def _get_structural_features(self, graph: Any, node_ids: list[str]) -> torch.Tensor: """Compute normalised in/out degrees for nodes.""" features = [] rx_graph = graph.graph id_to_idx = {} for idx in rx_graph.node_indices(): data = rx_graph.get_node_data(idx) if isinstance(data, dict): nid = data.get("id", str(idx)) elif hasattr(data, "agent_id"): nid = data.agent_id else: nid = str(idx) id_to_idx[nid] = idx num_nodes = rx_graph.num_nodes() for node_id in node_ids: idx = id_to_idx.get(node_id) if idx is not None: in_deg = rx_graph.in_degree(idx) / max(num_nodes - 1, 1) out_deg = rx_graph.out_degree(idx) / max(num_nodes - 1, 1) features.append([in_deg, out_deg]) else: features.append([0.0, 0.0]) return torch.tensor(features, dtype=torch.float32) def _get_centrality_features(self, graph: Any, node_ids: list[str]) -> torch.Tensor: """Compute simple centrality (PageRank) for nodes.""" try: pagerank = rx.pagerank(graph.graph) id_to_idx = {} for idx in graph.graph.node_indices(): data = graph.graph.get_node_data(idx) if isinstance(data, dict): nid = data.get("id", str(idx)) elif hasattr(data, "agent_id"): nid = data.agent_id else: nid = str(idx) id_to_idx[nid] = idx features = [] for node_id in node_ids: idx = id_to_idx.get(node_id) if idx is not None and idx in pagerank: features.append([pagerank[idx]]) else: features.append([0.0]) return torch.tensor(features, dtype=torch.float32) except (ValueError, RuntimeError, KeyError) as e: logger.warning(f"Failed to compute centrality features: {e}") return torch.zeros((len(node_ids), 1), dtype=torch.float32) def _get_edge_weight(self, graph: Any, src: str, tgt: str) -> float: """Extract the graph edge weight or return 1.0 by default.""" try: rx_graph = graph.graph src_idx = tgt_idx = None for idx in rx_graph.node_indices(): data = rx_graph.get_node_data(idx) if isinstance(data, dict): nid = data.get("id", str(idx)) elif hasattr(data, "agent_id"): nid = data.agent_id else: nid = str(idx) if nid == src: src_idx = idx if nid == tgt: tgt_idx = idx if src_idx is not None and tgt_idx is not None: edge_data = rx_graph.get_edge_data(src_idx, tgt_idx) if isinstance(edge_data, dict): return edge_data.get("weight", 1.0) except (ValueError, RuntimeError, KeyError) as e: logger.debug(f"Failed to get edge weight: {e}") return 1.0 def _normalize(self, features: torch.Tensor) -> torch.Tensor: """Normalize features column-wise and clip outliers.""" if features.numel() == 0: return features mean = features.mean(dim=0, keepdim=True) std = features.std(dim=0, keepdim=True) std = torch.where(std == 0, torch.ones_like(std), std) normalized = (features - mean) / std if self.config.clip_outliers: normalized = torch.clamp( normalized, -self.config.outlier_std, self.config.outlier_std, ) return normalized if TORCH_AVAILABLE and PYG_AVAILABLE: class BaseGNNModel(nn.Module): """Base GNN architecture with input/output projection and residual layers.""" def __init__( self, in_channels: int, hidden_channels: int, out_channels: int, num_layers: int = 3, dropout: float = 0.1, use_batch_norm: bool = True, use_residual: bool = True, ): super().__init__() self.in_channels = in_channels self.hidden_channels = hidden_channels self.out_channels = out_channels self.num_layers = num_layers self.dropout = dropout self.use_batch_norm = use_batch_norm self.use_residual = use_residual self.convs = nn.ModuleList() self.batch_norms = nn.ModuleList() self.input_proj = nn.Linear(in_channels, hidden_channels) self.output_proj = nn.Linear(hidden_channels, out_channels) def reset_parameters(self): """Reset the parameters of all model layers.""" for conv in self.convs: if hasattr(conv, "reset_parameters") and callable(conv.reset_parameters): conv.reset_parameters() for bn in self.batch_norms: if hasattr(bn, "reset_parameters") and callable(bn.reset_parameters): bn.reset_parameters() self.input_proj.reset_parameters() self.output_proj.reset_parameters() def forward( self, x: torch.Tensor, edge_index: torch.Tensor, _edge_attr: torch.Tensor | None = None ) -> torch.Tensor: """Forward pass: graph convolutions with Dropout, BatchNorm and residual.""" x = self.input_proj(x) x = functional.relu(x) for i, conv in enumerate(self.convs): x_prev = x x = conv(x, edge_index) if self.use_batch_norm and i < len(self.batch_norms): x = self.batch_norms[i](x) x = functional.relu(x) x = functional.dropout(x, p=self.dropout, training=self.training) if self.use_residual and x.shape == x_prev.shape: x = x + x_prev return self.output_proj(x) class GCNRouter(BaseGNNModel): def __init__(self, **kwargs): super().__init__(**kwargs) self.convs.append(GCNConv(self.hidden_channels, self.hidden_channels)) self.batch_norms.append(nn.BatchNorm1d(self.hidden_channels)) for _ in range(self.num_layers - 1): self.convs.append(GCNConv(self.hidden_channels, self.hidden_channels)) self.batch_norms.append(nn.BatchNorm1d(self.hidden_channels)) class GATRouter(BaseGNNModel): def __init__(self, heads: int = 4, **kwargs): super().__init__(**kwargs) self.heads = heads self.convs.append( GATConv( self.hidden_channels, self.hidden_channels // heads, heads=heads, dropout=self.dropout, ) ) self.batch_norms.append(nn.BatchNorm1d(self.hidden_channels)) for _ in range(self.num_layers - 1): self.convs.append( GATConv( self.hidden_channels, self.hidden_channels // heads, heads=heads, dropout=self.dropout, ) ) self.batch_norms.append(nn.BatchNorm1d(self.hidden_channels)) class GraphSAGERouter(BaseGNNModel): def __init__(self, **kwargs): super().__init__(**kwargs) self.convs.append(SAGEConv(self.hidden_channels, self.hidden_channels)) self.batch_norms.append(nn.BatchNorm1d(self.hidden_channels)) for _ in range(self.num_layers - 1): self.convs.append(SAGEConv(self.hidden_channels, self.hidden_channels)) self.batch_norms.append(nn.BatchNorm1d(self.hidden_channels)) class TrainingResult(BaseModel): train_losses: list[float] val_losses: list[float] best_epoch: int best_val_loss: float metrics: dict[str, float] = Field(default_factory=dict) class GNNTrainer: """Training utility for GNN models for routing tasks.""" def __init__( self, model: Any, config: TrainingConfig | None = None, device: str | None = None, ): """Prepare the trainer, optimizer, and device.""" if not TORCH_AVAILABLE: msg = "PyTorch is required for GNNTrainer" raise ImportError(msg) self.config = config or TrainingConfig() self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") self.model = model.to(self.device) self.optimizer = AdamW( self.model.parameters(), lr=self.config.learning_rate, weight_decay=self.config.weight_decay, ) self._best_model_state = None self._history: TrainingResult | None = None def train( self, train_data: list[Any], # list[Data] val_data: list[Any] | None = None, verbose: bool = True, ) -> TrainingResult: """Train the model on a train/val dataset with early stopping.""" if not PYG_AVAILABLE: msg = "PyTorch Geometric is required for training" raise ImportError(msg) train_loader = DataLoader(train_data, batch_size=self.config.batch_size, shuffle=True) val_loader = DataLoader(val_data, batch_size=self.config.batch_size) if val_data else None train_losses = [] val_losses = [] best_val_loss = float("inf") best_epoch = 0 patience_counter = 0 for epoch in range(self.config.epochs): train_loss = self._train_epoch(train_loader) train_losses.append(train_loss) if val_loader: val_loss = self._eval_epoch(val_loader) val_losses.append(val_loss) if val_loss < best_val_loss: best_val_loss = val_loss best_epoch = epoch patience_counter = 0 self._best_model_state = {k: v.cpu().clone() for k, v in self.model.state_dict().items()} else: patience_counter += 1 if patience_counter >= self.config.patience: if verbose: pass break if verbose and epoch % 10 == 0: f", val_loss={val_losses[-1]:.4f}" if val_losses else "" if self._best_model_state: self.model.load_state_dict(self._best_model_state) self._history = TrainingResult( train_losses=train_losses, val_losses=val_losses, best_epoch=best_epoch, best_val_loss=best_val_loss, ) return self._history def _train_epoch(self, loader: Any) -> float: """One training epoch, return the average loss.""" self.model.train() total_loss = 0.0 for batch_data in loader: batch_on_device = batch_data.to(self.device) self.optimizer.zero_grad() out = self.model(batch_on_device.x, batch_on_device.edge_index, getattr(batch_on_device, "edge_attr", None)) if self.config.task == "node_classification": loss = functional.cross_entropy(out, batch_on_device.y) elif self.config.task == "path_ranking": loss = self._compute_ranking_loss(out, batch_on_device) else: loss = functional.mse_loss(out.squeeze(), batch_on_device.y.float()) loss.backward() self.optimizer.step() total_loss += loss.item() return total_loss / len(loader) def _eval_epoch(self, loader: Any) -> float: """Evaluate the model on the validation dataset, return average loss.""" self.model.eval() total_loss = 0.0 with torch.no_grad(): for batch_data in loader: batch_on_device = batch_data.to(self.device) edge_attr = getattr(batch_on_device, "edge_attr", None) out = self.model(batch_on_device.x, batch_on_device.edge_index, edge_attr) if self.config.task == "node_classification": loss = functional.cross_entropy(out, batch_on_device.y) else: loss = functional.mse_loss(out.squeeze(), batch_on_device.y.float()) total_loss += loss.item() return total_loss / len(loader) def _compute_ranking_loss(self, out: Any, batch: Any) -> Any: """Simplest MSE loss for path ranking.""" scores = out.squeeze() targets = batch.y.float() return functional.mse_loss(scores, targets) def save(self, path: str | Path) -> None: """Save model weights and training history to a checkpoint.""" path = Path(path) torch.save( { "model_state_dict": self.model.state_dict(), "config": self.config, "history": self._history, }, path, ) def load(self, path: str | Path) -> None: """Load a checkpoint and restore the model/configuration.""" path = Path(path) checkpoint = torch.load(path, map_location=self.device) self.model.load_state_dict(checkpoint["model_state_dict"]) if "config" in checkpoint: self.config = checkpoint["config"] if "history" in checkpoint: self._history = checkpoint["history"] class GNNRouterInference: """Inference module: generates features, runs GNN and selects routes.""" def __init__( self, model: Any, feature_generator: FeatureGenerator | None = None, device: str | None = None, ): """Prepare the model for inference on the given device.""" if not TORCH_AVAILABLE: msg = "PyTorch is required for GNNRouterInference" raise ImportError(msg) self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") self.model = model.to(self.device) self.model.eval() self.feature_generator = feature_generator or DefaultFeatureGenerator() def predict( self, graph: Any, # RoleGraph source: str | None = None, candidates: list[str] | None = None, metrics_tracker: Any | None = None, strategy: RoutingStrategy = RoutingStrategy.ARGMAX, top_k: int = 3, threshold: float = 0.5, ) -> RoutingPrediction: """Predict the next routing nodes according to the selected strategy.""" if not PYG_AVAILABLE: msg = "PyTorch Geometric is required for inference" raise ImportError(msg) node_ids = graph.node_ids if hasattr(graph, "node_ids") else [] if not node_ids: node_ids = [str(i) for i in range(graph.num_nodes)] node_features = self.feature_generator.generate_node_features(graph, node_ids, metrics_tracker) edge_index = graph.edge_index # node_features and edge_index are already Tensors, just move to device if isinstance(node_features, torch.Tensor): x = node_features.to(dtype=torch.float32, device=self.device) else: x = torch.tensor(node_features, dtype=torch.float32, device=self.device) if isinstance(edge_index, torch.Tensor): edge_index_tensor = edge_index.to(dtype=torch.long, device=self.device) else: edge_index_tensor = torch.tensor(edge_index, dtype=torch.long, device=self.device) with torch.no_grad(): out = self.model(x, edge_index_tensor) # out shape: [N] or [N, C] where C = num_classes if out.dim() > 1 and out.shape[-1] > 1: # Multi-class: softmax over classes, take probability of the last (positive) class probs = functional.softmax(out, dim=-1) scores = probs[:, -1].cpu() else: scores = functional.softmax(out.squeeze(-1), dim=0).cpu() node_scores = {node_ids[i]: float(scores[i].item()) for i in range(len(node_ids))} if candidates: node_scores = {k: v for k, v in node_scores.items() if k in candidates} if source and source in node_scores: del node_scores[source] recommended = self._apply_strategy(node_scores, strategy, top_k, threshold) confidence = torch.mean(torch.tensor([node_scores[n] for n in recommended])).item() if recommended else 0.0 return RoutingPrediction( recommended_nodes=recommended, scores=node_scores, confidence=float(confidence), strategy=strategy, metadata={"source": source, "num_candidates": len(node_scores)}, ) def _apply_strategy( self, scores: dict[str, float], strategy: RoutingStrategy, top_k: int, threshold: float, ) -> list[str]: """Select recommended nodes according to the given selection strategy.""" if not scores: return [] if strategy == RoutingStrategy.ARGMAX: best_node = max(scores.keys(), key=lambda k: scores[k]) return [best_node] if strategy == RoutingStrategy.TOP_K: sorted_nodes = sorted(scores.keys(), key=lambda k: scores[k], reverse=True) return sorted_nodes[:top_k] if strategy == RoutingStrategy.THRESHOLD: return [n for n, s in scores.items() if s >= threshold] if strategy == RoutingStrategy.SOFTMAX_SAMPLE: nodes = list(scores.keys()) probs = torch.tensor([scores[n] for n in nodes]) probs = probs / probs.sum() idx = int(torch.multinomial(probs, 1).item()) return [nodes[idx]] return [] def get_all_scores( self, graph: Any, metrics_tracker: Any | None = None, ) -> dict[str, float]: """Get scores for all graph nodes.""" prediction = self.predict( graph, metrics_tracker=metrics_tracker, strategy=RoutingStrategy.TOP_K, top_k=graph.num_nodes, ) return prediction.scores @classmethod def load(cls, path: str | Path, feature_generator: FeatureGenerator | None = None) -> "GNNRouterInference": """Load the model and create an inference instance from a checkpoint.""" path = Path(path) checkpoint = torch.load(path, map_location="cpu") config = checkpoint.get("config", TrainingConfig()) model = GCNRouter( in_channels=64, hidden_channels=config.hidden_dim, out_channels=config.num_classes, num_layers=config.num_layers, dropout=config.dropout, ) model.load_state_dict(checkpoint["model_state_dict"]) return cls(model, feature_generator) def create_gnn_router( model_type: GNNModelType, in_channels: int, out_channels: int, config: TrainingConfig | None = None, ) -> Any: """Factory: create a GNN router of the required type (GCN/GAT/SAGE).""" if not TORCH_AVAILABLE or not PYG_AVAILABLE: msg = "PyTorch and PyTorch Geometric are required" raise ImportError(msg) config = config or TrainingConfig() kwargs = { "in_channels": in_channels, "hidden_channels": config.hidden_dim, "out_channels": out_channels, "num_layers": config.num_layers, "dropout": config.dropout, "use_batch_norm": config.use_batch_norm, "use_residual": config.use_residual, } if model_type == GNNModelType.GCN: return GCNRouter(**kwargs) if model_type == GNNModelType.GAT: return GATRouter(heads=int(config.gat_heads), **kwargs) if model_type == GNNModelType.SAGE: return GraphSAGERouter(**kwargs) msg = f"Unknown model type: {model_type}" raise ValueError(msg) if not TORCH_AVAILABLE or not PYG_AVAILABLE: class BaseGNNModel: def __init__(self, *_args, **_kwargs): msg = "PyTorch and PyTorch Geometric are required for GNN models" raise ImportError(msg) class GCNRouter(BaseGNNModel): pass class GATRouter(BaseGNNModel): pass class GraphSAGERouter(BaseGNNModel): pass