Spaces:
Sleeping
Sleeping
import asyncio
import io
import logging
import os
import secrets
import time
from contextlib import asynccontextmanager

from fastapi import FastAPI, File, Form, HTTPException, Security, UploadFile
from fastapi.responses import JSONResponse
from fastapi.security import APIKeyHeader
from PIL import Image, UnidentifiedImageError
from transformers import AutoTokenizer, BlipForConditionalGeneration, BlipImageProcessor, BlipProcessor
# Pillow's decompression-bomb guard. Uploads come from untrusted clients, and the
# whole image is decoded before it is thumbnailed. An unbounded limit would
# therefore let a small compressed file expand into an enormous pixel buffer.
# Pillow's own default (89,478,485 pixels) is used unless MAX_IMAGE_PIXELS
# overrides it. The value "none" disables the check explicitly.
_max_image_pixels = os.getenv("MAX_IMAGE_PIXELS", "89478485").strip().lower()
Image.MAX_IMAGE_PIXELS = None if _max_image_pixels == "none" else int(_max_image_pixels)

# Hugging Face model repository used for both the processor and the model.
MODEL_ID = os.getenv("MODEL_ID", "Salesforce/blip-image-captioning-large")
# Shared secret that callers must send in the API_KEY_HEADER_NAME header.
API_KEY = os.getenv("API_KEY")
API_KEY_HEADER_NAME = os.getenv("API_KEY_HEADER_NAME", "x-api-key")
# Truthy strings ("1", "true", "yes", "on", case-insensitive) select the fast tokenizer.
USE_FAST_PROCESSOR = os.getenv("USE_FAST_PROCESSOR", "true").strip().lower() in {
    "1",
    "true",
    "yes",
    "on",
}
# Bounding box passed to Image.thumbnail (aspect ratio preserved) before preprocessing.
# NOTE(review): this is smaller than the resolution BLIP processors usually resize
# to, so images are likely upscaled again afterwards. Confirm 128 is intentional.
MAX_IMAGE_SIZE = (128, 128)
# Filled in by load_model(); None until it has run.
processor: BlipProcessor | None = None
model: BlipForConditionalGeneration | None = None
# auto_error=False: a missing header yields None instead of an automatic error response.
api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False)
def load_model() -> None:
    """Populate the module-level ``processor`` and ``model`` from ``MODEL_ID``.

    Each global is built only while it is still ``None``, so repeated calls
    reuse whatever has already been loaded. The processor combines a
    ``BlipImageProcessor`` with a tokenizer whose fast/slow variant follows
    ``USE_FAST_PROCESSOR``.
    """
    global processor, model
    if processor is None:
        # Keyword arguments are evaluated left to right, so the image processor
        # is fetched before the tokenizer.
        processor = BlipProcessor(
            image_processor=BlipImageProcessor.from_pretrained(MODEL_ID),
            tokenizer=AutoTokenizer.from_pretrained(MODEL_ID, use_fast=USE_FAST_PROCESSOR),
        )
    if model is None:
        model = BlipForConditionalGeneration.from_pretrained(MODEL_ID)
def verify_api_key(api_key: str | None = Security(api_key_header)) -> None:
    """FastAPI dependency that rejects requests lacking the configured API key.

    Args:
        api_key: Value of the API key request header. It is ``None`` when the
            header is absent, because the header scheme uses ``auto_error=False``.

    Raises:
        RuntimeError: If the server is running without ``API_KEY`` configured.
        HTTPException: 401 when the header is missing or does not match.
    """
    if not API_KEY:
        raise RuntimeError("API_KEY environment variable is required.")
    # A constant-time comparison keeps response timing from revealing how much of
    # a guessed key was correct. Both sides are compared as bytes because
    # compare_digest rejects str arguments that contain non-ASCII characters.
    if api_key is None or not secrets.compare_digest(
        api_key.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid or missing API key.")
def generate_caption(image_bytes: bytes, min_new_tokens: int, max_new_tokens: int) -> str:
    """Decode an uploaded image and return a BLIP caption for it.

    The steps are:

    1. Convert the image to RGB.
    2. Shrink it in place with ``thumbnail`` so neither side exceeds
       ``MAX_IMAGE_SIZE``.
    3. Pass it through the module-level ``processor``.
    4. Caption it with ``model.generate``.

    Args:
        image_bytes: Raw bytes of the uploaded file.
        min_new_tokens: Minimum number of generated tokens, forwarded to ``generate``.
        max_new_tokens: Maximum number of generated tokens, forwarded to ``generate``.

    Returns:
        The decoded caption with special tokens stripped.

    Raises:
        RuntimeError: If ``processor`` or ``model`` has not been loaded yet.
        ValueError: If the bytes cannot be decoded as an image. This covers an
            unrecognised format, truncated or corrupt data, and images over
            Pillow's decompression-bomb limit.
    """
    if processor is None or model is None:
        raise RuntimeError("Model is not loaded.")
    try:
        # The context manager closes the source image once the RGB copy exists.
        with Image.open(io.BytesIO(image_bytes)) as opened:
            raw_image = opened.convert("RGB")
    except (UnidentifiedImageError, Image.DecompressionBombError, OSError) as exc:
        # Truncated or corrupt data surfaces as OSError inside convert(), and
        # oversize images raise DecompressionBombError from open(). Both are bad
        # client input, not server failures, so map them to ValueError (HTTP 400).
        raise ValueError("Uploaded file is not a valid image.") from exc
    raw_image.thumbnail(MAX_IMAGE_SIZE)
    inputs = processor(raw_image, return_tensors="pt")
    output = model.generate(
        **inputs,
        min_new_tokens=min_new_tokens,
        max_new_tokens=max_new_tokens,
    )
    return processor.decode(output[0], skip_special_tokens=True)
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Load the BLIP processor and model once, before the app serves requests.

    The decorator turns this generator into the async context manager that
    FastAPI's ``lifespan=`` parameter expects. Nothing is released on shutdown;
    the loaded objects stay in module globals.
    """
    load_model()
    yield
# ASGI application. Model loading is hooked in through the lifespan handler.
app = FastAPI(
    title="BLIP Image Captioning API",
    # Built from MODEL_ID so the description stays accurate when the model is
    # overridden through the environment. With the default MODEL_ID it matches
    # the previous hard-coded text.
    description=f"FastAPI wrapper for {MODEL_ID} on CPU.",
    version="1.0.0",
    lifespan=lifespan,
)
@app.get("/")
async def root() -> dict[str, str]:
    """Return a short index of the service's endpoints.

    Without the route decorator this handler was never registered with ``app``.
    """
    return {
        "message": "BLIP captioning API is running.",
        "docs": "/docs",
        "health": "/health",
        "caption_endpoint": "/caption",
    }
@app.get("/health")
async def health() -> dict[str, str]:
    """Liveness probe that always answers ``{"status": "ok"}``.

    It does not check whether the model has been loaded. Without the route
    decorator this handler was never registered with ``app``.
    """
    return {
        "status": "ok",
    }
@app.post("/caption")
async def caption_image(
    _: None = Security(verify_api_key),
    image: UploadFile = File(...),
    min_new_tokens: int = Form(5),
    max_new_tokens: int = Form(20),
):
    """Caption an uploaded image (multipart form; requires the API key header).

    Args:
        image: The uploaded image file.
        min_new_tokens: Minimum caption length in tokens (>= 1).
        max_new_tokens: Maximum caption length in tokens (>= min_new_tokens).

    Returns:
        ``{"caption": str, "elapsed_seconds": float}``. The timing covers decoding
        plus generation, rounded to two decimal places.

    Responses:
        400 for invalid token bounds, an empty upload, or undecodable image data.
        500 with a generic message for any other failure; the traceback is logged.
    """
    if min_new_tokens < 1 or max_new_tokens < 1:
        raise HTTPException(status_code=400, detail="Token values must be at least 1.")
    if min_new_tokens > max_new_tokens:
        raise HTTPException(
            status_code=400,
            detail="min_new_tokens must be less than or equal to max_new_tokens.",
        )
    image_bytes = await image.read()
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Uploaded image is empty.")
    # perf_counter is monotonic, so wall-clock adjustments cannot skew the timing.
    started_at = time.perf_counter()
    try:
        # Decoding and generation are synchronous and CPU-bound. Running them in
        # a worker thread keeps the event loop free to serve other requests
        # (e.g. /health) while a caption is being generated.
        caption = await asyncio.to_thread(
            generate_caption, image_bytes, min_new_tokens, max_new_tokens
        )
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception:
        # Record the full traceback server-side. Do not echo internal error text
        # (paths, tensor shapes, library internals) back to the client.
        logging.getLogger(__name__).exception("Caption generation failed")
        return JSONResponse(
            status_code=500,
            content={"detail": "Caption generation failed."},
        )
    elapsed = time.perf_counter() - started_at
    return {
        "caption": caption,
        "elapsed_seconds": round(elapsed, 2),
    }