# Spaces:
# Running
# Running
import sqlite3

import database
def _tables(db_path: str) -> list:
    """Return the names of all tables in the SQLite database at *db_path*.

    The names are read from ``sqlite_master`` (rows whose type is
    ``'table'``) and returned as a list of strings ordered by name.
    The connection is always closed before returning, even if the
    query raises.
    """
    conn = sqlite3.connect(db_path)
    try:
        rows = conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;"
        ).fetchall()
        # Each row is a 1-tuple holding the table name.
        return [r[0] for r in rows]
    finally:
        conn.close()
def test_init_db_creates_tables(test_db_path):
    """``init_db`` must create the tables the rest of the suite relies on.

    ``test_db_path`` is a fixture defined outside this file (presumably a
    path to a fresh, temporary database -- confirm in conftest).
    """
    database.init_db(db_path=str(test_db_path))
    tables = _tables(str(test_db_path))
    assert "requests" in tables
    assert "model_stats" in tables
    assert "low_confidence_flags" in tables
def test_log_request_creates_row(test_db_path):
    """Logging one request must make exactly one row visible in the history.

    After ``init_db`` and a single ``log_request`` call, the test expects
    ``get_request_history`` to return one entry whose ``predicted_label``
    matches the value that was logged.
    """
    db_path = str(test_db_path)
    database.init_db(db_path=db_path)
    database.log_request(
        db_path=db_path,
        request_id="abc-123",
        model_name="lr",
        input_text="Fed raises rates",
        predicted_label="Business",
        predicted_label_id=2,
        confidence=0.91,
        latency_ms=12.4,
        is_batch=False,
    )
    history = database.get_request_history(db_path=db_path, limit=10)
    assert len(history) == 1
    assert history[0]["predicted_label"] == "Business"
def test_low_confidence_flag(test_db_path, monkeypatch):
    """A request logged below the confidence threshold must be flagged.

    The threshold on ``database.CFG`` is patched to 0.60 so the test does
    not depend on the configured value; a request logged with confidence
    0.45 is then expected to appear as the single unreviewed flag.
    """
    monkeypatch.setattr(database.CFG, "low_confidence_threshold", 0.60)
    db_path = str(test_db_path)
    database.init_db(db_path=db_path)
    database.log_request(
        db_path=db_path,
        request_id="low-1",
        model_name="lr",
        input_text="Fed raises rates",
        predicted_label="Business",
        predicted_label_id=2,
        confidence=0.45,
        latency_ms=12.4,
        is_batch=False,
    )
    flags = database.get_low_confidence_flags(db_path=db_path, reviewed=False)
    assert len(flags) == 1
    # Compare with a tolerance: the stored value is converted via float().
    assert abs(float(flags[0]["confidence"]) - 0.45) < 1e-9
def test_mark_reviewed(test_db_path, monkeypatch):
    """Marking a flagged request as reviewed must persist the flag and note.

    A request is logged with confidence 0.45 under a patched threshold of
    0.60, then passed to ``mark_reviewed`` with a note. The test expects the
    reviewed-flags query to return that single entry with ``reviewed`` equal
    to 1 and the note stored under ``review_note``.
    """
    monkeypatch.setattr(database.CFG, "low_confidence_threshold", 0.60)
    db_path = str(test_db_path)
    database.init_db(db_path=db_path)
    database.log_request(
        db_path=db_path,
        request_id="low-2",
        model_name="lr",
        input_text="Fed raises rates",
        predicted_label="Business",
        predicted_label_id=2,
        confidence=0.45,
        latency_ms=12.4,
        is_batch=False,
    )
    database.mark_reviewed("low-2", note="reviewed", db_path=db_path)
    flags = database.get_low_confidence_flags(db_path=db_path, reviewed=True)
    assert len(flags) == 1
    assert int(flags[0]["reviewed"]) == 1
    assert flags[0]["review_note"] == "reviewed"
def test_get_summary_empty(test_db_path):
    """The summary of a freshly initialised database must report zero requests."""
    database.init_db(db_path=str(test_db_path))
    summary = database.get_summary(db_path=str(test_db_path), days=7)
    assert summary["total_requests"] == 0