File size: 4,426 Bytes
97f04e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""Unified `predict(scenario)` wrapper around the v1.1 CNN surrogate.

Returns the same schema as :func:`tools.predict_solver.predict` so the
agent can swap engines without changing downstream plumbing.

The surrogate is a fixed 32x32 grid trained on a specific physical domain.
The scenario dict may supply `dx`/`dy` as grid metadata for downstream
coordinate mapping, but the CNN itself always runs on 32x32.
"""

from __future__ import annotations

import time
from typing import Any

import numpy as np

from surrogate.predict import predict as _cnn_predict

SURROGATE_NX = 32
SURROGATE_NY = 32

# Scenario fallbacks; `predict` overlays the caller's dict on top of these.
DEFAULTS: dict[str, Any] = {
    "dx": 12.5,  # meters; surrogate native domain is 400m x 400m
    "dy": 12.5,
    "porosity": 0.12,
    "log_permeability": -12.3,  # log10(5e-13)
    "base_pressure": 1.5e7,
    "depth": 1500.0,
    "T_initial": 200.0,
    "wells": [],
}


def _rescale_axis(index: int, source_n: int | None, surrogate_n: int) -> int:
    """Map one cell index onto a surrogate axis of length ``surrogate_n``.

    With ``source_n`` given, applies the endpoint-aligned linear map
    ``0 -> 0`` and ``source_n - 1 -> surrogate_n - 1``, rounding with Python's
    ``round`` (ties go to the even integer). With ``source_n`` None the index
    is taken as already being in surrogate space. Either way the result is
    clipped into ``[0, surrogate_n - 1]``.
    """
    if source_n is not None:
        # max(..., 1) keeps a degenerate 1-cell source axis from dividing by zero.
        index = int(round(index * (surrogate_n - 1) / max(source_n - 1, 1)))
    # Out-of-range indices are clamped onto the surrogate edge, not rejected.
    return int(np.clip(index, 0, surrogate_n - 1))


def _rescale_well_indices(
    wells: list[dict[str, Any]],
    *,
    source_nx: int | None,
    source_ny: int | None,
) -> list[tuple[int, int]]:
    """Map (i, j) cell indices from an arbitrary grid onto the surrogate's 32x32.

    Interprets (i, j) as (row, col) in the source grid: ``i`` is scaled using
    ``source_nx`` onto ``SURROGATE_NX`` cells and ``j`` using ``source_ny``
    onto ``SURROGATE_NY`` cells. Each axis is handled independently, so if
    only one source dimension is known, that axis is still rescaled while
    the other is assumed to already be in surrogate space. Results are
    clipped into the surrogate grid.

    Args:
        wells: dicts with at least ``"i"`` and ``"j"`` keys (int-convertible).
        source_nx: source-grid size along ``i``, or None if ``i`` is already
            a surrogate index.
        source_ny: source-grid size along ``j``, or None if ``j`` is already
            a surrogate index.

    Returns:
        One ``(i, j)`` surrogate-grid tuple per well, in input order.

    Raises:
        KeyError: if a well dict lacks ``"i"`` or ``"j"``.
    """
    return [
        (
            _rescale_axis(int(w["i"]), source_nx, SURROGATE_NX),
            _rescale_axis(int(w["j"]), source_ny, SURROGATE_NY),
        )
        for w in wells
    ]


def predict(scenario: dict[str, Any]) -> dict[str, Any]:
    """Run the v1.1 CNN surrogate on a scenario dict.

    Args:
        scenario: plain dict, overlaid on ``DEFAULTS``. Recognised keys:
            - ``T_initial`` (scalar degC) OR ``initial_temperature`` ((32,32) array;
              takes precedence over ``T_initial`` when present)
            - ``log_permeability`` (float, must be finite)
            - ``permeability`` (float > 0, converted via log10 if ``log_permeability`` absent)
            - ``base_pressure`` (Pa)
            - ``porosity``, ``depth``
            - ``wells``: list of {i, j, ...}; indices rescaled from (nx, ny) → 32x32.
              ``None`` is treated as an empty list.
            - ``nx``, ``ny``: source-grid dims used to rescale well indices
              (absent or ``None`` → not supplied to the rescaler).
            - ``dx``, ``dy``: optional, just used for grid metadata in the output.

    Returns:
        Same schema as predict_solver: temperature/pressure (32, 32) at the
        final time step + grid metadata + engine + elapsed_seconds.
        ``elapsed_seconds`` times only the CNN call, not input preparation.

    Raises:
        ValueError: if ``initial_temperature`` is not (32, 32), if
            ``permeability`` is not strictly positive, or if the resulting
            log-permeability is not finite.
    """
    cfg = {**DEFAULTS, **scenario}

    if "initial_temperature" in cfg:
        t_init = np.asarray(cfg["initial_temperature"], dtype=np.float32)
        if t_init.shape != (SURROGATE_NX, SURROGATE_NY):
            msg = f"initial_temperature shape {t_init.shape} != ({SURROGATE_NX}, {SURROGATE_NY})"
            raise ValueError(msg)
    else:
        t_init = np.full((SURROGATE_NX, SURROGATE_NY), float(cfg["T_initial"]), dtype=np.float32)

    # Inspect `scenario`, not `cfg`: cfg always carries DEFAULTS' log_permeability,
    # which would otherwise shadow a caller-supplied `permeability`.
    if "log_permeability" in scenario:
        log_k = float(scenario["log_permeability"])
    elif "permeability" in scenario:
        k = float(scenario["permeability"])
        if not k > 0.0:  # written this way so NaN is rejected too
            msg = f"permeability must be positive, got {k}"
            raise ValueError(msg)
        log_k = float(np.log10(k))
    else:
        log_k = float(cfg["log_permeability"])
    # Guard against inf/NaN (e.g. permeability=inf) reaching the CNN silently.
    if not np.isfinite(log_k):
        msg = f"log_permeability must be finite, got {log_k}"
        raise ValueError(msg)

    wells = cfg["wells"]
    wells_raw = [] if wells is None else list(wells)
    nx_raw = cfg.get("nx")
    ny_raw = cfg.get("ny")
    well_locations = _rescale_well_indices(
        wells_raw,
        source_nx=None if nx_raw is None else int(nx_raw),
        source_ny=None if ny_raw is None else int(ny_raw),
    )

    start = time.perf_counter()
    out = _cnn_predict(
        initial_temperature=t_init,
        log_permeability=log_k,
        well_locations=well_locations,
        base_pressure=float(cfg["base_pressure"]),
        porosity=float(cfg["porosity"]),
        depth=float(cfg["depth"]),
    )
    elapsed = time.perf_counter() - start

    # Keep only the last entry along the leading axis (the final time step).
    t_final = out["temperature"][-1]
    p_final = out["pressure"][-1]

    dx = float(cfg["dx"])
    dy = float(cfg["dy"])
    return {
        "temperature": t_final,
        "pressure": p_final,
        "grid": {
            "nx": SURROGATE_NX,
            "ny": SURROGATE_NY,
            "dx": dx,
            "dy": dy,
            "extent_x": [0.0, SURROGATE_NX * dx],
            "extent_y": [0.0, SURROGATE_NY * dy],
        },
        "engine": "surrogate",
        "elapsed_seconds": elapsed,
    }