Image-Captioning / main.py
VIKRAM989's picture
Update main.py
ec85d7b verified
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
import io
import torch
import pickle
import os
import uvicorn
from huggingface_hub import hf_hub_download
# Import from model.py
from model import (
Vocabulary,
ResNetEncoder,
DecoderLSTM,
ImageCaptioningModel,
generate_caption,
transform,
EMBED_DIM,
HIDDEN_DIM,
)
app = FastAPI(title="Image Captioning API")

# -------------------------
# Enable CORS
# -------------------------
# Any origin may call this API. Credentialed requests are disabled because a
# wildcard origin must not be combined with allow_credentials=True, and none
# of the routes defined in this file read cookies or authorization headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Use the GPU when torch reports CUDA as available; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
# -------------------------
# Paths
# -------------------------
# Directory that contains this file; the vocabulary pickle is looked up
# next to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
VOCAB_PATH = os.path.join(BASE_DIR, "vocab.pkl")

# Resolved once at import time via the Hugging Face Hub client; the returned
# value is later handed to torch.load as the checkpoint location.
CHECKPOINT_PATH = hf_hub_download(
    filename="best_checkpoint.pth",
    repo_id="VIKRAM989/image-label",
)
# -------------------------
# Load Vocabulary
# -------------------------
class CustomUnpickler(pickle.Unpickler):
    """Unpickler that maps every class named ``Vocabulary`` to ``model.Vocabulary``.

    The lookup ignores the module recorded in the pickle, so the vocabulary
    file loads regardless of which module originally defined the class.
    """

    def find_class(self, module, name):
        """Return the class object pickle should use for ``module.name``.

        Args:
            module: Module path stored in the pickle stream.
            name: Class (or global) name stored in the pickle stream.

        Returns:
            The imported ``Vocabulary`` class when ``name`` is
            ``"Vocabulary"``; otherwise whatever the standard unpickler
            resolves.
        """
        # Only the name is compared: the module recorded in the pickle is
        # deliberately not consulted for Vocabulary.
        return (
            Vocabulary
            if name == "Vocabulary"
            else super().find_class(module, name)
        )
# Deserialize the vocabulary object shipped next to this file. The pickle is
# a local file; it is read through CustomUnpickler rather than pickle.load.
with open(VOCAB_PATH, mode="rb") as f:
    vocab = CustomUnpickler(f).load()

# Number of entries in the vocabulary; passed to the decoder constructor.
vocab_size = len(vocab)
# -------------------------
# Build Model
# -------------------------
# Assemble the captioning network from its two halves, then move the combined
# module onto the selected device.
encoder = ResNetEncoder(EMBED_DIM)
decoder = DecoderLSTM(EMBED_DIM, HIDDEN_DIM, vocab_size)
model = ImageCaptioningModel(encoder, decoder)
model = model.to(DEVICE)
# -------------------------
# Load Weights
# -------------------------
# NOTE(review): torch.load unpickles the file, and this checkpoint is
# downloaded from the Hub. Consider weights_only=True once it is confirmed
# that the checkpoint contains only tensors and plain containers.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)

# The checkpoint is a mapping; only its "model_state_dict" entry is used here.
_weights = checkpoint["model_state_dict"]
model.load_state_dict(_weights)

# Switch to evaluation mode before serving requests.
model.eval()
print("✅ Model Loaded Successfully")
# -------------------------
# Health Check
# -------------------------
@app.get("/")
def root():
    """Health-check endpoint: report that the service is up.

    Returns:
        A dict with a single ``"message"`` key.
    """
    return dict(message="Image Captioning API Running")
# -------------------------
# Caption Endpoint
# -------------------------
@app.post("/caption")
async def caption_image(file: UploadFile = File(...)):
    """Generate a caption for an uploaded image.

    Args:
        file: Multipart upload whose bytes are decoded as an image.

    Returns:
        A dict ``{"caption": ...}`` holding the value returned by
        ``generate_caption``.

    Raises:
        HTTPException: With status 400 when the upload is empty or cannot be
            decoded as an image.
    """
    # Imported here so the handler is self-contained and does not rely on
    # changes to the module's import block.
    from fastapi import HTTPException

    contents = await file.read()
    if not contents:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except OSError as exc:
        # Pillow signals unrecognised data with UnidentifiedImageError (an
        # OSError subclass) and truncated/corrupt files with OSError. Report
        # these as a client error instead of an unhandled 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    image = transform(image)
    # NOTE(review): generate_caption is called synchronously inside this
    # coroutine, so the event loop is occupied while it runs.
    caption = generate_caption(model, image, vocab)
    return {"caption": caption}
if __name__ == "__main__":
    # Hand uvicorn the app object instead of the import string "main:app".
    # With the string form, running this file as a script makes uvicorn import
    # it a second time as module "main", which repeats every module-level
    # step (checkpoint download, model build, weight load).
    uvicorn.run(app, host="0.0.0.0", port=7860)