ayushm98 commited on
Commit
16ccf4e
·
1 Parent(s): c3f5513

Add tests for API schemas and cost tracking

Browse files
Files changed (1) hide show
  1. tests/test_api.py +191 -0
tests/test_api.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for the API module."""
2
+
3
+ import pytest
4
+ from unittest.mock import AsyncMock, patch
5
+ from fastapi.testclient import TestClient
6
+
7
+ from cascade.api.schemas import (
8
+ ChatCompletionRequest,
9
+ ChatCompletionResponse,
10
+ ChatMessage,
11
+ UsageInfo,
12
+ ChatCompletionChoice,
13
+ )
14
+
15
+
16
+ class TestSchemas:
17
+ """Tests for API schemas."""
18
+
19
+ def test_chat_message_creation(self):
20
+ """ChatMessage should be created correctly."""
21
+ msg = ChatMessage(role="user", content="Hello")
22
+ assert msg.role == "user"
23
+ assert msg.content == "Hello"
24
+
25
+ def test_chat_message_roles(self):
26
+ """ChatMessage should accept valid roles."""
27
+ for role in ["system", "user", "assistant"]:
28
+ msg = ChatMessage(role=role, content="test")
29
+ assert msg.role == role
30
+
31
+ def test_chat_request_defaults(self):
32
+ """ChatCompletionRequest should have correct defaults."""
33
+ request = ChatCompletionRequest(
34
+ messages=[ChatMessage(role="user", content="Hello")]
35
+ )
36
+ assert request.model == "gpt-4o"
37
+ assert request.temperature == 0.7
38
+ assert request.max_tokens is None
39
+ assert request.stream is False
40
+
41
+ def test_chat_request_custom_values(self):
42
+ """ChatCompletionRequest should accept custom values."""
43
+ request = ChatCompletionRequest(
44
+ model="gpt-4o-mini",
45
+ messages=[ChatMessage(role="user", content="Hello")],
46
+ temperature=0.5,
47
+ max_tokens=100,
48
+ )
49
+ assert request.model == "gpt-4o-mini"
50
+ assert request.temperature == 0.5
51
+ assert request.max_tokens == 100
52
+
53
+ def test_usage_info(self):
54
+ """UsageInfo should track token usage."""
55
+ usage = UsageInfo(
56
+ prompt_tokens=10,
57
+ completion_tokens=20,
58
+ total_tokens=30,
59
+ )
60
+ assert usage.prompt_tokens == 10
61
+ assert usage.completion_tokens == 20
62
+ assert usage.total_tokens == 30
63
+
64
+ def test_chat_response_creation(self):
65
+ """ChatCompletionResponse should be created correctly."""
66
+ response = ChatCompletionResponse(
67
+ id="test-123",
68
+ created=1234567890,
69
+ model="gpt-4o",
70
+ choices=[
71
+ ChatCompletionChoice(
72
+ index=0,
73
+ message=ChatMessage(role="assistant", content="Hi there!"),
74
+ finish_reason="stop",
75
+ )
76
+ ],
77
+ usage=UsageInfo(
78
+ prompt_tokens=5,
79
+ completion_tokens=10,
80
+ total_tokens=15,
81
+ ),
82
+ )
83
+ assert response.id == "test-123"
84
+ assert response.model == "gpt-4o"
85
+ assert len(response.choices) == 1
86
+ assert response.choices[0].message.content == "Hi there!"
87
+
88
+
89
+ class TestCostTracking:
90
+ """Tests for cost tracking."""
91
+
92
+ def test_cost_calculation(self):
93
+ """Cost should be calculated correctly."""
94
+ from cascade.cost.pricing import calculate_cost
95
+
96
+ # GPT-4o: $0.03/1K input, $0.06/1K output
97
+ cost = calculate_cost("gpt-4o", 1000, 1000)
98
+ assert cost == 0.09 # 0.03 + 0.06
99
+
100
+ def test_free_model_cost(self):
101
+ """Free models should have zero cost."""
102
+ from cascade.cost.pricing import calculate_cost, is_free_model
103
+
104
+ cost = calculate_cost("llama3.2", 1000, 1000)
105
+ assert cost == 0.0
106
+ assert is_free_model("llama3.2")
107
+
108
+ def test_savings_calculation(self):
109
+ """Savings should be calculated correctly."""
110
+ from cascade.cost.pricing import calculate_savings
111
+
112
+ dollars_saved, percentage_saved = calculate_savings(1.0, 10.0)
113
+ assert dollars_saved == 9.0
114
+ assert percentage_saved == 90.0
115
+
116
+ def test_savings_zero_baseline(self):
117
+ """Savings with zero baseline should not error."""
118
+ from cascade.cost.pricing import calculate_savings
119
+
120
+ dollars_saved, percentage_saved = calculate_savings(0.0, 0.0)
121
+ assert dollars_saved == 0.0
122
+ assert percentage_saved == 0.0
123
+
124
+
125
+ class TestCostTracker:
126
+ """Tests for the cost tracker service."""
127
+
128
+ def test_tracker_initialization(self):
129
+ """CostTracker should initialize with zero values."""
130
+ from cascade.cost.tracker import CostTracker
131
+
132
+ tracker = CostTracker()
133
+ assert tracker.total_requests == 0
134
+ assert tracker.total_cost == 0.0
135
+ assert tracker.cache_hits_exact == 0
136
+
137
+ def test_record_request(self):
138
+ """Recording requests should update totals."""
139
+ from cascade.cost.tracker import CostTracker
140
+
141
+ tracker = CostTracker()
142
+ tracker.record_request(
143
+ model="gpt-4o",
144
+ prompt_tokens=100,
145
+ completion_tokens=50,
146
+ latency=0.5,
147
+ )
148
+
149
+ assert tracker.total_requests == 1
150
+ assert tracker.total_cost > 0
151
+ assert tracker.total_tokens == 150
152
+
153
+ def test_cache_hit_tracking(self):
154
+ """Cache hits should be tracked correctly."""
155
+ from cascade.cost.tracker import CostTracker
156
+
157
+ tracker = CostTracker()
158
+ tracker.record_cache_hit("exact")
159
+ tracker.record_cache_hit("semantic")
160
+ tracker.record_cache_hit("miss")
161
+
162
+ assert tracker.cache_hits_exact == 1
163
+ assert tracker.cache_hits_semantic == 1
164
+ assert tracker.cache_misses == 1
165
+
166
+ def test_get_summary(self):
167
+ """Summary should contain all metrics."""
168
+ from cascade.cost.tracker import CostTracker
169
+
170
+ tracker = CostTracker()
171
+ tracker.record_request("gpt-4o", 100, 50, 0.5)
172
+ tracker.record_cache_hit("exact")
173
+
174
+ summary = tracker.get_summary()
175
+
176
+ assert "total_requests" in summary
177
+ assert "cost" in summary
178
+ assert "cache" in summary
179
+ assert "latency" in summary
180
+ assert "models" in summary
181
+
182
+ def test_reset(self):
183
+ """Reset should clear all metrics."""
184
+ from cascade.cost.tracker import CostTracker
185
+
186
+ tracker = CostTracker()
187
+ tracker.record_request("gpt-4o", 100, 50, 0.5)
188
+ tracker.reset()
189
+
190
+ assert tracker.total_requests == 0
191
+ assert tracker.total_cost == 0.0