Spaces:
Running
Running
File size: 2,733 Bytes
a229747 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 | import sqlite3
import database
def _tables(db_path: str):
    """Return the table names stored in the SQLite file at *db_path*.

    The names come from ``sqlite_master`` and are ordered by name.
    """
    query = "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;"
    connection = sqlite3.connect(db_path)
    try:
        cursor = connection.execute(query)
        # Each result row is a 1-tuple holding the table name.
        return [name for (name, *_) in cursor.fetchall()]
    finally:
        # Release the handle whether or not the query succeeded.
        connection.close()
def test_init_db_creates_tables(test_db_path):
    """After ``init_db`` runs, the three expected tables are present."""
    db = str(test_db_path)
    database.init_db(db_path=db)

    existing = _tables(db)
    # Checked one at a time, in the same order, so a failure names the
    # first missing table.
    for expected in ("requests", "model_stats", "low_confidence_flags"):
        assert expected in existing
def test_log_request_creates_row(test_db_path):
    """One logged request appears as one row in the request history."""
    db = str(test_db_path)
    database.init_db(db_path=db)

    record = {
        "request_id": "abc-123",
        "model_name": "lr",
        "input_text": "Fed raises rates",
        "predicted_label": "Business",
        "predicted_label_id": 2,
        "confidence": 0.91,
        "latency_ms": 12.4,
        "is_batch": False,
    }
    database.log_request(db_path=db, **record)

    rows = database.get_request_history(db_path=db, limit=10)
    assert len(rows) == 1
    first = rows[0]
    assert first["predicted_label"] == "Business"
def test_low_confidence_flag(test_db_path, monkeypatch):
    """A request logged at confidence 0.45 is flagged when the threshold is 0.60.

    The threshold is patched onto ``database.CFG`` before the database is
    initialised; the flag is then read back with ``reviewed=False``.
    """
    monkeypatch.setattr(database.CFG, "low_confidence_threshold", 0.60)

    db = str(test_db_path)
    database.init_db(db_path=db)

    record = {
        "request_id": "low-1",
        "model_name": "lr",
        "input_text": "Fed raises rates",
        "predicted_label": "Business",
        "predicted_label_id": 2,
        "confidence": 0.45,
        "latency_ms": 12.4,
        "is_batch": False,
    }
    database.log_request(db_path=db, **record)

    pending = database.get_low_confidence_flags(db_path=db, reviewed=False)
    assert len(pending) == 1

    # Compare with a small tolerance rather than exact float equality.
    stored_confidence = float(pending[0]["confidence"])
    assert abs(stored_confidence - 0.45) < 1e-9
def test_mark_reviewed(test_db_path, monkeypatch):
    """Marking a flag as reviewed stores the reviewed state and the note.

    A request is logged at confidence 0.45 with the threshold patched to
    0.60, then marked reviewed; the flag is read back with ``reviewed=True``.
    """
    monkeypatch.setattr(database.CFG, "low_confidence_threshold", 0.60)

    db = str(test_db_path)
    database.init_db(db_path=db)

    record = {
        "request_id": "low-2",
        "model_name": "lr",
        "input_text": "Fed raises rates",
        "predicted_label": "Business",
        "predicted_label_id": 2,
        "confidence": 0.45,
        "latency_ms": 12.4,
        "is_batch": False,
    }
    database.log_request(db_path=db, **record)
    database.mark_reviewed("low-2", note="reviewed", db_path=db)

    done = database.get_low_confidence_flags(db_path=db, reviewed=True)
    assert len(done) == 1

    flag = done[0]
    assert int(flag["reviewed"]) == 1
    assert flag["review_note"] == "reviewed"
def test_get_summary_empty(test_db_path):
    """A freshly initialised database reports zero total requests."""
    db = str(test_db_path)
    database.init_db(db_path=db)

    report = database.get_summary(db_path=db, days=7)
    total = report["total_requests"]
    assert total == 0
|