File size: 4,291 Bytes
7c3bfa9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""Backend selector: prefer the Rust extension, fall back to pure Python.

Both paths expose identical call signatures so the Dash app doesn't need
to care which one is running. ``backend_name()`` tells you which it is.
"""

from __future__ import annotations

from typing import List, Tuple

import numpy as np

# Backend selection runs once, at import time. On success ``_rust`` is the
# compiled ``_core`` extension module and ``_RUST_IMPORT_ERROR`` is None.
# On ImportError ``_rust`` is None, ``_BACKEND`` is "python", and the failure
# text is stored in ``_RUST_IMPORT_ERROR`` and echoed to stderr. Only
# ImportError is caught, so any other exception raised while importing
# ``_core`` propagates to whoever imported this module.
try:  # pragma: no cover - import-time branching
    from . import _core as _rust  # type: ignore

    _BACKEND = "rust"
    _RUST_IMPORT_ERROR: str | None = None
except ImportError as _exc:
    _rust = None  # type: ignore
    _BACKEND = "python"
    # Keep the real exception message around so the launcher can log it.
    # ImportError on a Rust extension almost always means either the wheel
    # never installed into site-packages, or it installed but a runtime
    # dep (libc/libgomp/libpython ABI) can't be resolved. Silently falling
    # back to the Python reference on a production deploy hides real bugs.
    _RUST_IMPORT_ERROR = f"{type(_exc).__name__}: {_exc}"
    # Imported only here because this failure path is the only user of stderr.
    import sys as _sys
    print(
        f"[mountain-waves] Rust _core import failed, using Python fallback: "
        f"{_RUST_IMPORT_ERROR}",
        file=_sys.stderr,
        flush=True,
    )

from . import reference as _ref


def backend_name() -> str:
    """Report which compute backend this module dispatches to.

    The choice is fixed when the module is imported: ``"rust"`` if the
    compiled ``_core`` extension imported cleanly, otherwise ``"python"``
    (the pure-Python reference).
    """
    active_backend = _BACKEND
    return active_backend


def _ensure_4tuple_two_layer(result, args):
    """Coerce older Rust binaries (returning 3-tuples) into the 4-tuple ABI.

    The Rust extension was extended to return u_prime alongside w. If the
    user hasn't rebuilt the .so yet, fall back to the Python reference for
    the full solve so u_prime is still correct rather than silently zero.
    """
    if isinstance(result, tuple) and len(result) == 4:
        return result
    # Stale binary — recompute via Python reference so u' is accurate.
    return _ref.compute_two_layer(*args)


def _ensure_4tuple_profile(result, args):
    if isinstance(result, tuple) and len(result) == 4:
        return result
    return _ref.compute_from_profile(*args)


def compute_two_layer(
    l_upper: float,
    l_lower: float,
    u: float,
    h: float,
    a: float,
    ho: float,
    xdom: float,
    zdom: float,
    mink: float,
    maxk: float,
    npts: int = 100,
):
    """Run the two-layer solve on whichever backend is active.

    Every argument is forwarded positionally, in signature order, to
    ``compute_two_layer`` of the Rust extension if it was imported, else of
    the pure-Python reference. The physical meaning of each parameter is
    defined by those implementations.

    If the Rust call returns something that is not a 4-tuple (a binary built
    before ``u_prime`` was added), the Python reference result is returned
    instead.
    """
    forwarded = (
        l_upper, l_lower, u, h, a, ho, xdom, zdom, mink, maxk, npts,
    )
    if _rust is None:
        return _ref.compute_two_layer(*forwarded)
    rust_out = _rust.compute_two_layer(*forwarded)
    return _ensure_4tuple_two_layer(rust_out, forwarded)


def compute_from_profile(
    z_profile,
    u_profile,
    theta_profile,
    a: float,
    ho: float,
    xdom: float,
    zdom: float,
    mink: float,
    maxk: float,
    npts: int = 100,
):
    """Run the sounding-profile solve on whichever backend is active.

    ``z_profile``, ``u_profile`` and ``theta_profile`` are first coerced to
    C-contiguous float64 NumPy arrays. The same coerced arrays go to either
    backend. The remaining arguments pass through positionally and unchanged.

    If the Rust call returns something that is not a 4-tuple (a binary built
    before ``u_prime`` was added), the Python reference result is returned
    instead.
    """
    # A single ascontiguousarray call with dtype does the float64 conversion
    # and the contiguity guarantee together.
    zp, up, tp = (
        np.ascontiguousarray(profile, dtype=np.float64)
        for profile in (z_profile, u_profile, theta_profile)
    )
    forwarded = (zp, up, tp, a, ho, xdom, zdom, mink, maxk, npts)
    if _rust is None:
        return _ref.compute_from_profile(*forwarded)
    return _ensure_4tuple_profile(
        _rust.compute_from_profile(*forwarded), forwarded
    )


def streamlines(x, z, u, w, num: int = 10) -> List[Tuple[np.ndarray, np.ndarray]]:
    """Trace ``num`` linearized streamlines through the ``w(x, z)`` field.

    ``u`` can be a scalar (uniform mean flow, as in the two-layer analytic)
    or a 1-D array of length ``nz`` giving the mean wind at each render-grid
    height. When an array is given we route through the Python tracer so the
    per-streamline advection speed is ``U(z₀)``, not ``U_surface``.

    ``x``, ``z`` and ``w`` are coerced to C-contiguous float64 arrays before
    dispatch. A ``u`` holding exactly one value (scalar or length-1 array) is
    treated as uniform flow and may use the Rust tracer.

    Returns
    -------
    list of (ndarray, ndarray)
        Whatever the selected backend's ``streamlines`` produces.

    Raises
    ------
    ValueError
        If ``u`` contains no values.
    """
    x = np.ascontiguousarray(np.asarray(x, dtype=np.float64))
    z = np.ascontiguousarray(np.asarray(z, dtype=np.float64))
    w = np.ascontiguousarray(np.asarray(w, dtype=np.float64))
    u_arr = np.atleast_1d(np.asarray(u, dtype=np.float64))
    if u_arr.size == 0:
        # Without this guard, ``u_arr[0]`` below fails with an opaque
        # IndexError that doesn't mention which argument was wrong.
        raise ValueError(
            "streamlines(): u must be a scalar or a non-empty array of "
            "mean-wind values"
        )
    # Rust streamlines() only accepts scalar u. For array-valued u (profile
    # mode with shear) dispatch to the Python tracer — it's only num*nx
    # floating-point adds, so the perf difference is negligible.
    if u_arr.size > 1:
        return _ref.streamlines(x, z, u_arr, w, num)
    u_scalar = float(u_arr[0])
    if _rust is not None:
        return _rust.streamlines(x, z, u_scalar, w, num)
    return _ref.streamlines(x, z, u_scalar, w, num)