# SuperAI_Forecast / backend / runtime_utils.py
# Thang6822
# Update HF Space deployment
# 4106e0f
from __future__ import annotations

import datetime
import os
import socket
import sys
from copy import deepcopy
from dataclasses import dataclass
from pathlib import Path
from typing import Any, List, Mapping, MutableMapping, Protocol

import numpy as np
import pandas as pd
# Fallback thread counts used by read_torch_thread_settings when the
# environment variable of the matching name (TORCH_NUM_THREADS,
# TORCH_NUM_INTEROP_THREADS, OMP_NUM_THREADS, MKL_NUM_THREADS) is missing
# or not an integer.
DEFAULT_TORCH_NUM_THREADS = 2
DEFAULT_TORCH_NUM_INTEROP_THREADS = 1
DEFAULT_OMP_NUM_THREADS = 2
DEFAULT_MKL_NUM_THREADS = 2

# Environment flag that, when truthy, lets prepare_runtime_environment run
# under an interpreter other than the project's venv python.
ALLOW_NON_VENV_PYTHON_ENV_KEY = "AIFORECAST_ALLOW_NON_VENV_PYTHON"
@dataclass(frozen=True)
class RuntimePaths:
    """Immutable pair of directories resolved at startup.

    Built by ``resolve_runtime_paths``; both fields are plain strings.
    """

    # Directory containing the running module (the bundle dir when frozen).
    current_dir: str
    # Project root: parent of current_dir, or the bundle dir when frozen.
    project_root: str
@dataclass(frozen=True)
class TorchThreadSettings:
    """Immutable snapshot of the thread caps parsed from the environment.

    Produced by ``read_torch_thread_settings``; every field is an int >= 1
    when built through that function.
    """

    # Value passed to torch.set_num_threads (TORCH_NUM_THREADS).
    torch_num_threads: int
    # Value passed to torch.set_num_interop_threads (TORCH_NUM_INTEROP_THREADS).
    torch_num_interop_threads: int
    # Mirrors the OMP_NUM_THREADS environment variable.
    omp_num_threads: int
    # Mirrors the MKL_NUM_THREADS environment variable.
    mkl_num_threads: int
class LoggerLike(Protocol):
    """Structural type for the two logger methods this module calls.

    Any object with ``info`` and ``warning`` accepting
    ``(msg, *args, **kwargs)`` satisfies it, e.g. a ``logging.Logger``.
    """

    def info(self, msg: str, *args: Any, **kwargs: Any) -> None:
        """Emit *msg* at informational level."""

    def warning(self, msg: str, *args: Any, **kwargs: Any) -> None:
        """Emit *msg* at warning level."""
def _read_positive_int(
raw_value: object,
*,
default: int,
) -> int:
try:
value = int(str(raw_value).strip())
except (TypeError, ValueError, AttributeError):
return default
return max(1, value)
def _read_bool(raw_value: object, *, default: bool) -> bool:
normalized = str(raw_value or "").strip().lower()
if not normalized:
return default
if normalized in {"1", "true", "yes", "on"}:
return True
if normalized in {"0", "false", "no", "off"}:
return False
return default
def read_torch_thread_settings(
    env: Mapping[str, str] | None = None,
) -> TorchThreadSettings:
    """Build a TorchThreadSettings from thread-count environment variables.

    Reads TORCH_NUM_THREADS, TORCH_NUM_INTEROP_THREADS, OMP_NUM_THREADS and
    MKL_NUM_THREADS from *env* (``os.environ`` when omitted). Each value goes
    through ``_read_positive_int``: missing or unparseable values take the
    matching ``DEFAULT_*`` constant, and parsed values are floored at 1.
    """
    source = env if env is not None else os.environ

    def cap(variable: str, fallback: int) -> int:
        return _read_positive_int(source.get(variable), default=fallback)

    return TorchThreadSettings(
        torch_num_threads=cap("TORCH_NUM_THREADS", DEFAULT_TORCH_NUM_THREADS),
        torch_num_interop_threads=cap(
            "TORCH_NUM_INTEROP_THREADS", DEFAULT_TORCH_NUM_INTEROP_THREADS
        ),
        omp_num_threads=cap("OMP_NUM_THREADS", DEFAULT_OMP_NUM_THREADS),
        mkl_num_threads=cap("MKL_NUM_THREADS", DEFAULT_MKL_NUM_THREADS),
    )
def _expected_project_venv_python(project_root: Path) -> Path:
if os.name == "nt":
return project_root / "venv" / "Scripts" / "python.exe"
return project_root / "venv" / "bin" / "python"
def _same_path(left: Path, right: Path) -> bool:
try:
if left.exists() and right.exists():
return left.samefile(right)
except OSError:
pass
return str(left.resolve()).lower() == str(right.resolve()).lower()
def describe_runtime_environment(
    project_root: str | Path,
    *,
    env: Mapping[str, str] | None = None,
    current_executable: str | None = None,
) -> dict[str, Any]:
    """Summarise which interpreter is running and the configured thread caps.

    Args:
        project_root: Project checkout root; its ``venv`` interpreter path is
            derived via ``_expected_project_venv_python``.
        env: Environment to read (``os.environ`` when omitted).
        current_executable: Interpreter path to inspect (``sys.executable``
            when omitted).

    Returns:
        A dict with keys ``current_executable``, ``expected_venv_python``,
        ``expected_venv_exists``, ``is_project_venv_python`` (always True when
        no venv interpreter exists), ``allow_non_venv_python`` and
        ``thread_caps``. ``thread_caps`` is a dict of the four thread-count
        settings.
    """
    source_env = env if env is not None else os.environ
    root = Path(project_root).resolve()
    # NOTE(review): resolve() follows symlinks, and _same_path's samefile()
    # does too. venv interpreters are often symlinks to a base interpreter,
    # so running that base binary directly may also count as "in the venv".
    # Confirm this is acceptable.
    running = Path(current_executable or sys.executable).resolve()
    venv_python = _expected_project_venv_python(root)
    override = _read_bool(source_env.get(ALLOW_NON_VENV_PYTHON_ENV_KEY), default=False)
    venv_present = venv_python.exists()
    if venv_present:
        inside_venv = _same_path(running, venv_python)
    else:
        # Nothing to enforce when the project has no venv interpreter.
        inside_venv = True
    caps = read_torch_thread_settings(source_env)
    caps_summary = {
        "torch_num_threads": caps.torch_num_threads,
        "torch_num_interop_threads": caps.torch_num_interop_threads,
        "omp_num_threads": caps.omp_num_threads,
        "mkl_num_threads": caps.mkl_num_threads,
    }
    return {
        "current_executable": str(running),
        "expected_venv_python": str(venv_python),
        "expected_venv_exists": venv_present,
        "is_project_venv_python": inside_venv,
        "allow_non_venv_python": override,
        "thread_caps": caps_summary,
    }
def prepare_runtime_environment(
    project_root: str | Path,
    *,
    env: MutableMapping[str, str] | None = None,
    current_executable: str | None = None,
    logger: LoggerLike | None = None,
) -> TorchThreadSettings:
    """Export thread-cap environment variables and enforce the venv guard.

    Args:
        project_root: Project checkout root; the interpreter at its ``venv``
            location is the one this process is expected to run under.
        env: Mutable environment to read and update (``os.environ`` when
            omitted).
        current_executable: Interpreter path to check (``sys.executable``
            when omitted).
        logger: Optional logger; receives one info line with the caps.

    Returns:
        The TorchThreadSettings parsed from *env*.

    Raises:
        RuntimeError: the project venv interpreter exists, the current
            interpreter is not it, and the variable named by
            ALLOW_NON_VENV_PYTHON_ENV_KEY is not set to a truthy value.
    """
    runtime_env = os.environ if env is None else env
    settings = read_torch_thread_settings(runtime_env)
    thread_env_values = (
        ("TORCH_NUM_THREADS", settings.torch_num_threads),
        ("TORCH_NUM_INTEROP_THREADS", settings.torch_num_interop_threads),
        ("OMP_NUM_THREADS", settings.omp_num_threads),
        ("MKL_NUM_THREADS", settings.mkl_num_threads),
    )
    for key, value in thread_env_values:
        # Fill only variables that are missing or blank; non-blank values the
        # operator set are kept untouched. A plain setdefault() would also keep
        # a blank value (e.g. "OMP_NUM_THREADS="), so the cap logged below
        # would never be exported.
        if not str(runtime_env.get(key) or "").strip():
            runtime_env[key] = str(value)
    snapshot = describe_runtime_environment(
        project_root,
        env=runtime_env,
        current_executable=current_executable,
    )
    if (
        snapshot["expected_venv_exists"]
        and not snapshot["is_project_venv_python"]
        and not snapshot["allow_non_venv_python"]
    ):
        raise RuntimeError(
            "Current Python interpreter is outside the project venv. "
            f"Expected {snapshot['expected_venv_python']} but got {snapshot['current_executable']}. "
            f"Set {ALLOW_NON_VENV_PYTHON_ENV_KEY}=true to override."
        )
    if logger is not None:
        logger.info(
            "Prepared runtime thread caps: torch=%s interop=%s omp=%s mkl=%s",
            settings.torch_num_threads,
            settings.torch_num_interop_threads,
            settings.omp_num_threads,
            settings.mkl_num_threads,
        )
    return settings
def apply_torch_thread_settings(
    torch_module: Any | None,
    env: Mapping[str, str] | None = None,
    *,
    logger: LoggerLike | None = None,
) -> dict[str, Any]:
    """Push the environment's thread caps into *torch_module*, if given.

    Calls ``torch_module.set_num_threads`` and then
    ``torch_module.set_num_interop_threads`` when present and callable. A
    RuntimeError from the interop setter is logged as a warning and not
    re-raised.

    Returns:
        The four thread settings plus ``applied``. ``applied`` is True if at
        least one setter call succeeded.
    """
    settings = read_torch_thread_settings(env)
    applied = False

    if torch_module is not None:
        setter = getattr(torch_module, "set_num_threads", None)
        if callable(setter):
            setter(settings.torch_num_threads)
            applied = True

        interop_setter = getattr(torch_module, "set_num_interop_threads", None)
        if callable(interop_setter):
            try:
                interop_setter(settings.torch_num_interop_threads)
                applied = True
            except RuntimeError as exc:
                # PyTorch documents that the interop pool size can only be set
                # once, before inter-op parallel work starts. Treat a refusal
                # as non-fatal.
                if logger is not None:
                    logger.warning("Torch interop thread cap could not be updated: %s", exc)

    report: dict[str, Any] = {
        "torch_num_threads": settings.torch_num_threads,
        "torch_num_interop_threads": settings.torch_num_interop_threads,
        "omp_num_threads": settings.omp_num_threads,
        "mkl_num_threads": settings.mkl_num_threads,
        "applied": applied,
    }
    if logger is not None:
        logger.info(
            "Applied torch thread caps: torch=%s interop=%s omp=%s mkl=%s applied=%s",
            report["torch_num_threads"],
            report["torch_num_interop_threads"],
            report["omp_num_threads"],
            report["mkl_num_threads"],
            report["applied"],
        )
    return report
def can_bind_tcp_port(host: str, port: int) -> bool:
    """Return True if a TCP socket can currently be bound to (*host*, *port*).

    The probe socket is closed immediately, so the answer may be stale by the
    time the caller binds. Hosts containing ``:`` (IPv6 literals such as
    ``::`` or ``::1``) are probed with an AF_INET6 socket; all other hosts use
    AF_INET as before. Ports outside 0-65535 cannot be bound, so they return
    False rather than letting ``bind`` raise OverflowError.
    """
    family = socket.AF_INET6 if ":" in host else socket.AF_INET
    try:
        # Socket creation sits inside the try: AF_INET6 can be unavailable on
        # hosts without IPv6 support, which also means "cannot bind".
        with socket.socket(family, socket.SOCK_STREAM) as probe:
            probe.bind((host, port))
    except (OSError, OverflowError):
        return False
    return True
def find_free_tcp_port(host: str = "127.0.0.1") -> int:
    """Ask the OS for an unused TCP port on *host* and return its number.

    Binds to port 0 so the kernel picks the port, then closes the socket.
    Another process could claim the port before the caller uses it.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        probe.bind((host, 0))
        probe.listen(1)
        _bound_host, chosen_port = probe.getsockname()
    return int(chosen_port)
def parse_cors_origins(raw_origins: str) -> List[str]:
    """Split a comma-separated origin list into trimmed, non-blank entries.

    When no non-blank entry remains, ``["*"]`` (any origin) is returned.
    """
    cleaned: List[str] = []
    for piece in raw_origins.split(","):
        trimmed = piece.strip()
        if trimmed:
            cleaned.append(trimmed)
    if not cleaned:
        return ["*"]
    return cleaned
def clone_cache_payload(payload: Any) -> Any:
    """Return an independent copy of *payload* for handing out of a cache.

    - ``None``, ``str``, ``int``, ``float`` and ``bool`` are returned as-is.
    - Containers whose exact type is ``list``, ``tuple``, ``dict`` or ``set``
      are rebuilt recursively.
    - Everything else, including container subclasses such as namedtuple,
      defaultdict or OrderedDict, goes through ``copy.deepcopy``. This keeps
      the concrete type and extra state (e.g. ``default_factory``).
    - If ``deepcopy`` raises, the original object is returned (shared).
    """
    if payload is None or isinstance(payload, (str, int, float, bool)):
        return payload
    # Exact type checks: rebuilding a subclass through its base constructor
    # would silently drop its type (namedtuple -> tuple, defaultdict -> dict).
    payload_type = type(payload)
    if payload_type is list:
        return [clone_cache_payload(item) for item in payload]
    if payload_type is tuple:
        return tuple(clone_cache_payload(item) for item in payload)
    if payload_type is dict:
        return {clone_cache_payload(key): clone_cache_payload(value) for key, value in payload.items()}
    if payload_type is set:
        return {clone_cache_payload(item) for item in payload}
    try:
        return deepcopy(payload)
    except Exception:
        # Deliberate best effort: uncopyable objects (locks, generators, ...)
        # are shared rather than failing the cache read.
        return payload
def make_json_compatible(value: Any) -> Any:
    """Recursively convert *value* into JSON-serialisable Python data.

    - dict keys are stringified; list/tuple/set become lists.
    - numpy arrays and scalars become their Python equivalents; a 0-d array
      becomes a scalar.
    - ``pd.NaT`` and non-finite floats (NaN, +/-inf) become ``None``.
    - ``datetime``/``date``/``time`` values, including ``pd.Timestamp``,
      become ISO-8601 strings via ``isoformat()``.
    - Anything else is returned unchanged.
    """
    if isinstance(value, dict):
        return {str(key): make_json_compatible(item) for key, item in value.items()}
    if isinstance(value, (list, tuple, set)):
        return [make_json_compatible(item) for item in value]
    if isinstance(value, np.ndarray):
        # tolist() yields nested lists, or a bare scalar for a 0-d array.
        # Recursing on its result handles both; iterating it would not.
        return make_json_compatible(value.tolist())
    if isinstance(value, np.generic):
        return make_json_compatible(value.item())
    if value is pd.NaT:
        # Missing timestamp: same convention as NaN -> None below.
        return None
    if isinstance(value, (datetime.date, datetime.time)):
        # pd.Timestamp subclasses datetime.datetime, so it is handled here too.
        return value.isoformat()
    if isinstance(value, float) and not np.isfinite(value):
        return None
    return value
def resolve_runtime_paths(
    module_file: str,
    is_frozen: bool,
    bundle_dir: str,
) -> RuntimePaths:
    """Work out the module directory and project root.

    Args:
        module_file: Path of the calling module (typically ``__file__``).
        is_frozen: Whether the app runs from a frozen bundle. Presumably
            derived from ``sys.frozen`` by the caller (verify).
        bundle_dir: Bundle directory used for both paths when frozen.

    Returns:
        When frozen, RuntimePaths with both fields set to *bundle_dir*.
        Otherwise the module's absolute directory and that directory's
        parent.
    """
    module_dir = os.path.dirname(os.path.abspath(module_file))
    if not is_frozen:
        return RuntimePaths(
            current_dir=module_dir,
            project_root=os.path.dirname(module_dir),
        )
    # Frozen bundle: everything lives under the bundle directory.
    return RuntimePaths(current_dir=bundle_dir, project_root=bundle_dir)