| """ |
| FastAPI app for HuggingFaceTB/SmolVLM-Instruct |
| Supports: text-only prompts, single image, and multi-image inputs. |
| Port: 7860 (HuggingFace Spaces default) |
| """ |
|
|
| import io |
| import base64 |
| import logging |
| from contextlib import asynccontextmanager |
| from typing import Optional |
|
|
| import torch |
| from PIL import Image |
| from fastapi import FastAPI, File, Form, UploadFile, HTTPException |
| from pydantic import BaseModel |
| from transformers import AutoProcessor, AutoModelForVision2Seq |
|
|
| |
# Module-level logging: INFO level for startup/inference progress messages.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model configuration: prefer GPU with bfloat16; fall back to CPU/float32
# (bf16 is slow or unsupported on many CPUs).
MODEL_ID = "HuggingFaceTB/SmolVLM-Instruct"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32

# Populated by the FastAPI lifespan handler at startup; None until loaded.
model = None
processor = None
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the SmolVLM model/processor at startup; release them at shutdown.

    Runs once around the app's lifetime (FastAPI lifespan protocol): code
    before ``yield`` executes at startup, code after at shutdown.
    """
    global model, processor
    # Lazy %-style args avoid formatting when the log level filters the record.
    logger.info("Loading %s on %s ...", MODEL_ID, DEVICE)
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = AutoModelForVision2Seq.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        # "eager" attention avoids a hard dependency on flash-attn kernels.
        _attn_implementation="eager",
    ).to(DEVICE)
    model.eval()
    logger.info("SmolVLM ready")
    yield
    # Rebind to None rather than `del`: deleting the global names would make
    # any later access (e.g. health() during shutdown) raise NameError
    # instead of reporting "not loaded".
    model = None
    processor = None
    if DEVICE == "cuda":
        torch.cuda.empty_cache()
|
|
|
|
| |
# Application object; model loading/unloading is wired in via `lifespan`.
app = FastAPI(
    title="SmolVLM API",
    description="Multimodal inference with HuggingFaceTB/SmolVLM-Instruct",
    version="1.0.0",
    lifespan=lifespan,
)
|
|
|
|
| |
|
|
def run_inference(
    prompt: str,
    images: Optional[list[Image.Image]] = None,
    max_new_tokens: int = 512,
    temperature: float = 0.0,
) -> str:
    """Run one chat-style generation pass through SmolVLM.

    Builds a single user turn with one image placeholder per supplied image
    followed by the text prompt, then decodes only the newly generated
    tokens. ``temperature == 0`` selects greedy decoding; any positive value
    enables sampling at that temperature.
    """
    pil_images = list(images) if images else []

    # One {"type": "image"} placeholder per image, then the text prompt.
    user_content = [{"type": "image"} for _ in pil_images]
    user_content.append({"type": "text", "text": prompt})
    chat = [{"role": "user", "content": user_content}]

    templated = processor.apply_chat_template(chat, add_generation_prompt=True)

    model_inputs = processor(
        text=templated,
        images=pil_images or None,
        return_tensors="pt",
    ).to(DEVICE)

    sampling = temperature > 0
    gen_kwargs = {
        **model_inputs,
        "max_new_tokens": max_new_tokens,
        "do_sample": sampling,
    }
    if sampling:
        gen_kwargs["temperature"] = temperature

    with torch.no_grad():
        output_ids = model.generate(**gen_kwargs)

    # Strip the prompt tokens so only the model's continuation is decoded.
    prompt_len = model_inputs["input_ids"].shape[1]
    continuation = output_ids[0][prompt_len:]
    return processor.decode(continuation, skip_special_tokens=True).strip()
|
|
|
|
| |
|
|
@app.get("/", tags=["Health"])
def root():
    """Liveness probe: reports the configured model id and compute device."""
    payload = {"status": "ok", "model": MODEL_ID, "device": DEVICE}
    return payload
|
|
|
|
@app.get("/health", tags=["Health"])
def health():
    """Readiness probe: True once the lifespan handler has loaded the model."""
    loaded = model is not None
    return {"model_loaded": loaded}
|
|
|
|
| |
|
|
class TextRequest(BaseModel):
    """JSON body for /generate/text (text-only prompt, no images)."""
    prompt: str
    max_new_tokens: int = 512  # cap on generated tokens
    temperature: float = 0.0  # 0.0 = greedy decoding; >0 enables sampling
|
|
|
|
@app.post("/generate/text", tags=["Inference"])
def generate_text(req: TextRequest):
    """Plain text prompt — no image required.

    Returns the prompt together with the model's response. Raises 503 while
    the model is still loading and 500 on any inference failure.
    """
    if model is None:
        raise HTTPException(503, "Model not loaded yet")
    try:
        response = run_inference(
            req.prompt,
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
        )
    except Exception as e:
        logger.exception("Inference error")
        # Chain the cause so tracebacks keep the root failure (ruff B904).
        raise HTTPException(500, str(e)) from e
    return {"prompt": req.prompt, "response": response}
|
|
|
|
| |
|
|
@app.post("/generate/vision", tags=["Inference"])
async def generate_vision(
    prompt: str = Form("Describe the image(s) in detail."),
    max_new_tokens: int = Form(512),
    temperature: float = Form(0.0),
    images: list[UploadFile] = File(default=[]),
):
    """Upload one or more images (multipart form) with an optional text prompt.

    Returns the prompt, the number of decoded images, and the model response.
    Raises 503 while the model loads, 400 on an undecodable upload, and 500
    on any inference failure.
    """
    if model is None:
        raise HTTPException(503, "Model not loaded yet")

    pil_images: list[Image.Image] = []
    for upload in images:
        raw = await upload.read()
        try:
            # Normalize to RGB so the processor sees a consistent mode.
            pil_images.append(Image.open(io.BytesIO(raw)).convert("RGB"))
        except Exception as e:
            # Chain the decode failure instead of discarding it (ruff B904).
            raise HTTPException(400, f"Could not decode image: {upload.filename}") from e

    try:
        response = run_inference(
            prompt,
            images=pil_images or None,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
        )
    except Exception as e:
        logger.exception("Inference error")
        raise HTTPException(500, str(e)) from e
    return {"prompt": prompt, "num_images": len(pil_images), "response": response}
|
|
|
|
| |
|
|
class VisionB64Request(BaseModel):
    """JSON body for /generate/vision/base64 (base64-encoded images inline)."""
    prompt: str = "Describe the image(s) in detail."
    images_b64: list[str] = []  # safe: pydantic copies field defaults per instance
    max_new_tokens: int = 512  # cap on generated tokens
    temperature: float = 0.0  # 0.0 = greedy decoding; >0 enables sampling
|
|
|
|
@app.post("/generate/vision/base64", tags=["Inference"])
def generate_vision_b64(req: VisionB64Request):
    """Send base64-encoded images inside a JSON body.

    Accepts both raw base64 strings and data URLs (``data:image/...;base64,``
    prefix is stripped). Raises 503 while the model loads, 400 on an
    undecodable image, and 500 on any inference failure.
    """
    if model is None:
        raise HTTPException(503, "Model not loaded yet")

    pil_images: list[Image.Image] = []
    for idx, b64str in enumerate(req.images_b64):
        # Strip a data-URL header if present, keeping only the payload.
        if "," in b64str:
            b64str = b64str.split(",", 1)[1]
        try:
            raw = base64.b64decode(b64str)
            pil_images.append(Image.open(io.BytesIO(raw)).convert("RGB"))
        except Exception as e:
            # Chain the decode failure instead of discarding it (ruff B904).
            raise HTTPException(400, f"Could not decode base64 image at index {idx}") from e

    try:
        response = run_inference(
            req.prompt,
            images=pil_images or None,
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
        )
    except Exception as e:
        logger.exception("Inference error")
        raise HTTPException(500, str(e)) from e
    return {"prompt": req.prompt, "num_images": len(pil_images), "response": response}
|
|
|
|
| |
if __name__ == "__main__":
    # Local/dev entry point; port 7860 is the HuggingFace Spaces default
    # (see module docstring). Deferred import keeps uvicorn optional.
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)
|
|