AIstudioProxyAPI / tests /api_utils /test_model_switching.py
peijun1's picture
Deploy AI Studio Proxy API to Hugging Face Spaces
a5784e9
Raw
History Blame Contribute Delete
15.7 kB
"""
Tests for api_utils.model_switching module.
Test Strategy:
- Unit tests: Test model analysis and switching logic individually
- Use REAL asyncio.Lock for model_switching_lock and params_cache_lock
- Use real server_state.state object for state management
- Mock only external boundaries: Browser operations (switch_ai_studio_model)
- Integration tests for concurrent model switches:
See tests/integration/test_model_switching_concurrency.py
Coverage Target: 90%+
Mock Budget: <35 (down from 82)
"""
import asyncio
from typing import cast
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import HTTPException
from api_utils.context_types import RequestContext
from api_utils.model_switching import (
analyze_model_requirements,
handle_model_switching,
handle_parameter_cache,
)
from api_utils.server_state import state
# ==================== Test Classes ====================
class TestAnalyzeModelRequirements:
"""Tests for analyze_model_requirements function."""
@pytest.mark.asyncio
async def test_no_requested_model_returns_unchanged_context(
self, make_request_context
):
"""Test that no requested model leaves context unchanged."""
req_id = "test-req"
context = make_request_context(
current_ai_studio_model_id="gemini-1.5-pro",
parsed_model_list=[
{"id": "gemini-1.5-pro"},
{"id": "gemini-1.5-flash"},
],
)
context["model_id_to_use"] = "gemini-1.5-pro"
context["needs_model_switching"] = False
requested_model = "" # No model requested
proxy_model_name = "proxy-model"
result = await analyze_model_requirements(
req_id, context, requested_model, proxy_model_name
)
# Should not modify context when no model requested
assert "needs_model_switching" not in result or not result.get(
"needs_model_switching", False
)
@pytest.mark.asyncio
async def test_requested_model_same_as_proxy_no_switch(self, make_request_context):
"""Test that requesting proxy model doesn't trigger switch."""
req_id = "test-req"
context = make_request_context(
current_ai_studio_model_id="gemini-1.5-pro",
parsed_model_list=[{"id": "gemini-1.5-pro"}],
)
requested_model = "proxy-model"
proxy_model_name = "proxy-model"
result = await analyze_model_requirements(
req_id, context, requested_model, proxy_model_name
)
# Proxy model should not trigger switch
assert "needs_model_switching" not in result or not result.get(
"needs_model_switching", False
)
@pytest.mark.asyncio
async def test_requested_different_valid_model_requires_switch(
self, make_request_context
):
"""Test that valid different model triggers switch requirement."""
req_id = "test-req"
context = make_request_context(
current_ai_studio_model_id="gemini-1.5-pro",
parsed_model_list=[
{"id": "gemini-1.5-pro"},
{"id": "gemini-1.5-flash"},
],
)
requested_model = "gemini-1.5-flash"
proxy_model_name = "proxy-model"
result = await analyze_model_requirements(
req_id, context, requested_model, proxy_model_name
)
# Should flag model switch needed
assert result.get("needs_model_switching") is True
assert result["model_id_to_use"] == "gemini-1.5-flash"
@pytest.mark.asyncio
async def test_requested_same_model_no_switch(self, make_request_context):
"""Test that requesting current model doesn't trigger switch."""
req_id = "test-req"
context = make_request_context(
current_ai_studio_model_id="gemini-1.5-pro",
parsed_model_list=[{"id": "gemini-1.5-pro"}],
)
requested_model = "gemini-1.5-pro"
proxy_model_name = "proxy-model"
result = await analyze_model_requirements(
req_id, context, requested_model, proxy_model_name
)
# Should set model_id but not require switching
assert result["model_id_to_use"] == "gemini-1.5-pro"
assert "needs_model_switching" not in result or not result.get(
"needs_model_switching", False
)
@pytest.mark.asyncio
async def test_invalid_model_raises_bad_request(self, make_request_context):
"""Test that invalid model raises 400 bad request error."""
req_id = "test-req"
context = make_request_context(
current_ai_studio_model_id="gemini-1.5-pro",
parsed_model_list=[
{"id": "gemini-1.5-pro"},
{"id": "gemini-1.5-flash"},
],
)
requested_model = "invalid-model"
proxy_model_name = "proxy-model"
with pytest.raises(HTTPException) as exc:
await analyze_model_requirements(
req_id, context, requested_model, proxy_model_name
)
assert exc.value.status_code == 400
assert "Invalid model 'invalid-model'" in exc.value.detail
assert "gemini-1.5-pro" in exc.value.detail # Available models listed
@pytest.mark.asyncio
async def test_no_parsed_model_list_allows_any_model(self, make_request_context):
"""Test that empty model list skips validation."""
req_id = "test-req"
context = make_request_context(
current_ai_studio_model_id="gemini-1.5-pro",
parsed_model_list=[], # Empty list
)
requested_model = "any-model"
proxy_model_name = "proxy-model"
# Should not raise error when list is empty
result = await analyze_model_requirements(
req_id, context, requested_model, proxy_model_name
)
assert result["model_id_to_use"] == "any-model"
class TestHandleModelSwitching:
"""Tests for handle_model_switching function."""
@pytest.mark.asyncio
async def test_no_switch_needed_returns_immediately(
self, real_locks_mock_browser, make_request_context
):
"""Test that no switch needed skips lock acquisition."""
req_id = "test-req"
context = make_request_context(
model_switching_lock=real_locks_mock_browser.model_switching_lock
)
context["needs_model_switching"] = False
# Lock should not be locked initially
assert not real_locks_mock_browser.model_switching_lock.locked()
result = await handle_model_switching(req_id, context)
# Should return unchanged context
assert result == context
# Lock should never have been acquired
assert not real_locks_mock_browser.model_switching_lock.locked()
@pytest.mark.asyncio
async def test_successful_model_switch_updates_state(
self, real_locks_mock_browser, make_request_context
):
"""Test successful model switch updates server state."""
req_id = "test-req"
# Save original state
original_model = state.current_ai_studio_model_id
try:
# Set up initial state
state.current_ai_studio_model_id = "gemini-1.5-pro"
context = make_request_context(
current_ai_studio_model_id="gemini-1.5-pro",
model_switching_lock=real_locks_mock_browser.model_switching_lock,
)
context["needs_model_switching"] = True
context["model_id_to_use"] = "gemini-1.5-flash"
# Mock browser switch operation
with patch(
"browser_utils.switch_ai_studio_model",
new_callable=AsyncMock,
return_value=True,
) as mock_switch:
result = await handle_model_switching(req_id, context)
# Verify browser switch was called
mock_switch.assert_called_once_with(
context["page"], "gemini-1.5-flash", req_id
)
# Verify state was updated
assert state.current_ai_studio_model_id == "gemini-1.5-flash"
assert result["model_actually_switched"] is True
assert result["current_ai_studio_model_id"] == "gemini-1.5-flash"
finally:
# Restore original state
state.current_ai_studio_model_id = original_model
@pytest.mark.asyncio
async def test_failed_model_switch_reverts_state(
self, real_locks_mock_browser, make_request_context
):
"""Test that failed switch reverts state and raises error."""
req_id = "test-req"
# Save original state
original_model = state.current_ai_studio_model_id
try:
# Set up initial state
state.current_ai_studio_model_id = "gemini-1.5-pro"
context = make_request_context(
current_ai_studio_model_id="gemini-1.5-pro",
model_switching_lock=real_locks_mock_browser.model_switching_lock,
)
context["needs_model_switching"] = True
context["model_id_to_use"] = "gemini-1.5-flash"
# Mock browser switch to fail
with patch(
"browser_utils.switch_ai_studio_model",
new_callable=AsyncMock,
return_value=False,
):
with pytest.raises(HTTPException) as exc:
await handle_model_switching(req_id, context)
# Verify error status and message
assert exc.value.status_code == 422
assert "gemini-1.5-flash" in exc.value.detail
# Verify state was reverted to original
assert state.current_ai_studio_model_id == "gemini-1.5-pro"
finally:
# Restore original state
state.current_ai_studio_model_id = original_model
@pytest.mark.asyncio
async def test_already_switched_model_skips_switch(
self, real_locks_mock_browser, make_request_context
):
"""Test that already-correct model skips switch operation."""
req_id = "test-req"
# Save original state
original_model = state.current_ai_studio_model_id
try:
# State already matches target model
state.current_ai_studio_model_id = "gemini-1.5-flash"
context = make_request_context(
current_ai_studio_model_id="gemini-1.5-flash",
model_switching_lock=real_locks_mock_browser.model_switching_lock,
)
context["needs_model_switching"] = True
context["model_id_to_use"] = "gemini-1.5-flash"
with patch(
"browser_utils.switch_ai_studio_model",
new_callable=AsyncMock,
) as mock_switch:
result = await handle_model_switching(req_id, context)
# Should not call browser switch (already correct)
mock_switch.assert_not_called()
# Result should not have switched flag
assert "model_actually_switched" not in result or not result.get(
"model_actually_switched", False
)
finally:
# Restore original state
state.current_ai_studio_model_id = original_model
class TestHandleParameterCache:
"""Tests for handle_parameter_cache function."""
@pytest.mark.asyncio
async def test_cache_cleared_when_model_switched(
self, real_locks_mock_browser, make_request_context
):
"""Test that parameter cache is cleared after model switch."""
req_id = "test-req"
context = make_request_context(
current_ai_studio_model_id="gemini-1.5-flash",
params_cache_lock=real_locks_mock_browser.params_cache_lock,
)
context["model_actually_switched"] = True
context["page_params_cache"] = {
"temperature": 0.7,
"max_tokens": 1024,
"last_known_model_id_for_params": "gemini-1.5-pro",
}
await handle_parameter_cache(req_id, context)
# Cache should be cleared and only new model ID stored
assert context["page_params_cache"] == {
"last_known_model_id_for_params": "gemini-1.5-flash"
}
@pytest.mark.asyncio
async def test_cache_cleared_when_model_differs_from_cache(
self, real_locks_mock_browser, make_request_context
):
"""Test cache cleared when current model differs from cached model."""
req_id = "test-req"
context = make_request_context(
current_ai_studio_model_id="gemini-1.5-flash",
params_cache_lock=real_locks_mock_browser.params_cache_lock,
)
context["model_actually_switched"] = False # Not just switched
context["page_params_cache"] = {
"temperature": 0.7,
"last_known_model_id_for_params": "gemini-1.5-pro", # Different!
}
await handle_parameter_cache(req_id, context)
# Cache should be cleared because model ID mismatch
assert context["page_params_cache"] == {
"last_known_model_id_for_params": "gemini-1.5-flash"
}
@pytest.mark.asyncio
async def test_cache_preserved_when_model_matches(
self, real_locks_mock_browser, make_request_context
):
"""Test cache preserved when model matches cached model."""
req_id = "test-req"
context = make_request_context(
current_ai_studio_model_id="gemini-1.5-pro",
params_cache_lock=real_locks_mock_browser.params_cache_lock,
)
context["model_actually_switched"] = False
context["page_params_cache"] = {
"temperature": 0.7,
"max_tokens": 1024,
"last_known_model_id_for_params": "gemini-1.5-pro", # Matches!
}
await handle_parameter_cache(req_id, context)
# Cache should be preserved
assert context["page_params_cache"]["temperature"] == 0.7
assert context["page_params_cache"]["max_tokens"] == 1024
assert (
context["page_params_cache"]["last_known_model_id_for_params"]
== "gemini-1.5-pro"
)
@pytest.mark.asyncio
async def test_cache_uses_real_lock(self, real_locks_mock_browser):
"""Test that parameter cache uses real asyncio.Lock."""
req_id = "test-req"
lock = real_locks_mock_browser.params_cache_lock
# Verify lock is real asyncio.Lock
assert isinstance(lock, asyncio.Lock)
assert not lock.locked()
context = cast(
RequestContext,
{
"logger": MagicMock(),
"params_cache_lock": lock,
"page_params_cache": {},
"current_ai_studio_model_id": "gemini-1.5-pro",
"model_actually_switched": True,
},
)
# Acquire lock externally to test contention
async with lock:
# Start cache handling in background
cache_task = asyncio.create_task(handle_parameter_cache(req_id, context))
# Give it a moment to try acquiring lock
await asyncio.sleep(0.01)
# Lock should still be held by us
assert lock.locked()
# Wait for cache task to complete (should acquire lock after we release)
await cache_task
# Lock should be released
assert not lock.locked()