| from token_holdem.agents import AgentProfile, ROSTER |
| import pytest |
|
|
| from token_holdem.model_runtime import ( |
| DEFAULT_MODAL_MODEL_NAMES, |
| DEFAULT_MODAL_TIMEOUT_SECONDS, |
| LocalRuntime, |
| ModalRuntime, |
| ModelRuntimeUnavailable, |
| SUPPORTED_TRANSFORMERS_MODELS, |
| TransformersRuntime, |
| apply_poker_sanity_guard, |
| configured_modal_model_names, |
| finalize_table_talk, |
| first_valid_decision, |
| get_model_runtime, |
| modal_worker_class_name, |
| parse_model_json, |
| requires_gguf_runtime, |
| safe_action, |
| sanitize_table_talk, |
| validate_decision, |
| ) |
|
|
|
|
# Shared legal-action payload used by the tests below: the seat faces a
# 20-chip call and may fold, call, raise to one of the presets, or go all in.
LEGAL = {
    "actions": ["fold", "call", "raise", "all_in"],
    "to_call": 20,
    "raise_presets": {
        "min": 40,
        "half_pot": 80,
        "pot": 140,
        "all_in": 500,
    },
}
|
|
|
|
def summary(**overrides):
    """Return a baseline preflop game-state summary for runtime tests.

    Any keyword argument replaces (or adds to) the corresponding key of the
    baseline dictionary, so individual tests can tweak only what they need.
    """
    return {
        "hand_no": 1,
        "street": "preflop",
        "hole_cards": ["As", "Kd"],
        "community_cards": [],
        "stack": 1000,
        "pot": 30,
        "legal": LEGAL,
        "history": [],
        "recent_chats": [],
        "seed": 123,
        "session_id": "test-session",
        "hand_id": "test-hand",
        "orbit_id": "test-orbit",
        # Caller-supplied values win over the defaults above.
        **overrides,
    }
|
|
|
|
class FakeModalCall:
    """Test double for a spawned call handle that replays a canned response."""

    def __init__(self, response):
        self.response = response
        # Remembers the timeout handed to the most recent get() call.
        self.timeout = None

    def get(self, timeout=None):
        """Record ``timeout`` and hand back the canned response."""
        self.timeout = timeout
        return self.response
|
|
|
|
class FakeModalFunction:
    """Test double for a remote function exposing ``spawn``.

    Every spawn is recorded in ``calls``. When ``error`` is set it is raised
    instead of producing a call handle.
    """

    def __init__(self, response=None, error=None):
        self.response = response
        self.error = error
        self.calls = []
        self.last_call = None

    def spawn(self, *args):
        """Log the positional arguments, then raise or return a fake handle."""
        self.calls.append(args)
        if self.error:
            raise self.error
        handle = FakeModalCall(self.response)
        self.last_call = handle
        return handle
|
|
|
|
class FakeRemoteMethod:
    """Test double for a worker's remote method exposing ``spawn``.

    Bookkeeping (recorded arguments and the latest call handle) is written
    onto the owning ``worker`` object rather than onto this method.
    """

    def __init__(self, worker, response=None, error=None):
        self.worker = worker
        self.response = response
        self.error = error

    def spawn(self, *args):
        """Log the arguments on the worker, then raise or return a fake handle."""
        owner = self.worker
        owner.calls.append(args)
        if self.error:
            raise self.error
        handle = FakeModalCall(self.response)
        owner.last_call = handle
        return handle
|
|
|
|
class FakeModalWorker:
    """Test double for a worker instance bound to one ``model_id``.

    ``decide`` is a fake remote method that records its calls on this worker.
    """

    def __init__(self, model_id=None, response=None, error=None):
        self.model_id = model_id
        self.calls = []
        self.last_call = None
        # The remote method writes into ``calls`` / ``last_call`` above.
        self.decide = FakeRemoteMethod(self, response=response, error=error)
|
|
|
|
class FakeModalWorkerClass:
    """Callable test double standing in for a worker class.

    Calling it with ``model_id=...`` creates a fake worker configured with the
    shared response/error and remembers it in ``instances``.
    """

    def __init__(self, response=None, error=None):
        self.response = response
        self.error = error
        self.instances = []

    def __call__(self, *, model_id):
        """Create, record and return a fake worker for ``model_id``."""
        created = FakeModalWorker(
            model_id=model_id,
            response=self.response,
            error=self.error,
        )
        self.instances.append(created)
        return created
|
|
|
|
def test_parse_model_json_extracts_object_from_extra_text():
    """The JSON object is expected to be recovered from surrounding chatter."""
    noisy_output = 'thinking... {"action":"call","amount":0,"table_talk":"cheers"} done'
    expected = {"action": "call", "amount": 0, "table_talk": "cheers"}

    assert parse_model_json(noisy_output) == expected
|
|
|
|
def test_first_valid_decision_skips_invalid_schema_object():
    """An echoed schema object is skipped in favour of the next valid decision."""
    model_output = (
        '{"action":"fold|check|call|raise|all_in","amount":0,"table_talk":"schema"}\n'
        '{"action":"check","amount":0,"table_talk":"The candle can wait."}'
    )
    check_or_raise = {
        "actions": ["check", "raise", "all_in"],
        "to_call": 0,
        "raise_presets": {"min": 20, "half_pot": 40, "pot": 80, "all_in": 980},
    }
    expected = {
        "action": "check",
        "amount": 0,
        "reasoning_hint": "",
        "table_talk": "The candle can wait",
    }

    assert first_valid_decision(model_output, check_or_raise) == expected
|
|
|
|
def test_validate_decision_clamps_raise_to_preset():
    """A raise of 91 is expected to come back as the 80 preset with replaced talk."""
    proposed = {"action": "raise", "amount": 91, "table_talk": "I raise by candlelight."}
    expected = {
        "action": "raise",
        "amount": 80,
        "reasoning_hint": "",
        "table_talk": "The cards clink like tiny mugs.",
    }

    assert validate_decision(proposed, LEGAL) == expected
|
|
|
|
def test_sanitize_table_talk_removes_numbers_and_action_claims():
    """Chip amounts and action claims are expected to be stripped from banter."""
    cleaned = sanitize_table_talk("I'm raising 955 to make the pot bigger. Let's see...")

    assert "955" not in cleaned
    assert "raising" not in cleaned.lower()
    assert cleaned == "to make the pot bigger"
|
|
|
|
def test_sanitize_table_talk_strips_outer_quotes():
    """A leading quote mark (and the trailing period) is expected to be dropped."""
    cleaned = sanitize_table_talk('"That is a fine bluff, lad.')

    assert cleaned == "That is a fine bluff, lad"
|
|
|
|
def test_sanitize_table_talk_rejects_prompt_leakage():
    """Text that looks like leaked prompt instructions is expected to sanitize to ''."""
    leaked_samples = (
        'No markdown. Your first sentence must be: "take all the chips"',
        "Don't mention any other players. Now think carefully.",
        "short cozy poker banter",
        "If you're not already, let me know in this response what your name is and what color hat you're wearing",
    )

    for leaked in leaked_samples:
        assert sanitize_table_talk(leaked) == ""
|
|
|
|
def test_finalize_table_talk_avoids_recent_repetition():
    """A line already present in recent chats must not be repeated verbatim."""
    repeated_line = "I am curious enough to stay by the fire."
    gemma = AgentProfile("Gemma", "local", "warm", 0.3, 0.1, ())
    table_state = {
        "seed": 1,
        "street": "preflop",
        "history": [],
        "recent_chats": ["Gemma: I am curious enough to stay by the fire."],
    }

    spoken = finalize_table_talk(gemma, "call", repeated_line, table_state)

    assert spoken != "I am curious enough to stay by the fire."
    # The replacement still has to be a real sentence, not a stub.
    assert len(spoken.split()) >= 3
|
|
|
|
def test_validate_decision_rejects_illegal_action():
    """An action outside the legal set is expected to validate to ``None``."""
    bogus = {"action": "dance", "amount": 0}

    assert validate_decision(bogus, LEGAL) is None
|
|
|
|
def test_safe_action_prefers_call_when_check_unavailable():
    """With no 'check' among the legal actions, the safe action should be 'call'."""
    chosen = safe_action(LEGAL)

    assert chosen["action"] == "call"
|
|
|
|
def test_poker_sanity_guard_folds_weak_hand_to_huge_all_in():
    """A call with K-5 offsuit facing a 980 all-in on a four-spade board becomes a fold."""
    decision = {"action": "call", "amount": 0, "table_talk": "I call."}
    # Named ``hand_state`` rather than ``summary`` so the module-level
    # ``summary()`` factory is not shadowed inside this test.
    hand_state = {
        "hole_cards": ["Kd", "5c"],
        "community_cards": ["Qs", "Ts", "8s", "6s"],
        "pot": 1060,
        "legal": {
            "actions": ["fold", "call", "all_in"],
            "to_call": 980,
            "raise_presets": {"all_in": 980},
        },
    }

    assert apply_poker_sanity_guard(decision, hand_state)["action"] == "fold"
|
|
|
|
def test_poker_sanity_guard_allows_made_flush_call():
    """Holding the 3 of spades on a four-spade board, the call is left untouched."""
    decision = {"action": "call", "amount": 0, "table_talk": "I call."}
    # Named ``hand_state`` rather than ``summary`` so the module-level
    # ``summary()`` factory is not shadowed inside this test.
    hand_state = {
        "hole_cards": ["3s", "3d"],
        "community_cards": ["Qs", "Ts", "8s", "6s"],
        "pot": 1060,
        "legal": {
            "actions": ["fold", "call", "all_in"],
            "to_call": 980,
            "raise_presets": {"all_in": 980},
        },
    }

    assert apply_poker_sanity_guard(decision, hand_state)["action"] == "call"
|
|
|
|
def test_supported_model_mapping_uses_roster_model_ids():
    """The supported-model mapping must mirror the roster and hold unique ids."""
    expected_mapping = {}
    for seat in ROSTER:
        expected_mapping[seat.name] = seat.model_id

    assert SUPPORTED_TRANSFORMERS_MODELS == expected_mapping

    distinct_ids = set(SUPPORTED_TRANSFORMERS_MODELS.values())
    assert len(distinct_ids) == len(SUPPORTED_TRANSFORMERS_MODELS)
|
|
|
|
def test_configured_modal_model_names_default_includes_cohere_command(monkeypatch):
    """With the env override removed, the default model-name set is returned."""
    monkeypatch.delenv("TOKEN_HOLDEM_MODAL_MODEL_NAMES", raising=False)

    enabled = configured_modal_model_names()

    assert enabled == DEFAULT_MODAL_MODEL_NAMES
    assert "Cohere Command R7B" in enabled
    assert "Gemma" in enabled
|
|
|
|
def test_configured_modal_model_names_all_explicitly_includes_cohere():
    """Passing "all" is expected to enable every roster name."""
    enabled = configured_modal_model_names("all")
    every_roster_name = {seat.name for seat in ROSTER}

    assert enabled == every_roster_name
    assert "Cohere Command R7B" in enabled
|
|
|
|
def test_modal_worker_class_name_routes_by_runtime_family():
    """Each model id is expected to route to the worker class listed beside it."""
    expected_routes = (
        ("nvidia/NVIDIA-Nemotron-3-Nano-4B-GGUF", "GgufModelWorker"),
        ("unsloth/North-Mini-Code-1.0-GGUF", "GgufModelWorker"),
        ("Qwen/Qwen3-0.6B", "CausalModelWorker"),
        ("mistralai/Mistral-7B-Instruct-v0.2", "CausalModelWorker"),
        ("CohereLabs/c4ai-command-r7b-12-2024", "CausalModelWorker"),
        ("google/gemma-4-12B-it", "MultimodalModelWorker"),
        ("openai/gpt-oss-20b", "HeavyCausalModelWorker"),
    )

    for model_id, worker_name in expected_routes:
        assert modal_worker_class_name(model_id) == worker_name
|
|
|
|
def test_transformers_runtime_unsupported_model_fails_without_dev_fallback():
    """Without fallback enabled, an unmapped model id must raise."""
    strict_runtime = TransformersRuntime()
    unknown_seat = AgentProfile(
        "Unknown Seat",
        "local/unknown",
        "mysterious",
        0.5,
        0.1,
        ("A mysterious chip appears.",),
    )

    with pytest.raises(ModelRuntimeUnavailable, match="No local Transformers mapping"):
        strict_runtime.decide(unknown_seat, summary())
|
|
|
|
def test_transformers_runtime_dev_fallback_is_explicit():
    """With ``allow_fallback=True`` an unmapped model yields a labelled fallback."""
    lenient_runtime = TransformersRuntime(allow_fallback=True)
    unknown_seat = AgentProfile(
        "Unknown Seat",
        "local/unknown",
        "mysterious",
        0.5,
        0.1,
        ("A mysterious chip appears.",),
    )

    outcome = lenient_runtime.decide(unknown_seat, summary())
    fallback_decision = outcome.decision

    assert outcome.source == "fallback"
    assert fallback_decision["action"] in LEGAL["actions"]
    assert fallback_decision["table_talk"] not in unknown_seat.talk
    assert "No local Transformers mapping" in outcome.status
|
|
|
|
def test_transformers_runtime_gguf_model_fails_without_dev_fallback():
    """Without fallback enabled, a GGUF-only roster model must raise."""
    strict_runtime = TransformersRuntime()
    # First roster seat whose model id needs the GGUF runtime.
    gguf_seat = next(seat for seat in ROSTER if requires_gguf_runtime(seat.model_id))

    with pytest.raises(ModelRuntimeUnavailable, match="local runtime cannot call it"):
        strict_runtime.decide(gguf_seat, summary())
|
|
|
|
def test_transformers_runtime_gguf_model_dev_fallback_is_explicit():
    """With ``allow_fallback=True`` a GGUF-only model yields a labelled fallback."""
    lenient_runtime = TransformersRuntime(allow_fallback=True)
    # First roster seat whose model id needs the GGUF runtime.
    gguf_seat = next(seat for seat in ROSTER if requires_gguf_runtime(seat.model_id))

    outcome = lenient_runtime.decide(gguf_seat, summary())

    assert outcome.source == "fallback"
    assert outcome.decision["action"] in LEGAL["actions"]
    assert "local runtime cannot call it" in outcome.status
|
|
|
|
def test_transformers_runtime_alias_keeps_local_runtime_compatibility():
    """``TransformersRuntime`` and ``LocalRuntime`` must be the very same class."""
    alias, target = TransformersRuntime, LocalRuntime

    assert alias is target
|
|
|
|
def test_explicit_deterministic_bot_env_selects_dev_runtime(monkeypatch):
    """Opting in via the env flag (with Modal unset) selects the deterministic dev source."""
    monkeypatch.setenv("TOKEN_HOLDEM_ALLOW_DETERMINISTIC_BOTS", "1")
    monkeypatch.delenv("USE_MODAL_INFERENCE", raising=False)
    gemma = next(seat for seat in ROSTER if seat.name == "Gemma")

    selected_runtime = get_model_runtime()
    outcome = selected_runtime.decide(gemma, summary())

    assert outcome.source == "deterministic_dev"
    assert outcome.decision["action"] in LEGAL["actions"]
|
|
|
|
def test_modal_runtime_success_uses_structured_remote_decision():
    """A well-formed remote payload is turned into a ``modal_model`` decision."""
    remote_payload = {
        "action": "call",
        "bet_amount": 0,
        "explanation": "priced in",
        "commentary": "The candlelight keeps me curious.",
        "raw_model_output": '{"action":"call","amount":0}',
        "error": None,
    }
    worker_class = FakeModalWorkerClass(remote_payload)
    runtime = ModalRuntime(
        enabled_model_names={"Gemma"},
        remote_workers={"MultimodalModelWorker": worker_class},
        timeout_seconds=3,
    )
    gemma = next(seat for seat in ROSTER if seat.name == "Gemma")

    outcome = runtime.decide(gemma, summary())

    spawned_worker = worker_class.instances[0]
    assert outcome.source == "modal_model"
    assert outcome.status == gemma.model_id
    assert outcome.decision["action"] == "call"
    assert outcome.decision["table_talk"] == "The candlelight keeps me curious"
    assert spawned_worker.model_id == gemma.model_id
    # The configured timeout must be forwarded to the call handle's get().
    assert spawned_worker.last_call.timeout == 3
    first_spawn_args = spawned_worker.calls[0]
    assert first_spawn_args[1:4] == (gemma.name, gemma.persona, LEGAL)
|
|
|
|
def test_modal_runtime_failure_raises_instead_of_falling_back():
    """A spawn-time error must surface as ``ModelRuntimeUnavailable``."""
    failing_worker_class = FakeModalWorkerClass(error=TimeoutError("modal timed out"))
    runtime = ModalRuntime(
        enabled_model_names={"Gemma"},
        remote_workers={"MultimodalModelWorker": failing_worker_class},
    )
    gemma = next(seat for seat in ROSTER if seat.name == "Gemma")

    with pytest.raises(ModelRuntimeUnavailable, match="Modal inference unavailable"):
        runtime.decide(gemma, summary())
|
|
|
|
def test_modal_runtime_remote_error_raises_instead_of_falling_back():
    """A payload carrying an ``error`` field must raise ``ModelRuntimeUnavailable``."""
    error_payload = {
        "action": None,
        "bet_amount": None,
        "explanation": "",
        "commentary": "",
        "raw_model_output": "",
        "error": "model did not return valid decision JSON",
    }
    worker_class = FakeModalWorkerClass(error_payload)
    runtime = ModalRuntime(
        enabled_model_names={"Gemma"},
        remote_workers={"MultimodalModelWorker": worker_class},
    )
    gemma = next(seat for seat in ROSTER if seat.name == "Gemma")

    with pytest.raises(ModelRuntimeUnavailable, match="Modal inference returned an error"):
        runtime.decide(gemma, summary())
|
|
|
|
def test_modal_runtime_disabled_model_raises():
    """Deciding for a roster model outside the enabled set must raise."""
    runtime = ModalRuntime(
        enabled_model_names={"Gemma"},
        remote_function=FakeModalFunction({}),
    )
    # Qwen is on the roster but deliberately left out of the enabled names.
    qwen = next(seat for seat in ROSTER if seat.name == "Qwen")

    with pytest.raises(ModelRuntimeUnavailable, match="disabled"):
        runtime.decide(qwen, summary())
|
|
|
|
def test_modal_runtime_all_enabled_models_spawn_remote_calls():
    """With every roster model enabled, each one is dispatched to some worker."""
    canned = {
        "action": "call",
        "bet_amount": 0,
        "explanation": "priced in",
        "commentary": "The candlelight keeps me curious.",
        "raw_model_output": '{"action":"call","amount":0}',
        "error": None,
    }
    worker_classes = {
        "GgufModelWorker": FakeModalWorkerClass(canned),
        "MultimodalModelWorker": FakeModalWorkerClass(canned),
        "CausalModelWorker": FakeModalWorkerClass(canned),
        "HeavyCausalModelWorker": FakeModalWorkerClass(canned),
    }
    runtime = ModalRuntime(
        enabled_model_names={seat.name for seat in ROSTER},
        remote_workers=worker_classes,
    )

    for seat in ROSTER:
        outcome = runtime.decide(seat, summary())
        assert outcome.source == "modal_model"

    # Gather the model ids of every worker instantiated across all families.
    dispatched_ids = []
    for worker_class in worker_classes.values():
        for worker in worker_class.instances:
            dispatched_ids.append(worker.model_id)

    assert sorted(dispatched_ids) == sorted(seat.model_id for seat in ROSTER)
|
|
|
|
def test_modal_runtime_default_timeout_matches_modal_worker_timeout():
    """When no timeout is configured, the module default is passed to get()."""
    canned = {
        "action": "call",
        "bet_amount": 0,
        "explanation": "priced in",
        "commentary": "The candlelight keeps me curious.",
        "raw_model_output": '{"action":"call","amount":0}',
        "error": None,
    }
    worker_class = FakeModalWorkerClass(canned)
    runtime = ModalRuntime(
        enabled_model_names={"Gemma"},
        remote_workers={"MultimodalModelWorker": worker_class},
    )
    gemma = next(seat for seat in ROSTER if seat.name == "Gemma")

    runtime.decide(gemma, summary())

    observed_timeout = worker_class.instances[0].last_call.timeout
    assert observed_timeout == DEFAULT_MODAL_TIMEOUT_SECONDS
|
|