from __future__ import annotations import sys import tempfile import textwrap import unittest from pathlib import Path from backend import server_runtime class _FakeLogger: def __init__(self) -> None: self.warnings: list[tuple[str, tuple[object, ...]]] = [] def warning(self, msg: str, *args: object, **kwargs: object) -> None: self.warnings.append((msg, args)) class ServerRuntimeTests(unittest.TestCase): def test_bootstrap_runtime_port_sets_huggingface_default(self) -> None: env = {"SPACE_ID": "demo-space"} server_runtime.bootstrap_runtime_port(env, huggingface_port=7860) self.assertEqual(env["PORT"], "7860") def test_bootstrap_runtime_port_keeps_existing_port(self) -> None: env = {"SPACE_HOST": "space.example", "PORT": "9000"} server_runtime.bootstrap_runtime_port(env, huggingface_port=7860) self.assertEqual(env["PORT"], "9000") def test_resolve_server_port_keeps_configured_port_when_available(self) -> None: env = {"PORT": "8000"} port = server_runtime.resolve_server_port( "127.0.0.1", env=env, port_checker=lambda host, value: host == "127.0.0.1" and value == 8000, free_port_finder=lambda host: 8123, ) self.assertEqual(port, 8000) self.assertEqual(env["PORT"], "8000") def test_resolve_server_port_falls_back_when_configured_port_is_busy(self) -> None: env = {"PORT": "8000"} logger = _FakeLogger() port = server_runtime.resolve_server_port( "127.0.0.1", env=env, logger=logger, port_checker=lambda host, value: False, free_port_finder=lambda host: 8123, ) self.assertEqual(port, 8123) self.assertEqual(env["PORT"], "8123") self.assertEqual(len(logger.warnings), 1) def test_resolve_server_port_uses_huggingface_default_when_missing(self) -> None: env = {"SPACE_ID": "demo-space"} port = server_runtime.resolve_server_port( "0.0.0.0", env=env, port_checker=lambda host, value: False, free_port_finder=lambda host: 8123, ) self.assertEqual(port, 7860) self.assertEqual(env["PORT"], "7860") def test_load_runtime_env_returns_false_for_missing_file(self) -> None: with tempfile.TemporaryDirectory() as temp_dir: env_file = Path(temp_dir) / "missing-aiforecast.env" self.assertFalse(server_runtime.load_runtime_env(env_file)) def test_is_admin_token_configured_requires_non_blank_value(self) -> None: self.assertFalse(server_runtime.is_admin_token_configured({})) self.assertFalse(server_runtime.is_admin_token_configured({"ADMIN_TOKEN": " "})) self.assertTrue(server_runtime.is_admin_token_configured({"ADMIN_TOKEN": "runtime-secret"})) def test_require_admin_token_configured_fails_fast_when_missing(self) -> None: with self.assertRaisesRegex(RuntimeError, "ADMIN_TOKEN"): server_runtime.require_admin_token_configured({}) def test_load_fastapi_app_imports_requested_module_attribute(self) -> None: with tempfile.TemporaryDirectory() as temp_dir: project_root = Path(temp_dir) module_name = "temp_runtime_app_module" module_path = project_root / f"{module_name}.py" module_path.write_text( textwrap.dedent( """ app = {"name": "demo-app"} """ ).strip(), encoding="utf-8", ) try: app = server_runtime.load_fastapi_app( project_root, module_name=module_name, attr_name="app", ) self.assertEqual(app, {"name": "demo-app"}) self.assertIn(str(project_root), sys.path) finally: sys.modules.pop(module_name, None) def test_prepare_runtime_environment_sets_thread_env_defaults(self) -> None: with tempfile.TemporaryDirectory() as temp_dir: project_root = Path(temp_dir) env: dict[str, str] = {} settings = server_runtime.prepare_runtime_environment( project_root, env=env, current_executable=str(project_root / "venv" / "Scripts" / "python.exe"), ) self.assertEqual(settings.torch_num_threads, 2) self.assertEqual(settings.torch_num_interop_threads, 1) self.assertEqual(settings.omp_num_threads, 2) self.assertEqual(settings.mkl_num_threads, 2) self.assertEqual(env["TORCH_NUM_THREADS"], "2") self.assertEqual(env["TORCH_NUM_INTEROP_THREADS"], "1") self.assertEqual(env["OMP_NUM_THREADS"], "2") self.assertEqual(env["MKL_NUM_THREADS"], "2") def test_prepare_runtime_environment_requires_project_venv_python_by_default(self) -> None: with tempfile.TemporaryDirectory() as temp_dir: project_root = Path(temp_dir) venv_python = project_root / "venv" / "Scripts" / "python.exe" venv_python.parent.mkdir(parents=True, exist_ok=True) venv_python.write_text("", encoding="utf-8") with self.assertRaisesRegex(RuntimeError, "venv"): server_runtime.prepare_runtime_environment( project_root, env={}, current_executable=r"C:\Python311\python.exe", ) def test_prepare_runtime_environment_allows_non_venv_override(self) -> None: with tempfile.TemporaryDirectory() as temp_dir: project_root = Path(temp_dir) venv_python = project_root / "venv" / "Scripts" / "python.exe" venv_python.parent.mkdir(parents=True, exist_ok=True) venv_python.write_text("", encoding="utf-8") settings = server_runtime.prepare_runtime_environment( project_root, env={"AIFORECAST_ALLOW_NON_VENV_PYTHON": "true"}, current_executable=r"C:\Python311\python.exe", ) self.assertEqual(settings.torch_num_threads, 2) if __name__ == "__main__": unittest.main()