import io
from typing import List
from fastapi.middleware.cors import CORSMiddleware
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
import torch
from torchvision import transforms
from image_captioning.config import TrainingConfig, get_device
from image_captioning.dataset import IMAGENET_MEAN, IMAGENET_STD, create_tokenizer
from image_captioning.model import ImageCaptioningModel
# 1. Initialize App and CORS
app = FastAPI(title="Image Captioning API (HF Space)")
app.add_middleware(
    CORSMiddleware,
    # NOTE(review): browsers do not honor a wildcard origin together with
    # allow_credentials=True per the CORS spec — confirm whether credentialed
    # cross-origin requests are actually needed here.
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# 2. Load Model & Assets (Global Scope)
# Runs once at import time so every request reuses the same loaded model.
device = get_device()
training_cfg = TrainingConfig(max_caption_length=50)
tokenizer = create_tokenizer()
model = ImageCaptioningModel(training_cfg=training_cfg)

# Load weights from the checkpoint file shipped alongside the app.
CHECKPOINT_PATH = "best_model.pt"
# NOTE(review): torch.load unpickles arbitrary objects unless
# weights_only=True is passed (the default changed in torch 2.6) —
# confirm the checkpoint source is trusted.
state_dict = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # inference mode: disables dropout / batch-norm updates
# 3. Preprocessing Pipeline
# Standard ImageNet-style eval transform: resize, center-crop to 224x224,
# convert to tensor, then normalize with the dataset's mean/std constants.
_PREPROCESS_STEPS = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
]
preprocess = transforms.Compose(_PREPROCESS_STEPS)
# 4. API Routes
@app.get("/")
async def root():
    """Landing endpoint that points callers at the interactive docs page."""
    payload = {"message": "API is online. Go to /docs for testing."}
    return payload
@app.get("/health")
async def health() -> dict:
    """Liveness probe: always reports the service as up."""
    return dict(status="ok")
@app.post("/caption")
async def caption_image(file: UploadFile = File(...)) -> JSONResponse:
    """Generate a caption for an uploaded image.

    Returns ``{"caption": <str>}`` on success, a 400 with an error message
    when the upload cannot be decoded as an image, and a 500 on any
    internal inference failure.
    """
    contents = await file.read()
    try:
        # PIL raises UnidentifiedImageError (an OSError subclass) for
        # undecodable bytes — that is the client's fault, so answer 400.
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, ValueError) as exc:
        return JSONResponse(
            status_code=400,
            content={"error": f"Invalid image file: {exc}"},
        )
    try:
        # Preprocess and move to the model's device.
        tensor = preprocess(image).unsqueeze(0).to(device)
        # Inference — no autograd graph needed.
        with torch.no_grad():
            captions: List[str] = model.generate(
                images=tensor,
                # Keep generation length consistent with the configured
                # training limit instead of repeating the literal 50.
                max_length=training_cfg.max_caption_length,
                num_beams=1,
            )
        return JSONResponse({"caption": captions[0]})
    except Exception:
        # Model/runtime failures are server faults: answer 500 and avoid
        # leaking internal exception details to the client.
        return JSONResponse(status_code=500, content={"error": "Internal Error"})