"""HTTP route that classifies a base64-encoded image against a label bank."""

from __future__ import annotations

from typing import Optional

from fastapi import APIRouter, Depends, Query, Request

from api.classify.schemas import ClassifyRequest, ClassifyResponse, Hit
from api.common.deps import get_classifier, get_registry, get_request_id, resolve_bank
from api.common.logging import log_json, setup_logging
from api.common.image_io import load_image_from_base64
from api.common.settings import settings

logger = setup_logging()

router = APIRouter(prefix="/api/v1", tags=["classify"])


def _to_hits(pairs) -> list[Hit]:
    """Convert an iterable of ``(id, score)`` pairs into ``Hit`` models."""
    return [Hit(id=hit_id, score=score) for hit_id, score in pairs]


@router.post("/classify", response_model=ClassifyResponse)
def classify(
    payload: ClassifyRequest,
    request: Request,
    request_id: str = Depends(get_request_id),
    label_set_hash: Optional[str] = Query(
        default=None,
        description="If omitted, uses the default label set.",
    ),
    classifier=Depends(get_classifier),
    registry=Depends(get_registry),
) -> ClassifyResponse:
    """Classify a base64-encoded image against a label bank.

    The bank is looked up via ``resolve_bank(registry, label_set_hash)``. The
    image is decoded from ``payload.image_base64`` with a size cap of
    ``settings.max_image_mb`` megabytes. It is then scored by
    ``classifier.classify``.

    Args:
        payload: Request body. Provides ``image_base64`` and the optional
            ``domain_top_n`` / ``top_k`` overrides.
        request: The raw FastAPI request. It is not read here and is kept so
            the route signature is unchanged.
        request_id: Correlation id supplied by the ``get_request_id``
            dependency. It is used only for logging.
        label_set_hash: Optional query parameter selecting the label set. When
            omitted, the default label set is used.
        classifier: Classifier supplied by ``get_classifier``.
        registry: Label-bank registry supplied by ``get_registry``.

    Returns:
        A ``ClassifyResponse`` containing:
        - the resolved bank's ``label_set_hash``;
        - the configured model id;
        - the domain and label hits;
        - the chosen domains;
        - the timing fields copied from ``res.timings``.

    Exceptions raised by ``resolve_bank`` or ``load_image_from_base64`` are
    not caught here. They propagate to FastAPI's error handling.
    """
    bank = resolve_bank(registry, label_set_hash)
    image = load_image_from_base64(
        payload.image_base64,
        max_bytes=settings.max_image_mb * 1024 * 1024,
    )

    # Resolve the limits once so that the classifier call and the log line
    # agree on what was actually used. ``or`` (rather than ``is None``) is
    # kept on purpose: any falsy request value, including 0, falls back to
    # the configured default, exactly as before.
    domain_top_n = payload.domain_top_n or settings.default_domain_top_n
    top_k = payload.top_k or settings.default_top_k

    res = classifier.classify(
        bank=bank,
        image=image,
        domain_top_n=domain_top_n,
        top_k=top_k,
    )

    log_json(
        logger,
        event="classify",
        request_id=request_id,
        label_set_hash=bank.label_set_hash,
        model_id=settings.clip_model_id,
        # Raw request values. These are None when the client omitted them;
        # the keys are kept unchanged for existing log consumers.
        domain_top_n=payload.domain_top_n,
        top_k=payload.top_k,
        # The values actually passed to the classifier, after defaulting.
        effective_domain_top_n=domain_top_n,
        effective_top_k=top_k,
        chosen_domains=res.chosen_domains,
        elapsed_ms=res.timings.total_ms,
        elapsed_domain_ms=res.timings.domain_ms,
        elapsed_labels_ms=res.timings.labels_ms,
    )

    return ClassifyResponse(
        label_set_hash=bank.label_set_hash,
        model_id=settings.clip_model_id,
        domain_hits=_to_hits(res.domain_hits),
        chosen_domains=res.chosen_domains,
        label_hits=_to_hits(res.label_hits),
        elapsed_ms=res.timings.total_ms,
        elapsed_domain_ms=res.timings.domain_ms,
        elapsed_labels_ms=res.timings.labels_ms,
    )