Spaces:
Sleeping
Sleeping
File size: 3,877 Bytes
"""Tests for inference.py competition logging format (dispatch domain)."""
from __future__ import annotations
import os
import re
import subprocess
import sys
class TestInferenceFormatCompliance:
TASK_IDS = ["single_incident", "multi_incident", "mass_casualty", "shift_surge"]
def _run_inference_capture(self, env: dict[str, str]) -> tuple[int, str, str]:
cmd = [sys.executable, "inference.py"]
merged_env = os.environ.copy()
merged_env.update(env)
merged_env.setdefault("USE_RANDOM", "true")
result = subprocess.run(
cmd,
capture_output=True,
text=True,
env=merged_env,
cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
)
return result.returncode, result.stdout, result.stderr
def test_inference_runs_all_tasks(self) -> None:
env = {
"API_BASE_URL": "https://api.example.com",
"MODEL_NAME": "test-model",
"OPENAI_API_KEY": "test-token",
"USE_RANDOM": "true",
}
returncode, stdout, stderr = self._run_inference_capture(env)
assert returncode == 0, f"inference.py failed: {stderr}"
tasks_run = []
for line in stdout.split("\n"):
if line.startswith("[START]"):
match = re.match(r"\[START\] task=(\S+) env=(\S+) model=(\S+)", line)
assert match
tasks_run.append(match.group(1))
assert tasks_run == self.TASK_IDS
def test_start_line_format(self) -> None:
env = {
"API_BASE_URL": "https://api.example.com",
"MODEL_NAME": "test-model",
"OPENAI_API_KEY": "test-token",
"USE_RANDOM": "true",
}
_, stdout, _ = self._run_inference_capture(env)
pattern = r"\[START\] task=\S+ env=citywide-dispatch-supervisor model=\S+"
for line in stdout.split("\n"):
if line.startswith("[START]"):
assert re.match(pattern, line)
def test_step_line_error_format(self) -> None:
env = {
"API_BASE_URL": "https://api.example.com",
"MODEL_NAME": "test-model",
"OPENAI_API_KEY": "test-token",
"USE_RANDOM": "true",
}
_, stdout, _ = self._run_inference_capture(env)
valid_errors = {"null", "max_steps_exceeded", "illegal_transition", "step_error"}
for line in stdout.split("\n"):
if not line.startswith("[STEP]"):
continue
match = re.match(r"\[STEP\].+ error=(.+)", line)
assert match
assert match.group(1) in valid_errors
class TestEnvVarValidation:
def _run_inference_capture(self, env: dict[str, str]) -> tuple[int, str, str]:
cmd = [sys.executable, "inference.py"]
merged_env = os.environ.copy()
merged_env.update(env)
# Ensure tests are not affected by host environment variables.
if "API_BASE_URL" not in env:
merged_env.pop("API_BASE_URL", None)
if "MODEL_NAME" not in env:
merged_env.pop("MODEL_NAME", None)
if "OPENAI_API_KEY" not in env:
merged_env.pop("OPENAI_API_KEY", None)
if "HF_TOKEN" not in env:
merged_env.pop("HF_TOKEN", None)
result = subprocess.run(
cmd,
capture_output=True,
text=True,
env=merged_env,
cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
)
return result.returncode, result.stdout, result.stderr
def test_missing_api_key_when_not_random(self) -> None:
env = {
"USE_RANDOM": "false",
}
returncode, stdout, stderr = self._run_inference_capture(env)
assert returncode != 0
assert "HF_TOKEN" in (stdout + stderr)
|