"""FastAPI service that serves image-caption predictions.

At import time this module downloads the model checkpoint from the Hugging
Face Hub, unpickles the vocabulary, builds the encoder/decoder network, and
loads the trained weights. It then exposes:

* ``GET /``        — liveness probe.
* ``POST /caption`` — accepts an uploaded image, returns ``{"caption": str}``.
"""

from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, UnidentifiedImageError
import io
import torch
import pickle
import os
import uvicorn
from huggingface_hub import hf_hub_download

# Import from model.py
from model import (
    Vocabulary,
    ResNetEncoder,
    DecoderLSTM,
    ImageCaptioningModel,
    generate_caption,
    transform,
    EMBED_DIM,
    HIDDEN_DIM,
)

app = FastAPI(title="Image Captioning API")

# -------------------------
# Enable CORS
# -------------------------
# Wide-open CORS so any front-end origin can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------------
# Paths
# -------------------------
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
VOCAB_PATH = os.path.join(BASE_DIR, "vocab.pkl")

# Checkpoint is fetched from the Hub; hf_hub_download caches it locally, so
# only the first start pays the download cost.
CHECKPOINT_PATH = hf_hub_download(
    repo_id="VIKRAM989/image-label",
    filename="best_checkpoint.pth",
)


# -------------------------
# Load Vocabulary
# -------------------------
class CustomUnpickler(pickle.Unpickler):
    """Unpickler that remaps any pickled ``Vocabulary`` reference to the
    local class, so the vocab loads even if it was pickled under a different
    module path (e.g. ``__main__.Vocabulary`` from a training notebook)."""

    def find_class(self, module, name):
        if name == "Vocabulary":
            return Vocabulary
        return super().find_class(module, name)


with open(VOCAB_PATH, "rb") as f:
    vocab = CustomUnpickler(f).load()

vocab_size = len(vocab)

# -------------------------
# Build Model
# -------------------------
encoder = ResNetEncoder(EMBED_DIM)
decoder = DecoderLSTM(EMBED_DIM, HIDDEN_DIM, vocab_size)
model = ImageCaptioningModel(encoder, decoder).to(DEVICE)

# -------------------------
# Load Weights
# -------------------------
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # inference mode: disables dropout / batch-norm updates

print("✅ Model Loaded Successfully")


# -------------------------
# Health Check
# -------------------------
@app.get("/")
def root():
    """Liveness probe: confirms the API process is up and the model loaded."""
    return {"message": "Image Captioning API Running"}


# -------------------------
# Caption Endpoint
# -------------------------
@app.post("/caption")
async def caption_image(file: UploadFile = File(...)):
    """Generate a caption for an uploaded image.

    Returns ``{"caption": <str>}``. Responds with HTTP 400 when the upload
    cannot be decoded as an image (previously this surfaced as an unhandled
    500 from PIL).
    """
    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except (UnidentifiedImageError, OSError):
        # PIL raises when the payload is not a decodable image — report a
        # client error rather than crashing the request handler.
        raise HTTPException(status_code=400, detail="Invalid image file")
    image = transform(image)
    caption = generate_caption(model, image, vocab)
    return {"caption": caption}


if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=7860)