esandorfi's picture
Domain features first reorganisation
68f48a7
from __future__ import annotations
from typing import Optional
from fastapi import APIRouter, Depends, Query, Request
from api.classify.schemas import ClassifyRequest, ClassifyResponse, Hit
from api.common.deps import get_classifier, get_registry, get_request_id, resolve_bank
from api.common.logging import log_json, setup_logging
from api.common.image_io import load_image_from_base64
from api.common.settings import settings
# Structured logger used by the handler(s) in this module for log_json records.
logger = setup_logging()

# Routes registered on this router are served under /api/v1 and grouped
# under the "classify" tag in the generated OpenAPI schema.
router = APIRouter(
    prefix="/api/v1",
    tags=["classify"],
)
@router.post("/classify", response_model=ClassifyResponse)
def classify(
    payload: ClassifyRequest,
    request: Request,
    request_id: str = Depends(get_request_id),
    label_set_hash: Optional[str] = Query(
        default=None,
        description="If omitted, uses the default label set.",
    ),
    classifier=Depends(get_classifier),
    registry=Depends(get_registry),
) -> ClassifyResponse:
    """Classify a base64-encoded image against a label-set bank.

    Args:
        payload: Request body carrying ``image_base64`` plus optional
            ``domain_top_n`` and ``top_k``. A falsy value (None or 0) for
            either knob falls back to the corresponding ``settings`` default.
        request: Injected by FastAPI; currently unused by this handler.
        request_id: Injected via ``get_request_id``; used only to correlate
            the emitted ``classify`` log record.
        label_set_hash: Optional query parameter selecting the label-set bank
            via ``resolve_bank``; when omitted the default label set is used.
        classifier: Injected classifier whose ``classify`` produces domain
            hits, chosen domains, label hits and timings.
        registry: Injected registry passed to ``resolve_bank``.

    Returns:
        A ``ClassifyResponse`` with the bank's ``label_set_hash``, the CLIP
        model id from settings, domain/label hits as ``Hit`` objects, the
        chosen domains, and the elapsed-time breakdown.

    Errors raised by ``resolve_bank`` or ``load_image_from_base64`` (e.g. an
    unknown hash or an oversized image) propagate unchanged. Presumably they
    are mapped to HTTP errors in ``api.common``; TODO confirm.
    """
    bank = resolve_bank(registry, label_set_hash)

    # settings.max_image_mb is scaled by 1024 * 1024 into a byte limit.
    image = load_image_from_base64(
        payload.image_base64,
        max_bytes=settings.max_image_mb * 1024 * 1024,
    )

    # Resolve defaults once so the classifier call and the log record agree.
    # NOTE: `or` means any falsy request value (None or 0) selects the default.
    domain_top_n = payload.domain_top_n or settings.default_domain_top_n
    top_k = payload.top_k or settings.default_top_k

    res = classifier.classify(
        bank=bank,
        image=image,
        domain_top_n=domain_top_n,
        top_k=top_k,
    )
    timings = res.timings

    log_json(
        logger,
        event="classify",
        request_id=request_id,
        label_set_hash=bank.label_set_hash,
        model_id=settings.clip_model_id,
        # Log the effective values used for this call, not the raw
        # (possibly None) request fields.
        domain_top_n=domain_top_n,
        top_k=top_k,
        chosen_domains=res.chosen_domains,
        elapsed_ms=timings.total_ms,
        elapsed_domain_ms=timings.domain_ms,
        elapsed_labels_ms=timings.labels_ms,
    )

    return ClassifyResponse(
        label_set_hash=bank.label_set_hash,
        model_id=settings.clip_model_id,
        domain_hits=[Hit(id=i, score=s) for i, s in res.domain_hits],
        chosen_domains=res.chosen_domains,
        label_hits=[Hit(id=i, score=s) for i, s in res.label_hits],
        elapsed_ms=timings.total_ms,
        elapsed_domain_ms=timings.domain_ms,
        elapsed_labels_ms=timings.labels_ms,
    )