AIstudioProxyAPI / tests /integration /test_model_management.py
peijun1's picture
Deploy AI Studio Proxy API to Hugging Face Spaces
a5784e9
Raw
History Blame Contribute Delete
14 kB
"""
Integration tests for browser_utils/model_management.py and api_utils/model_switching.py.
These tests verify lock hierarchy and concurrent model switching behavior with REAL locks:
- Uses REAL asyncio.Lock from server_state
- Mocks ONLY external boundaries (browser, page)
- Verifies actual concurrency control and serialization
Focus: Ensure model_switching_lock serializes model switches and params_cache_lock is properly nested.
"""
import asyncio
from typing import cast
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from api_utils.context_types import RequestContext
@pytest.mark.integration
@pytest.mark.asyncio
async def test_concurrent_model_switches_serialized(real_server_state):
"""
测试场景: 两个并发请求尝试切换到不同模型
预期: model_switching_lock 序列化操作,第二个请求等待第一个完成
"""
from api_utils.model_switching import handle_model_switching
execution_order = []
async def mock_switch_model(page, model_id, req_id):
execution_order.append(f"start_{req_id}")
await asyncio.sleep(0.1) # 模拟真实的模型切换延迟
execution_order.append(f"end_{req_id}")
return True
real_server_state.current_ai_studio_model_id = "gemini-1.5-pro"
# 创建两个请求上下文
def create_context(model_id):
return cast(
RequestContext,
{
"needs_model_switching": True,
"model_id_to_use": model_id,
"page": AsyncMock(),
"logger": MagicMock(),
"model_switching_lock": real_server_state.model_switching_lock,
"model_actually_switched": False,
"current_ai_studio_model_id": "gemini-1.5-pro",
},
)
context1 = create_context("gemini-2.0-flash")
context2 = create_context("gemini-1.5-flash")
with patch("browser_utils.switch_ai_studio_model", side_effect=mock_switch_model):
task1 = asyncio.create_task(handle_model_switching("req1", context1))
task2 = asyncio.create_task(handle_model_switching("req2", context2))
await asyncio.gather(task1, task2)
# 验证: 执行顺序证明序列化(一个完全完成后另一个才开始)
assert len(execution_order) == 4
# 找到第一个完成的请求
first_end_idx = min(
execution_order.index("end_req1"), execution_order.index("end_req2")
)
first_end_req = execution_order[first_end_idx].split("_")[1]
first_start_idx = execution_order.index(f"start_{first_end_req}")
assert first_start_idx < first_end_idx
# 验证: 第二个请求的 start 在第一个请求的 end 之后
second_req = "req2" if first_end_req == "req1" else "req1"
second_start_idx = execution_order.index(f"start_{second_req}")
assert second_start_idx >= first_end_idx
# 验证: 两个请求都成功完成
assert context1["model_actually_switched"] is True
assert context2["model_actually_switched"] is True
@pytest.mark.integration
@pytest.mark.asyncio
async def test_model_switch_invalidates_param_cache(real_server_state):
"""
测试场景: 模型切换后,参数缓存被正确清空
预期: handle_parameter_cache 检测到模型切换,清空缓存
"""
from api_utils.model_switching import handle_parameter_cache
# 设置初始状态:缓存中有旧参数
page_params_cache = {
"last_known_model_id_for_params": "gemini-1.5-pro",
"temperature": 0.7,
"max_tokens": 1024,
}
context = cast(
RequestContext,
{
"logger": MagicMock(),
"params_cache_lock": real_server_state.params_cache_lock,
"page_params_cache": page_params_cache,
"current_ai_studio_model_id": "gemini-2.0-flash", # 模型已切换
"model_actually_switched": True, # 标记实际切换
},
)
# 执行参数缓存处理
await handle_parameter_cache("req1", context)
# 验证: 缓存被清空,只保留新模型ID
assert len(page_params_cache) == 1
assert page_params_cache["last_known_model_id_for_params"] == "gemini-2.0-flash"
assert "temperature" not in page_params_cache
assert "max_tokens" not in page_params_cache
@pytest.mark.integration
@pytest.mark.asyncio
async def test_param_cache_waits_for_model_switch(real_server_state):
"""
测试场景: 参数缓存更新等待模型切换完成(锁层级)
预期: params_cache_lock 在 model_switching_lock 释放后才能获取
锁层级: processing_lock > model_switching_lock > params_cache_lock
"""
from api_utils.model_switching import handle_model_switching, handle_parameter_cache
execution_order = []
# Mock page 和 logger
mock_page = AsyncMock()
mock_logger = MagicMock()
# Mock 模型切换,引入延迟
async def mock_switch_model(page, model_id, req_id):
execution_order.append("model_switch_start")
await asyncio.sleep(0.1)
execution_order.append("model_switch_end")
return True
# 设置初始模型
real_server_state.current_ai_studio_model_id = "gemini-1.5-pro"
# 创建请求上下文
switch_context = cast(
RequestContext,
{
"needs_model_switching": True,
"model_id_to_use": "gemini-2.0-flash",
"page": mock_page,
"logger": mock_logger,
"model_switching_lock": real_server_state.model_switching_lock,
"model_actually_switched": False,
"current_ai_studio_model_id": "gemini-1.5-pro",
},
)
cache_context = cast(
RequestContext,
{
"logger": mock_logger,
"params_cache_lock": real_server_state.params_cache_lock,
"page_params_cache": {
"last_known_model_id_for_params": "gemini-1.5-pro",
"temperature": 0.7,
},
"current_ai_studio_model_id": "gemini-2.0-flash",
"model_actually_switched": True,
},
)
async def cache_handler():
execution_order.append("cache_handler_start")
await handle_parameter_cache("req1", cache_context)
execution_order.append("cache_handler_end")
with patch(
"browser_utils.switch_ai_studio_model",
side_effect=mock_switch_model,
):
# 同时启动模型切换和缓存处理
# 注意: 这里我们不直接嵌套锁,而是测试两个独立操作的并发行为
task1 = asyncio.create_task(handle_model_switching("req1", switch_context))
# 稍微延迟启动缓存处理,确保模型切换先获取锁
await asyncio.sleep(0.01)
task2 = asyncio.create_task(cache_handler())
await asyncio.gather(task1, task2)
# 验证: 模型切换完成后,缓存处理才开始
# 由于 params_cache_lock 优先级低于 model_switching_lock,
# 但在这个测试中它们是独立操作,所以只需验证缓存被正确清空
assert (
cache_context["page_params_cache"]["last_known_model_id_for_params"]
== "gemini-2.0-flash"
)
@pytest.mark.integration
@pytest.mark.asyncio
async def test_model_switch_no_change_skips_operation(real_server_state):
"""
测试场景: 请求的模型与当前模型相同,跳过切换
预期: model_switching_lock 获取后立即释放,不调用 switch_ai_studio_model
"""
from api_utils.model_switching import handle_model_switching
mock_page = AsyncMock()
mock_logger = MagicMock()
# 设置当前模型
real_server_state.current_ai_studio_model_id = "gemini-1.5-pro"
context = cast(
RequestContext,
{
"needs_model_switching": True,
"model_id_to_use": "gemini-1.5-pro", # 与当前相同
"page": mock_page,
"logger": mock_logger,
"model_switching_lock": real_server_state.model_switching_lock,
"model_actually_switched": False,
"current_ai_studio_model_id": "gemini-1.5-pro",
},
)
with patch("browser_utils.switch_ai_studio_model") as mock_switch:
await handle_model_switching("req1", context)
# 验证: switch_ai_studio_model 未被调用
mock_switch.assert_not_called()
# 验证: model_actually_switched 保持 False
assert context["model_actually_switched"] is False
@pytest.mark.integration
@pytest.mark.asyncio
async def test_model_switch_failure_raises_error(real_server_state):
"""
测试场景: 模型切换失败,抛出 HTTP 422 错误
预期: switch_ai_studio_model 返回 False,触发异常
"""
from api_utils.model_switching import handle_model_switching
from models.exceptions import HTTPException
mock_page = AsyncMock()
mock_logger = MagicMock()
real_server_state.current_ai_studio_model_id = "gemini-1.5-pro"
context = cast(
RequestContext,
{
"needs_model_switching": True,
"model_id_to_use": "gemini-2.0-flash",
"page": mock_page,
"logger": mock_logger,
"model_switching_lock": real_server_state.model_switching_lock,
"model_actually_switched": False,
"current_ai_studio_model_id": "gemini-1.5-pro",
},
)
with patch("browser_utils.switch_ai_studio_model", return_value=False):
# 验证: 抛出 HTTP 422 异常
with pytest.raises(HTTPException) as exc_info:
await handle_model_switching("req1", context)
assert exc_info.value.status_code == 422
assert "gemini-2.0-flash" in exc_info.value.detail
@pytest.mark.integration
@pytest.mark.asyncio
async def test_three_concurrent_model_switches(real_server_state):
"""
测试场景: 三个并发请求切换到不同模型
预期: model_switching_lock 确保严格序列化,无交错执行
"""
from api_utils.model_switching import handle_model_switching
execution_log = []
async def mock_switch(page, model_id, req_id):
execution_log.append(f"{req_id}_start")
await asyncio.sleep(0.05)
execution_log.append(f"{req_id}_end")
return True
real_server_state.current_ai_studio_model_id = "base-model"
contexts = [
cast(
RequestContext,
{
"needs_model_switching": True,
"model_id_to_use": f"model-{i}",
"page": AsyncMock(),
"logger": MagicMock(),
"model_switching_lock": real_server_state.model_switching_lock,
"model_actually_switched": False,
"current_ai_studio_model_id": "base-model",
},
)
for i in range(1, 4)
]
with patch("browser_utils.switch_ai_studio_model", side_effect=mock_switch):
tasks = [
asyncio.create_task(handle_model_switching(f"req{i}", contexts[i - 1]))
for i in range(1, 4)
]
await asyncio.gather(*tasks)
# 验证: 6个事件(3个start + 3个end)
assert len(execution_log) == 6
# 验证: 每个请求的 start 和 end 成对出现
for i in range(1, 4):
start_idx = execution_log.index(f"req{i}_start")
end_idx = execution_log.index(f"req{i}_end")
assert start_idx < end_idx
# 验证: 在 start 和 end 之间,没有其他请求的事件
between_events = execution_log[start_idx + 1 : end_idx]
other_req_events = [e for e in between_events if not e.startswith(f"req{i}")]
assert len(other_req_events) == 0, f"请求 req{i} 的执行被其他请求打断"
@pytest.mark.integration
@pytest.mark.asyncio
async def test_param_cache_cleared_on_model_change_detection(real_server_state):
"""
测试场景: 检测到模型ID变化(即使未标记 model_actually_switched)
预期: 缓存仍然被清空
"""
from api_utils.model_switching import handle_parameter_cache
page_params_cache = {
"last_known_model_id_for_params": "old-model",
"temperature": 0.9,
"top_p": 0.95,
}
context = cast(
RequestContext,
{
"logger": MagicMock(),
"params_cache_lock": real_server_state.params_cache_lock,
"page_params_cache": page_params_cache,
"current_ai_studio_model_id": "new-model", # 模型ID已变化
"model_actually_switched": False, # 未显式标记,但ID不同
},
)
await handle_parameter_cache("req1", context)
# 验证: 缓存被清空,因为检测到模型ID变化
assert page_params_cache["last_known_model_id_for_params"] == "new-model"
assert "temperature" not in page_params_cache
assert "top_p" not in page_params_cache
@pytest.mark.integration
@pytest.mark.asyncio
async def test_param_cache_preserved_when_model_unchanged(real_server_state):
"""
测试场景: 模型未变化,参数缓存保持不变
预期: 缓存内容保留
"""
from api_utils.model_switching import handle_parameter_cache
page_params_cache = {
"last_known_model_id_for_params": "gemini-1.5-pro",
"temperature": 0.7,
"max_tokens": 2048,
}
context = cast(
RequestContext,
{
"logger": MagicMock(),
"params_cache_lock": real_server_state.params_cache_lock,
"page_params_cache": page_params_cache,
"current_ai_studio_model_id": "gemini-1.5-pro", # 模型未变
"model_actually_switched": False,
},
)
await handle_parameter_cache("req1", context)
# 验证: 缓存内容保持不变
assert page_params_cache["last_known_model_id_for_params"] == "gemini-1.5-pro"
assert page_params_cache["temperature"] == 0.7
assert page_params_cache["max_tokens"] == 2048
assert len(page_params_cache) == 3