|
|
"""Tests for the API module.""" |
|
|
|
|
|
import pytest |
|
|
from unittest.mock import AsyncMock, patch |
|
|
from fastapi.testclient import TestClient |
|
|
|
|
|
from cascade.api.schemas import ( |
|
|
ChatCompletionRequest, |
|
|
ChatCompletionResponse, |
|
|
ChatMessage, |
|
|
UsageInfo, |
|
|
ChatCompletionChoice, |
|
|
) |
|
|
|
|
|
|
|
|
class TestSchemas:
    """Unit tests for the request/response Pydantic schemas."""

    def test_chat_message_creation(self):
        """ChatMessage should be created correctly."""
        message = ChatMessage(role="user", content="Hello")
        assert message.role == "user"
        assert message.content == "Hello"

    def test_chat_message_roles(self):
        """ChatMessage should accept valid roles."""
        for valid_role in ("system", "user", "assistant"):
            assert ChatMessage(role=valid_role, content="test").role == valid_role

    def test_chat_request_defaults(self):
        """ChatCompletionRequest should have correct defaults."""
        req = ChatCompletionRequest(
            messages=[ChatMessage(role="user", content="Hello")]
        )
        # Defaults when only `messages` is supplied.
        assert req.model == "gpt-4o"
        assert req.temperature == 0.7
        assert req.max_tokens is None
        assert req.stream is False

    def test_chat_request_custom_values(self):
        """ChatCompletionRequest should accept custom values."""
        req = ChatCompletionRequest(
            model="gpt-4o-mini",
            messages=[ChatMessage(role="user", content="Hello")],
            temperature=0.5,
            max_tokens=100,
        )
        assert (req.model, req.temperature, req.max_tokens) == (
            "gpt-4o-mini",
            0.5,
            100,
        )

    def test_usage_info(self):
        """UsageInfo should track token usage."""
        usage = UsageInfo(
            prompt_tokens=10,
            completion_tokens=20,
            total_tokens=30,
        )
        assert usage.prompt_tokens == 10
        assert usage.completion_tokens == 20
        assert usage.total_tokens == 30

    def test_chat_response_creation(self):
        """ChatCompletionResponse should be created correctly."""
        # Build the nested pieces first so the response construction
        # stays readable.
        reply = ChatMessage(role="assistant", content="Hi there!")
        choice = ChatCompletionChoice(index=0, message=reply, finish_reason="stop")
        usage = UsageInfo(prompt_tokens=5, completion_tokens=10, total_tokens=15)

        response = ChatCompletionResponse(
            id="test-123",
            created=1234567890,
            model="gpt-4o",
            choices=[choice],
            usage=usage,
        )

        assert response.id == "test-123"
        assert response.model == "gpt-4o"
        assert len(response.choices) == 1
        assert response.choices[0].message.content == "Hi there!"
|
|
|
|
|
|
|
|
class TestCostTracking:
    """Tests for cost tracking.

    Computed-cost assertions use ``pytest.approx`` rather than exact
    ``==``: the pricing math is a sum of float products, so exact
    equality is fragile against floating-point rounding.
    """

    def test_cost_calculation(self):
        """Cost should be calculated correctly."""
        from cascade.cost.pricing import calculate_cost

        # 1000 prompt + 1000 completion tokens of a paid model.
        cost = calculate_cost("gpt-4o", 1000, 1000)
        assert cost == pytest.approx(0.09)

    def test_free_model_cost(self):
        """Free models should have zero cost."""
        from cascade.cost.pricing import calculate_cost, is_free_model

        # Zero is exact for free models, so plain equality is safe here.
        cost = calculate_cost("llama3.2", 1000, 1000)
        assert cost == 0.0
        assert is_free_model("llama3.2")

    def test_savings_calculation(self):
        """Savings should be calculated correctly."""
        from cascade.cost.pricing import calculate_savings

        # Spent $1.00 against a $10.00 baseline -> $9.00 / 90% saved.
        dollars_saved, percentage_saved = calculate_savings(1.0, 10.0)
        assert dollars_saved == pytest.approx(9.0)
        assert percentage_saved == pytest.approx(90.0)

    def test_savings_zero_baseline(self):
        """Savings with zero baseline should not error."""
        from cascade.cost.pricing import calculate_savings

        # Guards against ZeroDivisionError in the percentage calculation.
        dollars_saved, percentage_saved = calculate_savings(0.0, 0.0)
        assert dollars_saved == 0.0
        assert percentage_saved == 0.0
|
|
|
|
|
|
|
|
class TestCostTracker:
    """Unit tests for the CostTracker metrics service."""

    def test_tracker_initialization(self):
        """CostTracker should initialize with zero values."""
        from cascade.cost.tracker import CostTracker

        fresh = CostTracker()
        assert fresh.total_requests == 0
        assert fresh.total_cost == 0.0
        assert fresh.cache_hits_exact == 0

    def test_record_request(self):
        """Recording requests should update totals."""
        from cascade.cost.tracker import CostTracker

        tracker = CostTracker()
        tracker.record_request(
            model="gpt-4o",
            prompt_tokens=100,
            completion_tokens=50,
            latency=0.5,
        )

        # One paid request: counters advance, tokens are summed.
        assert tracker.total_requests == 1
        assert tracker.total_cost > 0
        assert tracker.total_tokens == 150

    def test_cache_hit_tracking(self):
        """Cache hits should be tracked correctly."""
        from cascade.cost.tracker import CostTracker

        tracker = CostTracker()
        # One event of each outcome kind.
        for outcome in ("exact", "semantic", "miss"):
            tracker.record_cache_hit(outcome)

        assert tracker.cache_hits_exact == 1
        assert tracker.cache_hits_semantic == 1
        assert tracker.cache_misses == 1

    def test_get_summary(self):
        """Summary should contain all metrics."""
        from cascade.cost.tracker import CostTracker

        tracker = CostTracker()
        tracker.record_request("gpt-4o", 100, 50, 0.5)
        tracker.record_cache_hit("exact")

        summary = tracker.get_summary()

        for section in ("total_requests", "cost", "cache", "latency", "models"):
            assert section in summary

    def test_reset(self):
        """Reset should clear all metrics."""
        from cascade.cost.tracker import CostTracker

        tracker = CostTracker()
        tracker.record_request("gpt-4o", 100, 50, 0.5)
        tracker.reset()

        # All counters return to their initial state.
        assert tracker.total_requests == 0
        assert tracker.total_cost == 0.0
|
|
|