Image-captioner / caption_app.py
Chand11's picture
Update caption_app.py
538317c verified
import base64
import io

import torch
from fastapi import FastAPI, File, HTTPException, Request, UploadFile
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor
# --- Application wiring ---------------------------------------------------
# The single FastAPI instance that the route decorators below register on.
app = FastAPI()

# Expose the local "static" directory under the /static URL prefix and
# render HTML pages from the local "templates" directory.
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# --- Captioning model -----------------------------------------------------
# The processor and the model come from the same checkpoint, so the
# identifier is spelled once. Both are loaded at import time, with
# `cache_dir` handed to `from_pretrained` as the local cache location.
cache_dir = "./model_cache"
_BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"
processor = BlipProcessor.from_pretrained(_BLIP_CHECKPOINT, cache_dir=cache_dir)
model = BlipForConditionalGeneration.from_pretrained(
    _BLIP_CHECKPOINT,
    cache_dir=cache_dir,
)
@app.get("/", response_class=HTMLResponse)
async def main(request: Request):
    """Render ``index.html`` for ``GET /``.

    The template context carries only the request object, so the page is
    rendered without a ``caption`` or ``image_data`` value.
    """
    page_context = {"request": request}
    return templates.TemplateResponse("index.html", page_context)
@app.post("/", response_class=HTMLResponse)
async def caption(request: Request, file: UploadFile = File(...)):
    """Caption an uploaded image and render the result in ``index.html``.

    The upload is read fully into memory, decoded and converted to RGB,
    passed through the BLIP processor and model to produce a caption, and
    re-encoded as a base64 PNG string so the template can show a preview.

    Args:
        request: The incoming request; forwarded into the template context.
        file: The uploaded file taken from the request's form data.

    Returns:
        The rendered ``index.html`` response whose context contains
        ``request``, ``caption`` and ``image_data``.

    Raises:
        HTTPException: With status 400 when the upload is empty or cannot
            be decoded as an image.
    """
    contents = await file.read()
    if not contents:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    # The upload is untrusted input: anything that is not a decodable image
    # (wrong file type, truncated data, oversized "decompression bomb") is
    # reported as a client error rather than surfacing as an unhandled 500.
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    # Generate caption using BLIP.
    # NOTE(review): these calls are synchronous, so this coroutine does not
    # yield to the event loop while the model runs. Consider offloading to a
    # worker thread if request concurrency matters -- confirm the model is
    # safe to call from multiple threads before doing so.
    inputs = processor(images=image, return_tensors="pt")
    output_ids = model.generate(**inputs)
    caption_text = processor.decode(output_ids[0], skip_special_tokens=True)

    # Convert the (RGB-converted) image to base64 PNG for the preview.
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

    return templates.TemplateResponse(
        "index.html",
        {
            "request": request,
            "caption": caption_text,
            "image_data": img_str,
        },
    )