File size: 3,458 Bytes
42d0b80
 
 
 
 
 
 
 
4ddf12d
42d0b80
 
 
 
 
 
4ddf12d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42d0b80
 
 
 
 
 
 
 
 
 
4ddf12d
 
42d0b80
4ddf12d
42d0b80
 
 
 
 
4ddf12d
 
42d0b80
4ddf12d
42d0b80
 
4ddf12d
42d0b80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import unittest
from unittest.mock import patch

import inference


class InferenceContractTests(unittest.TestCase):
    """Contract tests for the submission entry points exposed by ``inference``."""

    # Values validate_submission_env() is expected to fall back to when
    # API_BASE_URL / MODEL_NAME are absent from the environment.
    _DEFAULT_API_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai/"
    _DEFAULT_MODEL_NAME = "gemini-2.5-flash"

    @staticmethod
    def _make_recording_openai():
        """Return a fresh stand-in for ``inference.OpenAI`` that records its kwargs.

        A new class is built on every call so the recorded ``last_kwargs`` can
        never leak from one test into another.
        """

        class RecordingOpenAI:
            # Keyword arguments of the most recent construction (None until built).
            last_kwargs = None

            def __init__(self, **kwargs):
                RecordingOpenAI.last_kwargs = kwargs

        return RecordingOpenAI

    def test_validate_submission_env_requires_hf_token(self) -> None:
        """validate_submission_env() raises RuntimeError mentioning HF_TOKEN when it is unset."""
        with patch.dict("os.environ", {}, clear=True):
            with self.assertRaises(RuntimeError) as ctx:
                inference.validate_submission_env()

        self.assertIn("HF_TOKEN", str(ctx.exception))

    def test_validate_submission_env_uses_defaults_and_openai_client(self) -> None:
        """With only HF_TOKEN set, defaults are returned and forwarded to the client."""
        fake_openai = self._make_recording_openai()

        with patch.object(inference, "OpenAI", fake_openai):
            with patch.dict("os.environ", {"HF_TOKEN": "hf_test_token"}, clear=True):
                config = inference.validate_submission_env()

        self.assertEqual(config["API_BASE_URL"], self._DEFAULT_API_BASE_URL)
        self.assertEqual(config["MODEL_NAME"], self._DEFAULT_MODEL_NAME)
        self.assertEqual(config["HF_TOKEN"], "hf_test_token")
        self.assertIsNone(config["LOCAL_IMAGE_NAME"])
        self.assertEqual(
            fake_openai.last_kwargs,
            {"base_url": self._DEFAULT_API_BASE_URL, "api_key": "hf_test_token"},
        )

    def test_validate_submission_env_allows_explicit_overrides_and_local_image(self) -> None:
        """Explicit environment values override every default, including LOCAL_IMAGE_NAME."""
        fake_openai = self._make_recording_openai()
        overrides = {
            "API_BASE_URL": "https://example.invalid/v1",
            "MODEL_NAME": "test-model",
            "HF_TOKEN": "hf_test_token",
            "LOCAL_IMAGE_NAME": "local-env-image",
        }

        with patch.object(inference, "OpenAI", fake_openai):
            with patch.dict("os.environ", overrides, clear=True):
                config = inference.validate_submission_env()

        self.assertEqual(config["API_BASE_URL"], "https://example.invalid/v1")
        self.assertEqual(config["MODEL_NAME"], "test-model")
        self.assertEqual(config["HF_TOKEN"], "hf_test_token")
        self.assertEqual(config["LOCAL_IMAGE_NAME"], "local-env-image")
        self.assertEqual(
            fake_openai.last_kwargs,
            {"base_url": "https://example.invalid/v1", "api_key": "hf_test_token"},
        )

    def test_main_forces_submission_safe_defaults(self) -> None:
        """main() sets submission-safe env vars before delegating to baseline_main()."""
        env_snapshots = []

        def fake_baseline_main() -> None:
            # Capture the environment as baseline_main would see it: the
            # defaults only matter if they are in place before it runs.
            env_snapshots.append(dict(os.environ))

        with patch.object(inference, "baseline_main", fake_baseline_main):
            with patch.object(inference, "validate_submission_env", return_value={"ok": True}):
                with patch.dict("os.environ", {}, clear=True):
                    inference.main()

        self.assertEqual(len(env_snapshots), 1)
        env_at_call = env_snapshots[0]
        self.assertEqual(env_at_call.get("BASELINE_AGENT"), "openai")
        self.assertEqual(env_at_call.get("BASELINE_OUTPUT_MODE"), "submission")
        self.assertEqual(env_at_call.get("EPISODES_PER_DIFFICULTY"), "1")
        self.assertEqual(
            env_at_call.get("BASELINE_RESULTS_PATH"),
            "baseline/submission_results.json",
        )


if __name__ == "__main__":
    # Run this module's tests when the file is executed directly as a script.
    unittest.main(module="__main__")