# app.py — Image Captioning API served from a Hugging Face Space.
import io
from typing import List
from fastapi.middleware.cors import CORSMiddleware
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
import torch
from torchvision import transforms
from image_captioning.config import TrainingConfig, get_device
from image_captioning.dataset import IMAGENET_MEAN, IMAGENET_STD, create_tokenizer
from image_captioning.model import ImageCaptioningModel
# 1. Initialize App and CORS
app = FastAPI(title="Image Captioning API (HF Space)")
app.add_middleware(
    CORSMiddleware,
    # Wide-open CORS so any front-end origin can call the API; acceptable for
    # a public demo Space, should be restricted for production.
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# 2. Load Model & Assets (Global Scope)
# Loaded once at import time so each request pays no model-setup cost; the
# objects below are treated as read-only by the request handlers.
device = get_device()
training_cfg = TrainingConfig(max_caption_length=50)
# NOTE(review): `tokenizer` is never referenced again in this file — presumably
# decoding happens inside model.generate; confirm and drop if truly unused.
tokenizer = create_tokenizer()
model = ImageCaptioningModel(training_cfg=training_cfg)

# Load weights
# NOTE(review): torch.load unpickles arbitrary Python objects; if the
# checkpoint is a plain tensor state_dict, prefer weights_only=True.
CHECKPOINT_PATH = "best_model.pt"
state_dict = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # inference mode: disables dropout / batch-norm running updates

# 3. Preprocessing Pipeline
# Standard ImageNet-style eval transform: resize shorter side to 256,
# center-crop to 224x224, convert to tensor, normalize with dataset stats.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
# 4. API Routes
@app.get("/")
async def root():
    """Landing route: confirms the service is up and points at /docs."""
    greeting = "API is online. Go to /docs for testing."
    return {"message": greeting}
@app.get("/health")
async def health() -> dict:
    """Liveness probe; always reports a healthy status."""
    return dict(status="ok")
@app.post("/caption")
async def caption_image(file: UploadFile = File(...)) -> JSONResponse:
    """Generate a caption for an uploaded image.

    Args:
        file: Multipart upload containing the image bytes.

    Returns:
        JSONResponse with {"caption": str} on success,
        400 with {"error": ...} when the upload is not a decodable image,
        500 with {"error": ...} on unexpected inference failures.
        (Previously every failure — including server-side errors — was
        reported as 400 "Internal Error", conflating client and server faults.)
    """
    contents = await file.read()

    # Decode failures are the client's fault -> 400. PIL raises
    # UnidentifiedImageError (an OSError subclass) for unreadable data.
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as exc:
        return JSONResponse(status_code=400, content={"error": f"Invalid image: {exc}"})

    try:
        # Preprocess and move to the model's device; add a batch dimension.
        tensor = preprocess(image).unsqueeze(0).to(device)

        # Inference — no_grad avoids building the autograd graph.
        with torch.no_grad():
            captions: List[str] = model.generate(
                images=tensor,
                # Use the configured cap (same value, 50) instead of a
                # duplicated magic number.
                max_length=training_cfg.max_caption_length,
                num_beams=1,
            )
        return JSONResponse({"caption": captions[0]})
    except Exception as exc:
        # Anything past decoding is a server-side failure -> 500.
        return JSONResponse(status_code=500, content={"error": f"Internal Error: {exc}"})