# File: ecom-qa-bert/tests/test_app.py
# Uploaded by: rnyx
# Commit: 3338b6d ("Initial deploy: BERT QA app")
"""
Smoke tests for the Flask app routes.
We mock init_model + predict_qa so tests run in milliseconds without
actually loading BERT. This verifies routing, input validation, DB
persistence, and error handling.
"""
from unittest.mock import patch
import pytest
from src import config, db
@pytest.fixture
def client(tmp_path, monkeypatch):
    """Yield a Flask test client backed by a temp DB and a stubbed model.

    Config values are overridden through ``monkeypatch`` and both
    ``src.app.init_model`` and ``src.app.predict_qa`` stay patched for as
    long as the client is in use, so no real model is ever loaded.
    """
    # Applied in order: temp DB file, reset of the db module's init flag
    # (presumably so the schema is created again for the new file), rate
    # limiting off for deterministic tests, and warmup off at startup.
    overrides = (
        (config, "DB_PATH", str(tmp_path / "test.db")),
        (db, "_initialized", False),
        (config, "RATE_LIMIT_ENABLED", False),
        (config, "WARMUP_ON_START", False),
    )
    for target, attr_name, value in overrides:
        monkeypatch.setattr(target, attr_name, value)

    # Canned prediction returned by the stubbed predict_qa on every call.
    canned_prediction = {
        "answer": "5000 mAh",
        "confidence": 0.87,
        "confidence_pct": "87.0%",
        "confidence_level": "high",
        "answer_start_char": 10,
        "answer_end_char": 18,
        "context_used": "Battery: 5000 mAh capacity.",
        "tokens": [],
        "num_tokens": 0,
        "inference_time_ms": 42,
    }

    # Both stubs must stay active while the test body runs, hence the
    # yield sits inside the patch contexts.
    with patch("src.app.init_model", return_value=None):
        with patch("src.app.predict_qa", return_value=canned_prediction):
            from src.app import create_app

            flask_app = create_app()
            flask_app.config["TESTING"] = True
            with flask_app.test_client() as test_client:
                yield test_client
def test_healthz(client):
    """GET /healthz returns 200 with status "ok" and a "model" key."""
    response = client.get("/healthz")
    assert response.status_code == 200

    payload = response.get_json()
    assert payload["status"] == "ok"
    assert "model" in payload
def test_index_renders(client):
    """GET / returns 200 and the body contains "E-Commerce Product QA"."""
    page = client.get("/")
    assert page.status_code == 200
    assert b"E-Commerce Product QA" in page.data
def test_predict_requires_both_fields(client):
    """A predict request with no context gets a 400 mentioning "required"."""
    question_only = {"question": "What?"}
    response = client.post("/api/predict", json=question_only)
    assert response.status_code == 400

    error_text = response.get_json()["error"]
    assert "required" in error_text.lower()
def test_predict_rejects_short_context(client):
    """A predict request whose context is only "too short" gets a 400."""
    payload = {
        "question": "What is the battery?",
        "context": "too short",
    }
    response = client.post("/api/predict", json=payload)
    assert response.status_code == 400
def test_predict_success_and_persists(client):
    """A valid predict call returns the stubbed answer and lands in history."""
    payload = {
        "question": "What is the battery?",
        "context": "Battery: 5000 mAh capacity. Long enough context here to pass validation.",
        "source_url": "https://example.com/x",
        "source_type": "amazon",
        "product_title": "Phone",
    }
    response = client.post("/api/predict", json=payload)
    assert response.status_code == 200

    prediction = response.get_json()
    assert prediction["answer"] == "5000 mAh"
    assert prediction["confidence_level"] == "high"
    assert "history_id" in prediction

    # The request must be readable back through the history endpoint.
    history = client.get("/api/history").get_json()
    stored_items = history["items"]
    assert len(stored_items) == 1
    assert stored_items[0]["product_title"] == "Phone"
def test_scrape_requires_url(client):
    """POST /api/scrape with an empty JSON object is rejected with a 400."""
    empty_payload = {}
    response = client.post("/api/scrape", json=empty_payload)
    assert response.status_code == 400
def test_history_delete_and_clear(client):
    """Cover deleting one entry, deleting a missing id, and clearing all."""
    # Seed history with one entry by going through the predict endpoint.
    first_payload = {
        "question": "Q?",
        "context": "A context that is sufficiently long to pass validation cleanly.",
    }
    client.post("/api/predict", json=first_payload)
    listing = client.get("/api/history").get_json()
    entry_id = listing["items"][0]["id"]

    # Removing the entry that exists succeeds.
    delete_resp = client.delete(f"/api/history/{entry_id}")
    assert delete_resp.status_code == 200
    assert delete_resp.get_json()["deleted"] is True

    # Removing an id that was never created yields 404.
    missing_resp = client.delete("/api/history/99999")
    assert missing_resp.status_code == 404

    # Add another entry, then wipe the whole history.
    second_payload = {
        "question": "Q2?",
        "context": "Another context long enough to satisfy the length check cleanly.",
    }
    client.post("/api/predict", json=second_payload)
    clear_resp = client.delete("/api/history")
    assert clear_resp.status_code == 200
    assert clear_resp.get_json()["cleared"] >= 1
if __name__ == "__main__":
    # pytest.main() returns the run's exit code instead of exiting the
    # process. Raise it through SystemExit so that running this file
    # directly exits non-zero when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))