# gMAS / src/core/gnn.py
# Артём Боярских
# chore: initial commit
# 3193174
"""
Graph Neural Networks integration for routing.
Provides:
- Node and edge feature generators
- Model wrappers (GCN, GAT, GraphSAGE)
- Training and inference utilities
- Saving/loading weights
- Application to online routing decisions
"""
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import Any, Literal
import rustworkx as rx
import torch
from pydantic import BaseModel, Field
from torch import nn
from torch.nn import functional
from torch.optim import AdamW
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATConv, GCNConv, SAGEConv
from config.logging import logger
# Availability flags consulted by the guards further down in this module.
# Both are hard-coded to True: torch and torch_geometric are imported
# unconditionally at the top of the file, so a missing dependency already
# fails at import time and the "unavailable" branches are effectively inert.
TORCH_AVAILABLE = True
PYG_AVAILABLE = True

# Public API. The list is kept in sorted order, so it is intentionally not
# grouped by category (earlier per-category comments had drifted onto the
# wrong names after sorting).
__all__ = [
    "BaseGNNModel",
    "DefaultFeatureGenerator",
    "FeatureConfig",
    "FeatureGenerator",
    "GATRouter",
    "GCNRouter",
    "GNNModelType",
    "GNNRouterInference",
    "GNNTrainer",
    "GraphSAGERouter",
    "RoutingPrediction",
    "RoutingStrategy",
    "TrainingConfig",
    # Returned by GNNTrainer.train, so it is part of the public surface.
    "TrainingResult",
    "create_gnn_router",
]
class GNNModelType(str, Enum):
    """Identifiers of the GNN architectures that ``create_gnn_router`` can build."""

    GCN = "gcn"  # built as GCNRouter
    GAT = "gat"  # built as GATRouter
    SAGE = "sage"  # built as GraphSAGERouter
class RoutingStrategy(str, Enum):
    """How ``GNNRouterInference`` turns per-node scores into recommended nodes."""

    ARGMAX = "argmax"  # the single highest-scoring node
    SOFTMAX_SAMPLE = "softmax_sample"  # one node sampled with probability proportional to its score
    TOP_K = "top_k"  # the top_k highest-scoring nodes
    THRESHOLD = "threshold"  # every node whose score is >= threshold
class FeatureConfig(BaseModel):
    """Switches and parameters read by ``DefaultFeatureGenerator``."""

    # The three *_dim fields are declared here but are not read by any code
    # in this module; presumably they are consumed by callers — TODO confirm.
    node_feature_dim: int = 64
    edge_feature_dim: int = 16
    embedding_dim: int = 128
    # Include graph.embeddings in the node features when the graph has them.
    use_embeddings: bool = True
    # Include metrics_tracker.get_node_features(...) when a tracker is passed.
    use_metrics: bool = True
    # Include a PageRank column.
    use_centrality: bool = True
    # Include normalised in-degree / out-degree columns.
    use_structural: bool = True
    # Standardise every feature column (subtract mean, divide by std).
    normalize_features: bool = True
    # After standardising, clamp values to [-outlier_std, +outlier_std].
    clip_outliers: bool = True
    outlier_std: float = 3.0
class TrainingConfig(BaseModel):
    """Hyperparameters shared by model construction and ``GNNTrainer``."""

    # Optimizer settings, passed to AdamW by GNNTrainer.
    learning_rate: float = 1e-3
    weight_decay: float = 1e-4
    # Architecture settings, forwarded to the model by create_gnn_router.
    hidden_dim: int = 128
    num_layers: int = 3
    dropout: float = 0.1
    gat_heads: int = 4  # only used when building a GATRouter
    # Training-loop settings used by GNNTrainer.train.
    epochs: int = 100  # upper bound on the number of epochs
    batch_size: int = 32  # DataLoader batch size
    patience: int = 10  # epochs without validation improvement before stopping early
    use_batch_norm: bool = True
    use_residual: bool = True
    # Selects the loss in GNNTrainer: cross-entropy for "node_classification",
    # mean squared error for the other two tasks.
    task: Literal["node_classification", "edge_prediction", "path_ranking"] = "node_classification"
    # Used as the model's out_channels by GNNRouterInference.load.
    num_classes: int = 2
class RoutingPrediction(BaseModel):
    """Result returned by ``GNNRouterInference.predict``."""

    # Node ids selected by the strategy (may be empty).
    recommended_nodes: list[str]
    # Score per node id, after candidate filtering and removal of the source node.
    scores: dict[str, float]
    # Mean score of the recommended nodes; 0.0 when nothing was recommended.
    confidence: float
    # Strategy that produced recommended_nodes.
    strategy: RoutingStrategy
    # predict() stores "source" and "num_candidates" here.
    metadata: dict[str, Any] = Field(default_factory=dict)
class FeatureGenerator(ABC):
    """Abstract interface for turning a graph into node and edge feature tensors."""

    @abstractmethod
    def generate_node_features(
        self,
        graph: Any,
        node_ids: list[str],
        metrics_tracker: Any | None = None,
    ) -> torch.Tensor:
        """Build a feature matrix for a list of nodes.

        Args:
            graph: Graph object to read from; its concrete type is not fixed
                by this interface.
            node_ids: Identifiers of the nodes to build features for.
            metrics_tracker: Optional source of runtime metrics.

        Returns:
            A tensor of node features.
        """

    @abstractmethod
    def generate_edge_features(
        self,
        graph: Any,
        edges: list[tuple[str, str]],
        metrics_tracker: Any | None = None,
    ) -> torch.Tensor:
        """Build a feature matrix for a list of edges.

        Args:
            graph: Graph object to read from; its concrete type is not fixed
                by this interface.
            edges: ``(source_id, target_id)`` pairs to build features for.
            metrics_tracker: Optional source of runtime metrics.

        Returns:
            A tensor of edge features.
        """
class DefaultFeatureGenerator(FeatureGenerator):
    """Feature generator that assembles features according to a ``FeatureConfig``.

    Node features are the column-wise concatenation of whichever groups are
    enabled in the config: stored graph embeddings, tracker-provided metrics,
    normalised in/out degree, and PageRank. Edge features always have four
    columns: edge weight, reliability, latency (ms / 1000) and data volume
    (/ 1000).
    """

    def __init__(self, config: FeatureConfig | None = None):
        """Store *config*, falling back to a default ``FeatureConfig``."""
        self.config = config or FeatureConfig()

    def generate_node_features(
        self,
        graph: Any,
        node_ids: list[str],
        metrics_tracker: Any | None = None,
    ) -> torch.Tensor:
        """Collect node features: embeddings, metrics, structure, centrality.

        Args:
            graph: Object exposing ``graph`` (a rustworkx graph) and,
                optionally, ``embeddings``.
            node_ids: Node identifiers, one output row per entry.
            metrics_tracker: Optional object providing
                ``get_node_features(node_ids)``.

        Returns:
            A float32 tensor with one row per node id. When every feature
            group is disabled or unavailable, an identity matrix of size
            ``len(node_ids)`` is returned instead.
        """
        cfg = self.config
        blocks: list[torch.Tensor] = []
        if cfg.use_embeddings:
            emb_features = self._get_embedding_features(graph, node_ids)
            if emb_features is not None:
                blocks.append(emb_features)
        if cfg.use_metrics and metrics_tracker is not None:
            blocks.append(metrics_tracker.get_node_features(node_ids))
        if cfg.use_structural:
            blocks.append(self._get_structural_features(graph, node_ids))
        if cfg.use_centrality:
            blocks.append(self._get_centrality_features(graph, node_ids))
        if not blocks:
            return torch.eye(len(node_ids), dtype=torch.float32)
        combined = torch.cat(blocks, dim=1)
        if cfg.normalize_features:
            combined = self._normalize(combined)
        return combined.to(torch.float32)

    def generate_edge_features(
        self,
        graph: Any,
        edges: list[tuple[str, str]],
        metrics_tracker: Any | None = None,
    ) -> torch.Tensor:
        """Collect edge features: weight, reliability/latency/traffic metrics.

        Args:
            graph: Object exposing ``graph`` (a rustworkx graph).
            edges: ``(source_id, target_id)`` pairs, one output row per pair.
            metrics_tracker: Optional object providing
                ``get_edge_metrics(src, tgt)``.

        Returns:
            A float32 tensor of shape ``(len(edges), 4)``. Edges without
            metrics get the neutral triple ``[1.0, 0.0, 0.0]``.
        """
        if not edges:
            # Keep the result two-dimensional even when there is nothing to describe.
            return torch.zeros((0, 4), dtype=torch.float32)
        # Resolve node ids to graph indices once instead of rescanning all
        # nodes for every edge.
        try:
            id_to_idx = self._build_id_index(graph.graph)
        except (ValueError, RuntimeError, KeyError) as e:
            logger.debug(f"Failed to get edge weight: {e}")
            id_to_idx = {}
        rows = []
        for src, tgt in edges:
            row = [self._get_edge_weight(graph, src, tgt, id_to_idx)]
            edge_metrics = metrics_tracker.get_edge_metrics(src, tgt) if metrics_tracker is not None else None
            if edge_metrics:
                row.extend(
                    [
                        edge_metrics.reliability,
                        edge_metrics.avg_latency_ms / 1000.0,
                        edge_metrics.avg_data_volume / 1000.0,
                    ]
                )
            else:
                row.extend([1.0, 0.0, 0.0])
            rows.append(row)
        result = torch.tensor(rows, dtype=torch.float32)
        if self.config.normalize_features:
            result = self._normalize(result)
        return result

    @staticmethod
    def _build_id_index(rx_graph: Any) -> dict[Any, int]:
        """Map node identifiers to rustworkx node indices.

        The identifier is ``data["id"]`` for dict payloads, ``data.agent_id``
        for objects that have it, and ``str(index)`` otherwise. If two nodes
        share an identifier, the later index wins.
        """
        id_to_idx: dict[Any, int] = {}
        for idx in rx_graph.node_indices():
            data = rx_graph.get_node_data(idx)
            if isinstance(data, dict):
                nid = data.get("id", str(idx))
            elif hasattr(data, "agent_id"):
                nid = data.agent_id
            else:
                nid = str(idx)
            id_to_idx[nid] = idx
        return id_to_idx

    def _get_embedding_features(self, graph: Any, _node_ids: list[str]) -> torch.Tensor | None:
        """Return the node embedding matrix if embeddings are available in the graph.

        Non-tensor embeddings (lists, arrays) are converted to a tensor
        before the emptiness check. ``_node_ids`` is not used: the whole
        stored matrix is returned as-is.
        NOTE(review): this assumes the rows of ``graph.embeddings`` line up
        with the node ids requested by the caller — confirm.
        """
        embeddings = getattr(graph, "embeddings", None)
        if embeddings is None:
            return None
        if isinstance(embeddings, torch.Tensor):
            tensor = embeddings.to(torch.float32)
        else:
            tensor = torch.tensor(embeddings, dtype=torch.float32)
        return tensor if tensor.numel() else None

    def _get_structural_features(self, graph: Any, node_ids: list[str]) -> torch.Tensor:
        """Compute normalised in/out degrees for nodes.

        Degrees are divided by ``max(num_nodes - 1, 1)``. Unknown node ids
        get ``[0.0, 0.0]``. The result always has shape ``(len(node_ids), 2)``.
        """
        rx_graph = graph.graph
        id_to_idx = self._build_id_index(rx_graph)
        denom = max(rx_graph.num_nodes() - 1, 1)
        rows = []
        for node_id in node_ids:
            idx = id_to_idx.get(node_id)
            if idx is None:
                rows.append([0.0, 0.0])
            else:
                rows.append([rx_graph.in_degree(idx) / denom, rx_graph.out_degree(idx) / denom])
        # reshape keeps the tensor 2-D when node_ids is empty
        return torch.tensor(rows, dtype=torch.float32).reshape(-1, 2)

    def _get_centrality_features(self, graph: Any, node_ids: list[str]) -> torch.Tensor:
        """Compute simple centrality (PageRank) for nodes.

        Returns a ``(len(node_ids), 1)`` tensor. Unknown node ids get 0.0,
        and a failure inside the computation yields an all-zero column.
        """
        try:
            pagerank = rx.pagerank(graph.graph)
            id_to_idx = self._build_id_index(graph.graph)
            rows = []
            for node_id in node_ids:
                idx = id_to_idx.get(node_id)
                if idx is not None and idx in pagerank:
                    rows.append([pagerank[idx]])
                else:
                    rows.append([0.0])
            return torch.tensor(rows, dtype=torch.float32).reshape(-1, 1)
        except (ValueError, RuntimeError, KeyError) as e:
            logger.warning(f"Failed to compute centrality features: {e}")
            return torch.zeros((len(node_ids), 1), dtype=torch.float32)

    def _get_edge_weight(
        self,
        graph: Any,
        src: str,
        tgt: str,
        id_to_idx: dict[Any, int] | None = None,
    ) -> float:
        """Extract the graph edge weight or return 1.0 by default.

        Args:
            graph: Object exposing ``graph`` (a rustworkx graph).
            src: Source node identifier.
            tgt: Target node identifier.
            id_to_idx: Optional precomputed identifier-to-index map; built
                on demand when omitted.

        Returns:
            ``edge_data["weight"]`` as a float when the edge exists and has
            a dict payload; 1.0 in every other case.
        """
        try:
            rx_graph = graph.graph
            if id_to_idx is None:
                id_to_idx = self._build_id_index(rx_graph)
            src_idx = id_to_idx.get(src)
            tgt_idx = id_to_idx.get(tgt)
            if src_idx is None or tgt_idx is None:
                return 1.0
            # get_edge_data raises for a missing edge, so check first.
            if not rx_graph.has_edge(src_idx, tgt_idx):
                return 1.0
            edge_data = rx_graph.get_edge_data(src_idx, tgt_idx)
            if isinstance(edge_data, dict):
                return float(edge_data.get("weight", 1.0))
        except (ValueError, RuntimeError, KeyError, TypeError) as e:
            logger.debug(f"Failed to get edge weight: {e}")
        return 1.0

    def _normalize(self, features: torch.Tensor) -> torch.Tensor:
        """Normalize features column-wise and clip outliers.

        Columns with zero spread are left centred but unscaled. A single-row
        input is only centred, because the sample standard deviation of one
        value is undefined and would otherwise turn the row into NaN.
        """
        if features.numel() == 0:
            return features
        mean = features.mean(dim=0, keepdim=True)
        if features.shape[0] > 1:
            std = features.std(dim=0, keepdim=True)
            std = torch.where(std == 0, torch.ones_like(std), std)
        else:
            std = torch.ones_like(mean)
        normalized = (features - mean) / std
        if self.config.clip_outliers:
            limit = self.config.outlier_std
            normalized = torch.clamp(normalized, -limit, limit)
        return normalized
if TORCH_AVAILABLE and PYG_AVAILABLE:

    class BaseGNNModel(nn.Module):
        """Base GNN architecture with input/output projection and residual layers.

        Subclasses populate ``convs`` and ``batch_norms`` (see
        ``_build_layers``); this class owns the linear projections and the
        shared forward pass.
        """

        def __init__(
            self,
            in_channels: int,
            hidden_channels: int,
            out_channels: int,
            num_layers: int = 3,
            dropout: float = 0.1,
            use_batch_norm: bool = True,
            use_residual: bool = True,
        ):
            """Create the projections and empty layer containers.

            Args:
                in_channels: Width of the input node features.
                hidden_channels: Width used by every graph convolution.
                out_channels: Width of the per-node output.
                num_layers: Requested number of convolution layers.
                dropout: Dropout probability applied after each convolution.
                use_batch_norm: Apply batch normalisation after each convolution.
                use_residual: Add the layer input back to the layer output.
            """
            super().__init__()
            self.in_channels = in_channels
            self.hidden_channels = hidden_channels
            self.out_channels = out_channels
            self.num_layers = num_layers
            self.dropout = dropout
            self.use_batch_norm = use_batch_norm
            self.use_residual = use_residual
            self.convs = nn.ModuleList()
            self.batch_norms = nn.ModuleList()
            self.input_proj = nn.Linear(in_channels, hidden_channels)
            self.output_proj = nn.Linear(hidden_channels, out_channels)

        def _build_layers(self, make_conv: Any) -> None:
            """Append ``max(num_layers, 1)`` convolution / batch-norm pairs.

            ``make_conv`` is a zero-argument callable returning a new
            convolution layer. A batch-norm module is always registered, even
            when ``use_batch_norm`` is false, so that state-dict keys do not
            depend on that flag.
            """
            for _ in range(max(self.num_layers, 1)):
                self.convs.append(make_conv())
                self.batch_norms.append(nn.BatchNorm1d(self.hidden_channels))

        def reset_parameters(self):
            """Reset the parameters of all model layers."""
            for layer in (*self.convs, *self.batch_norms):
                reset = getattr(layer, "reset_parameters", None)
                if callable(reset):
                    reset()
            self.input_proj.reset_parameters()
            self.output_proj.reset_parameters()

        def forward(
            self, x: torch.Tensor, edge_index: torch.Tensor, _edge_attr: torch.Tensor | None = None
        ) -> torch.Tensor:
            """Forward pass: graph convolutions with Dropout, BatchNorm and residual.

            Args:
                x: Node feature matrix whose last dimension is ``in_channels``.
                edge_index: Edge index tensor passed to every convolution.
                _edge_attr: Accepted for interface compatibility; not used.

            Returns:
                The output projection of the final hidden representation.
            """
            x = functional.relu(self.input_proj(x))
            for i, conv in enumerate(self.convs):
                x_prev = x
                x = conv(x, edge_index)
                if self.use_batch_norm and i < len(self.batch_norms):
                    x = self.batch_norms[i](x)
                x = functional.relu(x)
                x = functional.dropout(x, p=self.dropout, training=self.training)
                # The residual is only added when the layer kept the shape.
                if self.use_residual and x.shape == x_prev.shape:
                    x = x + x_prev
            return self.output_proj(x)

    class GCNRouter(BaseGNNModel):
        """Router built from ``GCNConv`` layers."""

        def __init__(self, **kwargs):
            """Build the model; ``kwargs`` are forwarded to ``BaseGNNModel``."""
            super().__init__(**kwargs)
            self._build_layers(lambda: GCNConv(self.hidden_channels, self.hidden_channels))

    class GATRouter(BaseGNNModel):
        """Router built from multi-head ``GATConv`` layers."""

        def __init__(self, heads: int = 4, **kwargs):
            """Build the model; ``kwargs`` are forwarded to ``BaseGNNModel``.

            Args:
                heads: Number of attention heads per layer. Each head outputs
                    ``hidden_channels // heads`` features, so ``heads`` must
                    divide ``hidden_channels`` for the concatenated output to
                    keep the hidden width.

            Raises:
                ValueError: If ``heads`` is not positive or does not divide
                    ``hidden_channels``.
            """
            super().__init__(**kwargs)
            if heads <= 0 or self.hidden_channels % heads != 0:
                msg = (
                    f"hidden_channels ({self.hidden_channels}) must be divisible by "
                    f"a positive number of heads (got {heads})"
                )
                raise ValueError(msg)
            self.heads = heads
            per_head = self.hidden_channels // heads
            self._build_layers(
                lambda: GATConv(
                    self.hidden_channels,
                    per_head,
                    heads=heads,
                    dropout=self.dropout,
                )
            )

    class GraphSAGERouter(BaseGNNModel):
        """Router built from ``SAGEConv`` layers."""

        def __init__(self, **kwargs):
            """Build the model; ``kwargs`` are forwarded to ``BaseGNNModel``."""
            super().__init__(**kwargs)
            self._build_layers(lambda: SAGEConv(self.hidden_channels, self.hidden_channels))
class TrainingResult(BaseModel):
    """Training history returned by ``GNNTrainer.train``."""

    # Average training loss per completed epoch.
    train_losses: list[float]
    # Average validation loss per completed epoch; empty without validation data.
    val_losses: list[float]
    # Zero-based index of the epoch with the lowest validation loss
    # (stays 0 when no validation data was given).
    best_epoch: int
    # Lowest validation loss seen (stays infinity when no validation data was given).
    best_val_loss: float
    # Free-form extra metrics; GNNTrainer.train leaves this at its default.
    metrics: dict[str, float] = Field(default_factory=dict)
class GNNTrainer:
    """Training utility for GNN models for routing tasks.

    Wraps a model with an AdamW optimizer and provides a training loop with
    validation-based early stopping, plus checkpoint save/load.
    """

    def __init__(
        self,
        model: Any,
        config: TrainingConfig | None = None,
        device: str | None = None,
    ):
        """Prepare the trainer, optimizer, and device.

        Args:
            model: Module to train; it is moved to the selected device.
            config: Hyperparameters; defaults to ``TrainingConfig()``.
            device: Target device string. When omitted, "cuda" is used if
                available, otherwise "cpu".

        Raises:
            ImportError: If PyTorch is flagged as unavailable.
        """
        if not TORCH_AVAILABLE:
            msg = "PyTorch is required for GNNTrainer"
            raise ImportError(msg)
        self.config = config or TrainingConfig()
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
        self.optimizer = AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay,
        )
        self._best_model_state: dict[str, Any] | None = None
        self._history: TrainingResult | None = None

    def train(
        self,
        train_data: list[Any],  # list[Data]
        val_data: list[Any] | None = None,
        verbose: bool = True,
    ) -> TrainingResult:
        """Train the model on a train/val dataset with early stopping.

        Args:
            train_data: Non-empty collection of graph samples.
            val_data: Optional validation samples. Without them there is no
                early stopping and no best-weights restore.
            verbose: Log progress every 10 epochs and on early stopping.

        Returns:
            The recorded training history (also kept on the trainer).

        Raises:
            ImportError: If PyTorch Geometric is flagged as unavailable.
            ValueError: If ``train_data`` is empty.
        """
        if not PYG_AVAILABLE:
            msg = "PyTorch Geometric is required for training"
            raise ImportError(msg)
        if not train_data:
            msg = "train_data must not be empty"
            raise ValueError(msg)
        train_loader = DataLoader(train_data, batch_size=self.config.batch_size, shuffle=True)
        val_loader = DataLoader(val_data, batch_size=self.config.batch_size) if val_data else None

        # Forget the snapshot of any previous run; otherwise a run without
        # validation data would end by loading stale weights.
        self._best_model_state = None
        train_losses: list[float] = []
        val_losses: list[float] = []
        best_val_loss = float("inf")
        best_epoch = 0
        patience_counter = 0

        for epoch in range(self.config.epochs):
            train_loss = self._train_epoch(train_loader)
            train_losses.append(train_loss)
            if val_loader is not None:
                val_loss = self._eval_epoch(val_loader)
                val_losses.append(val_loss)
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_epoch = epoch
                    patience_counter = 0
                    # Snapshot on CPU so the copy does not occupy device memory.
                    self._best_model_state = {k: v.cpu().clone() for k, v in self.model.state_dict().items()}
                else:
                    patience_counter += 1
                    if patience_counter >= self.config.patience:
                        if verbose:
                            logger.info(
                                f"Early stopping at epoch {epoch} "
                                f"(best epoch {best_epoch}, val_loss={best_val_loss:.4f})"
                            )
                        break
            if verbose and epoch % 10 == 0:
                val_msg = f", val_loss={val_losses[-1]:.4f}" if val_losses else ""
                logger.info(f"Epoch {epoch}: train_loss={train_loss:.4f}{val_msg}")

        if self._best_model_state is not None:
            self.model.load_state_dict(self._best_model_state)
        self._history = TrainingResult(
            train_losses=train_losses,
            val_losses=val_losses,
            best_epoch=best_epoch,
            best_val_loss=best_val_loss,
        )
        return self._history

    def _forward(self, batch: Any) -> Any:
        """Run the model on a batch that is already on the trainer's device."""
        return self.model(batch.x, batch.edge_index, getattr(batch, "edge_attr", None))

    def _compute_loss(self, out: Any, batch: Any) -> Any:
        """Select the loss for the configured task (shared by train and eval)."""
        task = self.config.task
        if task == "node_classification":
            return functional.cross_entropy(out, batch.y)
        if task == "path_ranking":
            return self._compute_ranking_loss(out, batch)
        # squeeze(-1) drops only the trailing channel axis, never the batch axis
        return functional.mse_loss(out.squeeze(-1), batch.y.float())

    def _train_epoch(self, loader: Any) -> float:
        """One training epoch, return the average loss."""
        self.model.train()
        total_loss = 0.0
        for batch_data in loader:
            batch_on_device = batch_data.to(self.device)
            self.optimizer.zero_grad()
            loss = self._compute_loss(self._forward(batch_on_device), batch_on_device)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
        # max(..., 1) avoids a division by zero for a loader without batches
        return total_loss / max(len(loader), 1)

    def _eval_epoch(self, loader: Any) -> float:
        """Evaluate the model on the validation dataset, return average loss."""
        self.model.eval()
        total_loss = 0.0
        with torch.no_grad():
            for batch_data in loader:
                batch_on_device = batch_data.to(self.device)
                loss = self._compute_loss(self._forward(batch_on_device), batch_on_device)
                total_loss += loss.item()
        return total_loss / max(len(loader), 1)

    def _compute_ranking_loss(self, out: Any, batch: Any) -> Any:
        """Simplest MSE loss for path ranking."""
        scores = out.squeeze(-1)
        targets = batch.y.float()
        return functional.mse_loss(scores, targets)

    def save(self, path: str | Path) -> None:
        """Save model weights and training history to a checkpoint.

        NOTE(review): ``config`` and ``history`` are stored as pickled
        objects, so the checkpoint can only be read back by a loader that
        allows full unpickling.
        """
        path = Path(path)
        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
                "config": self.config,
                "history": self._history,
            },
            path,
        )

    def load(self, path: str | Path) -> None:
        """Load a checkpoint and restore the model/configuration.

        NOTE(review): the checkpoint contains pickled objects; only load
        files from a trusted source.
        """
        path = Path(path)
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        if "config" in checkpoint:
            self.config = checkpoint["config"]
        if "history" in checkpoint:
            self._history = checkpoint["history"]
class GNNRouterInference:
    """Inference module: generates features, runs GNN and selects routes."""

    def __init__(
        self,
        model: Any,
        feature_generator: FeatureGenerator | None = None,
        device: str | None = None,
    ):
        """Prepare the model for inference on the given device.

        Args:
            model: Trained module; moved to the device and put in eval mode.
            feature_generator: Source of node features; defaults to
                ``DefaultFeatureGenerator()``.
            device: Target device string. When omitted, "cuda" is used if
                available, otherwise "cpu".

        Raises:
            ImportError: If PyTorch is flagged as unavailable.
        """
        if not TORCH_AVAILABLE:
            msg = "PyTorch is required for GNNRouterInference"
            raise ImportError(msg)
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
        self.model.eval()
        self.feature_generator = feature_generator or DefaultFeatureGenerator()

    def predict(
        self,
        graph: Any,  # RoleGraph
        source: str | None = None,
        candidates: list[str] | None = None,
        metrics_tracker: Any | None = None,
        strategy: RoutingStrategy = RoutingStrategy.ARGMAX,
        top_k: int = 3,
        threshold: float = 0.5,
    ) -> RoutingPrediction:
        """Predict the next routing nodes according to the selected strategy.

        Args:
            graph: Object exposing ``edge_index`` and either ``node_ids`` or
                ``num_nodes``.
            source: Node to exclude from the result (the current node).
            candidates: When given and non-empty, only these nodes are scored.
            metrics_tracker: Optional metrics source for the feature generator.
            strategy: Selection strategy applied to the scores.
            top_k: Number of nodes returned by ``RoutingStrategy.TOP_K``.
            threshold: Minimum score for ``RoutingStrategy.THRESHOLD``.

        Returns:
            The recommended nodes, the filtered score table, and the mean
            score of the recommendation as confidence.

        Raises:
            ImportError: If PyTorch Geometric is flagged as unavailable.
            ValueError: If the model returns fewer scores than there are nodes.
        """
        if not PYG_AVAILABLE:
            msg = "PyTorch Geometric is required for inference"
            raise ImportError(msg)
        node_ids = list(getattr(graph, "node_ids", None) or [])
        if not node_ids:
            node_ids = [str(i) for i in range(graph.num_nodes)]

        node_features = self.feature_generator.generate_node_features(graph, node_ids, metrics_tracker)
        edge_index = graph.edge_index
        # Tensors are moved and cast; anything else is converted first.
        if isinstance(node_features, torch.Tensor):
            x = node_features.to(dtype=torch.float32, device=self.device)
        else:
            x = torch.tensor(node_features, dtype=torch.float32, device=self.device)
        if isinstance(edge_index, torch.Tensor):
            edge_index_tensor = edge_index.to(dtype=torch.long, device=self.device)
        else:
            edge_index_tensor = torch.tensor(edge_index, dtype=torch.long, device=self.device)

        with torch.no_grad():
            out = self.model(x, edge_index_tensor)
        # out shape: [N] or [N, C] where C = num_classes
        if out.dim() > 1 and out.shape[-1] > 1:
            # Multi-class: softmax over classes, take probability of the last (positive) class
            scores = functional.softmax(out, dim=-1)[:, -1]
        else:
            # Single output per node: softmax across the nodes
            scores = functional.softmax(out.reshape(-1), dim=0)
        score_values = scores.cpu().tolist()
        if len(score_values) < len(node_ids):
            msg = f"Model returned {len(score_values)} scores for {len(node_ids)} nodes"
            raise ValueError(msg)

        node_scores = dict(zip(node_ids, score_values))
        if candidates:
            allowed = set(candidates)
            node_scores = {k: v for k, v in node_scores.items() if k in allowed}
        if source:
            node_scores.pop(source, None)

        recommended = self._apply_strategy(node_scores, strategy, top_k, threshold)
        confidence = sum(node_scores[n] for n in recommended) / len(recommended) if recommended else 0.0
        return RoutingPrediction(
            recommended_nodes=recommended,
            scores=node_scores,
            confidence=float(confidence),
            strategy=strategy,
            metadata={"source": source, "num_candidates": len(node_scores)},
        )

    def _apply_strategy(
        self,
        scores: dict[str, float],
        strategy: RoutingStrategy,
        top_k: int,
        threshold: float,
    ) -> list[str]:
        """Select recommended nodes according to the given selection strategy.

        Returns an empty list for an empty score table or an unknown strategy.
        """
        if not scores:
            return []
        if strategy == RoutingStrategy.ARGMAX:
            return [max(scores, key=scores.__getitem__)]
        if strategy == RoutingStrategy.TOP_K:
            ranked = sorted(scores, key=scores.__getitem__, reverse=True)
            # A negative top_k would otherwise slice from the wrong end.
            return ranked[: max(top_k, 0)]
        if strategy == RoutingStrategy.THRESHOLD:
            return [n for n, s in scores.items() if s >= threshold]
        if strategy == RoutingStrategy.SOFTMAX_SAMPLE:
            nodes = list(scores)
            weights = torch.tensor([scores[n] for n in nodes], dtype=torch.float32).clamp(min=0)
            total = weights.sum()
            if not torch.isfinite(total) or total <= 0:
                # No usable probability mass: fall back to a uniform draw.
                weights = torch.ones_like(weights)
            else:
                weights = weights / total
            idx = int(torch.multinomial(weights, 1).item())
            return [nodes[idx]]
        return []

    def get_all_scores(
        self,
        graph: Any,
        metrics_tracker: Any | None = None,
    ) -> dict[str, float]:
        """Get scores for all graph nodes."""
        prediction = self.predict(
            graph,
            metrics_tracker=metrics_tracker,
            strategy=RoutingStrategy.TOP_K,
            top_k=graph.num_nodes,
        )
        return prediction.scores

    @classmethod
    def load(
        cls,
        path: str | Path,
        feature_generator: FeatureGenerator | None = None,
        model_type: GNNModelType = GNNModelType.GCN,
    ) -> "GNNRouterInference":
        """Load the model and create an inference instance from a checkpoint.

        The input and output widths are read from the stored projection
        weights, so checkpoints trained with any feature dimension load
        correctly. The remaining architecture settings come from the stored
        ``TrainingConfig``.

        Args:
            path: Checkpoint written by ``GNNTrainer.save``.
            feature_generator: Optional feature generator for the instance.
            model_type: Architecture the checkpoint was trained with;
                defaults to GCN.

        NOTE(review): the checkpoint contains pickled objects; only load
        files from a trusted source.
        """
        path = Path(path)
        checkpoint = torch.load(path, map_location="cpu")
        config = checkpoint.get("config") or TrainingConfig()
        state_dict = checkpoint["model_state_dict"]
        in_weight = state_dict.get("input_proj.weight")
        out_weight = state_dict.get("output_proj.weight")
        # Fall back to the historical defaults if the projections are missing.
        in_channels = int(in_weight.shape[1]) if in_weight is not None else 64
        out_channels = int(out_weight.shape[0]) if out_weight is not None else config.num_classes
        model = create_gnn_router(model_type, in_channels, out_channels, config)
        model.load_state_dict(state_dict)
        return cls(model, feature_generator)
def create_gnn_router(
    model_type: GNNModelType,
    in_channels: int,
    out_channels: int,
    config: TrainingConfig | None = None,
) -> Any:
    """Factory: create a GNN router of the required type (GCN/GAT/SAGE).

    Args:
        model_type: Architecture to instantiate.
        in_channels: Width of the input node features.
        out_channels: Width of the per-node output.
        config: Source of the remaining architecture settings; defaults to
            ``TrainingConfig()``.

    Returns:
        A newly constructed router model.

    Raises:
        ImportError: If PyTorch or PyTorch Geometric is flagged as unavailable.
        ValueError: If ``model_type`` matches none of the known architectures.
    """
    if not (TORCH_AVAILABLE and PYG_AVAILABLE):
        msg = "PyTorch and PyTorch Geometric are required"
        raise ImportError(msg)

    cfg = config or TrainingConfig()
    shared = {
        "in_channels": in_channels,
        "hidden_channels": cfg.hidden_dim,
        "out_channels": out_channels,
        "num_layers": cfg.num_layers,
        "dropout": cfg.dropout,
        "use_batch_norm": cfg.use_batch_norm,
        "use_residual": cfg.use_residual,
    }
    # Ordered (type, builder) pairs. Matching uses ==, so the comparison
    # semantics are the same as an explicit if-chain.
    builders = (
        (GNNModelType.GCN, lambda: GCNRouter(**shared)),
        (GNNModelType.GAT, lambda: GATRouter(heads=int(cfg.gat_heads), **shared)),
        (GNNModelType.SAGE, lambda: GraphSAGERouter(**shared)),
    )
    for known_type, build in builders:
        if model_type == known_type:
            return build()

    msg = f"Unknown model type: {model_type}"
    raise ValueError(msg)
if not (TORCH_AVAILABLE and PYG_AVAILABLE):

    class BaseGNNModel:
        """Placeholder used when the GNN dependencies are flagged as unavailable.

        Instantiating it, or any of the router placeholders below, raises
        ``ImportError``.
        """

        def __init__(self, *_args, **_kwargs):
            """Reject construction regardless of the arguments passed."""
            msg = "PyTorch and PyTorch Geometric are required for GNN models"
            raise ImportError(msg)

    class GCNRouter(BaseGNNModel):
        """Placeholder for the GCN router; inherits the failing constructor."""

    class GATRouter(BaseGNNModel):
        """Placeholder for the GAT router; inherits the failing constructor."""

    class GraphSAGERouter(BaseGNNModel):
        """Placeholder for the GraphSAGE router; inherits the failing constructor."""