# blip-large / app/main.py
# khushalcodiste's picture
# fix: added
# c54acdc
import asyncio
import hmac
import io
import os
import time
from contextlib import asynccontextmanager

from fastapi import FastAPI, File, Form, HTTPException, Security, UploadFile
from fastapi.responses import JSONResponse
from fastapi.security import APIKeyHeader
from PIL import Image, UnidentifiedImageError
from transformers import AutoTokenizer, BlipForConditionalGeneration, BlipImageProcessor, BlipProcessor
# NOTE(review): assigning None removes Pillow's pixel-count limit (its
# decompression-bomb guard). Images here come from client uploads, so confirm
# this is intentional before relying on it in production.
Image.MAX_IMAGE_PIXELS = None

# Runtime configuration, read from the environment once at import time.
MODEL_ID = os.getenv("MODEL_ID", "Salesforce/blip-image-captioning-large")
API_KEY = os.getenv("API_KEY")
API_KEY_HEADER_NAME = os.getenv("API_KEY_HEADER_NAME", "x-api-key")
# Any of "1", "true", "yes", "on" (case-insensitive, surrounding whitespace
# ignored) enables the flag; every other value disables it.
USE_FAST_PROCESSOR = os.getenv("USE_FAST_PROCESSOR", "true").strip().lower() in ("1", "true", "yes", "on")

# Bounding box handed to Image.thumbnail() before the image is processed.
MAX_IMAGE_SIZE = (128, 128)

# Module-level singletons; both stay None until load_model() fills them in.
processor: BlipProcessor | None = None
model: BlipForConditionalGeneration | None = None

# auto_error=False: the scheme itself does not reject requests, leaving the
# decision to verify_api_key().
api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False)
def load_model() -> None:
    """Populate the module-level ``processor`` and ``model`` if they are unset.

    Each global is only loaded while it is still ``None``, so calling this
    function again after a successful load does nothing.
    """
    global processor, model

    if processor is None:
        # The image processor is loaded first, then the tokenizer; the two are
        # combined into a single BlipProcessor.
        processor = BlipProcessor(
            image_processor=BlipImageProcessor.from_pretrained(MODEL_ID),
            tokenizer=AutoTokenizer.from_pretrained(MODEL_ID, use_fast=USE_FAST_PROCESSOR),
        )

    if model is None:
        model = BlipForConditionalGeneration.from_pretrained(MODEL_ID)
def verify_api_key(api_key: str | None = Security(api_key_header)) -> None:
    """Reject the request unless it carries the configured API key.

    Args:
        api_key: Value extracted by ``api_key_header``; ``None`` when the
            header is absent.

    Raises:
        RuntimeError: If the ``API_KEY`` setting is empty or unset, so the
            service never runs unauthenticated by accident.
        HTTPException: 401 when the supplied key is missing or does not match.
    """
    if not API_KEY:
        raise RuntimeError("API_KEY environment variable is required.")

    # Compare as bytes in constant time so response timing does not reveal how
    # much of the key matched. "surrogatepass" keeps encode() from raising on
    # unusual environment values. A missing header is treated as an empty
    # key, which can never equal the non-empty API_KEY.
    supplied = (api_key or "").encode("utf-8", "surrogatepass")
    expected = API_KEY.encode("utf-8", "surrogatepass")
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="Invalid or missing API key.")
def generate_caption(image_bytes: bytes, min_new_tokens: int, max_new_tokens: int) -> str:
    """Decode an uploaded image and return the caption generated for it.

    Args:
        image_bytes: Raw bytes of the uploaded image file.
        min_new_tokens: Forwarded to ``model.generate`` as ``min_new_tokens``.
        max_new_tokens: Forwarded to ``model.generate`` as ``max_new_tokens``.

    Returns:
        The decoded caption with special tokens removed.

    Raises:
        RuntimeError: If ``processor`` or ``model`` has not been loaded.
        ValueError: If the bytes cannot be opened or fully decoded as an image.
    """
    if processor is None or model is None:
        raise RuntimeError("Model is not loaded.")

    try:
        # convert() forces a full decode, so truncated or corrupt data fails
        # here. The context manager releases the source image once the RGB
        # copy exists.
        with Image.open(io.BytesIO(image_bytes)) as source_image:
            raw_image = source_image.convert("RGB")
    except (UnidentifiedImageError, OSError, Image.DecompressionBombError) as exc:
        # Unrecognised formats, truncated files and oversized images are all
        # problems with the upload, so they share the same client-facing error.
        raise ValueError("Uploaded file is not a valid image.") from exc

    # In-place downscale that keeps the aspect ratio within MAX_IMAGE_SIZE.
    raw_image.thumbnail(MAX_IMAGE_SIZE)

    inputs = processor(raw_image, return_tensors="pt")
    output = model.generate(
        **inputs,
        min_new_tokens=min_new_tokens,
        max_new_tokens=max_new_tokens,
    )
    return processor.decode(output[0], skip_special_tokens=True)
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Application lifespan hook: load the model, then hand control back.

    The application argument is unused. Nothing runs after the ``yield``, so
    there is no teardown step.
    """
    # Startup phase: make sure the processor and model globals are populated.
    load_model()

    # Control returns to the caller here for the rest of the lifespan.
    yield
# The ASGI application object; `lifespan` runs load_model() during startup.
app = FastAPI(
    lifespan=lifespan,
    title="BLIP Image Captioning API",
    description="FastAPI wrapper for Salesforce/blip-image-captioning-large on CPU.",
    version="1.0.0",
)
@app.get("/")
async def root() -> dict[str, str]:
    """Return a short banner plus the paths of the other endpoints."""
    service_index = dict(
        message="BLIP captioning API is running.",
        docs="/docs",
        health="/health",
        caption_endpoint="/caption",
    )
    return service_index
@app.get("/health")
async def health() -> dict[str, str]:
    """Return a constant payload; it does not inspect the model state."""
    return dict(status="ok")
@app.post("/caption")
async def caption_image(
    _: None = Security(verify_api_key),
    image: UploadFile = File(...),
    min_new_tokens: int = Form(5),
    max_new_tokens: int = Form(20),
):
    """Generate a caption for an uploaded image.

    Args:
        _: Placeholder for the API-key check; its value is unused.
        image: The uploaded image file.
        min_new_tokens: Lower bound on generated tokens (must be >= 1).
        max_new_tokens: Upper bound on generated tokens (must be >= 1 and
            >= ``min_new_tokens``).

    Returns:
        A dict with the ``caption`` text and ``elapsed_seconds`` spent
        generating it, rounded to two decimals.

    Raises:
        HTTPException: 400 for invalid token bounds, an empty upload, or a
            ``ValueError`` from caption generation; 500 for any other failure
            during caption generation.
    """
    if min_new_tokens < 1 or max_new_tokens < 1:
        raise HTTPException(status_code=400, detail="Token values must be at least 1.")
    if min_new_tokens > max_new_tokens:
        raise HTTPException(
            status_code=400,
            detail="min_new_tokens must be less than or equal to max_new_tokens.",
        )

    image_bytes = await image.read()
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Uploaded image is empty.")

    # perf_counter() is monotonic, so the measurement is unaffected by
    # wall-clock adjustments.
    started_at = time.perf_counter()
    try:
        # Inference is CPU-bound and synchronous; running it in a worker thread
        # keeps the event loop free to serve other requests meanwhile.
        caption = await asyncio.to_thread(
            generate_caption, image_bytes, min_new_tokens, max_new_tokens
        )
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        # Same status and body as before, but raised like the other error paths.
        raise HTTPException(
            status_code=500,
            detail=f"Caption generation failed: {exc}",
        ) from exc
    elapsed = time.perf_counter() - started_at

    return {
        "caption": caption,
        "elapsed_seconds": round(elapsed, 2),
    }