"""API tests for request-id propagation, label-set management and classification.

The ``client``, ``sample_label_set`` and ``tiny_image_b64`` fixtures are not
defined in this module; presumably they come from a ``conftest.py`` (TODO confirm).
"""

from __future__ import annotations


def _upload_label_set(client, label_set) -> str:
    """POST *label_set* and return its ``label_set_hash``.

    Fails the calling test immediately (showing the response body) if the
    upload is not accepted, instead of surfacing later as a ``KeyError`` or
    an unrelated downstream assertion.
    """
    r = client.post("/api/v1/label-sets", json=label_set)
    assert r.status_code == 200, r.text
    return r.json()["label_set_hash"]


def test_request_id_header_roundtrip(client):
    """An incoming ``x-request-id`` header is echoed back on the response."""
    r = client.get("/docs", headers={"x-request-id": "abc-123"})
    # swagger can be large; we only check header
    assert r.headers.get("x-request-id") == "abc-123"


def test_upload_label_set_and_list(client, sample_label_set):
    """Uploading a label set returns its summary, and listing shows it as default."""
    r = client.post("/api/v1/label-sets", json=sample_label_set)
    assert r.status_code == 200
    data = r.json()
    assert "label_set_hash" in data
    assert data["name"] == "test-v1"
    assert data["domain_count"] == 2
    assert data["label_count"] == 4

    r2 = client.get("/api/v1/label-sets")
    assert r2.status_code == 200
    items = r2.json()
    assert len(items) == 1
    assert items[0]["label_set_hash"] == data["label_set_hash"]
    # With only one label set uploaded, it is expected to be the default.
    assert items[0]["is_default"] is True


def test_activate_label_set(client, sample_label_set):
    """Activating a label set by hash makes it the default."""
    h = _upload_label_set(client, sample_label_set)

    r2 = client.post(f"/api/v1/label-sets/{h}/activate")
    assert r2.status_code == 200
    assert r2.json()["default_label_set_hash"] == h


def test_classify_requires_default_if_no_hash(client, tiny_image_b64):
    """Classify without a hash fails with 400 when no default label set exists."""
    # No label set uploaded => no default
    r = client.post(
        "/api/v1/classify",
        json={"image_base64": tiny_image_b64, "domain_top_n": 1, "top_k": 2},
    )
    assert r.status_code == 400
    assert "No default label set" in r.text


def test_classify_with_default_label_set(client, sample_label_set, tiny_image_b64):
    """Classify without a hash uses the default label set and reports timings."""
    h = _upload_label_set(client, sample_label_set)

    r = client.post(
        "/api/v1/classify",
        json={"image_base64": tiny_image_b64, "domain_top_n": 1, "top_k": 2},
    )
    assert r.status_code == 200
    data = r.json()
    # The only uploaded label set is the default, so classify must have used it.
    assert data["label_set_hash"] == h
    assert isinstance(data["model_id"], str) and data["model_id"]
    # domain_top_n=1 in the request => exactly one domain hit.
    assert len(data["domain_hits"]) == 1
    # Labels are not asserted: they may be empty if the chosen domain has none.
    assert "elapsed_ms" in data
    assert "elapsed_domain_ms" in data
    assert "elapsed_labels_ms" in data


def test_classify_with_explicit_hash_query_param(client, sample_label_set, tiny_image_b64):
    """Classify honours an explicit ``label_set_hash`` query parameter."""
    h = _upload_label_set(client, sample_label_set)

    r2 = client.post(
        f"/api/v1/classify?label_set_hash={h}",
        json={"image_base64": tiny_image_b64, "domain_top_n": 2, "top_k": 3},
    )
    assert r2.status_code == 200
    assert r2.json()["label_set_hash"] == h