"""
FastAPI app for HuggingFaceTB/SmolVLM-Instruct
Supports: text-only prompts, single image, and multi-image inputs.
Port: 7860 (HuggingFace Spaces default)
"""

import io
import base64
import logging
from contextlib import asynccontextmanager
from typing import Optional

import torch
from PIL import Image
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from pydantic import BaseModel
from transformers import AutoProcessor, AutoModelForVision2Seq

# ── Logging ───────────────────────────────────────────────────────────────────
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ── Config ────────────────────────────────────────────────────────────────────
MODEL_ID = "HuggingFaceTB/SmolVLM-Instruct"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 on GPU to halve memory; CPUs get full float32.
DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32

# ── Globals ───────────────────────────────────────────────────────────────────
# Set by ``lifespan`` at startup and reset to None at shutdown; the routes
# treat ``model is None`` as "not ready" and answer 503.
model = None
processor = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the SmolVLM processor and model at startup; release them at shutdown.

    The model is moved to ``DEVICE`` in ``DTYPE`` and put in eval mode before
    the app starts serving requests.
    """
    global model, processor
    logger.info("Loading %s on %s …", MODEL_ID, DEVICE)
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = AutoModelForVision2Seq.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        _attn_implementation="eager",  # swap to "flash_attention_2" on supported GPUs
    ).to(DEVICE)
    model.eval()
    logger.info("SmolVLM ready ✓")
    try:
        yield
    finally:
        # Reset instead of ``del``: deleting the globals would make the
        # ``model is None`` readiness checks raise NameError instead of 503.
        model = None
        processor = None
        if DEVICE == "cuda":
            torch.cuda.empty_cache()


# ── App ───────────────────────────────────────────────────────────────────────
app = FastAPI(
    title="SmolVLM API",
    description="Multimodal inference with HuggingFaceTB/SmolVLM-Instruct",
    version="1.0.0",
    lifespan=lifespan,
)


# ── Helpers ───────────────────────────────────────────────────────────────────
def run_inference(
    prompt: str,
    images: Optional[list[Image.Image]] = None,
    max_new_tokens: int = 512,
    temperature: float = 0.0,
) -> str:
    """Generate a SmolVLM reply for ``prompt`` and optional images.

    Args:
        prompt: User text, placed after the image placeholders in the message.
        images: PIL images. One ``{"type": "image"}`` placeholder is emitted
            per image, in order. ``None`` or an empty list means text-only.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: If ``> 0``, sampling is enabled at this temperature.
            Otherwise decoding is greedy.

    Returns:
        The decoded newly generated text, with special tokens skipped and
        surrounding whitespace stripped.

    Raises:
        RuntimeError: If called before ``lifespan`` has loaded the model.
    """
    if model is None or processor is None:
        raise RuntimeError("Model not loaded yet")

    images = images or []

    # Build the chat message in the standard messages format SmolVLM uses:
    # one image placeholder per image, then the text prompt.
    content = [{"type": "image"} for _ in images]
    content.append({"type": "text", "text": prompt})
    messages = [{"role": "user", "content": content}]

    text_input = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(
        text=text_input,
        images=images or None,
        return_tensors="pt",
    ).to(DEVICE)

    generate_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=temperature > 0,
    )
    if temperature > 0:
        # Only meaningful when sampling; greedy decoding ignores it.
        generate_kwargs["temperature"] = temperature

    # inference_mode is a stricter, cheaper variant of no_grad for pure inference.
    with torch.inference_mode():
        output_ids = model.generate(**generate_kwargs)

    # generate() returns prompt + continuation; decode only the new tokens.
    input_len = inputs["input_ids"].shape[1]
    generated = output_ids[0][input_len:]
    return processor.decode(generated, skip_special_tokens=True).strip()


# ── Routes ────────────────────────────────────────────────────────────────────
@app.get("/", tags=["Health"])
def root():
    """Basic liveness info: status, model id and device."""
    return {"status": "ok", "model": MODEL_ID, "device": DEVICE}


@app.get("/health", tags=["Health"])
def health():
    """Report whether the model has finished loading."""
    return {"model_loaded": model is not None}


# ── 1. Text-only ──────────────────────────────────────────────────────────────
class TextRequest(BaseModel):
    # The user prompt sent to the model as the only message content.
    prompt: str
    # Forwarded to run_inference as the generation length cap.
    max_new_tokens: int = 512
    # > 0 enables sampling at this temperature; otherwise greedy decoding.
    temperature: float = 0.0


@app.post("/generate/text", tags=["Inference"])
def generate_text(req: TextRequest):
    """Plain text prompt — no image required."""
    # Sync route: FastAPI runs it in its threadpool, so blocking in
    # run_inference does not stall the event loop.
    if model is None:
        raise HTTPException(503, "Model not loaded yet")
    try:
        response = run_inference(
            req.prompt,
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
        )
    except Exception as exc:
        logger.exception("Inference error")
        raise HTTPException(500, str(exc)) from exc
    return {"prompt": req.prompt, "response": response}
# ── 2. Image upload (multipart/form-data) ─────────────────────────────────────


def _decode_image(raw: bytes) -> Image.Image:
    """Decode encoded image bytes into an RGB PIL image.

    Errors from PIL propagate unchanged. The endpoints translate them into an
    HTTP 400 that identifies the offending image.
    """
    return Image.open(io.BytesIO(raw)).convert("RGB")


def _infer_or_500(
    prompt: str,
    images: list[Image.Image],
    max_new_tokens: int,
    temperature: float,
) -> str:
    """Call ``run_inference``, logging and converting any failure to HTTP 500.

    An empty ``images`` list is forwarded as ``None`` (text-only prompt).
    """
    try:
        return run_inference(
            prompt,
            images=images or None,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
        )
    except Exception as exc:
        logger.exception("Inference error")
        raise HTTPException(500, str(exc)) from exc


@app.post("/generate/vision", tags=["Inference"])
async def generate_vision(
    prompt: str = Form("Describe the image(s) in detail."),
    max_new_tokens: int = Form(512),
    temperature: float = Form(0.0),
    images: list[UploadFile] = File(default=[]),
):
    """Upload one or more images with an optional text prompt."""
    # Local import keeps this section self-contained.
    from fastapi.concurrency import run_in_threadpool

    if model is None:
        raise HTTPException(503, "Model not loaded yet")

    pil_images: list[Image.Image] = []
    for upload in images:
        raw = await upload.read()
        try:
            pil_images.append(_decode_image(raw))
        except Exception as exc:
            raise HTTPException(400, f"Could not decode image: {upload.filename}") from exc

    # run_inference is synchronous and compute-bound. Calling it directly in
    # this ``async def`` would block the event loop (and every other request,
    # including /health) for the whole generation, so offload it to a worker
    # thread.
    response = await run_in_threadpool(
        _infer_or_500, prompt, pil_images, max_new_tokens, temperature
    )
    return {"prompt": prompt, "num_images": len(pil_images), "response": response}


# ── 3. Base64 images via JSON ─────────────────────────────────────────────────
class VisionB64Request(BaseModel):
    # Text prompt; defaults to a generic description request.
    prompt: str = "Describe the image(s) in detail."
    # Base64 image payloads; an optional "<prefix>," (e.g. a data URL header)
    # is stripped by the endpoint before decoding.
    images_b64: list[str] = []
    # Forwarded to run_inference as the generation length cap.
    max_new_tokens: int = 512
    # > 0 enables sampling at this temperature; otherwise greedy decoding.
    temperature: float = 0.0


@app.post("/generate/vision/base64", tags=["Inference"])
def generate_vision_b64(req: VisionB64Request):
    """Send base64-encoded images inside a JSON body."""
    # Sync route: FastAPI runs it in its threadpool, so inference here does
    # not block the event loop.
    if model is None:
        raise HTTPException(503, "Model not loaded yet")

    pil_images: list[Image.Image] = []
    for idx, b64str in enumerate(req.images_b64):
        # Drop everything up to the first comma, e.g. "data:image/png;base64,".
        if "," in b64str:
            b64str = b64str.split(",", 1)[1]
        try:
            pil_images.append(_decode_image(base64.b64decode(b64str)))
        except Exception as exc:
            raise HTTPException(400, f"Could not decode base64 image at index {idx}") from exc

    response = _infer_or_500(req.prompt, pil_images, req.max_new_tokens, req.temperature)
    return {"prompt": req.prompt, "num_images": len(pil_images), "response": response}


# ── Entry point ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn

    # Pass the app object rather than the "app:app" import string. The import
    # string only works if this file is named app.py, and it imports the
    # module a second time. Import strings are only required for
    # reload/workers, which are off here.
    uvicorn.run(app, host="0.0.0.0", port=7860, reload=False)