File size: 2,158 Bytes
3c4d3f7
 
 
 
 
 
 
 
 
 
 
 
 
040b1a3
3c4d3f7
040b1a3
3c4d3f7
 
040b1a3
3c4d3f7
 
 
 
 
040b1a3
3c4d3f7
 
 
 
 
040b1a3
3c4d3f7
 
 
040b1a3
 
3c4d3f7
040b1a3
 
 
 
 
 
 
3c4d3f7
040b1a3
 
 
 
3c4d3f7
 
 
 
 
 
 
 
 
 
040b1a3
 
 
3c4d3f7
040b1a3
 
 
 
 
 
 
3c4d3f7
040b1a3
3c4d3f7
040b1a3
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import io
import logging
import threading
from typing import List

import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from torchvision import transforms

from image_captioning.config import TrainingConfig, get_device
from image_captioning.dataset import IMAGENET_MEAN, IMAGENET_STD, create_tokenizer
from image_captioning.model import ImageCaptioningModel

# 1. Initialize App and CORS
app = FastAPI(title="Image Captioning API (HF Space)")

# Any browser origin may call this public API. Credentials are deliberately
# disabled: none of the routes below reads cookies or auth headers, and the
# FastAPI CORS docs state that a wildcard origin list must not be combined
# with allow_credentials=True (Starlette would otherwise reflect the caller's
# Origin on credentialed requests, granting credentialed access to any site).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# 2. Load Model & Assets (Global Scope)
# Everything below runs once, at module import time, so model construction
# and checkpoint loading happen before the server handles any request.
device = get_device()
# max_caption_length=50 matches the max_length passed to model.generate in /caption.
training_cfg = TrainingConfig(max_caption_length=50)
# NOTE(review): `tokenizer` is not referenced in the routes below; presumably
# used elsewhere or kept for parity with training — confirm before removing.
tokenizer = create_tokenizer()
model = ImageCaptioningModel(training_cfg=training_cfg)

# Load weights
# Relative path: resolved against the process's current working directory.
CHECKPOINT_PATH = "best_model.pt"
# map_location loads the tensors onto the same device the model is moved to below.
# NOTE(review): torch.load unpickles the file, so the checkpoint must be trusted;
# since only a state dict is needed, weights_only=True (torch >= 1.13) would be safer.
state_dict = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(state_dict)  # strict=True by default: missing/unexpected keys raise
model.to(device)
model.eval()  # evaluation behavior for layers such as dropout / batch-norm

# 3. Preprocessing Pipeline
# Match the shorter image side to 256 px, keep the central 224x224 window,
# turn the PIL image into a float tensor, then normalize it with the ImageNet
# mean/std constants exported by the dataset module.
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
]
preprocess = transforms.Compose(_preprocess_steps)

# 4. API Routes
@app.get("/")
async def root():
    """Landing route: confirm the service is up and point callers to /docs."""
    greeting = {"message": "API is online. Go to /docs for testing."}
    return greeting

@app.get("/health")
async def health() -> dict:
    """Health-check route; unconditionally reports an ``ok`` status."""
    current_status = "ok"
    return {"status": current_status}

logger = logging.getLogger(__name__)

# Inference now runs on worker threads instead of the event loop. The model's
# generate() and the tokenizer are not known to be thread-safe, so this lock
# keeps the original one-inference-at-a-time behavior while leaving the event
# loop free to serve other requests.
_INFERENCE_LOCK = threading.Lock()


def _decode_image(contents: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        OSError: if PIL cannot identify or fully read the data
            (``PIL.UnidentifiedImageError`` is an ``OSError`` subclass).
        Image.DecompressionBombError: if the image exceeds PIL's pixel limit.
    """
    return Image.open(io.BytesIO(contents)).convert("RGB")


def _generate_caption(image: Image.Image) -> str:
    """Preprocess one RGB image, run the model, and return the first caption."""
    tensor = preprocess(image).unsqueeze(0).to(device)
    # torch.no_grad is thread-local, so it is entered here, on the worker
    # thread that actually runs the forward pass.
    with _INFERENCE_LOCK, torch.no_grad():
        captions: List[str] = model.generate(
            images=tensor,
            max_length=50,  # same value as training_cfg's max_caption_length
            num_beams=1,
        )
    return captions[0]


@app.post("/caption")
async def caption_image(file: UploadFile = File(...)) -> JSONResponse:
    """Generate a caption for an uploaded image.

    Decoding and inference are CPU/GPU-bound, so they are dispatched to a
    worker thread; running them inline would block the event loop and stall
    every other route (including /health) for the whole forward pass.

    Returns:
        200 ``{"caption": str}`` on success;
        400 ``{"error": ...}`` if the upload cannot be decoded as an image;
        500 ``{"error": ...}`` if preprocessing or caption generation fails
        (details are logged server-side, not returned to the client).
    """
    contents = await file.read()

    try:
        image = await run_in_threadpool(_decode_image, contents)
    except (OSError, Image.DecompressionBombError) as exc:
        # Client-side problem: not an image, truncated, or too large.
        return JSONResponse(
            status_code=400,
            content={"error": f"Invalid image: {exc}"},
        )

    try:
        caption = await run_in_threadpool(_generate_caption, image)
    except Exception:
        # Server-side failure: log the traceback, keep internals out of the response.
        logger.exception("Caption generation failed")
        return JSONResponse(
            status_code=500,
            content={"error": "Internal Error: caption generation failed"},
        )

    return JSONResponse({"caption": caption})