nexa-classify-api / tests /test_database.py
Prototype6239's picture
Upload folder using huggingface_hub
a229747 verified
Raw
History Blame Contribute Delete
2.73 kB
import sqlite3
import database
def _tables(db_path: str):
    """Return the names of all tables in the SQLite file at *db_path*.

    The names come from ``sqlite_master`` and are ordered by name. The
    connection is opened only for this query and is closed before the
    function returns, including when the query raises.
    """
    query = "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;"
    connection = sqlite3.connect(db_path)
    try:
        table_names = []
        # Each result row has a single column: the table name.
        for row in connection.execute(query).fetchall():
            table_names.append(row[0])
        return table_names
    finally:
        connection.close()
def test_init_db_creates_tables(test_db_path):
    """Check that ``init_db`` creates the three tables this suite relies on."""
    db_file = str(test_db_path)
    database.init_db(db_path=db_file)

    present = _tables(db_file)
    for expected in ("requests", "model_stats", "low_confidence_flags"):
        assert expected in present
def test_log_request_creates_row(test_db_path):
    """Check that one logged request appears once in the request history."""
    db_file = str(test_db_path)
    database.init_db(db_path=db_file)

    request_fields = {
        "request_id": "abc-123",
        "model_name": "lr",
        "input_text": "Fed raises rates",
        "predicted_label": "Business",
        "predicted_label_id": 2,
        "confidence": 0.91,
        "latency_ms": 12.4,
        "is_batch": False,
    }
    database.log_request(db_path=db_file, **request_fields)

    rows = database.get_request_history(db_path=db_file, limit=10)
    assert len(rows) == 1
    first_row = rows[0]
    assert first_row["predicted_label"] == "Business"
def test_low_confidence_flag(test_db_path, monkeypatch):
    """Check that a request logged below the threshold yields one unreviewed flag.

    The threshold is patched to 0.60 and the request is logged with a
    confidence of 0.45; the stored flag must carry that confidence.
    """
    # Patch the threshold before touching the database, as the test
    # depends on 0.45 being below it.
    monkeypatch.setattr(database.CFG, "low_confidence_threshold", 0.60)

    db_file = str(test_db_path)
    database.init_db(db_path=db_file)

    request_fields = {
        "request_id": "low-1",
        "model_name": "lr",
        "input_text": "Fed raises rates",
        "predicted_label": "Business",
        "predicted_label_id": 2,
        "confidence": 0.45,
        "latency_ms": 12.4,
        "is_batch": False,
    }
    database.log_request(db_path=db_file, **request_fields)

    pending = database.get_low_confidence_flags(db_path=db_file, reviewed=False)
    assert len(pending) == 1
    stored_confidence = float(pending[0]["confidence"])
    # Compare with a tolerance since the value is a float read back from storage.
    assert abs(stored_confidence - 0.45) < 1e-9
def test_mark_reviewed(test_db_path, monkeypatch):
    """Check that ``mark_reviewed`` marks a flag reviewed and stores the note.

    A request is logged with confidence 0.45 under a patched threshold
    of 0.60, marked as reviewed, and then read back through the
    ``reviewed=True`` listing.
    """
    monkeypatch.setattr(database.CFG, "low_confidence_threshold", 0.60)

    db_file = str(test_db_path)
    database.init_db(db_path=db_file)

    request_fields = {
        "request_id": "low-2",
        "model_name": "lr",
        "input_text": "Fed raises rates",
        "predicted_label": "Business",
        "predicted_label_id": 2,
        "confidence": 0.45,
        "latency_ms": 12.4,
        "is_batch": False,
    }
    database.log_request(db_path=db_file, **request_fields)

    database.mark_reviewed("low-2", note="reviewed", db_path=db_file)

    reviewed_flags = database.get_low_confidence_flags(db_path=db_file, reviewed=True)
    assert len(reviewed_flags) == 1
    flag = reviewed_flags[0]
    assert int(flag["reviewed"]) == 1
    assert flag["review_note"] == "reviewed"
def test_get_summary_empty(test_db_path):
    """Check that a freshly initialised database reports zero total requests."""
    db_file = str(test_db_path)
    database.init_db(db_path=db_file)

    report = database.get_summary(db_path=db_file, days=7)
    total = report["total_requests"]
    assert total == 0