Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from typing import Optional | |
| from fastapi import APIRouter, Depends, Query, Request | |
| from api.classify.schemas import ClassifyRequest, ClassifyResponse, Hit | |
| from api.common.deps import get_classifier, get_registry, get_request_id, resolve_bank | |
| from api.common.logging import log_json, setup_logging | |
| from api.common.image_io import load_image_from_base64 | |
| from api.common.settings import settings | |
# Module logger, obtained from the project's setup_logging() helper.
logger = setup_logging()
# Router for the classify endpoints: every path is prefixed with /api/v1
# and grouped under the "classify" tag.
router = APIRouter(tags=["classify"], prefix="/api/v1")
def classify(
    payload: ClassifyRequest,
    request: Request,
    request_id: str = Depends(get_request_id),
    label_set_hash: Optional[str] = Query(default=None, description="If omitted, uses the default label set."),
    classifier=Depends(get_classifier),
    registry=Depends(get_registry),
) -> ClassifyResponse:
    """Classify a base64-encoded image against a resolved label bank.

    NOTE(review): no ``@router`` decorator is attached to this function in
    the visible source, so it is not registered on ``router`` here. Confirm
    it is wired up elsewhere (e.g. ``router.add_api_route``) or add the
    intended ``@router.post(...)`` decorator.

    Args:
        payload: Request body. Supplies ``image_base64`` plus optional
            ``domain_top_n`` / ``top_k`` overrides.
        request: Incoming request object. Not read by this handler; kept so
            the signature stays unchanged for callers/framework wiring.
        request_id: Identifier from ``get_request_id``; attached to the log
            record.
        label_set_hash: Selects the label bank via ``resolve_bank``. ``None``
            requests the default label set.
        classifier: Object exposing
            ``classify(bank=, image=, domain_top_n=, top_k=)``.
        registry: Label-bank registry handed to ``resolve_bank``.

    Returns:
        A ``ClassifyResponse`` carrying the bank's ``label_set_hash``, the
        configured model id, domain and label hits, the chosen domains, and
        the total / domain / label timings reported by the classifier.
    """

    def _to_hits(pairs):
        # Convert (id, score) pairs from the classifier result into Hit models.
        return [Hit(id=hit_id, score=score) for hit_id, score in pairs]

    bank = resolve_bank(registry, label_set_hash)
    image = load_image_from_base64(
        payload.image_base64,
        max_bytes=settings.max_image_mb * 1024 * 1024,
    )

    # Resolve the effective parameters once so the classifier call and the
    # log record agree. `or` (rather than an `is None` check) is kept for
    # backward compatibility: any falsy override, including 0, falls back
    # to the configured default.
    domain_top_n = payload.domain_top_n or settings.default_domain_top_n
    top_k = payload.top_k or settings.default_top_k

    res = classifier.classify(
        bank=bank,
        image=image,
        domain_top_n=domain_top_n,
        top_k=top_k,
    )
    timings = res.timings

    # Log the values actually used, not the raw payload fields, which may be
    # unset when the defaults above were applied.
    log_json(
        logger,
        event="classify",
        request_id=request_id,
        label_set_hash=bank.label_set_hash,
        model_id=settings.clip_model_id,
        domain_top_n=domain_top_n,
        top_k=top_k,
        chosen_domains=res.chosen_domains,
        elapsed_ms=timings.total_ms,
        elapsed_domain_ms=timings.domain_ms,
        elapsed_labels_ms=timings.labels_ms,
    )

    return ClassifyResponse(
        label_set_hash=bank.label_set_hash,
        model_id=settings.clip_model_id,
        domain_hits=_to_hits(res.domain_hits),
        chosen_domains=res.chosen_domains,
        label_hits=_to_hits(res.label_hits),
        elapsed_ms=timings.total_ms,
        elapsed_domain_ms=timings.domain_ms,
        elapsed_labels_ms=timings.labels_ms,
    )