File size: 6,354 Bytes
9734b71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2eec8c3
 
 
 
 
 
 
 
 
9734b71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4106e0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9734b71
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
from __future__ import annotations

import sys
import tempfile
import textwrap
import unittest
from pathlib import Path

from backend import server_runtime


class _FakeLogger:
    def __init__(self) -> None:
        self.warnings: list[tuple[str, tuple[object, ...]]] = []

    def warning(self, msg: str, *args: object, **kwargs: object) -> None:
        self.warnings.append((msg, args))


class ServerRuntimeTests(unittest.TestCase):
    """Tests for the port, env-file, admin-token, app-loading and
    environment-preparation helpers in ``backend.server_runtime``."""

    def test_bootstrap_runtime_port_sets_huggingface_default(self) -> None:
        """With only ``SPACE_ID`` set, ``PORT`` is filled from ``huggingface_port``."""
        env = {"SPACE_ID": "demo-space"}

        server_runtime.bootstrap_runtime_port(env, huggingface_port=7860)

        self.assertEqual(env["PORT"], "7860")

    def test_bootstrap_runtime_port_keeps_existing_port(self) -> None:
        """An explicit ``PORT`` is left untouched even when ``SPACE_HOST`` is set."""
        env = {"SPACE_HOST": "space.example", "PORT": "9000"}

        server_runtime.bootstrap_runtime_port(env, huggingface_port=7860)

        self.assertEqual(env["PORT"], "9000")

    def test_resolve_server_port_keeps_configured_port_when_available(self) -> None:
        """A configured port the checker reports as available is returned as-is."""
        env = {"PORT": "8000"}
        port = server_runtime.resolve_server_port(
            "127.0.0.1",
            env=env,
            port_checker=lambda host, value: host == "127.0.0.1" and value == 8000,
            free_port_finder=lambda host: 8123,
        )
        self.assertEqual(port, 8000)
        self.assertEqual(env["PORT"], "8000")

    def test_resolve_server_port_falls_back_when_configured_port_is_busy(self) -> None:
        """A busy configured port is replaced by the free-port finder's result,
        written back to ``env``, and reported with exactly one warning."""
        env = {"PORT": "8000"}
        logger = _FakeLogger()

        port = server_runtime.resolve_server_port(
            "127.0.0.1",
            env=env,
            logger=logger,
            port_checker=lambda host, value: False,
            free_port_finder=lambda host: 8123,
        )

        self.assertEqual(port, 8123)
        self.assertEqual(env["PORT"], "8123")
        self.assertEqual(len(logger.warnings), 1)

    def test_resolve_server_port_uses_huggingface_default_when_missing(self) -> None:
        """With ``SPACE_ID`` and no ``PORT``, 7860 is used even though the
        checker reports every port as busy."""
        env = {"SPACE_ID": "demo-space"}

        port = server_runtime.resolve_server_port(
            "0.0.0.0",
            env=env,
            port_checker=lambda host, value: False,
            free_port_finder=lambda host: 8123,
        )

        self.assertEqual(port, 7860)
        self.assertEqual(env["PORT"], "7860")

    def test_load_runtime_env_returns_false_for_missing_file(self) -> None:
        """A non-existent env file yields ``False``."""
        with tempfile.TemporaryDirectory() as temp_dir:
            env_file = Path(temp_dir) / "missing-aiforecast.env"
            self.assertFalse(server_runtime.load_runtime_env(env_file))

    def test_is_admin_token_configured_requires_non_blank_value(self) -> None:
        """Missing and whitespace-only ``ADMIN_TOKEN`` values count as unconfigured."""
        self.assertFalse(server_runtime.is_admin_token_configured({}))
        self.assertFalse(server_runtime.is_admin_token_configured({"ADMIN_TOKEN": "   "}))
        self.assertTrue(server_runtime.is_admin_token_configured({"ADMIN_TOKEN": "runtime-secret"}))

    def test_require_admin_token_configured_fails_fast_when_missing(self) -> None:
        """A missing ``ADMIN_TOKEN`` raises ``RuntimeError`` naming the variable."""
        with self.assertRaisesRegex(RuntimeError, "ADMIN_TOKEN"):
            server_runtime.require_admin_token_configured({})

    def test_load_fastapi_app_imports_requested_module_attribute(self) -> None:
        """The requested attribute is imported from a module under
        ``project_root``, and that directory ends up on ``sys.path``."""
        with tempfile.TemporaryDirectory() as temp_dir:
            project_root = Path(temp_dir)
            module_name = "temp_runtime_app_module"
            module_path = project_root / f"{module_name}.py"
            module_path.write_text(
                textwrap.dedent(
                    """
                    app = {"name": "demo-app"}
                    """
                ).strip(),
                encoding="utf-8",
            )
            # load_fastapi_app is expected to add project_root to the
            # process-wide sys.path (asserted below). Snapshot it so the
            # soon-to-be-deleted temp directory does not leak into later tests.
            saved_sys_path = list(sys.path)

            try:
                app = server_runtime.load_fastapi_app(
                    project_root,
                    module_name=module_name,
                    attr_name="app",
                )
                self.assertEqual(app, {"name": "demo-app"})
                self.assertIn(str(project_root), sys.path)
            finally:
                sys.modules.pop(module_name, None)
                sys.path[:] = saved_sys_path
                # Drop the cached finder for the temp directory as well.
                sys.path_importer_cache.pop(str(project_root), None)

    def test_prepare_runtime_environment_sets_thread_env_defaults(self) -> None:
        """With an empty env, thread settings default to 2/1/2/2 and are
        exported into ``env`` as strings."""
        with tempfile.TemporaryDirectory() as temp_dir:
            project_root = Path(temp_dir)
            env: dict[str, str] = {}

            # The venv interpreter path is passed but never created on disk here.
            settings = server_runtime.prepare_runtime_environment(
                project_root,
                env=env,
                current_executable=str(project_root / "venv" / "Scripts" / "python.exe"),
            )

        self.assertEqual(settings.torch_num_threads, 2)
        self.assertEqual(settings.torch_num_interop_threads, 1)
        self.assertEqual(settings.omp_num_threads, 2)
        self.assertEqual(settings.mkl_num_threads, 2)
        self.assertEqual(env["TORCH_NUM_THREADS"], "2")
        self.assertEqual(env["TORCH_NUM_INTEROP_THREADS"], "1")
        self.assertEqual(env["OMP_NUM_THREADS"], "2")
        self.assertEqual(env["MKL_NUM_THREADS"], "2")

    def test_prepare_runtime_environment_requires_project_venv_python_by_default(self) -> None:
        """If ``venv/Scripts/python.exe`` exists and a different interpreter is
        running, a ``RuntimeError`` mentioning ``venv`` is raised."""
        with tempfile.TemporaryDirectory() as temp_dir:
            project_root = Path(temp_dir)
            venv_python = project_root / "venv" / "Scripts" / "python.exe"
            venv_python.parent.mkdir(parents=True, exist_ok=True)
            venv_python.write_text("", encoding="utf-8")

            with self.assertRaisesRegex(RuntimeError, "venv"):
                server_runtime.prepare_runtime_environment(
                    project_root,
                    env={},
                    current_executable=r"C:\Python311\python.exe",
                )

    def test_prepare_runtime_environment_allows_non_venv_override(self) -> None:
        """``AIFORECAST_ALLOW_NON_VENV_PYTHON=true`` bypasses the venv-interpreter check."""
        with tempfile.TemporaryDirectory() as temp_dir:
            project_root = Path(temp_dir)
            venv_python = project_root / "venv" / "Scripts" / "python.exe"
            venv_python.parent.mkdir(parents=True, exist_ok=True)
            venv_python.write_text("", encoding="utf-8")

            settings = server_runtime.prepare_runtime_environment(
                project_root,
                env={"AIFORECAST_ALLOW_NON_VENV_PYTHON": "true"},
                current_executable=r"C:\Python311\python.exe",
            )

        self.assertEqual(settings.torch_num_threads, 2)


if __name__ == "__main__":
    # Running the file directly executes every TestCase defined in this module.
    unittest.main(module="__main__")