Spaces:
Running
Running
| from __future__ import annotations | |
| from uuid import uuid4 | |
| from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile | |
| from sqlalchemy import select | |
| from sqlalchemy.dialects.postgresql import insert as pg_insert | |
| from sqlalchemy.ext.asyncio import AsyncSession | |
| from ..core.config import get_settings | |
| from ..db.models import ImageRow, PredictionRow | |
| from ..db.session import get_session | |
| from ..ml.classify import ( | |
| ClassifyResult, | |
| ensemble_models, | |
| list_runtime_models, | |
| predict, | |
| predict_all_real, | |
| predict_ensemble, | |
| ) | |
| from ..ml.hybrid import predict_hybrid | |
| from ..ml.zeroshot import predict_zeroshot | |
| from ..schemas.predict import HybridResponse, PredictionResponse, ZeroShotResponse | |
| from ..utils.images import load_pil, sha256_hex | |
| router = APIRouter() | |
| settings = get_settings() | |
async def _persist(session: AsyncSession, *, sha: str, width: int, height: int,
                   res: ClassifyResult, source: str = "upload") -> tuple[ImageRow, PredictionRow]:
    """Store an image row (keyed by its SHA-256) and a prediction row for it.

    The image insert is issued with ``ON CONFLICT DO NOTHING`` on ``sha256``,
    so an already-stored row with the same hash is left untouched. The row is
    then read back by hash to obtain its id, whichever request created it.
    The prediction is added to the session and flushed; this function does
    not commit.

    Args:
        session: Async SQLAlchemy session used for every statement.
        sha: Hash string stored in, and looked up through, ``ImageRow.sha256``.
        width: Image width stored on a newly inserted image row.
        height: Image height stored on a newly inserted image row.
        res: Classification result whose fields are copied to the prediction.
        source: Value for the image row's ``source`` column.

    Returns:
        The ``(ImageRow, PredictionRow)`` pair for this request.
    """
    image_values = {
        "id": uuid4(),
        "sha256": sha,
        "source": source,
        "width": width,
        "height": height,
    }
    upsert = pg_insert(ImageRow).values(**image_values)
    upsert = upsert.on_conflict_do_nothing(index_elements=["sha256"])
    await session.execute(upsert)

    # Always re-select: on conflict the insert is skipped, so the id generated
    # above may not be the one actually stored for this hash.
    lookup = await session.execute(select(ImageRow).where(ImageRow.sha256 == sha))
    image_row = lookup.scalar_one()

    prediction_row = PredictionRow(
        id=uuid4(),
        image_id=image_row.id,
        model=res.model,
        top1_class=res.top1_class,
        top1_prob=res.top1_prob,
        top5=res.top5,
        latency_ms=res.latency_ms,
    )
    session.add(prediction_row)
    await session.flush()
    return image_row, prediction_row
def _to_resp(image_id, prediction_id, res: ClassifyResult, cache: bool = False) -> dict:
    """Build the response payload shared by the prediction endpoints.

    Args:
        image_id: Identifier of the stored image; rendered with ``str()``.
        prediction_id: Identifier of the stored prediction; rendered with
            ``str()``.
        res: Result object providing ``model``, ``top1_class``, ``top1_prob``,
            ``top5`` and ``latency_ms``.
        cache: Value reported under the ``"cache"`` key.

    Returns:
        A dict with the two ids as strings, the result fields copied as-is,
        and the ``cache`` flag.
    """
    payload = dict(prediction_id=str(prediction_id), image_id=str(image_id))
    # Copy the result fields verbatim, in the order the payload exposes them.
    for field in ("model", "top1_class", "top1_prob", "top5", "latency_ms"):
        payload[field] = getattr(res, field)
    payload["cache"] = cache
    return payload
async def info() -> dict:
    """Describe the available models and prediction endpoints.

    Returns:
        A dict with the output of ``list_runtime_models()`` under
        ``"real_models"``, the output of ``ensemble_models()`` under
        ``"ensemble_components"``, and a fixed list of endpoint paths under
        ``"endpoints"``.
    """
    runtime_models = list_runtime_models()
    components = ensemble_models()
    return {
        "real_models": runtime_models,
        "ensemble_components": components,
        "endpoints": ["/single", "/ensemble", "/hybrid", "/zeroshot", "/all"],
    }
async def predict_single(
    request: Request,
    model: str = "efficientnet_v2_s",
    file: UploadFile = File(...),
    session: AsyncSession = Depends(get_session),
):
    """Classify an uploaded image with one named model and persist the result.

    Args:
        request: Incoming request (not used in the body).
        model: Name of the model to run; must be in ``list_runtime_models()``.
        file: Uploaded image file.
        session: Database session used to store the image and prediction.

    Returns:
        The payload built by ``_to_resp`` for the stored prediction.

    Raises:
        HTTPException: 400 if ``model`` is unknown, or if reading/decoding the
            upload raises ``ValueError``.
    """
    # Reject an unknown model before reading or decoding the upload, so a bad
    # query parameter never costs a full file read.
    if model not in list_runtime_models():
        raise HTTPException(status_code=400, detail=f"unknown model: {model}")

    try:
        data = await file.read()
        img = load_pil(data)
    except ValueError as exc:
        # Chain the cause so the original decoding error stays in tracebacks.
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    sha = sha256_hex(data)
    # NOTE(review): predict() is called synchronously inside this coroutine;
    # if it is CPU-heavy it will hold the event loop for its duration.
    res = predict(model, img)
    img_row, pred_row = await _persist(
        session, sha=sha, width=img.width, height=img.height, res=res,
    )
    await session.commit()
    return _to_resp(img_row.id, pred_row.id, res)
async def predict_ensemble_ep(
    file: UploadFile = File(...),
    mode: str = "uniform",
    session: AsyncSession = Depends(get_session),
):
    """Classify an uploaded image with the ensemble and persist the result.

    Args:
        file: Uploaded image file.
        mode: Ensemble mode, either ``"uniform"`` or ``"weighted"``.
        session: Database session used to store the image and prediction.

    Returns:
        The payload built by ``_to_resp`` for the stored prediction.

    Raises:
        HTTPException: 400 if ``mode`` is not one of the accepted values, or
            if reading/decoding the upload raises ``ValueError``.
    """
    if mode not in {"uniform", "weighted"}:
        raise HTTPException(status_code=400, detail="mode must be uniform|weighted")

    try:
        data = await file.read()
        img = load_pil(data)
    except ValueError as exc:
        # Chain the cause so the original decoding error stays in tracebacks.
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    sha = sha256_hex(data)
    # NOTE(review): predict_ensemble() runs synchronously inside this
    # coroutine; if it is CPU-heavy it will hold the event loop meanwhile.
    res = predict_ensemble(img, mode=mode)
    img_row, pred_row = await _persist(
        session, sha=sha, width=img.width, height=img.height, res=res,
    )
    await session.commit()
    return _to_resp(img_row.id, pred_row.id, res)
async def predict_hybrid_ep(
    file: UploadFile = File(...),
    session: AsyncSession = Depends(get_session),
):
    """Classify an uploaded image with the hybrid predictor and persist it.

    The hybrid result is copied into a ``ClassifyResult`` (with empty
    ``logits``) so it can be stored and serialized through the same helpers
    as the other endpoints.

    Args:
        file: Uploaded image file.
        session: Database session used to store the image and prediction.

    Returns:
        The ``_to_resp`` payload extended with the hybrid result's
        ``attributes`` and ``embedding_norm``.

    Raises:
        HTTPException: 400 if reading/decoding the upload raises
            ``ValueError``; 503 if ``predict_hybrid`` raises
            ``FileNotFoundError``.
    """
    try:
        data = await file.read()
        img = load_pil(data)
    except ValueError as exc:
        # Chain the cause so the original decoding error stays in tracebacks.
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    sha = sha256_hex(data)
    try:
        res = predict_hybrid(img)
    except FileNotFoundError as exc:
        # Presumably a missing model artifact; reported as service unavailable.
        raise HTTPException(status_code=503, detail=str(exc)) from exc

    proxy = ClassifyResult(
        model="hybrid_dinov2_segformer_histgbm",
        top1_class=res.top1_class,
        top1_prob=res.top1_prob,
        top5=res.top5,
        latency_ms=res.latency_ms,
        logits=[],
    )
    img_row, pred_row = await _persist(
        session, sha=sha, width=img.width, height=img.height, res=proxy,
    )
    await session.commit()
    return {
        **_to_resp(img_row.id, pred_row.id, proxy),
        "attributes": res.attributes,
        "embedding_norm": res.embedding_norm,
    }
async def predict_zeroshot_ep(
    file: UploadFile = File(...),
    prompt: str = "a photograph of {}",
    session: AsyncSession = Depends(get_session),
):
    """Classify an uploaded image with the zero-shot predictor and persist it.

    The zero-shot result is copied into a ``ClassifyResult`` (with empty
    ``logits``) so it can be stored and serialized through the same helpers
    as the other endpoints.

    Args:
        file: Uploaded image file.
        prompt: Prompt template forwarded to ``predict_zeroshot`` as
            ``prompt_template``.
        session: Database session used to store the image and prediction.

    Returns:
        The ``_to_resp`` payload extended with the result's ``prompt``.

    Raises:
        HTTPException: 400 if reading/decoding the upload raises
            ``ValueError``.
    """
    try:
        data = await file.read()
        img = load_pil(data)
    except ValueError as exc:
        # Chain the cause so the original decoding error stays in tracebacks.
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    sha = sha256_hex(data)
    # NOTE(review): `prompt` is client-supplied and passed through unchecked.
    # If predict_zeroshot() applies str.format to it, malformed placeholders
    # would surface as unhandled errors. Confirm and validate if so.
    res = predict_zeroshot(img, prompt_template=prompt)
    proxy = ClassifyResult(
        model="clip_zeroshot",
        top1_class=res.top1_class,
        top1_prob=res.top1_prob,
        top5=res.top5,
        latency_ms=res.latency_ms,
        logits=[],
    )
    img_row, pred_row = await _persist(
        session, sha=sha, width=img.width, height=img.height, res=proxy,
    )
    await session.commit()
    return {**_to_resp(img_row.id, pred_row.id, proxy), "prompt": res.prompt}
async def predict_all_ep(
    file: UploadFile = File(...),
):
    """Run every runtime model plus both ensemble modes on an uploaded image.

    Nothing is persisted here: the function takes no database session.

    Args:
        file: Uploaded image file.

    Returns:
        A dict mapping each model name from ``predict_all_real`` to a summary
        (``top1``, ``prob``, ``latency_ms``, ``top5``), followed by the same
        summary under ``"ensemble_top3_uniform"`` and
        ``"ensemble_top3_weighted"``.

    Raises:
        HTTPException: 400 if reading/decoding the upload raises
            ``ValueError``.
    """

    def _summary(r) -> dict:
        # Single definition of the per-result shape used for every entry.
        return {
            "top1": r.top1_class,
            "prob": r.top1_prob,
            "latency_ms": r.latency_ms,
            "top5": r.top5,
        }

    try:
        data = await file.read()
        img = load_pil(data)
    except ValueError as exc:
        # Chain the cause so the original decoding error stays in tracebacks.
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    out = {r.model: _summary(r) for r in predict_all_real(img)}
    # Ensemble entries are added last, uniform before weighted.
    out["ensemble_top3_uniform"] = _summary(predict_ensemble(img, mode="uniform"))
    out["ensemble_top3_weighted"] = _summary(predict_ensemble(img, mode="weighted"))
    return out