jebin2 commited on
Commit
b54a58a
·
1 Parent(s): 547f02a

feat: add comprehensive credit service tests (30 tests, all passing)

Browse files
Files changed (1) hide show
  1. tests/test_credit_service.py +491 -0
tests/test_credit_service.py ADDED
@@ -0,0 +1,491 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Comprehensive Tests for Credit Service
3
+
4
+ Tests cover:
5
+ 1. Credit Manager - reserve, confirm, refund operations
6
+ 2. Error pattern matching - refundable vs non-refundable
7
+ 3. Job completion handling
8
+ 4. Credits Router endpoints
9
+ 5. Credit middleware (if needed)
10
+
11
+ Uses mocked database and user models.
12
+ """
13
+ import pytest
14
+ from datetime import datetime
15
+ from unittest.mock import patch, MagicMock, AsyncMock
16
+ from fastapi.testclient import TestClient
17
+
18
+
19
+ # ============================================================================
20
+ # 1. Credit Manager Tests
21
+ # ============================================================================
22
+
23
+ class TestCreditReservation:
24
+ """Test credit reservation functionality."""
25
+
26
+ @pytest.mark.asyncio
27
+ async def test_reserve_credit_success(self):
28
+ """Successfully reserve credits from user balance."""
29
+ from services.credit_service.credit_manager import reserve_credit
30
+
31
+ # Mock user with sufficient credits
32
+ mock_user = MagicMock()
33
+ mock_user.user_id = "usr_123"
34
+ mock_user.credits = 10
35
+
36
+ mock_session = AsyncMock()
37
+
38
+ result = await reserve_credit(mock_session, mock_user, amount=5)
39
+
40
+ assert result == True
41
+ assert mock_user.credits == 5 # 10 - 5
42
+
43
+ @pytest.mark.asyncio
44
+ async def test_reserve_credit_insufficient(self):
45
+ """Cannot reserve more credits than user has."""
46
+ from services.credit_service.credit_manager import reserve_credit
47
+
48
+ mock_user = MagicMock()
49
+ mock_user.user_id = "usr_123"
50
+ mock_user.credits = 3
51
+
52
+ mock_session = AsyncMock()
53
+
54
+ result = await reserve_credit(mock_session, mock_user, amount=5)
55
+
56
+ assert result == False
57
+ assert mock_user.credits == 3 # Unchanged
58
+
59
+ @pytest.mark.asyncio
60
+ async def test_reserve_credit_exact_amount(self):
61
+ """Can reserve exact balance."""
62
+ from services.credit_service.credit_manager import reserve_credit
63
+
64
+ mock_user = MagicMock()
65
+ mock_user.credits = 10
66
+
67
+ mock_session = AsyncMock()
68
+
69
+ result = await reserve_credit(mock_session, mock_user, amount=10)
70
+
71
+ assert result == True
72
+ assert mock_user.credits == 0
73
+
74
+
75
+ class TestCreditConfirmation:
76
+ """Test credit confirmation on job completion."""
77
+
78
+ @pytest.mark.asyncio
79
+ async def test_confirm_credit_clears_reservation(self):
80
+ """Confirming credit clears the reservation tracking."""
81
+ from services.credit_service.credit_manager import confirm_credit
82
+
83
+ mock_job = MagicMock()
84
+ mock_job.job_id = "job_123"
85
+ mock_job.credits_reserved = 5
86
+
87
+ mock_session = AsyncMock()
88
+
89
+ await confirm_credit(mock_session, mock_job)
90
+
91
+ assert mock_job.credits_reserved == 0
92
+
93
+ @pytest.mark.asyncio
94
+ async def test_confirm_credit_no_reservation(self):
95
+ """Confirming when no credits reserved does nothing."""
96
+ from services.credit_service.credit_manager import confirm_credit
97
+
98
+ mock_job = MagicMock()
99
+ mock_job.credits_reserved = 0
100
+
101
+ mock_session = AsyncMock()
102
+
103
+ await confirm_credit(mock_session, mock_job)
104
+
105
+ assert mock_job.credits_reserved == 0
106
+
107
+
108
+ class TestCreditRefund:
109
+ """Test credit refund functionality."""
110
+
111
+ @pytest.mark.asyncio
112
+ async def test_refund_credit_success(self):
113
+ """Successfully refund credits to user."""
114
+ from services.credit_service.credit_manager import refund_credit
115
+ from core.models import User
116
+
117
+ # Mock job with reserved credits
118
+ mock_job = MagicMock()
119
+ mock_job.job_id = "job_123"
120
+ mock_job.user_id = 1
121
+ mock_job.credits_reserved = 5
122
+ mock_job.credits_refunded = False
123
+
124
+ # Mock user
125
+ mock_user = MagicMock(spec=User)
126
+ mock_user.id = 1
127
+ mock_user.user_id = "usr_123"
128
+ mock_user.credits = 10
129
+
130
+ # Mock database session
131
+ mock_session = AsyncMock()
132
+ mock_result = MagicMock()
133
+ mock_result.scalar_one_or_none.return_value = mock_user
134
+ mock_session.execute.return_value = mock_result
135
+
136
+ result = await refund_credit(mock_session, mock_job, "Test refund")
137
+
138
+ assert result == True
139
+ assert mock_user.credits == 15 # 10 + 5
140
+ assert mock_job.credits_reserved == 0
141
+ assert mock_job.credits_refunded == True
142
+
143
+ @pytest.mark.asyncio
144
+ async def test_refund_credit_no_reservation(self):
145
+ """Cannot refund if no credits were reserved."""
146
+ from services.credit_service.credit_manager import refund_credit
147
+
148
+ mock_job = MagicMock()
149
+ mock_job.credits_reserved = 0
150
+
151
+ mock_session = AsyncMock()
152
+
153
+ result = await refund_credit(mock_session, mock_job, "Test")
154
+
155
+ assert result == False
156
+
157
+ @pytest.mark.asyncio
158
+ async def test_refund_credit_already_refunded(self):
159
+ """Cannot refund credits twice."""
160
+ from services.credit_service.credit_manager import refund_credit
161
+
162
+ mock_job = MagicMock()
163
+ mock_job.credits_reserved = 5
164
+ mock_job.credits_refunded = True
165
+
166
+ mock_session = AsyncMock()
167
+
168
+ result = await refund_credit(mock_session, mock_job, "Test")
169
+
170
+ assert result == False
171
+
172
+
173
+ # ============================================================================
174
+ # 2. Error Pattern Matching Tests
175
+ # ============================================================================
176
+
177
+ class TestErrorPatternMatching:
178
+ """Test refundable vs non-refundable error detection."""
179
+
180
+ def test_refundable_api_key_error(self):
181
+ """API key errors are refundable."""
182
+ from services.credit_service.credit_manager import is_refundable_error
183
+
184
+ assert is_refundable_error("API_KEY_INVALID: The API key is invalid") == True
185
+
186
+ def test_refundable_quota_exceeded(self):
187
+ """Quota exceeded is refundable."""
188
+ from services.credit_service.credit_manager import is_refundable_error
189
+
190
+ assert is_refundable_error("QUOTA_EXCEEDED: Daily quota exceeded") == True
191
+
192
+ def test_refundable_internal_error(self):
193
+ """Internal server errors are refundable."""
194
+ from services.credit_service.credit_manager import is_refundable_error
195
+
196
+ assert is_refundable_error("INTERNAL_ERROR: Something went wrong") == True
197
+
198
+ def test_refundable_timeout(self):
199
+ """Timeouts are refundable."""
200
+ from services.credit_service.credit_manager import is_refundable_error
201
+
202
+ assert is_refundable_error("Request TIMEOUT after 30 seconds") == True
203
+
204
+ def test_refundable_500_error(self):
205
+ """HTTP 500 errors are refundable."""
206
+ from services.credit_service.credit_manager import is_refundable_error
207
+
208
+ assert is_refundable_error("Server returned 500 Internal Server Error") == True
209
+
210
+ def test_non_refundable_safety_filter(self):
211
+ """Safety filter blocks are not refundable."""
212
+ from services.credit_service.credit_manager import is_refundable_error
213
+
214
+ assert is_refundable_error("Content blocked by safety filter") == False
215
+
216
+ def test_non_refundable_invalid_input(self):
217
+ """Invalid input errors are not refundable."""
218
+ from services.credit_service.credit_manager import is_refundable_error
219
+
220
+ assert is_refundable_error("INVALID_INPUT: Bad image format") == False
221
+
222
+ def test_non_refundable_400_error(self):
223
+ """HTTP 400 errors are not refundable."""
224
+ from services.credit_service.credit_manager import is_refundable_error
225
+
226
+ assert is_refundable_error("Bad request: 400 status code") == False
227
+
228
+ def test_non_refundable_cancelled(self):
229
+ """User cancellations are not refundable."""
230
+ from services.credit_service.credit_manager import is_refundable_error
231
+
232
+ assert is_refundable_error("User cancelled the operation") == False
233
+
234
+ def test_refundable_max_retries(self):
235
+ """Max retries exceeded is refundable."""
236
+ from services.credit_service.credit_manager import is_refundable_error
237
+
238
+ assert is_refundable_error("Failed after max retries") == True
239
+
240
+ def test_unknown_error_not_refundable(self):
241
+ """Unknown errors default to non-refundable."""
242
+ from services.credit_service.credit_manager import is_refundable_error
243
+
244
+ assert is_refundable_error("Some random unknown error") == False
245
+
246
+ def test_empty_error_not_refundable(self):
247
+ """Empty error message is not refundable."""
248
+ from services.credit_service.credit_manager import is_refundable_error
249
+
250
+ assert is_refundable_error("") == False
251
+ assert is_refundable_error(None) == False
252
+
253
+
254
+ # ============================================================================
255
+ # 3. Job Completion Handling Tests
256
+ # ============================================================================
257
+
258
+ class TestJobCompletionHandling:
259
+ """Test credit handling when jobs complete."""
260
+
261
+ @pytest.mark.asyncio
262
+ async def test_completed_job_confirms_credits(self):
263
+ """Completed jobs confirm credit usage."""
264
+ from services.credit_service.credit_manager import handle_job_completion
265
+
266
+ mock_job = MagicMock()
267
+ mock_job.job_id = "job_123"
268
+ mock_job.status = "completed"
269
+ mock_job.credits_reserved = 5
270
+
271
+ mock_session = AsyncMock()
272
+
273
+ with patch('services.credit_service.credit_manager.confirm_credit') as mock_confirm:
274
+ await handle_job_completion(mock_session, mock_job)
275
+ mock_confirm.assert_called_once()
276
+
277
+ @pytest.mark.asyncio
278
+ async def test_failed_refundable_job_refunds(self):
279
+ """Failed jobs with refundable errors get refunds."""
280
+ from services.credit_service.credit_manager import handle_job_completion
281
+
282
+ mock_job = MagicMock()
283
+ mock_job.status = "failed"
284
+ mock_job.error_message = "API_KEY_INVALID: Bad key"
285
+ mock_job.credits_reserved = 5
286
+
287
+ mock_session = AsyncMock()
288
+
289
+ with patch('services.credit_service.credit_manager.refund_credit') as mock_refund:
290
+ await handle_job_completion(mock_session, mock_job)
291
+ mock_refund.assert_called_once()
292
+
293
+ @pytest.mark.asyncio
294
+ async def test_failed_non_refundable_job_keeps_credits(self):
295
+ """Failed jobs with non-refundable errors keep credits."""
296
+ from services.credit_service.credit_manager import handle_job_completion
297
+
298
+ mock_job = MagicMock()
299
+ mock_job.status = "failed"
300
+ mock_job.error_message = "Safety filter blocked content"
301
+ mock_job.credits_reserved = 5
302
+
303
+ mock_session = AsyncMock()
304
+
305
+ with patch('services.credit_service.credit_manager.confirm_credit') as mock_confirm:
306
+ await handle_job_completion(mock_session, mock_job)
307
+ mock_confirm.assert_called_once()
308
+
309
+ @pytest.mark.asyncio
310
+ async def test_cancelled_before_start_refunds(self):
311
+ """Cancelled jobs that never started get refunds."""
312
+ from services.credit_service.credit_manager import handle_job_completion
313
+
314
+ mock_job = MagicMock()
315
+ mock_job.status = "cancelled"
316
+ mock_job.started_at = None
317
+ mock_job.credits_reserved = 5
318
+
319
+ mock_session = AsyncMock()
320
+
321
+ with patch('services.credit_service.credit_manager.refund_credit') as mock_refund:
322
+ await handle_job_completion(mock_session, mock_job)
323
+ mock_refund.assert_called_once()
324
+
325
+ @pytest.mark.asyncio
326
+ async def test_cancelled_during_processing_keeps_credits(self):
327
+ """Cancelled jobs that started keep credits."""
328
+ from services.credit_service.credit_manager import handle_job_completion
329
+
330
+ mock_job = MagicMock()
331
+ mock_job.status = "cancelled"
332
+ mock_job.started_at = datetime.utcnow()
333
+ mock_job.credits_reserved = 5
334
+
335
+ mock_session = AsyncMock()
336
+
337
+ with patch('services.credit_service.credit_manager.confirm_credit') as mock_confirm:
338
+ await handle_job_completion(mock_session, mock_job)
339
+ mock_confirm.assert_called_once()
340
+
341
+
342
+ # ============================================================================
343
+ # 4. Credits Router Tests
344
+ # ============================================================================
345
+
346
+ class TestCreditsRouter:
347
+ """Test credits API endpoints."""
348
+
349
+ def test_get_balance_requires_auth(self):
350
+ """GET /credits/balance requires authentication."""
351
+ from routers.credits import router
352
+ from fastapi import FastAPI
353
+
354
+ app = FastAPI()
355
+ app.include_router(router)
356
+ client = TestClient(app)
357
+
358
+ response = client.get("/credits/balance")
359
+
360
+ # Should fail without auth
361
+ assert response.status_code in [401, 403, 422, 500]
362
+
363
+ def test_get_balance_returns_user_credits(self):
364
+ """GET /credits/balance returns user's credit balance."""
365
+ from routers.credits import router
366
+ from fastapi import FastAPI
367
+
368
+ app = FastAPI()
369
+
370
+ # Mock authenticated user in request state
371
+ mock_user = MagicMock()
372
+ mock_user.user_id = "usr_123"
373
+ mock_user.credits = 50
374
+ mock_user.last_used_at = None
375
+
376
+ # Create test client with middleware that sets request.state.user
377
+ @app.middleware("http")
378
+ async def add_user_to_state(request, call_next):
379
+ request.state.user = mock_user
380
+ return await call_next(request)
381
+
382
+ app.include_router(router)
383
+ client = TestClient(app)
384
+
385
+ response = client.get("/credits/balance")
386
+
387
+ assert response.status_code == 200
388
+ data = response.json()
389
+ assert data["user_id"] == "usr_123"
390
+ assert data["credits"] == 50
391
+
392
+ def test_get_history_requires_auth(self):
393
+ """GET /credits/history requires authentication."""
394
+ from routers.credits import router
395
+ from fastapi import FastAPI
396
+
397
+ app = FastAPI()
398
+ app.include_router(router)
399
+ client = TestClient(app)
400
+
401
+ response = client.get("/credits/history")
402
+
403
+ # Should fail without auth
404
+ assert response.status_code in [401, 403, 422, 500]
405
+
406
+ def test_get_history_returns_paginated_jobs(self):
407
+ """GET /credits/history returns paginated job list."""
408
+ from routers.credits import router
409
+ from fastapi import FastAPI
410
+ from core.database import get_db
411
+
412
+ app = FastAPI()
413
+
414
+ mock_user = MagicMock()
415
+ mock_user.user_id = "usr_123"
416
+ mock_user.credits = 50
417
+
418
+ # Mock database with jobs
419
+ mock_job = MagicMock()
420
+ mock_job.job_id = "job_123"
421
+ mock_job.job_type = "generate-video"
422
+ mock_job.status = "completed"
423
+ mock_job.credits_reserved = 10
424
+ mock_job.credits_refunded = False
425
+ mock_job.error_message = None
426
+ mock_job.created_at = datetime.utcnow()
427
+ mock_job.completed_at = datetime.utcnow()
428
+
429
+ async def mock_get_db():
430
+ mock_db = AsyncMock()
431
+ mock_result = MagicMock()
432
+ mock_result.scalars.return_value.all.return_value = [mock_job]
433
+ mock_db.execute.return_value = mock_result
434
+ yield mock_db
435
+
436
+ @app.middleware("http")
437
+ async def add_user_to_state(request, call_next):
438
+ request.state.user = mock_user
439
+ return await call_next(request)
440
+
441
+ app.dependency_overrides[get_db] = mock_get_db
442
+ app.include_router(router)
443
+ client = TestClient(app)
444
+
445
+ response = client.get("/credits/history")
446
+
447
+ assert response.status_code == 200
448
+ data = response.json()
449
+ assert data["user_id"] == "usr_123"
450
+ assert data["current_balance"] == 50
451
+ assert len(data["history"]) == 1
452
+ assert data["history"][0]["job_id"] == "job_123"
453
+
454
+ def test_get_history_pagination(self):
455
+ """GET /credits/history supports pagination."""
456
+ from routers.credits import router
457
+ from fastapi import FastAPI
458
+ from core.database import get_db
459
+
460
+ app = FastAPI()
461
+
462
+ mock_user = MagicMock()
463
+ mock_user.user_id = "usr_123"
464
+ mock_user.credits = 50
465
+
466
+ async def mock_get_db():
467
+ mock_db = AsyncMock()
468
+ mock_result = MagicMock()
469
+ mock_result.scalars.return_value.all.return_value = []
470
+ mock_db.execute.return_value = mock_result
471
+ yield mock_db
472
+
473
+ @app.middleware("http")
474
+ async def add_user_to_state(request, call_next):
475
+ request.state.user = mock_user
476
+ return await call_next(request)
477
+
478
+ app.dependency_overrides[get_db] = mock_get_db
479
+ app.include_router(router)
480
+ client = TestClient(app)
481
+
482
+ response = client.get("/credits/history?page=2&limit=10")
483
+
484
+ assert response.status_code == 200
485
+ data = response.json()
486
+ assert data["page"] == 2
487
+ assert data["limit"] == 10
488
+
489
+
490
+ if __name__ == "__main__":
491
+ pytest.main([__file__, "-v"])