File size: 2,475 Bytes
7ae1e80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from __future__ import annotations


def test_request_id_header_roundtrip(client):
    """An ``x-request-id`` header sent by the caller is echoed back on the response."""
    sent_id = "abc-123"
    response = client.get("/docs", headers={"x-request-id": sent_id})
    # The Swagger UI body can be large and is irrelevant here; only the header is checked.
    echoed_id = response.headers.get("x-request-id")
    assert echoed_id == sent_id


def test_upload_label_set_and_list(client, sample_label_set):
    """Uploading a label set returns its summary, and the list shows it as the default."""
    created = client.post("/api/v1/label-sets", json=sample_label_set)
    assert created.status_code == 200
    summary = created.json()
    assert "label_set_hash" in summary

    # Summary fields expected for the sample fixture (presumably 2 domains / 4 labels).
    expected_fields = {"name": "test-v1", "domain_count": 2, "label_count": 4}
    for field_name, expected_value in expected_fields.items():
        assert summary[field_name] == expected_value

    listing = client.get("/api/v1/label-sets")
    assert listing.status_code == 200
    entries = listing.json()
    assert len(entries) == 1
    only_entry = entries[0]
    assert only_entry["label_set_hash"] == summary["label_set_hash"]
    # A lone uploaded label set is expected to be flagged as the default.
    assert only_entry["is_default"] is True


def test_activate_label_set(client, sample_label_set):
    """Activating an uploaded label set reports it as the default label set.

    The upload status is asserted first so that a failed upload fails with a
    clear status-code mismatch rather than an opaque ``KeyError`` on the
    missing ``label_set_hash`` key.
    """
    r = client.post("/api/v1/label-sets", json=sample_label_set)
    assert r.status_code == 200
    h = r.json()["label_set_hash"]

    r2 = client.post(f"/api/v1/label-sets/{h}/activate")
    assert r2.status_code == 200
    assert r2.json()["default_label_set_hash"] == h


def test_classify_requires_default_if_no_hash(client, tiny_image_b64):
    """Classifying without a label-set hash fails with 400 when no default exists."""
    # No label set has been uploaded, so there is no default to fall back on.
    payload = {"image_base64": tiny_image_b64, "domain_top_n": 1, "top_k": 2}
    response = client.post("/api/v1/classify", json=payload)
    assert response.status_code == 400
    assert "No default label set" in response.text


def test_classify_with_default_label_set(client, sample_label_set, tiny_image_b64):
    """Classifying without an explicit hash uses the uploaded (default) label set.

    The upload is checked first: if it failed, ``/classify`` would answer 400
    for "no default", which would hide the real cause of the failure.
    """
    upload = client.post("/api/v1/label-sets", json=sample_label_set)
    assert upload.status_code == 200
    uploaded_hash = upload.json()["label_set_hash"]

    r = client.post(
        "/api/v1/classify",
        json={"image_base64": tiny_image_b64, "domain_top_n": 1, "top_k": 2},
    )
    assert r.status_code == 200
    data = r.json()
    # A lone uploaded label set is the default (see test_upload_label_set_and_list),
    # so the response should report exactly that hash.
    assert data["label_set_hash"] == uploaded_hash
    assert data["model_id"]  # non-empty model id string
    assert len(data["domain_hits"]) == 1  # matches domain_top_n=1 in the request
    # Label hits could be empty if the chosen domain had no labels; the sample
    # set's domains do have labels, but label contents are not asserted here.
    for timing_key in ("elapsed_ms", "elapsed_domain_ms", "elapsed_labels_ms"):
        assert timing_key in data


def test_classify_with_explicit_hash_query_param(client, sample_label_set, tiny_image_b64):
    """Passing ``label_set_hash`` as a query parameter classifies against that set.

    The upload status is asserted first so a failed upload fails clearly
    instead of raising ``KeyError`` on the missing ``label_set_hash`` key.
    """
    r = client.post("/api/v1/label-sets", json=sample_label_set)
    assert r.status_code == 200
    h = r.json()["label_set_hash"]

    r2 = client.post(
        f"/api/v1/classify?label_set_hash={h}",
        json={"image_base64": tiny_image_b64, "domain_top_n": 2, "top_k": 3},
    )
    assert r2.status_code == 200
    assert r2.json()["label_set_hash"] == h