Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
def test_request_id_header_roundtrip(client):
    """Check that a caller-supplied ``x-request-id`` header comes back on the response."""
    request_id = "abc-123"
    response = client.get("/docs", headers={"x-request-id": request_id})
    # The Swagger page body can be large, so only the header is inspected.
    assert response.headers.get("x-request-id") == request_id
def test_upload_label_set_and_list(client, sample_label_set):
    """Upload a label set, then confirm it is the single, default entry in the listing."""
    upload = client.post("/api/v1/label-sets", json=sample_label_set)
    assert upload.status_code == 200

    created = upload.json()
    assert "label_set_hash" in created
    # NOTE(review): these values presumably mirror the ``sample_label_set``
    # fixture defined outside this file - confirm if the fixture changes.
    expected_summary = {"name": "test-v1", "domain_count": 2, "label_count": 4}
    assert {key: created[key] for key in expected_summary} == expected_summary

    listing = client.get("/api/v1/label-sets")
    assert listing.status_code == 200

    entries = listing.json()
    assert len(entries) == 1
    only_entry = entries[0]
    assert only_entry["label_set_hash"] == created["label_set_hash"]
    assert only_entry["is_default"] is True
def test_activate_label_set(client, sample_label_set):
    """Activate an uploaded label set and check it is reported as the default hash."""
    upload = client.post("/api/v1/label-sets", json=sample_label_set)
    # Fail here with a clear message if setup broke, instead of a KeyError
    # when reading the hash from an error payload.
    assert upload.status_code == 200
    label_set_hash = upload.json()["label_set_hash"]

    activation = client.post(f"/api/v1/label-sets/{label_set_hash}/activate")
    assert activation.status_code == 200
    assert activation.json()["default_label_set_hash"] == label_set_hash
def test_classify_requires_default_if_no_hash(client, tiny_image_b64):
    """Expect HTTP 400 from classify when no hash is given and no label set was uploaded."""
    # Nothing is uploaded in this test, so there is no default label set.
    payload = {"image_base64": tiny_image_b64, "domain_top_n": 1, "top_k": 2}
    response = client.post("/api/v1/classify", json=payload)

    assert response.status_code == 400
    assert "No default label set" in response.text
def test_classify_with_default_label_set(client, sample_label_set, tiny_image_b64):
    """Classify without an explicit hash after a label set has been uploaded.

    Checks that the response names a label set and a model, returns one
    domain hit for ``domain_top_n == 1``, and carries the timing fields.
    """
    upload = client.post("/api/v1/label-sets", json=sample_label_set)
    # Verify the setup step so a failed upload is not misreported as a
    # classify failure further down.
    assert upload.status_code == 200

    domain_top_n = 1
    response = client.post(
        "/api/v1/classify",
        json={"image_base64": tiny_image_b64, "domain_top_n": domain_top_n, "top_k": 2},
    )
    assert response.status_code == 200

    data = response.json()
    assert "label_set_hash" in data
    assert data["model_id"]  # non-empty model id string
    assert len(data["domain_hits"]) == domain_top_n
    # Labels may be empty if the chosen domain has no labels (here it does),
    # so their count is deliberately not asserted.
    for timing_key in ("elapsed_ms", "elapsed_domain_ms", "elapsed_labels_ms"):
        assert timing_key in data
def test_classify_with_explicit_hash_query_param(client, sample_label_set, tiny_image_b64):
    """Classify with ``label_set_hash`` as a query parameter and expect it echoed back."""
    upload = client.post("/api/v1/label-sets", json=sample_label_set)
    # Fail here with a clear message if setup broke, instead of a KeyError
    # when reading the hash from an error payload.
    assert upload.status_code == 200
    label_set_hash = upload.json()["label_set_hash"]

    response = client.post(
        f"/api/v1/classify?label_set_hash={label_set_hash}",
        json={"image_base64": tiny_image_b64, "domain_top_n": 2, "top_k": 3},
    )
    assert response.status_code == 200
    assert response.json()["label_set_hash"] == label_set_hash