911 / tests /test_inference.py
SayedZahur786's picture
fix: revert inference format and make /reset payload optional
3c855d7
"""Tests for inference.py competition logging format (dispatch domain)."""
from __future__ import annotations
import os
import re
import subprocess
import sys
class TestInferenceFormatCompliance:
    """Check the log lines that ``inference.py`` writes to stdout.

    Every test launches ``inference.py`` in a child interpreter and inspects
    the captured output for ``[START]`` and ``[STEP]`` lines.
    """

    # Task ids in the order their [START] lines are expected to appear.
    TASK_IDS = ["single_incident", "multi_incident", "mass_casualty", "shift_surge"]

    # Environment passed to the child process by every test in this class.
    _BASE_ENV = {
        "API_BASE_URL": "https://api.example.com",
        "MODEL_NAME": "test-model",
        "OPENAI_API_KEY": "test-token",
        "USE_RANDOM": "true",
    }

    # Values accepted in the ``error=`` field of a [STEP] line.
    _VALID_STEP_ERRORS = frozenset(
        {"null", "max_steps_exceeded", "illegal_transition", "step_error"}
    )

    def _run_inference_capture(self, env: dict[str, str]) -> tuple[int, str, str]:
        """Run ``inference.py`` and return ``(returncode, stdout, stderr)``.

        Args:
            env: Variables layered on top of the host environment.

        The working directory is the parent of the directory that contains
        this test file.
        """
        cmd = [sys.executable, "inference.py"]
        merged_env = os.environ.copy()
        merged_env.update(env)
        # Only applies when neither the host nor the caller set USE_RANDOM.
        # NOTE(review): presumably this keeps the run off a real model API;
        # confirm against inference.py.
        merged_env.setdefault("USE_RANDOM", "true")
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            env=merged_env,
            cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
        )
        return result.returncode, result.stdout, result.stderr

    def _run_successfully(self) -> str:
        """Run ``inference.py`` with the base env and return its stdout.

        Fails the calling test when the child exits with a non-zero status,
        so that format checks never run against the output of a crashed run.
        """
        returncode, stdout, stderr = self._run_inference_capture(self._BASE_ENV)
        assert returncode == 0, f"inference.py failed: {stderr}"
        return stdout

    @staticmethod
    def _lines_with_prefix(stdout: str, prefix: str) -> list[str]:
        """Return the stdout lines that begin with ``prefix``, in order."""
        return [line for line in stdout.split("\n") if line.startswith(prefix)]

    def test_inference_runs_all_tasks(self) -> None:
        """One [START] line per task must be printed, in ``TASK_IDS`` order."""
        stdout = self._run_successfully()
        tasks_run = []
        for line in self._lines_with_prefix(stdout, "[START]"):
            match = re.match(r"\[START\] task=(\S+) env=(\S+) model=(\S+)", line)
            assert match, f"malformed [START] line: {line!r}"
            tasks_run.append(match.group(1))
        assert tasks_run == self.TASK_IDS

    def test_start_line_format(self) -> None:
        """Every [START] line must name the citywide-dispatch-supervisor env."""
        stdout = self._run_successfully()
        pattern = r"\[START\] task=\S+ env=citywide-dispatch-supervisor model=\S+"
        start_lines = self._lines_with_prefix(stdout, "[START]")
        # Without this guard the loop below passes vacuously on empty output.
        assert start_lines, "inference.py printed no [START] lines"
        for line in start_lines:
            assert re.match(pattern, line), f"malformed [START] line: {line!r}"

    def test_step_line_error_format(self) -> None:
        """Every [STEP] line must end in ``error=<one of the known values>``."""
        stdout = self._run_successfully()
        step_lines = self._lines_with_prefix(stdout, "[STEP]")
        # Without this guard the loop below passes vacuously on empty output.
        assert step_lines, "inference.py printed no [STEP] lines"
        for line in step_lines:
            match = re.match(r"\[STEP\].+ error=(.+)", line)
            assert match, f"[STEP] line has no error field: {line!r}"
            assert match.group(1) in self._VALID_STEP_ERRORS, (
                f"unexpected error value in [STEP] line: {line!r}"
            )
class TestEnvVarValidation:
    """Check how ``inference.py`` reacts to missing credential variables."""

    def _run_inference_capture(self, env: dict[str, str]) -> tuple[int, str, str]:
        """Run ``inference.py`` and return ``(returncode, stdout, stderr)``.

        Args:
            env: Variables layered on top of the host environment. Any of
                the credential/endpoint variables listed below that is not
                a key of ``env`` is removed from the child environment.
        """
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

        child_env = os.environ.copy()
        child_env.update(env)
        # Ensure tests are not affected by host environment variables.
        for name in ("API_BASE_URL", "MODEL_NAME", "OPENAI_API_KEY", "HF_TOKEN"):
            if name not in env:
                child_env.pop(name, None)

        completed = subprocess.run(
            [sys.executable, "inference.py"],
            capture_output=True,
            text=True,
            env=child_env,
            cwd=project_root,
        )
        return completed.returncode, completed.stdout, completed.stderr

    def test_missing_api_key_when_not_random(self) -> None:
        """With USE_RANDOM=false and no key set, the run fails and mentions HF_TOKEN."""
        returncode, stdout, stderr = self._run_inference_capture(
            {"USE_RANDOM": "false"}
        )
        combined_output = stdout + stderr
        assert returncode != 0
        assert "HF_TOKEN" in combined_output