Spaces:
Running
Running
| from __future__ import annotations | |
| import sys | |
| import tempfile | |
| import textwrap | |
| import unittest | |
| from pathlib import Path | |
| from backend import server_runtime | |
| class _FakeLogger: | |
| def __init__(self) -> None: | |
| self.warnings: list[tuple[str, tuple[object, ...]]] = [] | |
| def warning(self, msg: str, *args: object, **kwargs: object) -> None: | |
| self.warnings.append((msg, args)) | |
class ServerRuntimeTests(unittest.TestCase):
    """Unit tests for the helpers exposed by ``backend.server_runtime``.

    The tests drive the helpers through injected collaborators (plain dict
    environments, lambda port checkers, a recording logger, temporary
    directories), so the process environment is not read or modified by the
    test code itself.
    """

    def test_bootstrap_runtime_port_sets_huggingface_default(self) -> None:
        # SPACE_ID is set and PORT is absent: the supplied huggingface_port
        # is expected to be written into the mapping as a string.
        env = {"SPACE_ID": "demo-space"}

        server_runtime.bootstrap_runtime_port(env, huggingface_port=7860)

        self.assertEqual(env["PORT"], "7860")

    def test_bootstrap_runtime_port_keeps_existing_port(self) -> None:
        # An already-present PORT must survive even when SPACE_HOST is set.
        env = {"SPACE_HOST": "space.example", "PORT": "9000"}

        server_runtime.bootstrap_runtime_port(env, huggingface_port=7860)

        self.assertEqual(env["PORT"], "9000")

    def test_resolve_server_port_keeps_configured_port_when_available(self) -> None:
        env = {"PORT": "8000"}

        port = server_runtime.resolve_server_port(
            "127.0.0.1",
            env=env,
            # Reports "available" only for the exact host/port under test.
            port_checker=lambda host, value: host == "127.0.0.1" and value == 8000,
            free_port_finder=lambda host: 8123,
        )

        self.assertEqual(port, 8000)
        self.assertEqual(env["PORT"], "8000")

    def test_resolve_server_port_falls_back_when_configured_port_is_busy(self) -> None:
        env = {"PORT": "8000"}
        logger = _FakeLogger()

        port = server_runtime.resolve_server_port(
            "127.0.0.1",
            env=env,
            logger=logger,
            # Every port is reported as unavailable, forcing the fallback.
            port_checker=lambda host, value: False,
            free_port_finder=lambda host: 8123,
        )

        self.assertEqual(port, 8123)
        self.assertEqual(env["PORT"], "8123")
        # Exactly one warning is expected for the fallback.
        self.assertEqual(len(logger.warnings), 1)

    def test_resolve_server_port_uses_huggingface_default_when_missing(self) -> None:
        # The checker rejects every port, yet 7860 is still expected when
        # SPACE_ID is set and PORT is absent.
        env = {"SPACE_ID": "demo-space"}

        port = server_runtime.resolve_server_port(
            "0.0.0.0",
            env=env,
            port_checker=lambda host, value: False,
            free_port_finder=lambda host: 8123,
        )

        self.assertEqual(port, 7860)
        self.assertEqual(env["PORT"], "7860")

    def test_load_runtime_env_returns_false_for_missing_file(self) -> None:
        with tempfile.TemporaryDirectory() as temp_dir:
            # The file is never created, so the path does not exist.
            env_file = Path(temp_dir) / "missing-aiforecast.env"

            self.assertFalse(server_runtime.load_runtime_env(env_file))

    def test_is_admin_token_configured_requires_non_blank_value(self) -> None:
        self.assertFalse(server_runtime.is_admin_token_configured({}))
        self.assertFalse(server_runtime.is_admin_token_configured({"ADMIN_TOKEN": " "}))
        self.assertTrue(server_runtime.is_admin_token_configured({"ADMIN_TOKEN": "runtime-secret"}))

    def test_require_admin_token_configured_fails_fast_when_missing(self) -> None:
        with self.assertRaisesRegex(RuntimeError, "ADMIN_TOKEN"):
            server_runtime.require_admin_token_configured({})

    def test_load_fastapi_app_imports_requested_module_attribute(self) -> None:
        with tempfile.TemporaryDirectory() as temp_dir:
            project_root = Path(temp_dir)
            module_name = "temp_runtime_app_module"
            module_path = project_root / f"{module_name}.py"
            module_path.write_text(
                textwrap.dedent(
                    """
                    app = {"name": "demo-app"}
                    """
                ).strip(),
                encoding="utf-8",
            )

            # The assertion below shows the temp directory ends up on
            # sys.path. Snapshot it so the soon-to-be-deleted directory does
            # not stay there for the rest of the test run.
            saved_sys_path = list(sys.path)
            try:
                app = server_runtime.load_fastapi_app(
                    project_root,
                    module_name=module_name,
                    attr_name="app",
                )

                self.assertEqual(app, {"name": "demo-app"})
                self.assertIn(str(project_root), sys.path)
            finally:
                # Undo global interpreter state touched by the import.
                sys.modules.pop(module_name, None)
                sys.path[:] = saved_sys_path

    def test_prepare_runtime_environment_sets_thread_env_defaults(self) -> None:
        with tempfile.TemporaryDirectory() as temp_dir:
            project_root = Path(temp_dir)
            env: dict[str, str] = {}

            settings = server_runtime.prepare_runtime_environment(
                project_root,
                env=env,
                current_executable=str(project_root / "venv" / "Scripts" / "python.exe"),
            )

            # The returned settings and the env mapping are expected to agree.
            self.assertEqual(settings.torch_num_threads, 2)
            self.assertEqual(settings.torch_num_interop_threads, 1)
            self.assertEqual(settings.omp_num_threads, 2)
            self.assertEqual(settings.mkl_num_threads, 2)
            self.assertEqual(env["TORCH_NUM_THREADS"], "2")
            self.assertEqual(env["TORCH_NUM_INTEROP_THREADS"], "1")
            self.assertEqual(env["OMP_NUM_THREADS"], "2")
            self.assertEqual(env["MKL_NUM_THREADS"], "2")

    def test_prepare_runtime_environment_requires_project_venv_python_by_default(self) -> None:
        with tempfile.TemporaryDirectory() as temp_dir:
            project_root = Path(temp_dir)
            # Create an (empty) interpreter file inside the project's venv.
            venv_python = project_root / "venv" / "Scripts" / "python.exe"
            venv_python.parent.mkdir(parents=True, exist_ok=True)
            venv_python.write_text("", encoding="utf-8")

            # The current executable is outside that venv and no override is
            # set, so a RuntimeError mentioning "venv" is expected.
            with self.assertRaisesRegex(RuntimeError, "venv"):
                server_runtime.prepare_runtime_environment(
                    project_root,
                    env={},
                    current_executable=r"C:\Python311\python.exe",
                )

    def test_prepare_runtime_environment_allows_non_venv_override(self) -> None:
        with tempfile.TemporaryDirectory() as temp_dir:
            project_root = Path(temp_dir)
            venv_python = project_root / "venv" / "Scripts" / "python.exe"
            venv_python.parent.mkdir(parents=True, exist_ok=True)
            venv_python.write_text("", encoding="utf-8")

            # Same setup as the failing case, but with the override flag set.
            settings = server_runtime.prepare_runtime_environment(
                project_root,
                env={"AIFORECAST_ALLOW_NON_VENV_PYTHON": "true"},
                current_executable=r"C:\Python311\python.exe",
            )

            self.assertEqual(settings.torch_num_threads, 2)
# Run the test suite only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    unittest.main()