# NOTE: removed HuggingFace web-page artifacts accidentally pasted above the
# module docstring ("Upload 3 files", commit 406e5bf) — they are not Python.
"""
FastAPI app for HuggingFaceTB/SmolVLM-Instruct
Supports: text-only prompts, single image, and multi-image inputs.
Port: 7860 (HuggingFace Spaces default)
"""
import io
import base64
import logging
from contextlib import asynccontextmanager
from typing import Optional
import torch
from PIL import Image
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from pydantic import BaseModel
from transformers import AutoProcessor, AutoModelForVision2Seq
# ── Logging ───────────────────────────────────────────────────────────────────
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ── Config ────────────────────────────────────────────────────────────────────
MODEL_ID = "HuggingFaceTB/SmolVLM-Instruct"
# Prefer GPU when available; bfloat16 only on CUDA, CPU falls back to float32.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
# ── Globals ───────────────────────────────────────────────────────────────────
# Populated by the lifespan handler at startup; None means "not loaded yet".
model = None
processor = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model and processor at startup; release them at shutdown.

    Runs once per application lifetime (FastAPI lifespan protocol): everything
    before ``yield`` is startup, everything after is shutdown.
    """
    global model, processor
    logger.info(f"Loading {MODEL_ID} on {DEVICE} …")
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = AutoModelForVision2Seq.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        _attn_implementation="eager",  # swap to "flash_attention_2" on supported GPUs
    ).to(DEVICE)
    model.eval()
    logger.info("SmolVLM ready ✓")
    yield
    # Shutdown: reset the globals to None rather than `del`-ing them.
    # `del model, processor` removes the global names entirely, so any later
    # read (e.g. the /health route during shutdown) would raise NameError.
    model = None
    processor = None
    if DEVICE == "cuda":
        torch.cuda.empty_cache()
# ── App ───────────────────────────────────────────────────────────────────────
# Application object; the lifespan handler above owns model load/unload.
app = FastAPI(
    title="SmolVLM API",
    description="Multimodal inference with HuggingFaceTB/SmolVLM-Instruct",
    version="1.0.0",
    lifespan=lifespan,
)
# ── Helpers ───────────────────────────────────────────────────────────────────
def run_inference(
    prompt: str,
    images: Optional[list[Image.Image]] = None,
    max_new_tokens: int = 512,
    temperature: float = 0.0,
) -> str:
    """Generate a completion for `prompt` with zero or more attached images.

    Args:
        prompt: User text, inserted after the image placeholders.
        images: Optional PIL images; each gets an image slot in the chat turn.
        max_new_tokens: Generation length cap.
        temperature: 0 means greedy decoding; >0 enables sampling.

    Returns:
        The decoded continuation only (prompt tokens stripped), whitespace-trimmed.
    """
    pil_list = list(images) if images else []
    # SmolVLM chat format: one {"type": "image"} placeholder per image,
    # followed by the text part, all inside a single user turn.
    user_content = [{"type": "image"} for _ in pil_list]
    user_content.append({"type": "text", "text": prompt})
    messages = [{"role": "user", "content": user_content}]
    chat_text = processor.apply_chat_template(messages, add_generation_prompt=True)
    model_inputs = processor(
        text=chat_text,
        images=pil_list if pil_list else None,
        return_tensors="pt",
    ).to(DEVICE)
    gen_kwargs = {**model_inputs}
    gen_kwargs["max_new_tokens"] = max_new_tokens
    sampling = temperature > 0
    gen_kwargs["do_sample"] = sampling
    if sampling:
        gen_kwargs["temperature"] = temperature
    with torch.no_grad():
        output_ids = model.generate(**gen_kwargs)
    # Slice off the prompt so only newly generated tokens are decoded.
    prompt_len = model_inputs["input_ids"].shape[1]
    new_tokens = output_ids[0][prompt_len:]
    return processor.decode(new_tokens, skip_special_tokens=True).strip()
# ── Routes ────────────────────────────────────────────────────────────────────
@app.get("/", tags=["Health"])
def root():
    """Liveness probe returning model and device metadata."""
    payload = {"status": "ok", "model": MODEL_ID, "device": DEVICE}
    return payload
@app.get("/health", tags=["Health"])
def health():
    """Readiness probe: True once the lifespan handler finished loading."""
    loaded = model is not None
    return {"model_loaded": loaded}
# ── 1. Text-only ──────────────────────────────────────────────────────────────
class TextRequest(BaseModel):
    """JSON body for POST /generate/text."""
    prompt: str
    max_new_tokens: int = 512  # generation length cap
    temperature: float = 0.0  # 0 => greedy decoding; >0 => sampling
@app.post("/generate/text", tags=["Inference"])
def generate_text(req: TextRequest):
    """Text-only generation endpoint; no image attached."""
    # Guard: reject until the lifespan handler has finished loading.
    if model is None:
        raise HTTPException(503, "Model not loaded yet")
    try:
        answer = run_inference(
            req.prompt,
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
        )
    except Exception as e:
        logger.exception("Inference error")
        raise HTTPException(500, str(e))
    return {"prompt": req.prompt, "response": answer}
# ── 2. Image upload (multipart/form-data) ─────────────────────────────────────
@app.post("/generate/vision", tags=["Inference"])
async def generate_vision(
    prompt: str = Form("Describe the image(s) in detail."),
    max_new_tokens: int = Form(512),
    temperature: float = Form(0.0),
    images: list[UploadFile] = File(default=[]),
):
    """Multipart endpoint: zero or more image uploads plus an optional prompt."""
    if model is None:
        raise HTTPException(503, "Model not loaded yet")
    decoded: list[Image.Image] = []
    for upload in images:
        raw = await upload.read()
        try:
            img = Image.open(io.BytesIO(raw)).convert("RGB")
        except Exception:
            # Any PIL failure means the upload was not a decodable image.
            raise HTTPException(400, f"Could not decode image: {upload.filename}")
        decoded.append(img)
    try:
        answer = run_inference(
            prompt,
            images=decoded if decoded else None,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
        )
    except Exception as e:
        logger.exception("Inference error")
        raise HTTPException(500, str(e))
    return {"prompt": prompt, "num_images": len(decoded), "response": answer}
# ── 3. Base64 images via JSON ─────────────────────────────────────────────────
class VisionB64Request(BaseModel):
    """JSON body for POST /generate/vision/base64."""
    prompt: str = "Describe the image(s) in detail."
    images_b64: list[str] = []  # plain base64 or full data URLs; pydantic copies the default per instance
    max_new_tokens: int = 512  # generation length cap
    temperature: float = 0.0  # 0 => greedy decoding; >0 => sampling
@app.post("/generate/vision/base64", tags=["Inference"])
def generate_vision_b64(req: VisionB64Request):
    """JSON endpoint: images arrive as base64 strings (raw or data URLs)."""
    if model is None:
        raise HTTPException(503, "Model not loaded yet")
    decoded: list[Image.Image] = []
    for idx, payload in enumerate(req.images_b64):
        # Strip a data-URL prefix such as "data:image/png;base64," if present.
        if "," in payload:
            payload = payload.split(",", 1)[1]
        try:
            raw = base64.b64decode(payload)
            decoded.append(Image.open(io.BytesIO(raw)).convert("RGB"))
        except Exception:
            raise HTTPException(400, f"Could not decode base64 image at index {idx}")
    try:
        answer = run_inference(
            req.prompt,
            images=decoded if decoded else None,
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
        )
    except Exception as e:
        logger.exception("Inference error")
        raise HTTPException(500, str(e))
    return {"prompt": req.prompt, "num_images": len(decoded), "response": answer}
# ── Entry point ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Local import keeps uvicorn optional when the app is served externally.
    import uvicorn
    # 0.0.0.0:7860 is the HuggingFace Spaces convention; reload off for production.
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)