Spaces:
Running
Running
File size: 6,354 Bytes
9734b71 2eec8c3 9734b71 4106e0f 9734b71 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 | from __future__ import annotations
import sys
import tempfile
import textwrap
import unittest
from pathlib import Path
from backend import server_runtime
class _FakeLogger:
def __init__(self) -> None:
self.warnings: list[tuple[str, tuple[object, ...]]] = []
def warning(self, msg: str, *args: object, **kwargs: object) -> None:
self.warnings.append((msg, args))
class ServerRuntimeTests(unittest.TestCase):
    """Unit tests for the startup helpers in ``backend.server_runtime``.

    Tests pass explicit ``env`` dicts and, where the API accepts them,
    injected ``port_checker`` / ``free_port_finder`` callables, so results are
    driven by test inputs rather than the real process environment.
    """

    @staticmethod
    def _create_fake_venv_python(project_root: Path) -> Path:
        """Create an empty placeholder interpreter at ``venv/Scripts/python.exe``.

        NOTE(review): this is a Windows-style venv layout; confirm that
        ``prepare_runtime_environment`` looks for this path on every platform
        the suite runs on (e.g. Linux-based Spaces use ``venv/bin/python``).
        """
        venv_python = project_root / "venv" / "Scripts" / "python.exe"
        venv_python.parent.mkdir(parents=True, exist_ok=True)
        venv_python.write_text("", encoding="utf-8")
        return venv_python

    def test_bootstrap_runtime_port_sets_huggingface_default(self) -> None:
        env = {"SPACE_ID": "demo-space"}
        server_runtime.bootstrap_runtime_port(env, huggingface_port=7860)
        self.assertEqual(env["PORT"], "7860")

    def test_bootstrap_runtime_port_keeps_existing_port(self) -> None:
        env = {"SPACE_HOST": "space.example", "PORT": "9000"}
        server_runtime.bootstrap_runtime_port(env, huggingface_port=7860)
        self.assertEqual(env["PORT"], "9000")

    def test_resolve_server_port_keeps_configured_port_when_available(self) -> None:
        env = {"PORT": "8000"}
        port = server_runtime.resolve_server_port(
            "127.0.0.1",
            env=env,
            port_checker=lambda host, value: host == "127.0.0.1" and value == 8000,
            free_port_finder=lambda host: 8123,
        )
        self.assertEqual(port, 8000)
        self.assertEqual(env["PORT"], "8000")

    def test_resolve_server_port_falls_back_when_configured_port_is_busy(self) -> None:
        env = {"PORT": "8000"}
        logger = _FakeLogger()
        port = server_runtime.resolve_server_port(
            "127.0.0.1",
            env=env,
            logger=logger,
            port_checker=lambda host, value: False,
            free_port_finder=lambda host: 8123,
        )
        self.assertEqual(port, 8123)
        self.assertEqual(env["PORT"], "8123")
        # Exactly one warning should announce the fallback.
        self.assertEqual(len(logger.warnings), 1)

    def test_resolve_server_port_uses_huggingface_default_when_missing(self) -> None:
        env = {"SPACE_ID": "demo-space"}
        # port_checker always reports "busy", yet the HF default is still
        # expected: on a Space the fixed port is used without probing.
        port = server_runtime.resolve_server_port(
            "0.0.0.0",
            env=env,
            port_checker=lambda host, value: False,
            free_port_finder=lambda host: 8123,
        )
        self.assertEqual(port, 7860)
        self.assertEqual(env["PORT"], "7860")

    def test_load_runtime_env_returns_false_for_missing_file(self) -> None:
        with tempfile.TemporaryDirectory() as temp_dir:
            env_file = Path(temp_dir) / "missing-aiforecast.env"
            self.assertFalse(server_runtime.load_runtime_env(env_file))

    def test_is_admin_token_configured_requires_non_blank_value(self) -> None:
        self.assertFalse(server_runtime.is_admin_token_configured({}))
        self.assertFalse(server_runtime.is_admin_token_configured({"ADMIN_TOKEN": " "}))
        self.assertTrue(server_runtime.is_admin_token_configured({"ADMIN_TOKEN": "runtime-secret"}))

    def test_require_admin_token_configured_fails_fast_when_missing(self) -> None:
        with self.assertRaisesRegex(RuntimeError, "ADMIN_TOKEN"):
            server_runtime.require_admin_token_configured({})

    def test_load_fastapi_app_imports_requested_module_attribute(self) -> None:
        with tempfile.TemporaryDirectory() as temp_dir:
            project_root = Path(temp_dir)
            module_name = "temp_runtime_app_module"
            module_path = project_root / f"{module_name}.py"
            module_path.write_text(
                textwrap.dedent(
                    """
                    app = {"name": "demo-app"}
                    """
                ).strip(),
                encoding="utf-8",
            )
            path_entry = str(project_root)
            # load_fastapi_app is expected to put project_root on sys.path
            # (asserted below). Remember whether it was already present so the
            # cleanup only undoes what this test caused.
            path_was_present = path_entry in sys.path
            try:
                app = server_runtime.load_fastapi_app(
                    project_root,
                    module_name=module_name,
                    attr_name="app",
                )
                self.assertEqual(app, {"name": "demo-app"})
                self.assertIn(path_entry, sys.path)
            finally:
                # Restore interpreter-global state. Otherwise the temp module
                # stays cached and sys.path keeps pointing at a directory that
                # is deleted when the `with` block exits, leaking into later tests.
                sys.modules.pop(module_name, None)
                if not path_was_present:
                    while path_entry in sys.path:
                        sys.path.remove(path_entry)

    def test_prepare_runtime_environment_sets_thread_env_defaults(self) -> None:
        with tempfile.TemporaryDirectory() as temp_dir:
            project_root = Path(temp_dir)
            env: dict[str, str] = {}
            # No venv interpreter is created on disk here; the executable path
            # only names where one would live.
            settings = server_runtime.prepare_runtime_environment(
                project_root,
                env=env,
                current_executable=str(project_root / "venv" / "Scripts" / "python.exe"),
            )
            self.assertEqual(settings.torch_num_threads, 2)
            self.assertEqual(settings.torch_num_interop_threads, 1)
            self.assertEqual(settings.omp_num_threads, 2)
            self.assertEqual(settings.mkl_num_threads, 2)
            # The defaults must also be exported into the supplied env mapping.
            self.assertEqual(env["TORCH_NUM_THREADS"], "2")
            self.assertEqual(env["TORCH_NUM_INTEROP_THREADS"], "1")
            self.assertEqual(env["OMP_NUM_THREADS"], "2")
            self.assertEqual(env["MKL_NUM_THREADS"], "2")

    def test_prepare_runtime_environment_requires_project_venv_python_by_default(self) -> None:
        with tempfile.TemporaryDirectory() as temp_dir:
            project_root = Path(temp_dir)
            self._create_fake_venv_python(project_root)
            # A project venv exists but a different interpreter is running.
            with self.assertRaisesRegex(RuntimeError, "venv"):
                server_runtime.prepare_runtime_environment(
                    project_root,
                    env={},
                    current_executable=r"C:\Python311\python.exe",
                )

    def test_prepare_runtime_environment_allows_non_venv_override(self) -> None:
        with tempfile.TemporaryDirectory() as temp_dir:
            project_root = Path(temp_dir)
            self._create_fake_venv_python(project_root)
            # Same mismatch as above, but the opt-out flag permits it.
            settings = server_runtime.prepare_runtime_environment(
                project_root,
                env={"AIFORECAST_ALLOW_NON_VENV_PYTHON": "true"},
                current_executable=r"C:\Python311\python.exe",
            )
            self.assertEqual(settings.torch_num_threads, 2)
if __name__ == "__main__":
    # Run the suite through unittest's command-line runner when executed directly.
    unittest.main(module="__main__")
|