Spaces:
Sleeping
Sleeping
| """Tests for inference.py competition logging format (dispatch domain).""" | |
| from __future__ import annotations | |
| import os | |
| import re | |
| import subprocess | |
| import sys | |
class TestInferenceFormatCompliance:
    """Check the log lines that ``inference.py`` prints on stdout.

    Every test launches ``inference.py`` in a subprocess and inspects the
    lines that start with ``[START]`` or ``[STEP]``.
    """

    TASK_IDS = ["single_incident", "multi_incident", "mass_casualty", "shift_surge"]

    # Environment overrides shared by every test in this class.
    BASE_ENV = {
        "API_BASE_URL": "https://api.example.com",
        "MODEL_NAME": "test-model",
        "OPENAI_API_KEY": "test-token",
        "USE_RANDOM": "true",
    }

    def _run_inference_capture(self, env: dict[str, str]) -> tuple[int, str, str]:
        """Run ``inference.py`` and return ``(returncode, stdout, stderr)``.

        Args:
            env: Variables layered on top of a copy of the host environment.
                ``USE_RANDOM`` falls back to ``"true"`` when neither the host
                environment nor ``env`` provides it.
        """
        cmd = [sys.executable, "inference.py"]
        merged_env = os.environ.copy()
        merged_env.update(env)
        merged_env.setdefault("USE_RANDOM", "true")
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            env=merged_env,
            # Run from the parent of the directory that contains this file.
            cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
        )
        return result.returncode, result.stdout, result.stderr

    def _tagged_lines(self, tag: str) -> list[str]:
        """Run inference with ``BASE_ENV`` and return stdout lines starting with ``tag``.

        Fails the calling test when no such line exists, so that a per-line
        format check can never pass just because there was nothing to check.
        """
        returncode, stdout, stderr = self._run_inference_capture(dict(self.BASE_ENV))
        lines = [line for line in stdout.split("\n") if line.startswith(tag)]
        assert lines, (
            f"inference.py (exit code {returncode}) printed no {tag} lines; "
            f"stderr: {stderr}"
        )
        return lines

    def test_inference_runs_all_tasks(self) -> None:
        """The run exits with 0 and emits one ``[START]`` line per task, in ``TASK_IDS`` order."""
        returncode, stdout, stderr = self._run_inference_capture(dict(self.BASE_ENV))
        assert returncode == 0, f"inference.py failed: {stderr}"
        tasks_run = []
        for line in stdout.split("\n"):
            if line.startswith("[START]"):
                match = re.match(r"\[START\] task=(\S+) env=(\S+) model=(\S+)", line)
                assert match, f"malformed [START] line: {line!r}"
                tasks_run.append(match.group(1))
        assert tasks_run == self.TASK_IDS

    def test_start_line_format(self) -> None:
        """Every ``[START]`` line names the ``citywide-dispatch-supervisor`` env."""
        pattern = r"\[START\] task=\S+ env=citywide-dispatch-supervisor model=\S+"
        for line in self._tagged_lines("[START]"):
            assert re.match(pattern, line), f"malformed [START] line: {line!r}"

    def test_step_line_error_format(self) -> None:
        """Every ``[STEP]`` line ends with ``error=<value>`` drawn from a fixed set."""
        valid_errors = {"null", "max_steps_exceeded", "illegal_transition", "step_error"}
        for line in self._tagged_lines("[STEP]"):
            match = re.match(r"\[STEP\].+ error=(.+)", line)
            assert match, f"[STEP] line has no error field: {line!r}"
            assert match.group(1) in valid_errors, (
                f"unexpected error value {match.group(1)!r} in line: {line!r}"
            )
class TestEnvVarValidation:
    """Check how ``inference.py`` behaves when environment variables are absent."""

    def _run_inference_capture(self, env: dict[str, str]) -> tuple[int, str, str]:
        """Run ``inference.py`` and return ``(returncode, stdout, stderr)``.

        Args:
            env: Variables the child process should see in addition to the
                host environment.
        """
        # Ensure tests are not affected by host environment variables: these
        # four reach the child process only when the caller passes them in.
        scrubbed = ("API_BASE_URL", "MODEL_NAME", "OPENAI_API_KEY", "HF_TOKEN")
        child_env = {
            name: value for name, value in os.environ.items() if name not in scrubbed
        }
        child_env.update(env)

        # Run from the parent of the directory that contains this file.
        this_dir = os.path.dirname(os.path.abspath(__file__))
        run_dir = os.path.dirname(this_dir)

        completed = subprocess.run(
            [sys.executable, "inference.py"],
            capture_output=True,
            text=True,
            env=child_env,
            cwd=run_dir,
        )
        return completed.returncode, completed.stdout, completed.stderr

    def test_missing_api_key_when_not_random(self) -> None:
        """With only ``USE_RANDOM="false"`` set, the run fails and mentions ``HF_TOKEN``."""
        exit_code, out, err = self._run_inference_capture({"USE_RANDOM": "false"})
        assert exit_code != 0
        combined_output = out + err
        assert "HF_TOKEN" in combined_output