| """One-at-a-time (OAT) sensitivity analysis for scenario parameters. |
| |
| For each parameter in ``params``, sweep it across [low, high] with |
| ``n_points`` evenly-spaced values (holding the others at their base |
| scenario values), call the engine's ``predict()``, and summarize how a |
| probe metric responds. |
| |
| Default probe metric is the temperature at the cell (i, j) containing |
| ``probe_x_m``/``probe_y_m``, but callers can instead request the reservoir |
| mean, minimum, or maximum temperature, or the mean pressure. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import time |
| from typing import Any, Callable |
|
|
| import numpy as np |
|
|
| from tools.predict_solver import predict as _solver_predict |
| from tools.predict_surrogate import predict as _surrogate_predict |
|
|
# Registry of prediction back-ends selectable by name. Each value is a callable
# that accepts a scenario dict and returns a result dict.
ENGINES: dict[str, Callable[[dict[str, Any]], dict[str, Any]]] = dict(
    solver=_solver_predict,
    surrogate=_surrogate_predict,
)
|
|
|
|
| def _cell_index(grid: dict[str, Any], x_m: float, y_m: float) -> tuple[int, int]: |
| i = int(round(x_m / grid["dx"] - 0.5)) |
| j = int(round(y_m / grid["dy"] - 0.5)) |
| i = int(np.clip(i, 0, grid["nx"] - 1)) |
| j = int(np.clip(j, 0, grid["ny"] - 1)) |
| return i, j |
|
|
|
|
| def _metric(result: dict[str, Any], metric: str, probe: tuple[int, int] | None) -> float: |
| t = np.asarray(result["temperature"]) |
| p = np.asarray(result["pressure"]) |
| if metric == "probe_temperature_C": |
| if probe is None: |
| msg = "probe_temperature_C requires probe_x_m/probe_y_m" |
| raise ValueError(msg) |
| return float(t[probe]) |
| if metric == "mean_temperature_C": |
| return float(t.mean()) |
| if metric == "min_temperature_C": |
| return float(t.min()) |
| if metric == "max_temperature_C": |
| return float(t.max()) |
| if metric == "mean_pressure_MPa": |
| return float(p.mean()) / 1.0e6 |
| msg = f"Unknown metric {metric!r}" |
| raise ValueError(msg) |
|
|
|
|
def run(
    scenario: dict[str, Any],
    params: dict[str, dict[str, float]],
    *,
    engine: str = "surrogate",
    n_points: int = 5,
    metric: str = "probe_temperature_C",
    probe_x_m: float | None = None,
    probe_y_m: float | None = None,
) -> dict[str, Any]:
    """Run a one-at-a-time (OAT) sensitivity sweep.

    The engine is first evaluated on the unmodified ``scenario`` to obtain the
    baseline metric. Then each parameter in ``params`` is swept in turn over
    ``n_points`` evenly spaced values between its ``low`` and ``high`` while
    every other scenario entry keeps its base value.

    Args:
        scenario: base scenario dict; swept parameters override its top-level
            keys of the same name.
        params: name → {"low": ..., "high": ...} sweep range per parameter.
        engine: ``"surrogate"`` or ``"solver"``.
        n_points: samples per parameter (linspace); must be at least 1.
        metric: which scalar to track. One of
            ``probe_temperature_C``, ``mean_temperature_C``, ``min_temperature_C``,
            ``max_temperature_C``, ``mean_pressure_MPa``.
        probe_x_m, probe_y_m: required when metric is ``probe_temperature_C``.

    Returns:
        Dict with keys:

        * ``baseline_metric``: metric value for the unmodified scenario.
        * ``metric``, ``engine``: echo of the corresponding arguments.
        * ``probe_cell``: ``[i, j]`` of the probe cell, or ``None`` when no
          probe position was given.
        * ``curves``: per parameter, the swept ``values``, resulting
          ``metric`` values, ``delta`` (max − min of the metric) and
          ``slope_per_unit`` (end-to-end secant slope; 0.0 when the sweep
          has zero width).
        * ``ranking``: ``{"param", "delta"}`` entries, largest delta first.
        * ``elapsed_seconds``: wall time of the sweep loop only; the baseline
          evaluation is not included.

    Raises:
        ValueError: if ``n_points`` < 1, if ``probe_temperature_C`` is requested
            without both probe coordinates, if ``engine`` is unknown, or if
            ``metric`` is unknown.
    """
    # Reject bad arguments before any engine call: a predict() run may be
    # expensive, and an empty sweep would otherwise only surface later as an
    # opaque "max() arg is an empty sequence" error.
    n_samples = int(n_points)
    if n_samples < 1:
        msg = f"n_points must be >= 1, got {n_points!r}"
        raise ValueError(msg)
    if metric == "probe_temperature_C" and (probe_x_m is None or probe_y_m is None):
        msg = "probe_temperature_C requires probe_x_m/probe_y_m"
        raise ValueError(msg)
    if engine not in ENGINES:
        msg = f"Unknown engine {engine!r}; choose from {sorted(ENGINES)}"
        raise ValueError(msg)
    predict_fn = ENGINES[engine]

    base_result = predict_fn(scenario)
    probe: tuple[int, int] | None = None
    if probe_x_m is not None and probe_y_m is not None:
        # The probe cell is resolved once from the baseline grid and reused
        # for every sweep sample.
        probe = _cell_index(base_result["grid"], float(probe_x_m), float(probe_y_m))
    baseline_metric = _metric(base_result, metric, probe)

    curves: dict[str, dict[str, Any]] = {}
    start = time.perf_counter()
    for name, spec in params.items():
        xs = np.linspace(float(spec["low"]), float(spec["high"]), n_samples).tolist()
        # Override only the swept parameter; everything else stays at base.
        ys = [_metric(predict_fn({**scenario, name: x}), metric, probe) for x in xs]
        # Guard the division: a single sample or low == high has zero width.
        if xs[-1] != xs[0]:
            slope = float((ys[-1] - ys[0]) / (xs[-1] - xs[0]))
        else:
            slope = 0.0
        curves[name] = {
            "values": xs,
            "metric": ys,
            "delta": float(max(ys) - min(ys)),
            "slope_per_unit": slope,
        }
    elapsed = time.perf_counter() - start

    # sorted() is stable, so parameters with equal deltas keep their order
    # from ``params``.
    rankings = sorted(
        ({"param": name, "delta": curve["delta"]} for name, curve in curves.items()),
        key=lambda entry: entry["delta"],
        reverse=True,
    )

    return {
        "baseline_metric": float(baseline_metric),
        "metric": metric,
        "engine": engine,
        "probe_cell": list(probe) if probe is not None else None,
        "curves": curves,
        "ranking": rankings,
        "elapsed_seconds": elapsed,
    }
|
|