""" Training Loop — Train the GNN on IEEE 33-bus load scenarios. """ from __future__ import annotations import os import time import torch from torch_geometric.loader import DataLoader from config import CFG from src.grid.loader import load_network from src.ai.model import build_model from src.ai.dataset import generate_scenarios from src.ai.physics_loss import DynamicLagrangeLoss def train( system: str = "case33bw", n_scenarios: int | None = None, epochs: int | None = None, batch_size: int | None = None, lr: float | None = None, device: str | None = None, save_path: str | None = None, verbose: bool = True, ) -> dict: """Train the GNN model. Parameters ---------- system : str – IEEE test system n_scenarios : int – number of load scenarios to generate epochs : int – training epochs batch_size : int lr : float – learning rate device : str – "cuda" or "cpu" save_path : str – path to save model checkpoint verbose : bool Returns ------- dict with training history and model path. """ cfg = CFG.ai n_scenarios = n_scenarios or cfg.n_scenarios epochs = epochs or cfg.epochs batch_size = batch_size or cfg.batch_size lr = lr or cfg.lr device = device or (cfg.device if torch.cuda.is_available() else "cpu") save_path = save_path or cfg.checkpoint_path if verbose: print(f"[Train] System: {system}, Scenarios: {n_scenarios}, " f"Epochs: {epochs}, Device: {device}") # --- Generate data --- t0 = time.perf_counter() net = load_network(system) if verbose: print(f"[Train] Generating {n_scenarios} load scenarios...") scenarios = generate_scenarios(net, n_scenarios=n_scenarios) if verbose: print(f"[Train] Generated {len(scenarios)} scenarios in " f"{time.perf_counter() - t0:.1f}s") if len(scenarios) < 10: return {"error": "Too few scenarios converged."} # Split: 80% train, 20% val split = int(0.8 * len(scenarios)) train_data = scenarios[:split] val_data = scenarios[split:] train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True) val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False) # --- Model --- model = build_model().to(device) optimizer = torch.optim.Adam(model.parameters(), lr=lr) loss_fn = DynamicLagrangeLoss(lambda_v_init=cfg.lambda_v, dual_lr=cfg.dual_lr) # --- Training --- history = [] best_val_loss = float("inf") for epoch in range(1, epochs + 1): model.train() train_loss_sum = 0.0 train_count = 0 for batch in train_loader: batch = batch.to(device) optimizer.zero_grad() out = model(batch) losses = loss_fn(out["vm"], batch.y_vm.to(device)) losses["total"].backward() optimizer.step() train_loss_sum += losses["total"].item() * batch.num_graphs train_count += batch.num_graphs train_loss = train_loss_sum / max(train_count, 1) # Validation model.eval() val_loss_sum = 0.0 val_mse_sum = 0.0 val_count = 0 with torch.no_grad(): for batch in val_loader: batch = batch.to(device) out = model(batch) losses = loss_fn(out["vm"], batch.y_vm.to(device)) val_loss_sum += losses["total"].item() * batch.num_graphs val_mse_sum += losses["mse"].item() * batch.num_graphs val_count += batch.num_graphs val_loss = val_loss_sum / max(val_count, 1) val_mse = val_mse_sum / max(val_count, 1) history.append({ "epoch": epoch, "train_loss": round(train_loss, 6), "val_loss": round(val_loss, 6), "val_mse": round(val_mse, 6), "lambda_v": round(loss_fn.lambda_v, 4), }) if val_loss < best_val_loss: best_val_loss = val_loss os.makedirs(os.path.dirname(save_path), exist_ok=True) torch.save(model.state_dict(), save_path) if verbose and (epoch % 20 == 0 or epoch == 1): print(f" Epoch {epoch:3d}: train={train_loss:.6f} val={val_loss:.6f} " f"mse={val_mse:.6f} λ_v={loss_fn.lambda_v:.2f}") if verbose: print(f"[Train] Done. Best val loss: {best_val_loss:.6f}") print(f"[Train] Model saved to {save_path}") return { "history": history, "best_val_loss": best_val_loss, "model_path": save_path, "n_train": len(train_data), "n_val": len(val_data), } if __name__ == "__main__": import sys sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) result = train(n_scenarios=500, epochs=100, verbose=True) if "error" in result: print(f"ERROR: {result['error']}")