Spaces:
Sleeping
Sleeping
| # app.py | |
| from time import perf_counter | |
| from io import BytesIO | |
| from typing import List, Optional, Union | |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException | |
| from pydantic import BaseModel, Field, HttpUrl | |
| from PIL import Image | |
| import uvicorn | |
| from util import get_runner, SmolVLMRunner | |
# ASGI application object; the title and version strings are handed to FastAPI.
app = FastAPI(title="SmolVLM Inference API", version="1.2.0")

# Module-wide model runner. It starts out as None and is assigned by
# _load_model_on_startup(); health() treats a falsy value as "no model".
_runner: Optional[SmolVLMRunner] = None
class URLRequest(BaseModel):
    """Request body for generating text from a prompt plus remote image URLs."""

    # Required prompt text.
    prompt: str = Field(..., description="Text prompt to accompany the images.")

    # Required list of image locations; each entry is validated as an HttpUrl.
    image_urls: List[HttpUrl] = Field(..., description="List of image URLs.")

    # Defaults to 300 and is validated to the inclusive range 1..1024.
    max_new_tokens: int = Field(300, ge=1, le=1024)

    # Optional; when given it must lie in [0.0, 2.0]. None is forwarded as-is.
    temperature: Optional[float] = Field(None, ge=0.0, le=2.0)

    # Optional; when given it must lie in (0.0, 1.0]. None is forwarded as-is.
    top_p: Optional[float] = Field(None, gt=0.0, le=1.0)
class DetectDescribeURLRequest(BaseModel):
    """Request body for the detect-and-describe call on a single remote image.

    Every field is forwarded unchanged to ``_runner.detect_and_describe`` by
    ``detect_describe_url``, except ``image_url``, which is fetched first.
    """

    # Location of the one image to process; validated as an HttpUrl.
    image_url: HttpUrl

    # Either a single string or a list of strings.
    # NOTE(review): the file-upload variant documents a comma-separated string;
    # presumably the same format is accepted here — confirm against the runner.
    labels: Union[str, List[str]]

    # Numeric knobs passed straight through to the runner (no bounds enforced here).
    box_threshold: float = 0.40
    text_threshold: float = 0.30
    pad_frac: float = 0.06
    max_new_tokens: int = 160

    # Passed through as the runner's ``return_overlay`` argument.
    return_overlay: bool = True

    # Optional sampling parameters; None is forwarded as-is.
    temperature: Optional[float] = None
    top_p: Optional[float] = None
async def _load_model_on_startup():
    """Assign the object returned by ``get_runner()`` to the module-level ``_runner``."""
    # NOTE(review): no startup-hook registration is visible in this view;
    # presumably this coroutine is attached to the app elsewhere — confirm.
    global _runner
    _runner = get_runner()
def health():
    """Report a fixed OK status together with the loaded model's identifier.

    Returns:
        A dict with ``"status"`` set to ``"ok"`` and ``"model"`` set to
        ``_runner.model_id`` when ``_runner`` is truthy, otherwise ``None``.
    """
    model_id = None
    if _runner:
        model_id = _runner.model_id
    return {"status": "ok", "model": model_id}
async def generate_from_files(
    prompt: str = Form(...),
    images: List[UploadFile] = File(..., description="One or more image files."),
    max_new_tokens: int = Form(300),
    temperature: Optional[float] = Form(None),
    top_p: Optional[float] = Form(None),
):
    """Generate text from a prompt and one or more uploaded image files.

    Args:
        prompt: Prompt text, forwarded to ``_runner.generate``.
        images: Uploaded files; each must declare an ``image/*`` content type.
        max_new_tokens: Forwarded to ``_runner.generate``.
        temperature: Forwarded to ``_runner.generate``; may be None.
        top_p: Forwarded to ``_runner.generate``; may be None.

    Returns:
        ``{"text": ..., "metrics": ...}`` where ``metrics`` is the runner's own
        stats dict extended with a ``"request_ms"`` entry holding the
        ``"image_load"`` and ``"end_to_end"`` durations in milliseconds.

    Raises:
        HTTPException: 503 when the model runner has not been loaded, 400 when
            no images are supplied or the uploads cannot be turned into images,
            415 when a file's content type is missing or not ``image/*``.
    """
    # The runner is None until startup has run; fail with a clear 503 instead
    # of an AttributeError surfacing as a 500.
    if _runner is None:
        raise HTTPException(status_code=503, detail="Model is not loaded yet.")
    if not images:
        raise HTTPException(status_code=400, detail="At least one image must be provided.")

    t_req_start = perf_counter()

    # Read files
    t_load_start = perf_counter()
    blobs = []
    for f in images:
        if not f.content_type or not f.content_type.startswith("image/"):
            raise HTTPException(status_code=415, detail=f"Unsupported file type: {f.content_type}")
        blobs.append(await f.read())
    try:
        pil_images = _runner.load_pil_from_bytes(blobs)
    except Exception as e:
        # Same mapping as detect_describe: an unreadable upload is a client
        # error, not a server fault.
        raise HTTPException(status_code=400, detail=f"Failed to read image: {e}") from e
    t_load_end = perf_counter()

    text, inner_metrics = _runner.generate(
        prompt=prompt,
        images=pil_images,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        return_stats=True,
    )
    t_req_end = perf_counter()

    metrics = {
        **inner_metrics,
        "request_ms": {
            "image_load": round((t_load_end - t_load_start) * 1000.0, 2),
            "end_to_end": round((t_req_end - t_req_start) * 1000.0, 2),
        },
    }
    return {"text": text, "metrics": metrics}
async def generate_from_urls(req: URLRequest):
    """Generate text from a prompt and one or more images fetched by URL.

    Args:
        req: Validated request body; its prompt, token budget and sampling
            fields are forwarded to ``_runner.generate``.

    Returns:
        ``{"text": ..., "metrics": ...}`` where ``metrics`` is the runner's own
        stats dict extended with a ``"request_ms"`` entry holding the
        ``"image_load"`` and ``"end_to_end"`` durations in milliseconds.

    Raises:
        HTTPException: 503 when the model runner has not been loaded, 400 when
            the URL list is empty or the images cannot be fetched.
    """
    # The runner is None until startup has run; fail with a clear 503 instead
    # of an AttributeError surfacing as a 500.
    if _runner is None:
        raise HTTPException(status_code=503, detail="Model is not loaded yet.")

    t_req_start = perf_counter()
    if not req.image_urls:
        raise HTTPException(status_code=400, detail="At least one image URL is required.")

    # NOTE(review): these URLs are caller-supplied and fetched server-side;
    # confirm that load_pil_from_urls applies timeouts and target restrictions.
    t_load_start = perf_counter()
    try:
        pil_images = _runner.load_pil_from_urls([str(u) for u in req.image_urls])
    except Exception as e:
        # Same mapping as detect_describe_url: an unreachable or unreadable
        # URL is reported as a client error rather than a 500.
        raise HTTPException(status_code=400, detail=f"Failed to fetch image: {e}") from e
    t_load_end = perf_counter()

    text, inner_metrics = _runner.generate(
        prompt=req.prompt,
        images=pil_images,
        max_new_tokens=req.max_new_tokens,
        temperature=req.temperature,
        top_p=req.top_p,
        return_stats=True,
    )
    t_req_end = perf_counter()

    metrics = {
        **inner_metrics,
        "request_ms": {
            "image_load": round((t_load_end - t_load_start) * 1000.0, 2),
            "end_to_end": round((t_req_end - t_req_start) * 1000.0, 2),
        },
    }
    return {"text": text, "metrics": metrics}
async def detect_describe(
    image: UploadFile = File(..., description="One image file (image/*)"),
    labels: str = Form(..., description='Comma-separated phrases, e.g. "a man,a dog"'),
    box_threshold: float = Form(0.40),
    text_threshold: float = Form(0.30),
    pad_frac: float = Form(0.06),
    max_new_tokens: int = Form(160),
    temperature: Optional[float] = Form(None),
    top_p: Optional[float] = Form(None),
    return_overlay: bool = Form(True),
):
    """Run the runner's detect-and-describe call on one uploaded image.

    Args:
        image: Uploaded file; must declare an ``image/*`` content type.
        labels: Comma-separated phrases, passed to the runner as one string.
        box_threshold: Forwarded to ``_runner.detect_and_describe``.
        text_threshold: Forwarded to ``_runner.detect_and_describe``.
        pad_frac: Forwarded to ``_runner.detect_and_describe``.
        max_new_tokens: Forwarded to ``_runner.detect_and_describe``.
        temperature: Forwarded to ``_runner.detect_and_describe``; may be None.
        top_p: Forwarded to ``_runner.detect_and_describe``; may be None.
        return_overlay: Forwarded to ``_runner.detect_and_describe``.

    Returns:
        Whatever ``_runner.detect_and_describe`` returns, unmodified.

    Raises:
        HTTPException: 503 when the model runner has not been loaded, 415 when
            the content type is missing or not ``image/*``, 400 when the
            upload cannot be opened as an image.
    """
    # The runner is None until startup has run; fail with a clear 503 instead
    # of an AttributeError surfacing as a 500.
    if _runner is None:
        raise HTTPException(status_code=503, detail="Model is not loaded yet.")
    if not image.content_type or not image.content_type.startswith("image/"):
        raise HTTPException(status_code=415, detail=f"Unsupported file type: {image.content_type}")

    try:
        raw = await image.read()
        pil = Image.open(BytesIO(raw)).convert("RGB")
    except Exception as e:
        # Keep the underlying error attached as the cause for server-side logs.
        raise HTTPException(status_code=400, detail=f"Failed to read image: {e}") from e

    return _runner.detect_and_describe(
        image=pil,
        labels=labels,  # comma-separated string OK
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        pad_frac=pad_frac,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        return_overlay=return_overlay,
    )
async def detect_describe_url(req: DetectDescribeURLRequest):
    """Run the runner's detect-and-describe call on one image fetched by URL.

    Args:
        req: Validated request body; every field except ``image_url`` is
            forwarded unchanged to ``_runner.detect_and_describe``.

    Returns:
        Whatever ``_runner.detect_and_describe`` returns, unmodified.

    Raises:
        HTTPException: 503 when the model runner has not been loaded, 400 when
            the image cannot be fetched.
    """
    # Check before the try block: otherwise the AttributeError raised on a None
    # runner would be caught below and misreported as a 400 fetch failure.
    if _runner is None:
        raise HTTPException(status_code=503, detail="Model is not loaded yet.")

    # NOTE(review): the URL is caller-supplied and fetched server-side; confirm
    # that load_pil_from_urls applies timeouts and target restrictions.
    try:
        pil = _runner.load_pil_from_urls([str(req.image_url)])[0]
    except Exception as e:
        # Keep the underlying error attached as the cause for server-side logs.
        raise HTTPException(status_code=400, detail=f"Failed to fetch image: {e}") from e

    return _runner.detect_and_describe(
        image=pil,
        labels=req.labels,
        box_threshold=req.box_threshold,
        text_threshold=req.text_threshold,
        pad_frac=req.pad_frac,
        max_new_tokens=req.max_new_tokens,
        temperature=req.temperature,
        top_p=req.top_p,
        return_overlay=req.return_overlay,
    )
if __name__ == "__main__":
    # Direct-execution entry point: python app.py
    # Equivalent command line: uvicorn app:app --host 0.0.0.0 --port 8000
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=False)