Image-Captioning / main.py
VIKRAM989's picture
Add application file
40243b5
raw
history blame
2.4 kB
import io
import os
import pickle

import torch
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image

# Import from model.py
from model import (
    Vocabulary,
    ResNetEncoder,
    DecoderLSTM,
    ImageCaptioningModel,
    generate_caption,
    transform,
    EMBED_DIM,
    HIDDEN_DIM,
)
app = FastAPI(title="Image Captioning API")

# -------------------------
# Enable CORS
# -------------------------
# The API is open to any origin. Credentials (cookies / Authorization headers)
# are deliberately NOT allowed: wildcard origins, methods and headers must not
# be combined with allow_credentials=True, and nothing in this service relies
# on cookies or authentication.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Use the GPU when torch reports CUDA as available; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# -------------------------
# Paths (relative to main.py)
# -------------------------
# Both artifacts are resolved next to this file, independent of the current
# working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
CHECKPOINT_PATH = os.path.join(BASE_DIR, "best_checkpoint.pth")
VOCAB_PATH = os.path.join(BASE_DIR, "vocab.pkl")
# -------------------------
# Load Vocabulary
# -------------------------
class CustomUnpickler(pickle.Unpickler):
    """Unpickler that maps any class named ``Vocabulary`` to the local one.

    The module path recorded inside the pickle is ignored for that class, so
    the file loads regardless of which module the object was pickled from
    (presumably a training script -- confirm against how vocab.pkl was
    produced). Every other class is resolved by the standard lookup.
    """

    def find_class(self, module, name):
        """Return the class object for ``module.name`` during unpickling."""
        if name != "Vocabulary":
            return super().find_class(module, name)
        return Vocabulary
# Deserialize the vocabulary with the remapping unpickler.
# NOTE: pickle executes arbitrary code on load; VOCAB_PATH must only ever
# point at a trusted file.
with open(VOCAB_PATH, "rb") as vocab_file:
    vocab = CustomUnpickler(vocab_file).load()

# The number of vocabulary entries is passed to the decoder constructor.
vocab_size = len(vocab)
# -------------------------
# Build Model
# -------------------------
# Assemble the encoder/decoder pair and move the combined model to DEVICE.
# The decoder is constructed with vocab_size, so the vocabulary must already
# be loaded at this point.
encoder = ResNetEncoder(EMBED_DIM)
decoder = DecoderLSTM(EMBED_DIM, HIDDEN_DIM, vocab_size)
model = ImageCaptioningModel(encoder, decoder).to(DEVICE)
# -------------------------
# Load Weights
# -------------------------
# map_location remaps the stored tensors onto DEVICE at load time.
# The checkpoint is indexed as a mapping with a "model_state_dict" entry.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
# Put the model in evaluation mode; this service only runs inference.
model.eval()
print("✅ Model Loaded Successfully")
# -------------------------
# Health Check
# -------------------------
@app.get("/")
def root():
    """Health check: confirm that the API process is up and serving."""
    status = {"message": "Image Captioning API Running"}
    return status
# -------------------------
# Caption Endpoint
# -------------------------
@app.post("/caption")
async def caption_image(file: UploadFile = File(...)):
    """Generate a caption for an uploaded image.

    Args:
        file: Multipart upload whose bytes are decoded as an image.

    Returns:
        A dict of the form ``{"caption": ...}`` holding whatever
        ``generate_caption`` returns for the image.

    Raises:
        HTTPException: 400 when the upload is empty or cannot be decoded as
            an image (previously this surfaced as an unhandled 500).
    """
    contents = await file.read()
    if not contents:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    try:
        # convert() forces a full decode, so truncated or corrupt files fail
        # here rather than later inside the model pipeline.
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # PIL's UnidentifiedImageError is a subclass of OSError.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    image = transform(image)
    # NOTE(review): generate_caption is called synchronously inside an async
    # handler, so the event loop is blocked for the duration of inference.
    caption = generate_caption(model, image, vocab)
    return {"caption": caption}
if __name__ == "__main__":
    # Started as a script: serve the app on all interfaces, port 7860.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=7860,
    )