Spaces:
Sleeping
Sleeping
| import io | |
| from typing import List | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi import FastAPI, File, UploadFile | |
| from fastapi.responses import JSONResponse | |
| from PIL import Image | |
| import torch | |
| from torchvision import transforms | |
| from image_captioning.config import TrainingConfig, get_device | |
| from image_captioning.dataset import IMAGENET_MEAN, IMAGENET_STD, create_tokenizer | |
| from image_captioning.model import ImageCaptioningModel | |
# 1. Initialize App and CORS
app = FastAPI(title="Image Captioning API (HF Space)")

# Any origin may call this API. Credentials stay disabled because a wildcard
# origin must not be combined with credentialed requests: the CORS spec
# forbids "Access-Control-Allow-Origin: *" alongside credentials, and the
# FastAPI docs require explicit origins whenever allow_credentials=True.
# The routes visible in this file use no cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# 2. Load Model & Assets (Global Scope)
# Everything in this section runs once, at import time.
CHECKPOINT_PATH = "best_model.pt"

device = get_device()
training_cfg = TrainingConfig(max_caption_length=50)
tokenizer = create_tokenizer()

# Build the network, restore its weights (mapped onto `device`), move it to
# that device and switch it to evaluation mode.
# NOTE(review): torch.load unpickles the checkpoint file; only load
# checkpoints that come from a trusted source.
model = ImageCaptioningModel(training_cfg=training_cfg)
state_dict = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()
# 3. Preprocessing Pipeline
# Steps run in order: resize, center-crop, convert to a tensor, then
# normalize with the IMAGENET_MEAN / IMAGENET_STD constants.
_RESIZE_SIZE = 256
_CROP_SIZE = 224

preprocess = transforms.Compose(
    [
        transforms.Resize(_RESIZE_SIZE),
        transforms.CenterCrop(_CROP_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)
# 4. API Routes
@app.get("/")
async def root():
    """Return a short status message that points callers at the /docs page."""
    return {"message": "API is online. Go to /docs for testing."}
# NOTE(review): the "/health" path is an assumption based on the function
# name — confirm against whatever probes this endpoint.
@app.get("/health")
async def health() -> dict:
    """Return a constant ``{"status": "ok"}`` payload for health checks."""
    return {"status": "ok"}
# NOTE(review): the "/caption" path is an assumption — confirm against
# existing clients before deploying.
@app.post("/caption")
async def caption_image(file: UploadFile = File(...)) -> JSONResponse:
    """Generate a caption for an uploaded image.

    Args:
        file: The uploaded image. It is read fully into memory, decoded with
            Pillow and converted to RGB.

    Returns:
        ``{"caption": <str>}`` on success. On failure, ``{"error": <str>}``
        with status 400 when the upload cannot be read or decoded as an
        image, or status 500 when preprocessing or inference fails.
    """
    # Client-side problems: an unreadable upload or bytes that are not an image.
    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as exc:
        return JSONResponse(
            status_code=400,
            content={"error": f"Invalid image: {exc}"},
        )

    # Server-side problems: preprocessing or model failures are reported as
    # 500 so clients can tell them apart from a bad upload.
    # NOTE(review): generation runs synchronously inside this async handler,
    # so other requests wait while a caption is being produced.
    try:
        # Preprocess, add a leading batch dimension, and move to the device.
        tensor = preprocess(image).unsqueeze(0).to(device)
        with torch.no_grad():
            captions: List[str] = model.generate(
                images=tensor,
                max_length=50,
                num_beams=1,
            )
        caption = captions[0]
    except Exception as exc:
        return JSONResponse(
            status_code=500,
            content={"error": f"Internal Error: {exc}"},
        )

    return JSONResponse({"caption": caption})