"""FastAPI wrapper around the Salesforce BLIP image-captioning model (CPU).

Routes:
    GET  /         -- service index (docs / endpoint pointers)
    GET  /health   -- liveness probe
    POST /caption  -- caption an uploaded image (API-key protected)

Configuration is entirely via environment variables: MODEL_ID, API_KEY,
API_KEY_HEADER_NAME, USE_FAST_PROCESSOR.
"""

import io
import os
import secrets
import time
from contextlib import asynccontextmanager

import torch
from fastapi import FastAPI, File, Form, HTTPException, Security, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from fastapi.security import APIKeyHeader
from PIL import Image, UnidentifiedImageError
from transformers import (
    AutoTokenizer,
    BlipForConditionalGeneration,
    BlipImageProcessor,
    BlipProcessor,
)

# NOTE(review): the original set Image.MAX_IMAGE_PIXELS = None, which disables
# Pillow's decompression-bomb protection even though this service decodes
# *untrusted* uploads. The default guard is kept; an oversized image raises
# Image.DecompressionBombError, which /caption maps to HTTP 400.

MODEL_ID = os.getenv("MODEL_ID", "Salesforce/blip-image-captioning-large")
API_KEY = os.getenv("API_KEY")
API_KEY_HEADER_NAME = os.getenv("API_KEY_HEADER_NAME", "x-api-key")
USE_FAST_PROCESSOR = os.getenv("USE_FAST_PROCESSOR", "true").strip().lower() in {
    "1",
    "true",
    "yes",
    "on",
}
# Uploads are downscaled in place to fit this bounding box before inference,
# which bounds CPU latency regardless of the original image dimensions.
MAX_IMAGE_SIZE = (128, 128)

# Populated once by load_model() (called from the lifespan hook).
processor: BlipProcessor | None = None
model: BlipForConditionalGeneration | None = None

api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False)


def load_model() -> None:
    """Load the BLIP processor and model into the module globals.

    Idempotent: components that are already loaded are left untouched, so
    repeated calls (e.g. in tests) are cheap.
    """
    global processor, model
    if processor is None:
        image_processor = BlipImageProcessor.from_pretrained(MODEL_ID)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=USE_FAST_PROCESSOR)
        processor = BlipProcessor(image_processor=image_processor, tokenizer=tokenizer)
    if model is None:
        model = BlipForConditionalGeneration.from_pretrained(MODEL_ID)


def verify_api_key(api_key: str | None = Security(api_key_header)) -> None:
    """FastAPI dependency: reject requests lacking the configured API key.

    Raises:
        RuntimeError: deployment misconfiguration -- API_KEY was never set
            (surfaces as a 500, appropriate for a server-side error).
        HTTPException: 401 when the supplied key is missing or wrong.
    """
    if not API_KEY:
        raise RuntimeError("API_KEY environment variable is required.")
    # compare_digest is constant-time, so response timing does not leak
    # how many leading characters of the key were correct.
    if api_key is None or not secrets.compare_digest(api_key, API_KEY):
        raise HTTPException(status_code=401, detail="Invalid or missing API key.")


def generate_caption(image_bytes: bytes, min_new_tokens: int, max_new_tokens: int) -> str:
    """Decode, downscale and caption an image; return the caption text.

    Raises:
        RuntimeError: model/processor not loaded yet.
        ValueError: the bytes are not a decodable image, or the image
            exceeds Pillow's decompression-bomb pixel limit.
    """
    if processor is None or model is None:
        raise RuntimeError("Model is not loaded.")
    try:
        raw_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except UnidentifiedImageError as exc:
        raise ValueError("Uploaded file is not a valid image.") from exc
    except Image.DecompressionBombError as exc:
        # Pillow's bomb guard (re-enabled above) -- treat as client error.
        raise ValueError("Uploaded image is too large to process.") from exc
    # thumbnail() resizes in place, preserving aspect ratio.
    raw_image.thumbnail(MAX_IMAGE_SIZE)
    inputs = processor(raw_image, return_tensors="pt")
    # inference_mode disables autograd tracking: less memory, faster decode.
    with torch.inference_mode():
        output = model.generate(
            **inputs,
            min_new_tokens=min_new_tokens,
            max_new_tokens=max_new_tokens,
        )
    return processor.decode(output[0], skip_special_tokens=True)


@asynccontextmanager
async def lifespan(_: FastAPI):
    """Load model weights once at startup so the first request isn't slow."""
    load_model()
    yield


app = FastAPI(
    title="BLIP Image Captioning API",
    description="FastAPI wrapper for Salesforce/blip-image-captioning-large on CPU.",
    version="1.0.0",
    lifespan=lifespan,
)


@app.get("/")
async def root() -> dict[str, str]:
    """Service index: point clients at the docs and the main endpoints."""
    return {
        "message": "BLIP captioning API is running.",
        "docs": "/docs",
        "health": "/health",
        "caption_endpoint": "/caption",
    }


@app.get("/health")
async def health() -> dict[str, str]:
    """Liveness probe."""
    return {
        "status": "ok",
    }


@app.post("/caption")
async def caption_image(
    _: None = Security(verify_api_key),
    image: UploadFile = File(...),
    min_new_tokens: int = Form(5),
    max_new_tokens: int = Form(20),
):
    """Caption an uploaded image.

    Returns the caption plus wall-clock generation time. 400 on invalid
    input, 401 on a bad API key, 500 on unexpected inference failure.
    """
    if min_new_tokens < 1 or max_new_tokens < 1:
        raise HTTPException(status_code=400, detail="Token values must be at least 1.")
    if min_new_tokens > max_new_tokens:
        raise HTTPException(
            status_code=400,
            detail="min_new_tokens must be less than or equal to max_new_tokens.",
        )
    image_bytes = await image.read()
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Uploaded image is empty.")
    started_at = time.time()
    try:
        # model.generate() is CPU-bound and can take seconds; run it in the
        # threadpool so the event loop (and /health) stays responsive.
        caption = await run_in_threadpool(
            generate_caption, image_bytes, min_new_tokens, max_new_tokens
        )
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:  # top-level boundary: report, don't crash the app
        return JSONResponse(
            status_code=500,
            content={"detail": f"Caption generation failed: {exc}"},
        )
    elapsed = time.time() - started_at
    return {
        "caption": caption,
        "elapsed_seconds": round(elapsed, 2),
    }