# Source: commit 9847fe5 ("Apply enhancements") by esandorfi.
from __future__ import annotations
def test_request_id_header_roundtrip(client):
    """A caller-supplied ``x-request-id`` header is echoed back on the response."""
    request_id = "abc-123"
    # /docs is just a convenient route; its (possibly large) Swagger payload is
    # irrelevant here, only the response header is inspected.
    response = client.get("/docs", headers={"x-request-id": request_id})
    assert response.headers.get("x-request-id") == request_id
def test_upload_label_set_and_list(client, sample_label_set):
    """Uploading a label set returns its summary, and the set then shows up in the listing.

    The first (and only) uploaded label set is expected to be flagged as the default.
    """
    upload = client.post("/api/v1/label-sets", json=sample_label_set)
    assert upload.status_code == 200
    summary = upload.json()
    assert "label_set_hash" in summary
    assert summary["name"] == "test-v1"
    # Counts expected from the sample_label_set fixture.
    assert summary["domain_count"] == 2
    assert summary["label_count"] == 4

    listing = client.get("/api/v1/label-sets")
    assert listing.status_code == 200
    listed = listing.json()
    assert len(listed) == 1
    first = listed[0]
    assert first["label_set_hash"] == summary["label_set_hash"]
    assert first["is_default"] is True
def test_activate_label_set(client, sample_label_set):
    """Activating an uploaded label set reports that set as the default."""
    r = client.post("/api/v1/label-sets", json=sample_label_set)
    # Fail on the upload itself with a clear status mismatch instead of an
    # opaque KeyError when reading the hash below.
    assert r.status_code == 200
    h = r.json()["label_set_hash"]

    r2 = client.post(f"/api/v1/label-sets/{h}/activate")
    assert r2.status_code == 200
    assert r2.json()["default_label_set_hash"] == h
    # NOTE(review): test_upload_label_set_and_list shows a freshly uploaded set
    # is already the default, so this only checks the endpoint's response shape.
    # Activating a second, non-default set would exercise an actual switch.
def test_classify_requires_default_if_no_hash(client, tiny_image_b64):
    """Classifying without a label-set hash fails with 400 when no default exists."""
    payload = {"image_base64": tiny_image_b64, "domain_top_n": 1, "top_k": 2}
    # No label set is uploaded here, so there is no default to fall back on.
    response = client.post("/api/v1/classify", json=payload)
    assert response.status_code == 400
    assert "No default label set" in response.text
def test_classify_with_default_label_set(client, sample_label_set, tiny_image_b64):
    """Classifying without an explicit hash uses the default (only uploaded) label set.

    Also checks that the response carries a model id, one domain hit per
    requested ``domain_top_n``, and the timing fields.
    """
    upload = client.post("/api/v1/label-sets", json=sample_label_set)
    assert upload.status_code == 200
    uploaded_hash = upload.json()["label_set_hash"]

    r = client.post(
        "/api/v1/classify",
        json={"image_base64": tiny_image_b64, "domain_top_n": 1, "top_k": 2},
    )
    assert r.status_code == 200
    data = r.json()
    # The only uploaded set becomes the default, so classification must have
    # used it (a bare presence check would not catch a wrong set being used).
    assert data["label_set_hash"] == uploaded_hash
    assert data["model_id"]  # non-empty model id string
    # domain_top_n=1 => exactly one domain hit.
    assert len(data["domain_hits"]) == 1
    # NOTE(review): the per-label results are not asserted here; the response
    # key holding them is not visible in this file. Add a check once confirmed.
    assert "elapsed_ms" in data
    assert "elapsed_domain_ms" in data
    assert "elapsed_labels_ms" in data
def test_classify_with_explicit_hash_query_param(client, sample_label_set, tiny_image_b64):
    """Classifying with ``?label_set_hash=...`` uses exactly that label set."""
    r = client.post("/api/v1/label-sets", json=sample_label_set)
    # Fail on the upload itself rather than with a KeyError on the hash lookup.
    assert r.status_code == 200
    h = r.json()["label_set_hash"]

    r2 = client.post(
        f"/api/v1/classify?label_set_hash={h}",
        json={"image_base64": tiny_image_b64, "domain_top_n": 2, "top_k": 3},
    )
    assert r2.status_code == 200
    assert r2.json()["label_set_hash"] == h