"""FastAPI inference server for the image-captioning model.

Loads a trained checkpoint at startup and exposes:
  GET  /        -- liveness message
  GET  /health  -- health probe
  POST /caption -- accepts an uploaded image, returns a generated caption
"""

import io
from typing import List

import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image, UnidentifiedImageError
from torchvision import transforms

from image_captioning.config import TrainingConfig, get_device
from image_captioning.dataset import IMAGENET_MEAN, IMAGENET_STD, create_tokenizer
from image_captioning.model import ImageCaptioningModel

# 1. Initialize App and CORS
app = FastAPI(title="Image Captioning API (HF Space)")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# 2. Load Model & Assets (Global Scope)
# Loaded once at import time so every request reuses the same weights.
device = get_device()
training_cfg = TrainingConfig(max_caption_length=50)
tokenizer = create_tokenizer()
model = ImageCaptioningModel(training_cfg=training_cfg)

# Load weights. `weights_only=True` restricts unpickling to tensor data,
# avoiding arbitrary-code execution from a tampered checkpoint file
# (torch.load's full pickle mode is unsafe on untrusted files).
CHECKPOINT_PATH = "best_model.pt"
state_dict = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

# 3. Preprocessing Pipeline
# Standard ImageNet eval transform: shorter side to 256, center-crop 224,
# then normalize with the training-set statistics.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])


# 4. API Routes
@app.get("/")
async def root():
    """Liveness message pointing users at the interactive docs."""
    return {"message": "API is online. Go to /docs for testing."}


@app.get("/health")
async def health() -> dict:
    """Simple health probe for orchestrators / load balancers."""
    return {"status": "ok"}


@app.post("/caption")
async def caption_image(file: UploadFile = File(...)) -> JSONResponse:
    """Generate a caption for an uploaded image.

    Returns 400 when the upload is not a decodable image, 500 on any
    internal failure (previously both cases returned 400 labeled
    "Internal Error", which misclassified server faults as client errors).
    """
    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except UnidentifiedImageError:
        return JSONResponse(
            status_code=400,
            content={"error": "Uploaded file is not a valid image."},
        )

    try:
        # Preprocess and move to the model's device; add batch dimension.
        tensor = preprocess(image).unsqueeze(0).to(device)

        # Inference; inference_mode disables autograd tracking entirely.
        with torch.inference_mode():
            captions: List[str] = model.generate(
                images=tensor,
                # Keep generation length consistent with the training config
                # instead of duplicating the constant here.
                max_length=training_cfg.max_caption_length,
                num_beams=1,
            )
        return JSONResponse({"caption": captions[0]})
    except Exception as exc:
        return JSONResponse(
            status_code=500,
            content={"error": f"Internal Error: {exc}"},
        )