# Snapshot 6172a47 (file size: 5,076 bytes)
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from api.app import app
from api.dependencies import get_settings
from config.settings import Settings
@pytest.fixture
def client():
    """Provide a ``TestClient`` wrapping the imported ``app``."""
    test_client = TestClient(app)
    return test_client
@pytest.fixture
def mock_settings():
    """Return a ``Settings`` instance with three feature flags set to True.

    The flags are ``fast_prefix_detection``, ``enable_network_probe_mock``
    and ``enable_title_generation_skip``.
    """
    configured = Settings()
    for flag_name in (
        "fast_prefix_detection",
        "enable_network_probe_mock",
        "enable_title_generation_skip",
    ):
        setattr(configured, flag_name, True)
    return configured
def test_create_message_fast_prefix_detection(client, mock_settings):
    """POST /v1/messages returns 200 with the patched prefix in the reply text.

    ``is_prefix_detection_request`` and ``extract_command_prefix`` in
    ``api.optimization_handlers`` are patched to report the ``/ask`` prefix,
    and ``get_settings`` is overridden to return ``mock_settings``.
    """
    app.dependency_overrides[get_settings] = lambda: mock_settings
    payload = {
        "model": "claude-3-sonnet",
        "max_tokens": 100,
        "messages": [{"role": "user", "content": "What is the prefix?"}],
    }
    try:
        with (
            patch(
                "api.optimization_handlers.is_prefix_detection_request",
                return_value=(True, "/ask"),
            ),
            patch(
                "api.optimization_handlers.extract_command_prefix",
                return_value="/ask",
            ),
        ):
            response = client.post("/v1/messages", json=payload)
        assert response.status_code == 200
        data = response.json()
        assert "/ask" in data["content"][0]["text"]
    finally:
        # Clear the override even when the request or an assertion fails, so
        # mock_settings cannot leak into later tests that share ``app``.
        app.dependency_overrides.clear()
def test_create_message_quota_check_mock(client, mock_settings):
    """POST /v1/messages returns 200 with "Quota check passed" in the reply text.

    ``api.optimization_handlers.is_quota_check_request`` is patched to return
    True, and ``get_settings`` is overridden to return ``mock_settings``.
    """
    app.dependency_overrides[get_settings] = lambda: mock_settings
    payload = {
        "model": "claude-3-sonnet",
        "max_tokens": 100,
        "messages": [{"role": "user", "content": "quota check"}],
    }
    try:
        with patch(
            "api.optimization_handlers.is_quota_check_request", return_value=True
        ):
            response = client.post("/v1/messages", json=payload)
        assert response.status_code == 200
        assert "Quota check passed" in response.json()["content"][0]["text"]
    finally:
        # Clear the override even when the request or an assertion fails, so
        # mock_settings cannot leak into later tests that share ``app``.
        app.dependency_overrides.clear()
def test_create_message_title_generation_skip(client, mock_settings):
    """POST /v1/messages returns 200 with "Conversation" in the reply text.

    ``api.optimization_handlers.is_title_generation_request`` is patched to
    return True, and ``get_settings`` is overridden to return ``mock_settings``.
    """
    app.dependency_overrides[get_settings] = lambda: mock_settings
    payload = {
        "model": "claude-3-sonnet",
        "max_tokens": 100,
        "messages": [{"role": "user", "content": "generate title"}],
    }
    try:
        with patch(
            "api.optimization_handlers.is_title_generation_request",
            return_value=True,
        ):
            response = client.post("/v1/messages", json=payload)
        assert response.status_code == 200
        assert "Conversation" in response.json()["content"][0]["text"]
    finally:
        # Clear the override even when the request or an assertion fails, so
        # mock_settings cannot leak into later tests that share ``app``.
        app.dependency_overrides.clear()
def test_create_message_empty_messages_returns_400(client):
    """An empty ``messages`` list on POST /v1/messages yields a 400 invalid_request_error."""
    response = client.post(
        "/v1/messages",
        json={
            "model": "claude-3-sonnet",
            "max_tokens": 100,
            "messages": [],
        },
    )
    assert response.status_code == 400

    body = response.json()
    assert body.get("type") == "error"

    # Look the error object up once; a missing key falls back to an empty dict.
    error = body.get("error", {})
    assert error.get("type") == "invalid_request_error"
    assert "cannot be empty" in error.get("message", "")
def test_count_tokens_endpoint(client):
    """count_tokens responds 200 and echoes the patched count as ``input_tokens``."""
    request_body = {
        "model": "claude-3-sonnet",
        "messages": [{"role": "user", "content": "hello"}],
    }
    fixed_count = patch("api.routes.get_token_count", return_value=5)
    with fixed_count:
        response = client.post("/v1/messages/count_tokens", json=request_body)
    assert response.status_code == 200
    assert response.json()["input_tokens"] == 5
def test_count_tokens_error_returns_500(client):
    """A ``RuntimeError`` from get_token_count yields a 500 whose detail holds the message."""
    request_body = {
        "model": "claude-3-sonnet",
        "messages": [{"role": "user", "content": "hello"}],
    }
    failing_count = patch(
        "api.routes.get_token_count",
        side_effect=RuntimeError("token error"),
    )
    with failing_count:
        response = client.post("/v1/messages/count_tokens", json=request_body)
    assert response.status_code == 500
    assert "token error" in response.json()["detail"]
def test_stop_cli_with_handler(client):
    """POST /stop reports the count returned by the handler's ``stop_all_tasks``.

    A mock handler is installed on ``app.state.message_handler`` and removed
    again whether or not the assertions pass.
    """
    mock_handler = MagicMock()
    # stop_all_tasks has to be mocked as an async function, so use AsyncMock
    # rather than the plain MagicMock attribute.
    mock_handler.stop_all_tasks = AsyncMock(return_value=3)
    app.state.message_handler = mock_handler
    try:
        response = client.post("/stop")
        assert response.status_code == 200
        assert response.json()["cancelled_count"] == 3
        mock_handler.stop_all_tasks.assert_called_once()
    finally:
        # Remove the mock even on failure so it cannot leak into other tests
        # that share ``app.state``.
        if hasattr(app.state, "message_handler"):
            del app.state.message_handler
def test_stop_cli_fallback_to_manager(client):
    """POST /stop reports ``source == "cli_manager"`` when no message handler is set.

    Any existing ``app.state.message_handler`` is removed first, a mock is
    installed on ``app.state.cli_manager``, and that mock is removed again
    whether or not the assertions pass.
    """
    if hasattr(app.state, "message_handler"):
        del app.state.message_handler
    mock_manager = MagicMock()
    # stop_all is mocked with AsyncMock so it can be used as an async function.
    mock_manager.stop_all = AsyncMock()
    app.state.cli_manager = mock_manager
    try:
        response = client.post("/stop")
        assert response.status_code == 200
        assert response.json()["source"] == "cli_manager"
        mock_manager.stop_all.assert_called_once()
    finally:
        # Remove the mock even on failure so it cannot leak into other tests
        # that share ``app.state``.
        if hasattr(app.state, "cli_manager"):
            del app.state.cli_manager
|