# ml-intern3 / tests/unit/test_agent_model_gating.py
# Last commit by lewtun (HF Staff): "Remove old Claude model support (#283)"
# (d32b803, unverified)
"""Tests for premium model handling in backend/routes/agent.py."""
import sys
from pathlib import Path
from types import SimpleNamespace
import pytest
from fastapi import HTTPException
_BACKEND_DIR = Path(__file__).resolve().parent.parent.parent.joinpath("backend")
# Prepend the in-repo backend directory to the import path so that
# ``from routes import agent`` below resolves without installing the backend;
# the guard avoids adding a duplicate entry if it is already present.
if str(_BACKEND_DIR) not in sys.path:
    sys.path.insert(0, str(_BACKEND_DIR))
from routes import agent # noqa: E402
@pytest.fixture(autouse=True)
def _reset_quota_store():
    """Wipe the shared premium quota store around every test.

    Runs automatically for each test in this module: the store is reset
    before the test body executes and again at teardown, so usage recorded
    by one test cannot leak into another.
    """

    def _wipe():
        agent.user_quotas._reset_for_tests()

    _wipe()
    yield
    _wipe()
def _premium_session(model: str = agent.DEFAULT_PREMIUM_MODEL_ID) -> SimpleNamespace:
    """Return a lightweight stand-in for an agent session configured with *model*.

    Only the attributes these tests read or assert on are provided:
    ``claude_counted`` on the outer object, and ``config.model_name`` plus
    ``premium_user_billed`` on the inner ``session``. Both flags start False.
    Despite the name, *model* may be any id (free models are passed too).
    """
    inner = SimpleNamespace(
        config=SimpleNamespace(model_name=model),
        premium_user_billed=False,
    )
    return SimpleNamespace(claude_counted=False, session=inner)
def test_premium_model_predicate_uses_router_ids_only():
    """Only the configured premium router ids are treated as premium."""
    for premium_id in (agent.DEFAULT_PREMIUM_MODEL_ID, agent.DEFAULT_GPT_MODEL_ID):
        assert agent._is_premium_model(premium_id)

    assert agent._is_user_billed(agent.DEFAULT_PREMIUM_MODEL_ID)

    # A free model and an id the router does not know are both non-premium.
    for other_id in ("moonshotai/Kimi-K2.6", "unsupported/model"):
        assert not agent._is_premium_model(other_id)
@pytest.mark.asyncio
async def test_default_session_uses_configured_default_model():
    """A new session with no requested model gets no override (``None``)."""
    assert await agent._model_override_for_new_session(None, None) is None
@pytest.mark.asyncio
async def test_explicit_premium_session_allowed_for_authenticated_user():
    """Explicitly requesting the premium model returns it as the override."""
    requested = agent.DEFAULT_PREMIUM_MODEL_ID
    override = await agent._model_override_for_new_session(None, requested)
    assert override == requested
@pytest.mark.asyncio
async def test_switching_to_premium_model_is_allowed_for_authenticated_user(
    monkeypatch,
):
    """A free-plan user may switch their session to a premium model.

    Access checking and the model update are faked. The route must echo the
    new model back and forward exactly one update to the session manager.
    """
    target_model = agent.DEFAULT_GPT_MODEL_ID
    model_updates = []

    async def _allow_access(session_id, user, request=None):
        assert session_id == "s1"
        assert user["user_id"] == "u1"
        return SimpleNamespace(user_id="u1")

    async def _record_model_update(session_id, model_id):
        model_updates.append((session_id, model_id))

    monkeypatch.setattr(agent, "_check_session_access", _allow_access)
    monkeypatch.setattr(
        agent.session_manager, "update_session_model", _record_model_update
    )

    response = await agent.set_session_model(
        "s1",
        {"model": target_model},
        request=None,
        user={"user_id": "u1", "plan": "free"},
    )

    assert response == {"session_id": "s1", "model": target_model}
    assert model_updates == [("s1", target_model)]
@pytest.mark.asyncio
async def test_switching_to_unknown_model_id_is_rejected(monkeypatch):
    """Switching to a model id the router does not know fails with HTTP 400."""

    async def _grant_access(session_id, user, request=None):
        return SimpleNamespace(user_id=user["user_id"])

    monkeypatch.setattr(agent, "_check_session_access", _grant_access)

    with pytest.raises(HTTPException) as exc_info:
        await agent.set_session_model(
            "s1",
            {"model": "unsupported/model"},
            request=None,
            user={"user_id": "u1", "plan": "free"},
        )

    error = exc_info.value
    assert error.status_code == 400
    assert "Unknown model" in error.detail
@pytest.mark.asyncio
async def test_premium_quota_charges_without_user_billing_inside_allowance(monkeypatch):
    """A premium session within the daily allowance is counted but not billed.

    The session is flagged as counted and left unbilled. It is persisted
    exactly once, and one unit of premium usage is recorded for the user.
    """
    snapshots = []

    async def _capture_snapshot(agent_session):
        snapshots.append(agent_session)

    monkeypatch.setattr(
        agent.session_manager, "persist_session_snapshot", _capture_snapshot
    )

    tracked = _premium_session()
    await agent._enforce_premium_model_quota({"user_id": "u1", "plan": "free"}, tracked)

    assert tracked.claude_counted is True
    assert tracked.session.premium_user_billed is False
    assert snapshots == [tracked]
    assert await agent.user_quotas.get_claude_used_today("u1") == 1
@pytest.mark.asyncio
async def test_free_user_gets_two_subsidized_premium_sessions_then_user_billing(
    monkeypatch,
):
    """Free users get two unbilled premium sessions per day; the third is billed.

    The billed overflow session does not add to the recorded daily usage,
    which stays at two.
    """

    async def _skip_persist(_agent_session):
        return None

    monkeypatch.setattr(agent.session_manager, "persist_session_snapshot", _skip_persist)

    for _ in range(2):
        subsidized = _premium_session()
        await agent._enforce_premium_model_quota(
            {"user_id": "g1", "plan": "free"}, subsidized
        )
        assert subsidized.session.premium_user_billed is False

    overflow = _premium_session()
    await agent._enforce_premium_model_quota({"user_id": "g1", "plan": "free"}, overflow)
    assert overflow.session.premium_user_billed is True
    assert await agent.user_quotas.get_claude_used_today("g1") == 2
@pytest.mark.asyncio
async def test_free_model_does_not_consume_premium_quota(monkeypatch):
    """A session on a non-premium model skips premium quota accounting.

    Nothing is persisted (the fake raises if it is called), the session is
    not marked as counted, and the user's recorded usage stays at zero.
    """

    async def fail_if_persisted(_agent_session):
        raise AssertionError("free model should not consume premium quota")

    monkeypatch.setattr(
        agent.session_manager, "persist_session_snapshot", fail_if_persisted
    )

    kimi_session = _premium_session("moonshotai/Kimi-K2.6")
    await agent._enforce_premium_model_quota(
        {"user_id": "u1", "plan": "free"}, kimi_session
    )

    assert kimi_session.claude_counted is False
    assert await agent.user_quotas.get_claude_used_today("u1") == 0
@pytest.mark.asyncio
async def test_pro_user_uses_pro_premium_quota(monkeypatch):
    """Pro users keep getting unbilled premium sessions past the free allowance.

    Two sessions is also the free plan's allowance. Stopping there therefore
    cannot tell the pro quota apart from the free one. This test first pins
    the contract (the pro cap exceeds the free cap). It then runs one session
    more than the free cap. If pro users were limited to the free cap, the
    last session would be user-billed and the test would fail. Every session
    must be counted, left unbilled, and raise recorded usage by exactly one.
    """

    async def fake_persist_session_snapshot(_agent_session):
        return None

    monkeypatch.setattr(
        agent.session_manager,
        "persist_session_snapshot",
        fake_persist_session_snapshot,
    )

    free_cap = agent.user_quotas.daily_cap_for("free")
    # Fail on the contract itself first, with a clearer message than a
    # billing-flag mismatch deep inside the loop.
    assert agent.user_quotas.daily_cap_for("pro") > free_cap

    for index in range(free_cap + 1):
        agent_session = _premium_session()
        await agent._enforce_premium_model_quota(
            {"user_id": "pro-user", "plan": "pro"},
            agent_session,
        )
        assert agent_session.claude_counted is True
        assert agent_session.session.premium_user_billed is False
        assert await agent.user_quotas.get_claude_used_today("pro-user") == index + 1
@pytest.mark.asyncio
async def test_pro_user_billable_overflow_also_bills_user(monkeypatch):
    """Pro users are billed too once their daily premium cap is used up.

    The cap is pinned to 1, so the second premium session overflows. The
    first session is also checked to be unbilled. Otherwise a regression that
    bills every pro session would still pass this test.
    """

    async def fake_persist(_agent_session):
        return None

    monkeypatch.setattr(agent.session_manager, "persist_session_snapshot", fake_persist)
    monkeypatch.setattr(agent.user_quotas, "daily_cap_for", lambda plan: 1)

    within_cap = _premium_session()
    await agent._enforce_premium_model_quota(
        {"user_id": "p1", "plan": "pro"}, within_cap
    )
    assert within_cap.session.premium_user_billed is False

    over = _premium_session()
    await agent._enforce_premium_model_quota({"user_id": "p1", "plan": "pro"}, over)
    assert over.session.premium_user_billed is True
@pytest.mark.asyncio
async def test_restore_summary_enforces_premium_quota_before_seed(monkeypatch):
    """Restoring from a summary enforces the premium quota before seeding.

    Every collaborator is faked and appends to ``events``. The expected order
    is:

    1. The session is created with ``model=None``.
    2. Access is checked with ``preload_sandbox=False``.
    3. The quota is enforced.
    4. Only then is the session seeded, and the seed step already sees the
       billing flag that the quota step set.
    """
    events = []
    agent_session = _premium_session()

    class _StubRequest:
        headers = {}
        cookies = {}

    async def fake_create_session(**kwargs):
        events.append(("create", kwargs["model"]))
        return "s1"

    async def fake_check_session_access(
        session_id, user, request, preload_sandbox=True
    ):
        events.append(("check", session_id, preload_sandbox))
        return agent_session

    async def fake_enforce_quota(user, session):
        assert user["user_id"] == "u1"
        assert session is agent_session
        # Flip the flag so the seed step can prove it ran after enforcement.
        session.session.premium_user_billed = True
        events.append(("quota", session.session.config.model_name))

    async def fake_seed(session_id, messages):
        events.append(("seed", session_id, agent_session.session.premium_user_billed))
        return len(messages)

    patches = (
        (agent.session_manager, "create_session", fake_create_session),
        (agent, "_check_session_access", fake_check_session_access),
        (agent, "_enforce_premium_model_quota", fake_enforce_quota),
        (agent.session_manager, "seed_from_summary", fake_seed),
    )
    for target, attr_name, replacement in patches:
        monkeypatch.setattr(target, attr_name, replacement)

    response = await agent.restore_session_summary(
        _StubRequest(),
        {"messages": [{"role": "user", "content": "resume this"}]},
        {"user_id": "u1", "plan": "free"},
    )

    assert response.session_id == "s1"
    expected_events = [
        ("create", None),
        ("check", "s1", False),
        ("quota", agent.DEFAULT_PREMIUM_MODEL_ID),
        ("seed", "s1", True),
    ]
    assert events == expected_events
@pytest.mark.asyncio
async def test_user_quota_response_uses_premium_fields_only(monkeypatch):
    """The quota endpoint reports usage under ``premium_*`` keys only.

    Usage and cap are faked, so the remaining allowance must be their
    difference.
    """
    used, cap = 2, 5

    async def fake_get_used_today(user_id):
        assert user_id == "u1"
        return used

    monkeypatch.setattr(agent.user_quotas, "get_claude_used_today", fake_get_used_today)
    monkeypatch.setattr(agent.user_quotas, "daily_cap_for", lambda plan: cap)

    response = await agent.get_user_quota({"user_id": "u1", "plan": "pro"})

    assert response == {
        "plan": "pro",
        "premium_used_today": used,
        "premium_daily_cap": cap,
        "premium_remaining": cap - used,
    }
@pytest.mark.asyncio
async def test_set_session_yolo_calls_manager_with_cap_presence(monkeypatch):
    """Enabling YOLO with an explicit cost cap forwards ``cap_provided=True``.

    The route's response is the fake manager's return value. The manager
    must receive the enabled flag, the cap, and whether a cap was supplied.
    """
    recorded = []

    async def fake_check_session_access(session_id, user, request=None):
        assert session_id == "s1"
        assert user["user_id"] == "u1"
        return object()

    async def fake_update_session_auto_approval(session_id, **kwargs):
        recorded.append((session_id, kwargs))
        return {
            "enabled": kwargs["enabled"],
            "cost_cap_usd": 7.5,
            "estimated_spend_usd": 0.0,
            "remaining_usd": 7.5,
        }

    monkeypatch.setattr(agent, "_check_session_access", fake_check_session_access)
    monkeypatch.setattr(
        agent.session_manager,
        "update_session_auto_approval",
        fake_update_session_auto_approval,
    )

    payload = agent.SessionYoloRequest(enabled=True, cost_cap_usd=7.5)
    response = await agent.set_session_yolo("s1", payload, {"user_id": "u1"})

    assert response["enabled"] is True
    assert response["remaining_usd"] == 7.5
    expected_kwargs = {"enabled": True, "cost_cap_usd": 7.5, "cap_provided": True}
    assert recorded == [("s1", expected_kwargs)]