# OptiQ / src/ai/train.py  (AhmedSamir1598, commit 55e3496)
# First baseline for project OptiQ: research resources, a first baseline
# using GNNs + QC, and benchmarks against current industry standards,
# while addressing the challenges that prevent better practices from
# being adopted in industry.
"""
Training Loop — Train the GNN on IEEE 33-bus load scenarios.
"""
from __future__ import annotations
import os
import time
import torch
from torch_geometric.loader import DataLoader
from config import CFG
from src.grid.loader import load_network
from src.ai.model import build_model
from src.ai.dataset import generate_scenarios
from src.ai.physics_loss import DynamicLagrangeLoss
def train(
    system: str = "case33bw",
    n_scenarios: int | None = None,
    epochs: int | None = None,
    batch_size: int | None = None,
    lr: float | None = None,
    device: str | None = None,
    save_path: str | None = None,
    verbose: bool = True,
) -> dict:
    """Train the GNN model on generated load scenarios.

    Parameters
    ----------
    system : str – IEEE test system identifier passed to ``load_network``
    n_scenarios : int | None – number of load scenarios to generate
        (``None`` falls back to ``CFG.ai.n_scenarios``)
    epochs : int | None – training epochs (``None`` → config default)
    batch_size : int | None – mini-batch size (``None`` → config default)
    lr : float | None – learning rate (``None`` → config default)
    device : str | None – "cuda" or "cpu"; defaults to the configured
        device when CUDA is available, otherwise "cpu"
    save_path : str | None – checkpoint path (``None`` → config default)
    verbose : bool – print progress to stdout

    Returns
    -------
    dict with training history, best validation loss, checkpoint path and
    split sizes — or ``{"error": ...}`` when too few scenarios converged.
    """
    cfg = CFG.ai
    # Use `is not None` rather than `or` so an explicit falsy argument
    # (e.g. 0) is honoured instead of being silently replaced by the
    # config default.
    n_scenarios = n_scenarios if n_scenarios is not None else cfg.n_scenarios
    epochs = epochs if epochs is not None else cfg.epochs
    batch_size = batch_size if batch_size is not None else cfg.batch_size
    lr = lr if lr is not None else cfg.lr
    device = device or (cfg.device if torch.cuda.is_available() else "cpu")
    save_path = save_path or cfg.checkpoint_path

    if verbose:
        print(f"[Train] System: {system}, Scenarios: {n_scenarios}, "
              f"Epochs: {epochs}, Device: {device}")

    # --- Generate data ---
    t0 = time.perf_counter()
    net = load_network(system)
    if verbose:
        print(f"[Train] Generating {n_scenarios} load scenarios...")
    scenarios = generate_scenarios(net, n_scenarios=n_scenarios)
    if verbose:
        print(f"[Train] Generated {len(scenarios)} scenarios in "
              f"{time.perf_counter() - t0:.1f}s")
    # Scenario generation can silently drop non-converged power flows;
    # bail out early if too few survive to train on.
    if len(scenarios) < 10:
        return {"error": "Too few scenarios converged."}

    # Split: 80% train, 20% val (no shuffle — scenarios are i.i.d. draws).
    split = int(0.8 * len(scenarios))
    train_data = scenarios[:split]
    val_data = scenarios[split:]
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

    # --- Model ---
    model = build_model().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = DynamicLagrangeLoss(lambda_v_init=cfg.lambda_v, dual_lr=cfg.dual_lr)

    # Create the checkpoint directory once, up front.  Guard against a
    # bare filename: os.path.dirname("model.pt") == "" and makedirs("")
    # raises FileNotFoundError.
    ckpt_dir = os.path.dirname(save_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    # --- Training ---
    history = []
    best_val_loss = float("inf")
    saved = False  # whether any checkpoint was actually written
    for epoch in range(1, epochs + 1):
        model.train()
        train_loss_sum = 0.0
        train_count = 0
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            out = model(batch)
            losses = loss_fn(out["vm"], batch.y_vm.to(device))
            losses["total"].backward()
            optimizer.step()
            # Weight by graphs per batch so the epoch mean is per-sample
            # even when the final batch is smaller.
            train_loss_sum += losses["total"].item() * batch.num_graphs
            train_count += batch.num_graphs
        train_loss = train_loss_sum / max(train_count, 1)

        # Validation (no gradients, model in eval mode).
        model.eval()
        val_loss_sum = 0.0
        val_mse_sum = 0.0
        val_count = 0
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                out = model(batch)
                losses = loss_fn(out["vm"], batch.y_vm.to(device))
                val_loss_sum += losses["total"].item() * batch.num_graphs
                val_mse_sum += losses["mse"].item() * batch.num_graphs
                val_count += batch.num_graphs
        val_loss = val_loss_sum / max(val_count, 1)
        val_mse = val_mse_sum / max(val_count, 1)

        history.append({
            "epoch": epoch,
            "train_loss": round(train_loss, 6),
            "val_loss": round(val_loss, 6),
            "val_mse": round(val_mse, 6),
            # NOTE(review): assumes loss_fn.lambda_v is a plain float
            # (round() on a tensor would fail) — confirm in physics_loss.
            "lambda_v": round(loss_fn.lambda_v, 4),
        })

        # Keep only the best-on-validation weights.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path)
            saved = True

        if verbose and (epoch % 20 == 0 or epoch == 1):
            print(f"  Epoch {epoch:3d}: train={train_loss:.6f} val={val_loss:.6f} "
                  f"mse={val_mse:.6f} λ_v={loss_fn.lambda_v:.2f}")

    if verbose:
        print(f"[Train] Done. Best val loss: {best_val_loss:.6f}")
        if saved:
            print(f"[Train] Model saved to {save_path}")

    return {
        "history": history,
        "best_val_loss": best_val_loss,
        "model_path": save_path,
        "n_train": len(train_data),
        "n_val": len(val_data),
    }
if __name__ == "__main__":
    import sys
    # NOTE(review): this path insertion runs *after* the module-level
    # imports above, so it cannot help them resolve; it only benefits
    # imports triggered lazily inside train().  Consider running the
    # script as `python -m src.ai.train` from the repo root instead.
    sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
    result = train(n_scenarios=500, epochs=100, verbose=True)
    if "error" in result:
        print(f"ERROR: {result['error']}")
        sys.exit(1)  # non-zero exit so shells/CI can detect the failure