| """Tests for premium model handling in backend/routes/agent.py.""" |
|
|
| import asyncio |
| import sys |
| from pathlib import Path |
| from types import SimpleNamespace |
|
|
| import pytest |
| from fastapi import HTTPException |
|
|
| _BACKEND_DIR = Path(__file__).resolve().parent.parent.parent / "backend" |
| if str(_BACKEND_DIR) not in sys.path: |
| sys.path.insert(0, str(_BACKEND_DIR)) |
|
|
| from routes import agent |
|
|
|
|
| @pytest.fixture(autouse=True) |
| def _reset_quota_store(): |
| agent.user_quotas._reset_for_tests() |
| yield |
| agent.user_quotas._reset_for_tests() |
|
|
|
|
| def test_premium_model_predicate_includes_bedrock_claude_and_gpt55_only(): |
| assert agent._is_premium_model("bedrock/us.anthropic.claude-opus-4-6-v1") |
| assert agent._is_premium_model("openai/gpt-5.5") |
| assert not agent._is_premium_model("anthropic/claude-opus-4-6") |
| assert not agent._is_premium_model("moonshotai/Kimi-K2.6") |
|
|
|
|
| @pytest.mark.asyncio |
| async def test_default_premium_session_falls_back_to_free_model(monkeypatch): |
| monkeypatch.setattr( |
| agent.session_manager.config, |
| "model_name", |
| agent.DEFAULT_CLAUDE_MODEL_ID, |
| ) |
|
|
| model = await agent._model_override_for_new_session(None, None) |
|
|
| assert model == agent.DEFAULT_FREE_MODEL_ID |
|
|
|
|
| @pytest.mark.asyncio |
| async def test_default_free_session_keeps_config_default(monkeypatch): |
| monkeypatch.setattr( |
| agent.session_manager.config, |
| "model_name", |
| agent.DEFAULT_FREE_MODEL_ID, |
| ) |
|
|
| model = await agent._model_override_for_new_session(None, None) |
|
|
| assert model is None |
|
|
|
|
| @pytest.mark.asyncio |
| async def test_explicit_premium_session_allowed_for_authenticated_user(): |
| model = await agent._model_override_for_new_session( |
| None, |
| agent.DEFAULT_CLAUDE_MODEL_ID, |
| ) |
|
|
| assert model == agent.DEFAULT_CLAUDE_MODEL_ID |
|
|
|
|
| @pytest.mark.asyncio |
| async def test_switching_to_premium_model_is_allowed_for_authenticated_user( |
| monkeypatch, |
| ): |
| updated = [] |
|
|
| async def fake_check_session_access(session_id, user, request=None): |
| assert session_id == "s1" |
| assert user["user_id"] == "u1" |
| return SimpleNamespace(user_id="u1") |
|
|
| async def fake_update_session_model(session_id, model_id): |
| updated.append((session_id, model_id)) |
|
|
| monkeypatch.setattr(agent, "_check_session_access", fake_check_session_access) |
| monkeypatch.setattr( |
| agent.session_manager, |
| "update_session_model", |
| fake_update_session_model, |
| ) |
|
|
| response = await agent.set_session_model( |
| "s1", |
| {"model": "openai/gpt-5.5"}, |
| request=None, |
| user={"user_id": "u1", "plan": "free"}, |
| ) |
|
|
| assert response == {"session_id": "s1", "model": "openai/gpt-5.5"} |
| assert updated == [("s1", "openai/gpt-5.5")] |
|
|
|
|
| @pytest.mark.asyncio |
| async def test_premium_quota_charges_gpt55(monkeypatch): |
| persisted = [] |
|
|
| async def fake_persist_session_snapshot(agent_session): |
| persisted.append(agent_session) |
|
|
| monkeypatch.setattr( |
| agent.session_manager, |
| "persist_session_snapshot", |
| fake_persist_session_snapshot, |
| ) |
|
|
| agent_session = SimpleNamespace( |
| claude_counted=False, |
| session=SimpleNamespace( |
| config=SimpleNamespace(model_name="openai/gpt-5.5"), |
| ), |
| ) |
|
|
| await agent._enforce_premium_model_quota( |
| {"user_id": "u1", "plan": "free"}, |
| agent_session, |
| ) |
|
|
| assert agent_session.claude_counted is True |
| assert persisted == [agent_session] |
| assert await agent.user_quotas.get_claude_used_today("u1") == 1 |
|
|
|
|
| @pytest.mark.asyncio |
| async def test_free_user_premium_quota_rejects_second_session(monkeypatch): |
| async def fake_persist_session_snapshot(_agent_session): |
| return None |
|
|
| monkeypatch.setattr( |
| agent.session_manager, |
| "persist_session_snapshot", |
| fake_persist_session_snapshot, |
| ) |
|
|
| first_session = SimpleNamespace( |
| claude_counted=False, |
| session=SimpleNamespace( |
| config=SimpleNamespace(model_name="openai/gpt-5.5"), |
| ), |
| ) |
| second_session = SimpleNamespace( |
| claude_counted=False, |
| session=SimpleNamespace( |
| config=SimpleNamespace(model_name="openai/gpt-5.5"), |
| ), |
| ) |
|
|
| await agent._enforce_premium_model_quota( |
| {"user_id": "free-user", "plan": "free"}, |
| first_session, |
| ) |
| with pytest.raises(HTTPException) as exc_info: |
| await agent._enforce_premium_model_quota( |
| {"user_id": "free-user", "plan": "free"}, |
| second_session, |
| ) |
|
|
| assert exc_info.value.status_code == 429 |
| assert exc_info.value.detail["error"] == "premium_model_daily_cap" |
| assert exc_info.value.detail["plan"] == "free" |
|
|
|
|
| @pytest.mark.asyncio |
| async def test_pro_user_uses_pro_premium_quota(monkeypatch): |
| async def fake_persist_session_snapshot(_agent_session): |
| return None |
|
|
| monkeypatch.setattr( |
| agent.session_manager, |
| "persist_session_snapshot", |
| fake_persist_session_snapshot, |
| ) |
|
|
| for index in range(2): |
| agent_session = SimpleNamespace( |
| claude_counted=False, |
| session=SimpleNamespace( |
| config=SimpleNamespace(model_name="openai/gpt-5.5"), |
| ), |
| ) |
| await agent._enforce_premium_model_quota( |
| {"user_id": "pro-user", "plan": "pro"}, |
| agent_session, |
| ) |
| assert agent_session.claude_counted is True |
| assert await agent.user_quotas.get_claude_used_today("pro-user") == index + 1 |
|
|
|
|
| @pytest.mark.asyncio |
| async def test_org_plan_uses_free_premium_quota(monkeypatch): |
| async def fake_persist_session_snapshot(_agent_session): |
| return None |
|
|
| monkeypatch.setattr( |
| agent.session_manager, |
| "persist_session_snapshot", |
| fake_persist_session_snapshot, |
| ) |
|
|
| first_session = SimpleNamespace( |
| claude_counted=False, |
| session=SimpleNamespace( |
| config=SimpleNamespace(model_name="openai/gpt-5.5"), |
| ), |
| ) |
| second_session = SimpleNamespace( |
| claude_counted=False, |
| session=SimpleNamespace( |
| config=SimpleNamespace(model_name="openai/gpt-5.5"), |
| ), |
| ) |
|
|
| await agent._enforce_premium_model_quota( |
| {"user_id": "org-user", "plan": "org"}, |
| first_session, |
| ) |
| with pytest.raises(HTTPException) as exc_info: |
| await agent._enforce_premium_model_quota( |
| {"user_id": "org-user", "plan": "org"}, |
| second_session, |
| ) |
|
|
| assert exc_info.value.status_code == 429 |
| assert exc_info.value.detail["plan"] == "org" |
| assert "Upgrade to HF Pro" in exc_info.value.detail["message"] |
|
|
|
|
| @pytest.mark.asyncio |
| async def test_premium_quota_skips_direct_anthropic(monkeypatch): |
| async def fail_if_persisted(_agent_session): |
| raise AssertionError("direct Anthropic should not consume premium quota") |
|
|
| monkeypatch.setattr( |
| agent.session_manager, |
| "persist_session_snapshot", |
| fail_if_persisted, |
| ) |
|
|
| agent_session = SimpleNamespace( |
| claude_counted=False, |
| session=SimpleNamespace( |
| config=SimpleNamespace(model_name="anthropic/claude-opus-4-6"), |
| ), |
| ) |
|
|
| await agent._enforce_premium_model_quota( |
| {"user_id": "u1", "plan": "free"}, |
| agent_session, |
| ) |
|
|
| assert agent_session.claude_counted is False |
| assert await agent.user_quotas.get_claude_used_today("u1") == 0 |
|
|
|
|
| @pytest.mark.asyncio |
| async def test_user_quota_response_uses_premium_fields_only(monkeypatch): |
| async def fake_get_used_today(user_id): |
| assert user_id == "u1" |
| return 2 |
|
|
| monkeypatch.setattr(agent.user_quotas, "get_claude_used_today", fake_get_used_today) |
| monkeypatch.setattr(agent.user_quotas, "daily_cap_for", lambda plan: 5) |
|
|
| response = await agent.get_user_quota({"user_id": "u1", "plan": "pro"}) |
|
|
| assert response == { |
| "plan": "pro", |
| "premium_used_today": 2, |
| "premium_daily_cap": 5, |
| "premium_remaining": 3, |
| } |
|
|
|
|
| @pytest.mark.asyncio |
| async def test_set_session_yolo_calls_manager_with_cap_presence(monkeypatch): |
| async def fake_check_session_access(session_id, user, request=None): |
| assert session_id == "s1" |
| assert user["user_id"] == "u1" |
| return object() |
|
|
| calls = [] |
|
|
| async def fake_update_session_auto_approval(session_id, **kwargs): |
| calls.append((session_id, kwargs)) |
| return { |
| "enabled": kwargs["enabled"], |
| "cost_cap_usd": 7.5, |
| "estimated_spend_usd": 0.0, |
| "remaining_usd": 7.5, |
| } |
|
|
| monkeypatch.setattr(agent, "_check_session_access", fake_check_session_access) |
| monkeypatch.setattr( |
| agent.session_manager, |
| "update_session_auto_approval", |
| fake_update_session_auto_approval, |
| ) |
|
|
| response = await agent.set_session_yolo( |
| "s1", |
| agent.SessionYoloRequest(enabled=True, cost_cap_usd=7.5), |
| {"user_id": "u1"}, |
| ) |
|
|
| assert response["enabled"] is True |
| assert response["remaining_usd"] == 7.5 |
| assert calls == [ |
| ( |
| "s1", |
| { |
| "enabled": True, |
| "cost_cap_usd": 7.5, |
| "cap_provided": True, |
| }, |
| ) |
| ] |
|
|
|
|
| @pytest.mark.asyncio |
| async def test_delete_session_access_check_skips_sandbox_preload(monkeypatch): |
| ensure_calls = [] |
| delete_calls = [] |
|
|
| async def fake_ensure_session_loaded(session_id, user_id, **kwargs): |
| ensure_calls.append((session_id, user_id, kwargs)) |
| return SimpleNamespace(user_id=user_id) |
|
|
| async def fake_delete_session(session_id): |
| delete_calls.append(session_id) |
| return True |
|
|
| monkeypatch.setattr( |
| agent.session_manager, |
| "ensure_session_loaded", |
| fake_ensure_session_loaded, |
| ) |
| monkeypatch.setattr(agent.session_manager, "delete_session", fake_delete_session) |
|
|
| response = await agent.delete_session("s1", {"user_id": "u1"}) |
|
|
| assert response == {"status": "deleted", "session_id": "s1"} |
| assert delete_calls == ["s1"] |
| assert ensure_calls[0][2]["preload_sandbox"] is False |
|
|
|
|
| @pytest.mark.asyncio |
| async def test_teardown_session_access_check_skips_sandbox_preload(monkeypatch): |
| ensure_calls = [] |
| teardown_calls = [] |
|
|
| async def fake_ensure_session_loaded(session_id, user_id, **kwargs): |
| ensure_calls.append((session_id, user_id, kwargs)) |
| return SimpleNamespace(user_id=user_id) |
|
|
| async def fake_teardown_sandbox(session_id): |
| teardown_calls.append(session_id) |
| return True |
|
|
| monkeypatch.setattr( |
| agent.session_manager, |
| "ensure_session_loaded", |
| fake_ensure_session_loaded, |
| ) |
| monkeypatch.setattr( |
| agent.session_manager, "teardown_sandbox", fake_teardown_sandbox |
| ) |
|
|
| response = await agent.teardown_session_sandbox("s1", {"user_id": "u1"}) |
| await asyncio.sleep(0) |
|
|
| assert response == {"status": "teardown_requested", "session_id": "s1"} |
| assert teardown_calls == ["s1"] |
| assert ensure_calls[0][2]["preload_sandbox"] is False |
|
|