from __future__ import annotations import os import socket import sys from copy import deepcopy from dataclasses import dataclass from pathlib import Path from typing import Any, List, Mapping, MutableMapping, Protocol import numpy as np import pandas as pd DEFAULT_TORCH_NUM_THREADS = 2 DEFAULT_TORCH_NUM_INTEROP_THREADS = 1 DEFAULT_OMP_NUM_THREADS = 2 DEFAULT_MKL_NUM_THREADS = 2 ALLOW_NON_VENV_PYTHON_ENV_KEY = "AIFORECAST_ALLOW_NON_VENV_PYTHON" @dataclass(frozen=True) class RuntimePaths: current_dir: str project_root: str @dataclass(frozen=True) class TorchThreadSettings: torch_num_threads: int torch_num_interop_threads: int omp_num_threads: int mkl_num_threads: int class LoggerLike(Protocol): def info(self, msg: str, *args: Any, **kwargs: Any) -> None: ... def warning(self, msg: str, *args: Any, **kwargs: Any) -> None: ... def _read_positive_int( raw_value: object, *, default: int, ) -> int: try: value = int(str(raw_value).strip()) except (TypeError, ValueError, AttributeError): return default return max(1, value) def _read_bool(raw_value: object, *, default: bool) -> bool: normalized = str(raw_value or "").strip().lower() if not normalized: return default if normalized in {"1", "true", "yes", "on"}: return True if normalized in {"0", "false", "no", "off"}: return False return default def read_torch_thread_settings( env: Mapping[str, str] | None = None, ) -> TorchThreadSettings: runtime_env = os.environ if env is None else env return TorchThreadSettings( torch_num_threads=_read_positive_int( runtime_env.get("TORCH_NUM_THREADS"), default=DEFAULT_TORCH_NUM_THREADS, ), torch_num_interop_threads=_read_positive_int( runtime_env.get("TORCH_NUM_INTEROP_THREADS"), default=DEFAULT_TORCH_NUM_INTEROP_THREADS, ), omp_num_threads=_read_positive_int( runtime_env.get("OMP_NUM_THREADS"), default=DEFAULT_OMP_NUM_THREADS, ), mkl_num_threads=_read_positive_int( runtime_env.get("MKL_NUM_THREADS"), default=DEFAULT_MKL_NUM_THREADS, ), ) def _expected_project_venv_python(project_root: Path) -> Path: if os.name == "nt": return project_root / "venv" / "Scripts" / "python.exe" return project_root / "venv" / "bin" / "python" def _same_path(left: Path, right: Path) -> bool: try: if left.exists() and right.exists(): return left.samefile(right) except OSError: pass return str(left.resolve()).lower() == str(right.resolve()).lower() def describe_runtime_environment( project_root: str | Path, *, env: Mapping[str, str] | None = None, current_executable: str | None = None, ) -> dict[str, Any]: runtime_env = os.environ if env is None else env root_path = Path(project_root).resolve() current_python = Path(current_executable or sys.executable).resolve() expected_python = _expected_project_venv_python(root_path) allow_non_venv_python = _read_bool( runtime_env.get(ALLOW_NON_VENV_PYTHON_ENV_KEY), default=False, ) expected_exists = expected_python.exists() is_project_venv_python = not expected_exists or _same_path(current_python, expected_python) settings = read_torch_thread_settings(runtime_env) return { "current_executable": str(current_python), "expected_venv_python": str(expected_python), "expected_venv_exists": expected_exists, "is_project_venv_python": is_project_venv_python, "allow_non_venv_python": allow_non_venv_python, "thread_caps": { "torch_num_threads": settings.torch_num_threads, "torch_num_interop_threads": settings.torch_num_interop_threads, "omp_num_threads": settings.omp_num_threads, "mkl_num_threads": settings.mkl_num_threads, }, } def prepare_runtime_environment( project_root: str | Path, *, env: MutableMapping[str, str] | None = None, current_executable: str | None = None, logger: LoggerLike | None = None, ) -> TorchThreadSettings: runtime_env = os.environ if env is None else env settings = read_torch_thread_settings(runtime_env) runtime_env.setdefault("TORCH_NUM_THREADS", str(settings.torch_num_threads)) runtime_env.setdefault("TORCH_NUM_INTEROP_THREADS", str(settings.torch_num_interop_threads)) runtime_env.setdefault("OMP_NUM_THREADS", str(settings.omp_num_threads)) runtime_env.setdefault("MKL_NUM_THREADS", str(settings.mkl_num_threads)) snapshot = describe_runtime_environment( project_root, env=runtime_env, current_executable=current_executable, ) if ( snapshot["expected_venv_exists"] and not snapshot["is_project_venv_python"] and not snapshot["allow_non_venv_python"] ): raise RuntimeError( "Current Python interpreter is outside the project venv. " f"Expected {snapshot['expected_venv_python']} but got {snapshot['current_executable']}. " f"Set {ALLOW_NON_VENV_PYTHON_ENV_KEY}=true to override." ) if logger is not None: logger.info( "Prepared runtime thread caps: torch=%s interop=%s omp=%s mkl=%s", settings.torch_num_threads, settings.torch_num_interop_threads, settings.omp_num_threads, settings.mkl_num_threads, ) return settings def apply_torch_thread_settings( torch_module: Any | None, env: Mapping[str, str] | None = None, *, logger: LoggerLike | None = None, ) -> dict[str, Any]: settings = read_torch_thread_settings(env) applied = False if torch_module is not None: set_num_threads = getattr(torch_module, "set_num_threads", None) if callable(set_num_threads): set_num_threads(settings.torch_num_threads) applied = True set_num_interop_threads = getattr(torch_module, "set_num_interop_threads", None) if callable(set_num_interop_threads): try: set_num_interop_threads(settings.torch_num_interop_threads) except RuntimeError as exc: if logger is not None: logger.warning("Torch interop thread cap could not be updated: %s", exc) else: applied = True if logger is not None: logger.info( "Applied torch thread caps: torch=%s interop=%s omp=%s mkl=%s applied=%s", settings.torch_num_threads, settings.torch_num_interop_threads, settings.omp_num_threads, settings.mkl_num_threads, applied, ) return { "torch_num_threads": settings.torch_num_threads, "torch_num_interop_threads": settings.torch_num_interop_threads, "omp_num_threads": settings.omp_num_threads, "mkl_num_threads": settings.mkl_num_threads, "applied": applied, } def can_bind_tcp_port(host: str, port: int) -> bool: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket: try: server_socket.bind((host, port)) except OSError: return False return True def find_free_tcp_port(host: str = "127.0.0.1") -> int: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket: server_socket.bind((host, 0)) server_socket.listen(1) return int(server_socket.getsockname()[1]) def parse_cors_origins(raw_origins: str) -> List[str]: origins = [origin.strip() for origin in raw_origins.split(",") if origin.strip()] return origins or ["*"] def clone_cache_payload(payload: Any) -> Any: if payload is None or isinstance(payload, (str, int, float, bool)): return payload if isinstance(payload, list): return [clone_cache_payload(item) for item in payload] if isinstance(payload, tuple): return tuple(clone_cache_payload(item) for item in payload) if isinstance(payload, dict): return {clone_cache_payload(key): clone_cache_payload(value) for key, value in payload.items()} if isinstance(payload, set): return {clone_cache_payload(item) for item in payload} try: return deepcopy(payload) except Exception: return payload def make_json_compatible(value: Any) -> Any: if isinstance(value, dict): return {str(key): make_json_compatible(item) for key, item in value.items()} if isinstance(value, (list, tuple, set)): return [make_json_compatible(item) for item in value] if isinstance(value, np.ndarray): return [make_json_compatible(item) for item in value.tolist()] if isinstance(value, np.generic): return make_json_compatible(value.item()) if isinstance(value, pd.Timestamp): return value.isoformat() if isinstance(value, float): if not np.isfinite(value): return None return value return value def resolve_runtime_paths( module_file: str, is_frozen: bool, bundle_dir: str, ) -> RuntimePaths: current_dir = os.path.dirname(os.path.abspath(module_file)) if is_frozen: return RuntimePaths(current_dir=bundle_dir, project_root=bundle_dir) return RuntimePaths( current_dir=current_dir, project_root=os.path.dirname(current_dir), )