Spaces:
Running
Running
| from __future__ import annotations | |
| import os | |
| import socket | |
| import sys | |
| from copy import deepcopy | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any, List, Mapping, MutableMapping, Protocol | |
| import numpy as np | |
| import pandas as pd | |
# Fallback thread caps, used when the matching environment variable
# (TORCH_NUM_THREADS, TORCH_NUM_INTEROP_THREADS, OMP_NUM_THREADS,
# MKL_NUM_THREADS) is unset or cannot be parsed as an integer.
DEFAULT_TORCH_NUM_THREADS = 2
DEFAULT_TORCH_NUM_INTEROP_THREADS = 1
DEFAULT_OMP_NUM_THREADS = 2
DEFAULT_MKL_NUM_THREADS = 2
# Name of the environment variable that, when set to a truthy value, lets an
# interpreter outside the project venv pass the runtime environment check.
ALLOW_NON_VENV_PYTHON_ENV_KEY = "AIFORECAST_ALLOW_NON_VENV_PYTHON"
# The decorator is required: this class is constructed with keyword arguments
# (``RuntimePaths(current_dir=..., project_root=...)``), which a bare class
# carrying only annotations would reject with a TypeError.
@dataclass
class RuntimePaths:
    """Pair of directory paths describing where the application is running.

    Attributes:
        current_dir: Directory the application treats as its own location.
        project_root: Directory the application treats as the project root.
    """

    current_dir: str
    project_root: str
# The decorator is required: this class is constructed with keyword arguments
# for all four fields, which a bare annotated class would reject.
@dataclass
class TorchThreadSettings:
    """Thread caps resolved for torch and its numeric backends.

    Attributes:
        torch_num_threads: Value passed to ``torch.set_num_threads``.
        torch_num_interop_threads: Value passed to
            ``torch.set_num_interop_threads``.
        omp_num_threads: Cap mirrored into ``OMP_NUM_THREADS``.
        mkl_num_threads: Cap mirrored into ``MKL_NUM_THREADS``.
    """

    torch_num_threads: int
    torch_num_interop_threads: int
    omp_num_threads: int
    mkl_num_threads: int
class LoggerLike(Protocol):
    """Structural type for the small logging surface used by this module.

    Any object exposing ``info`` and ``warning`` with these signatures
    (for example a ``logging.Logger``) satisfies the protocol.
    """

    def info(self, msg: str, *args: Any, **kwargs: Any) -> None:
        """Record an informational message."""

    def warning(self, msg: str, *args: Any, **kwargs: Any) -> None:
        """Record a warning message."""
def _read_positive_int(
    raw_value: object,
    *,
    default: int,
) -> int:
    """Parse ``raw_value`` as an integer that is at least 1.

    Args:
        raw_value: Value to parse; it is stringified and stripped first.
        default: Returned when the text cannot be parsed as an integer.

    Returns:
        The parsed integer, raised to 1 when it is zero or negative, or
        ``default`` when parsing fails.
    """
    try:
        parsed = int(str(raw_value).strip())
    except (TypeError, ValueError, AttributeError):
        # Covers ``None`` (stringified to "None") and any non-numeric text.
        return default
    if parsed < 1:
        return 1
    return parsed
def _read_bool(raw_value: object, *, default: bool) -> bool:
    """Interpret ``raw_value`` as a boolean flag.

    Recognised spellings (case-insensitive, surrounding whitespace ignored)
    are ``1/true/yes/on`` for True and ``0/false/no/off`` for False.

    Args:
        raw_value: Value to interpret, typically an environment variable
            string or ``None`` when the variable is unset.
        default: Returned when ``raw_value`` is ``None``, blank, or not one
            of the recognised spellings.

    Returns:
        The parsed flag, or ``default`` when no decision can be made.
    """
    # Only ``None`` means "not provided". Testing truthiness here would make
    # explicit falsy inputs such as ``False`` or ``0`` fall back to the
    # default instead of being read as False.
    if raw_value is None:
        return default
    normalized = str(raw_value).strip().lower()
    if normalized in {"1", "true", "yes", "on"}:
        return True
    if normalized in {"0", "false", "no", "off"}:
        return False
    # Blank or unrecognised text: defer to the caller's default.
    return default
def read_torch_thread_settings(
    env: Mapping[str, str] | None = None,
) -> TorchThreadSettings:
    """Resolve the thread caps from an environment mapping.

    Args:
        env: Mapping to read from; ``os.environ`` is used when ``None``.

    Returns:
        A ``TorchThreadSettings`` in which each field comes from its
        environment variable, or from the module default when the variable
        is missing or unparsable.
    """
    source = env if env is not None else os.environ

    def cap(variable: str, fallback: int) -> int:
        # Look up one variable and normalise it to an integer >= 1.
        return _read_positive_int(source.get(variable), default=fallback)

    torch_threads = cap("TORCH_NUM_THREADS", DEFAULT_TORCH_NUM_THREADS)
    interop_threads = cap(
        "TORCH_NUM_INTEROP_THREADS",
        DEFAULT_TORCH_NUM_INTEROP_THREADS,
    )
    omp_threads = cap("OMP_NUM_THREADS", DEFAULT_OMP_NUM_THREADS)
    mkl_threads = cap("MKL_NUM_THREADS", DEFAULT_MKL_NUM_THREADS)
    return TorchThreadSettings(
        torch_num_threads=torch_threads,
        torch_num_interop_threads=interop_threads,
        omp_num_threads=omp_threads,
        mkl_num_threads=mkl_threads,
    )
def _expected_project_venv_python(project_root: Path) -> Path:
    """Return where the project's venv interpreter is expected to live.

    Args:
        project_root: Root directory of the project.

    Returns:
        ``<root>/venv/Scripts/python.exe`` when ``os.name`` is ``"nt"``,
        otherwise ``<root>/venv/bin/python``. The path is not checked for
        existence.
    """
    if os.name == "nt":
        tail = ("Scripts", "python.exe")
    else:
        tail = ("bin", "python")
    return project_root.joinpath("venv", *tail)
def _same_path(left: Path, right: Path) -> bool:
    """Report whether two paths refer to the same filesystem location.

    When both paths exist the filesystem decides, via ``Path.samefile``.
    Otherwise, or when that check raises ``OSError``, the resolved path
    strings are compared.

    Args:
        left: First path.
        right: Second path.

    Returns:
        True when the paths are considered the same location.
    """
    try:
        if left.exists() and right.exists():
            return left.samefile(right)
    except OSError:
        # Fall through to the textual comparison below.
        pass
    # ``os.path.normcase`` folds case only where the platform's path rules
    # are case-insensitive (Windows). Lowercasing unconditionally would
    # report paths that differ only by case as identical on case-sensitive
    # filesystems.
    left_text = os.path.normcase(str(left.resolve()))
    right_text = os.path.normcase(str(right.resolve()))
    return left_text == right_text
def describe_runtime_environment(
    project_root: str | Path,
    *,
    env: Mapping[str, str] | None = None,
    current_executable: str | None = None,
) -> dict[str, Any]:
    """Build a snapshot of the interpreter and thread-cap configuration.

    Args:
        project_root: Root directory of the project; resolved to an
            absolute path.
        env: Mapping to read settings from; ``os.environ`` when ``None``.
        current_executable: Interpreter path to inspect; ``sys.executable``
            when ``None`` or empty.

    Returns:
        A dict with the keys ``current_executable``,
        ``expected_venv_python``, ``expected_venv_exists``,
        ``is_project_venv_python``, ``allow_non_venv_python`` and
        ``thread_caps`` (itself a dict of the four thread caps).
    """
    source = env if env is not None else os.environ
    root = Path(project_root).resolve()
    interpreter = Path(current_executable or sys.executable).resolve()
    venv_python = _expected_project_venv_python(root)
    override_allowed = _read_bool(
        source.get(ALLOW_NON_VENV_PYTHON_ENV_KEY),
        default=False,
    )

    venv_present = venv_python.exists()
    if venv_present:
        runs_in_venv = _same_path(interpreter, venv_python)
    else:
        # With no project venv on disk there is nothing to compare against,
        # so the interpreter is reported as acceptable.
        runs_in_venv = True

    caps = read_torch_thread_settings(source)
    thread_caps = {
        "torch_num_threads": caps.torch_num_threads,
        "torch_num_interop_threads": caps.torch_num_interop_threads,
        "omp_num_threads": caps.omp_num_threads,
        "mkl_num_threads": caps.mkl_num_threads,
    }
    return {
        "current_executable": str(interpreter),
        "expected_venv_python": str(venv_python),
        "expected_venv_exists": venv_present,
        "is_project_venv_python": runs_in_venv,
        "allow_non_venv_python": override_allowed,
        "thread_caps": thread_caps,
    }
def prepare_runtime_environment(
    project_root: str | Path,
    *,
    env: MutableMapping[str, str] | None = None,
    current_executable: str | None = None,
    logger: LoggerLike | None = None,
) -> TorchThreadSettings:
    """Seed thread-cap variables and verify the interpreter is the venv's.

    Args:
        project_root: Root directory of the project.
        env: Mutable mapping to read and update; ``os.environ`` when
            ``None``. Variables that are already present are left untouched.
        current_executable: Interpreter path to check; forwarded to
            ``describe_runtime_environment``.
        logger: Optional logger that receives the resolved caps.

    Returns:
        The resolved ``TorchThreadSettings``.

    Raises:
        RuntimeError: If the project venv interpreter exists, the current
            interpreter is a different one, and the override variable is
            not enabled.
    """
    target_env = env if env is not None else os.environ
    settings = read_torch_thread_settings(target_env)

    # Fill in only the variables the caller has not set.
    resolved_caps = (
        ("TORCH_NUM_THREADS", settings.torch_num_threads),
        ("TORCH_NUM_INTEROP_THREADS", settings.torch_num_interop_threads),
        ("OMP_NUM_THREADS", settings.omp_num_threads),
        ("MKL_NUM_THREADS", settings.mkl_num_threads),
    )
    for variable, cap in resolved_caps:
        target_env.setdefault(variable, str(cap))

    snapshot = describe_runtime_environment(
        project_root,
        env=target_env,
        current_executable=current_executable,
    )
    outside_existing_venv = (
        snapshot["expected_venv_exists"]
        and not snapshot["is_project_venv_python"]
    )
    if outside_existing_venv and not snapshot["allow_non_venv_python"]:
        raise RuntimeError(
            "Current Python interpreter is outside the project venv. "
            f"Expected {snapshot['expected_venv_python']} but got {snapshot['current_executable']}. "
            f"Set {ALLOW_NON_VENV_PYTHON_ENV_KEY}=true to override."
        )

    if logger is not None:
        logger.info(
            "Prepared runtime thread caps: torch=%s interop=%s omp=%s mkl=%s",
            settings.torch_num_threads,
            settings.torch_num_interop_threads,
            settings.omp_num_threads,
            settings.mkl_num_threads,
        )
    return settings
def apply_torch_thread_settings(
    torch_module: Any | None,
    env: Mapping[str, str] | None = None,
    *,
    logger: LoggerLike | None = None,
) -> dict[str, Any]:
    """Push the resolved thread caps into a torch-like module.

    Args:
        torch_module: Object expected to expose ``set_num_threads`` and
            ``set_num_interop_threads``; either may be absent. ``None``
            skips the update entirely.
        env: Mapping to read the caps from; forwarded to
            ``read_torch_thread_settings``.
        logger: Optional logger for the summary and for a warning when the
            interop cap cannot be set.

    Returns:
        A dict holding the four resolved caps plus ``applied``, which is
        True when at least one of the two setters ran successfully.
    """
    caps = read_torch_thread_settings(env)
    applied = False

    if torch_module is not None:
        intra_op_setter = getattr(torch_module, "set_num_threads", None)
        if callable(intra_op_setter):
            intra_op_setter(caps.torch_num_threads)
            applied = True

        inter_op_setter = getattr(torch_module, "set_num_interop_threads", None)
        if callable(inter_op_setter):
            try:
                inter_op_setter(caps.torch_num_interop_threads)
            except RuntimeError as exc:
                # A failure here is reported but never aborts the caller.
                if logger is not None:
                    logger.warning("Torch interop thread cap could not be updated: %s", exc)
            else:
                applied = True

    summary = {
        "torch_num_threads": caps.torch_num_threads,
        "torch_num_interop_threads": caps.torch_num_interop_threads,
        "omp_num_threads": caps.omp_num_threads,
        "mkl_num_threads": caps.mkl_num_threads,
        "applied": applied,
    }
    if logger is not None:
        logger.info(
            "Applied torch thread caps: torch=%s interop=%s omp=%s mkl=%s applied=%s",
            summary["torch_num_threads"],
            summary["torch_num_interop_threads"],
            summary["omp_num_threads"],
            summary["mkl_num_threads"],
            summary["applied"],
        )
    return summary
def can_bind_tcp_port(host: str, port: int) -> bool:
    """Check whether an IPv4 TCP socket can currently bind ``host:port``.

    The probe socket is closed before returning, so a True result does not
    reserve the port.

    Args:
        host: Address to bind.
        port: TCP port number to try.

    Returns:
        True when the bind succeeds. False when it fails, including when
        ``port`` lies outside the valid 0-65535 range.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        try:
            probe.bind((host, port))
        except (OSError, OverflowError):
            # ``bind`` signals an out-of-range port with OverflowError rather
            # than OSError; for this predicate both mean "cannot bind".
            return False
    return True
def find_free_tcp_port(host: str = "127.0.0.1") -> int:
    """Ask the operating system for an unused TCP port on ``host``.

    Args:
        host: IPv4 address to bind the probe socket to.

    Returns:
        The port number the OS assigned. The probe socket is closed before
        returning, so the port is not reserved for the caller.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        # Port 0 asks the OS to pick any free port.
        probe.bind((host, 0))
        probe.listen(1)
        bound_address = probe.getsockname()
        return int(bound_address[1])
def parse_cors_origins(raw_origins: str) -> List[str]:
    """Split a comma-separated origin list into cleaned entries.

    Args:
        raw_origins: Comma-separated origins; whitespace around each entry
            is ignored and blank entries are dropped.

    Returns:
        The non-empty origins in their original order, or ``["*"]`` when
        none remain.
    """
    cleaned: List[str] = []
    for piece in raw_origins.split(","):
        candidate = piece.strip()
        if candidate:
            cleaned.append(candidate)
    if not cleaned:
        return ["*"]
    return cleaned
def clone_cache_payload(payload: Any) -> Any:
    """Return a best-effort independent copy of a cache payload.

    Plain ``list``/``tuple``/``dict``/``set`` containers are rebuilt
    element by element, so a single uncopyable leaf is shared rather than
    preventing the rest from being cloned. Subclasses of those containers
    (namedtuple, ``OrderedDict``, ``defaultdict`` and so on) are copied with
    ``deepcopy`` so their type is preserved; if that fails they are rebuilt
    as the plain base container. Any other object is deep-copied, and
    returned unchanged when it cannot be copied.

    Args:
        payload: Value to clone.

    Returns:
        The clone, or ``payload`` itself for immutable scalars and for
        objects that cannot be copied.
    """
    if payload is None or isinstance(payload, (str, int, float, bool)):
        return payload

    plain_containers = (list, tuple, dict, set)
    is_container = isinstance(payload, plain_containers)

    if not is_container or type(payload) not in plain_containers:
        # Non-containers and container subclasses go through deepcopy so
        # the concrete type survives the clone.
        try:
            return deepcopy(payload)
        except Exception:
            if not is_container:
                # Deliberate best-effort: share what cannot be copied.
                return payload
            # Container subclass that cannot be deep-copied: fall through
            # and rebuild it element-wise as the plain base type.

    if isinstance(payload, list):
        return [clone_cache_payload(item) for item in payload]
    if isinstance(payload, tuple):
        return tuple(clone_cache_payload(item) for item in payload)
    if isinstance(payload, dict):
        return {
            clone_cache_payload(key): clone_cache_payload(value)
            for key, value in payload.items()
        }
    return {clone_cache_payload(item) for item in payload}
def make_json_compatible(value: Any) -> Any:
    """Recursively convert ``value`` into JSON-serialisable Python objects.

    Conversions applied:
        * dicts: keys are stringified, values converted recursively;
        * lists, tuples and sets: converted to lists;
        * NumPy arrays: converted via ``tolist()`` and then recursively;
        * NumPy scalars: unwrapped with ``item()`` and converted again;
        * ``pd.NaT``: ``None``;
        * ``pd.Timestamp``: ISO-8601 string;
        * non-finite floats (NaN, +/-inf): ``None``.

    Anything else is returned unchanged.

    Args:
        value: Arbitrary value to convert.

    Returns:
        The converted value.
    """
    if isinstance(value, dict):
        return {str(key): make_json_compatible(item) for key, item in value.items()}
    if isinstance(value, (list, tuple, set)):
        return [make_json_compatible(item) for item in value]
    if isinstance(value, np.ndarray):
        return [make_json_compatible(item) for item in value.tolist()]
    if isinstance(value, np.generic):
        return make_json_compatible(value.item())
    if value is pd.NaT:
        # NaT is not a ``pd.Timestamp`` instance, so without this check it
        # would be returned as-is and fail JSON encoding. It is mapped to
        # None, the same way NaN is below.
        return None
    if isinstance(value, pd.Timestamp):
        return value.isoformat()
    if isinstance(value, float) and not np.isfinite(value):
        # NaN and infinities have no strict-JSON representation.
        return None
    return value
def resolve_runtime_paths(
    module_file: str,
    is_frozen: bool,
    bundle_dir: str,
) -> RuntimePaths:
    """Work out the application's current directory and project root.

    Args:
        module_file: Path of the module used as the location anchor.
        is_frozen: When true, ``bundle_dir`` is used for both paths.
            NOTE(review): presumably set for packaged/frozen builds —
            confirm against the caller.
        bundle_dir: Directory to use for both paths when ``is_frozen``.

    Returns:
        ``RuntimePaths`` where, for a non-frozen run, ``current_dir`` is the
        directory containing ``module_file`` and ``project_root`` is that
        directory's parent.
    """
    module_dir = os.path.dirname(os.path.abspath(module_file))
    if not is_frozen:
        parent_dir = os.path.dirname(module_dir)
        return RuntimePaths(current_dir=module_dir, project_root=parent_dir)
    return RuntimePaths(current_dir=bundle_dir, project_root=bundle_dir)