File size: 3,877 Bytes
6172160
4904e85
 
 
 
 
 
 
 
 
 
e259b96
4904e85
6172160
4904e85
 
6172160
 
4904e85
 
 
 
 
 
 
 
 
6172160
4904e85
 
 
43f2683
6172160
4904e85
 
 
6172160
4904e85
3c855d7
 
 
 
6172160
4904e85
6172160
4904e85
 
 
43f2683
6172160
4904e85
 
3c855d7
4904e85
3c855d7
 
4904e85
6172160
4904e85
 
 
43f2683
6172160
4904e85
 
3c855d7
4904e85
3c855d7
 
 
 
 
4904e85
 
 
6172160
4904e85
 
6172160
5b25e42
 
 
 
 
 
43f2683
 
 
 
4904e85
 
 
 
 
 
 
 
 
15a5d1d
43f2683
 
 
 
 
15a5d1d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""Tests for inference.py competition logging format (dispatch domain)."""

from __future__ import annotations

import os
import re
import subprocess
import sys


class TestInferenceFormatCompliance:
    """Check that ``inference.py`` prints the competition logging format.

    Each test launches ``inference.py`` as a subprocess with placeholder
    credentials and ``USE_RANDOM=true``, then inspects the ``[START]`` and
    ``[STEP]`` lines it writes to stdout.
    """

    # Task ids inference.py is expected to run, in exactly this order.
    TASK_IDS = ["single_incident", "multi_incident", "mass_casualty", "shift_surge"]

    # Upper bound (seconds) on one inference.py run, so a hung process fails
    # the test instead of stalling the suite. NOTE(review): chosen generously;
    # tune to the script's real runtime.
    INFERENCE_TIMEOUT_S = 600

    @staticmethod
    def _default_env() -> dict[str, str]:
        """Return a fresh dict of the placeholder env vars shared by these tests."""
        return {
            "API_BASE_URL": "https://api.example.com",
            "MODEL_NAME": "test-model",
            "OPENAI_API_KEY": "test-token",
            "USE_RANDOM": "true",
        }

    def _run_inference_capture(self, env: dict[str, str]) -> tuple[int, str, str]:
        """Run ``inference.py`` with *env* overlaid on the current environment.

        Args:
            env: Variables merged over a copy of ``os.environ``. ``USE_RANDOM``
                defaults to ``"true"`` when neither *env* nor the host sets it.

        Returns:
            ``(returncode, stdout, stderr)`` of the finished process.

        Raises:
            subprocess.TimeoutExpired: if the run exceeds ``INFERENCE_TIMEOUT_S``.
        """
        merged_env = os.environ.copy()
        merged_env.update(env)
        merged_env.setdefault("USE_RANDOM", "true")
        # inference.py is resolved relative to the directory two levels above
        # this test file.
        run_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        result = subprocess.run(
            [sys.executable, "inference.py"],
            capture_output=True,
            text=True,
            env=merged_env,
            cwd=run_dir,
            timeout=self.INFERENCE_TIMEOUT_S,
        )
        return result.returncode, result.stdout, result.stderr

    def _run_successfully(self) -> str:
        """Run inference.py with the default env, assert exit 0, return stdout."""
        returncode, stdout, stderr = self._run_inference_capture(self._default_env())
        assert returncode == 0, f"inference.py failed: {stderr}"
        return stdout

    def test_inference_runs_all_tasks(self) -> None:
        stdout = self._run_successfully()
        start_re = re.compile(r"\[START\] task=(\S+) env=(\S+) model=(\S+)")
        tasks_run = []
        for line in stdout.splitlines():
            if line.startswith("[START]"):
                match = start_re.match(line)
                assert match, f"malformed [START] line: {line!r}"
                tasks_run.append(match.group(1))
        assert tasks_run == self.TASK_IDS

    def test_start_line_format(self) -> None:
        stdout = self._run_successfully()
        start_re = re.compile(r"\[START\] task=\S+ env=citywide-dispatch-supervisor model=\S+")
        start_lines = [line for line in stdout.splitlines() if line.startswith("[START]")]
        # Without this, a run that printed nothing would pass vacuously.
        assert start_lines, "inference.py printed no [START] lines"
        for line in start_lines:
            assert start_re.match(line), f"malformed [START] line: {line!r}"

    def test_step_line_error_format(self) -> None:
        stdout = self._run_successfully()
        valid_errors = {"null", "max_steps_exceeded", "illegal_transition", "step_error"}
        step_re = re.compile(r"\[STEP\].+ error=(.+)")
        step_lines = [line for line in stdout.splitlines() if line.startswith("[STEP]")]
        # Without this, a run that printed nothing would pass vacuously.
        assert step_lines, "inference.py printed no [STEP] lines"
        for line in step_lines:
            match = step_re.match(line)
            assert match, f"[STEP] line has no error= field: {line!r}"
            assert match.group(1) in valid_errors, f"unexpected error value in: {line!r}"


class TestEnvVarValidation:
    """Check that ``inference.py`` refuses to run without required configuration.

    Runs strip selected variables inherited from the host environment, so a
    developer's real settings cannot mask a missing-variable error.
    """

    # Host variables removed from the child environment unless a test sets them.
    _ISOLATED_VARS = ("API_BASE_URL", "MODEL_NAME", "OPENAI_API_KEY", "HF_TOKEN")

    # Upper bound (seconds) on one inference.py run, so a hang fails the test
    # instead of stalling the suite. NOTE(review): tune to the real runtime.
    INFERENCE_TIMEOUT_S = 600

    def _run_inference_capture(self, env: dict[str, str]) -> tuple[int, str, str]:
        """Run ``inference.py`` with *env* overlaid on a sanitized environment.

        Args:
            env: Variables merged over a copy of ``os.environ``. Any name in
                ``_ISOLATED_VARS`` that *env* does not set is removed.

        Returns:
            ``(returncode, stdout, stderr)`` of the finished process.

        Raises:
            subprocess.TimeoutExpired: if the run exceeds ``INFERENCE_TIMEOUT_S``.
        """
        merged_env = os.environ.copy()
        merged_env.update(env)

        # Ensure tests are not affected by host environment variables.
        for name in self._ISOLATED_VARS:
            if name not in env:
                merged_env.pop(name, None)

        # inference.py is resolved relative to the directory two levels above
        # this test file.
        run_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        result = subprocess.run(
            [sys.executable, "inference.py"],
            capture_output=True,
            text=True,
            env=merged_env,
            cwd=run_dir,
            timeout=self.INFERENCE_TIMEOUT_S,
        )
        return result.returncode, result.stdout, result.stderr

    def test_missing_api_key_when_not_random(self) -> None:
        env = {
            "USE_RANDOM": "false",
        }
        returncode, stdout, stderr = self._run_inference_capture(env)
        output = stdout + stderr
        assert returncode != 0, f"inference.py unexpectedly succeeded:\n{output}"
        assert "HF_TOKEN" in output, f"error output does not mention HF_TOKEN:\n{output}"