"""Contract tests for the submission entry point exposed by ``inference``."""

import os
import unittest
from unittest.mock import patch

import inference


def _make_recording_client():
    """Return a fresh stand-in class for ``inference.OpenAI``.

    The returned class stores the keyword arguments of its most recent
    construction on the class attribute ``last_kwargs``.  A new class is
    built on every call so recorded state never leaks between tests.
    """

    class _RecordingClient:
        last_kwargs = None

        def __init__(self, **kwargs):
            _RecordingClient.last_kwargs = kwargs

    return _RecordingClient


class InferenceContractTests(unittest.TestCase):
    """Checks on ``inference.validate_submission_env`` and ``inference.main``."""

    def test_validate_submission_env_requires_hf_token(self) -> None:
        """An empty environment raises RuntimeError that names HF_TOKEN."""
        empty_env = patch.dict("os.environ", {}, clear=True)
        with empty_env, self.assertRaises(RuntimeError) as raised:
            inference.validate_submission_env()
        self.assertIn("HF_TOKEN", str(raised.exception))

    def test_validate_submission_env_uses_defaults_and_openai_client(self) -> None:
        """With only HF_TOKEN set, default URL/model are used for config and client."""
        default_url = "https://generativelanguage.googleapis.com/v1beta/openai/"
        client_stub = _make_recording_client()
        client_patch = patch.object(inference, "OpenAI", client_stub)
        env_patch = patch.dict(
            "os.environ",
            {"HF_TOKEN": "hf_test_token"},
            clear=True,
        )

        with client_patch, env_patch:
            config = inference.validate_submission_env()

        self.assertEqual(config["API_BASE_URL"], default_url)
        self.assertEqual(config["MODEL_NAME"], "gemini-2.5-flash")
        self.assertEqual(config["HF_TOKEN"], "hf_test_token")
        self.assertIsNone(config["LOCAL_IMAGE_NAME"])
        self.assertEqual(
            client_stub.last_kwargs,
            {"base_url": default_url, "api_key": "hf_test_token"},
        )

    def test_validate_submission_env_allows_explicit_overrides_and_local_image(self) -> None:
        """Explicit environment values override the defaults and reach the client."""
        overrides = {
            "API_BASE_URL": "https://example.invalid/v1",
            "MODEL_NAME": "test-model",
            "HF_TOKEN": "hf_test_token",
            "LOCAL_IMAGE_NAME": "local-env-image",
        }
        client_stub = _make_recording_client()
        client_patch = patch.object(inference, "OpenAI", client_stub)
        env_patch = patch.dict("os.environ", overrides, clear=True)

        with client_patch, env_patch:
            config = inference.validate_submission_env()

        self.assertEqual(config["API_BASE_URL"], "https://example.invalid/v1")
        self.assertEqual(config["MODEL_NAME"], "test-model")
        self.assertEqual(config["HF_TOKEN"], "hf_test_token")
        self.assertEqual(config["LOCAL_IMAGE_NAME"], "local-env-image")
        self.assertEqual(
            client_stub.last_kwargs,
            {"base_url": "https://example.invalid/v1", "api_key": "hf_test_token"},
        )

    def test_main_forces_submission_safe_defaults(self) -> None:
        """``main`` sets the submission env vars and calls ``baseline_main`` once."""
        invocations = []

        def fake_baseline_main() -> None:
            invocations.append(None)

        expected_env = {
            "BASELINE_AGENT": "openai",
            "BASELINE_OUTPUT_MODE": "submission",
            "EPISODES_PER_DIFFICULTY": "1",
            "BASELINE_RESULTS_PATH": "baseline/submission_results.json",
        }

        baseline_patch = patch.object(inference, "baseline_main", fake_baseline_main)
        validate_patch = patch.object(
            inference, "validate_submission_env", return_value={"ok": True}
        )
        env_patch = patch.dict("os.environ", {}, clear=True)

        with baseline_patch, validate_patch, env_patch:
            inference.main()
            # Must be checked while the environment patch is still active:
            # patch.dict restores the previous os.environ on exit.
            for variable, value in expected_env.items():
                self.assertEqual(os.environ[variable], value)
            self.assertEqual(len(invocations), 1)


if __name__ == "__main__":
    unittest.main()