File size: 2,160 Bytes
68f48a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from __future__ import annotations

from typing import Optional

from fastapi import APIRouter, Depends, Query, Request

from api.classify.schemas import ClassifyRequest, ClassifyResponse, Hit
from api.common.deps import get_classifier, get_registry, get_request_id, resolve_bank
from api.common.logging import log_json, setup_logging
from api.common.image_io import load_image_from_base64
from api.common.settings import settings


# Module-level logger built by the project's logging helper; the handler
# below emits its structured records through it via log_json.
logger = setup_logging()

# Every route registered on this router is mounted under /api/v1 and grouped
# under the "classify" tag in the generated OpenAPI schema.
router = APIRouter(
    tags=["classify"],
    prefix="/api/v1",
)


@router.post("/classify", response_model=ClassifyResponse)
def classify(
    payload: ClassifyRequest,
    request: Request,
    request_id: str = Depends(get_request_id),
    label_set_hash: Optional[str] = Query(default=None, description="If omitted, uses the default label set."),
    classifier=Depends(get_classifier),
    registry=Depends(get_registry),
) -> ClassifyResponse:
    """Classify a base64-encoded image against a label set.

    The label set is selected by `label_set_hash`; when omitted, the default
    label set is used. The response contains the domain hits, the chosen
    domains, the label hits, and timing information for the classification.
    """
    # NOTE(review): `request` is not read in this handler; it is kept so the
    # endpoint signature stays unchanged.
    bank = resolve_bank(registry, label_set_hash)

    # Size cap is configured in megabytes and converted to bytes here.
    image = load_image_from_base64(
        payload.image_base64,
        max_bytes=settings.max_image_mb * 1024 * 1024,
    )

    # Resolve the effective parameters once so the classifier call and the
    # log record report the same values. Previously the log recorded the raw
    # payload fields (None when omitted) even though defaults were applied.
    # `or` is kept deliberately: any falsy value (None or 0) falls back to the
    # configured default, exactly as before.
    domain_top_n = payload.domain_top_n or settings.default_domain_top_n
    top_k = payload.top_k or settings.default_top_k

    res = classifier.classify(
        bank=bank,
        image=image,
        domain_top_n=domain_top_n,
        top_k=top_k,
    )

    # Timings come from the classifier result; they do not include the
    # image decode above.
    log_json(
        logger,
        event="classify",
        request_id=request_id,
        label_set_hash=bank.label_set_hash,
        model_id=settings.clip_model_id,
        domain_top_n=domain_top_n,
        top_k=top_k,
        chosen_domains=res.chosen_domains,
        elapsed_ms=res.timings.total_ms,
        elapsed_domain_ms=res.timings.domain_ms,
        elapsed_labels_ms=res.timings.labels_ms,
    )

    return ClassifyResponse(
        label_set_hash=bank.label_set_hash,
        model_id=settings.clip_model_id,
        domain_hits=[Hit(id=i, score=s) for i, s in res.domain_hits],
        chosen_domains=res.chosen_domains,
        label_hits=[Hit(id=i, score=s) for i, s in res.label_hits],
        elapsed_ms=res.timings.total_ms,
        elapsed_domain_ms=res.timings.domain_ms,
        elapsed_labels_ms=res.timings.labels_ms,
    )