from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
import io
import torch
import pickle
import os
import uvicorn
from huggingface_hub import hf_hub_download
# Import from model.py
from model import (
Vocabulary,
ResNetEncoder,
DecoderLSTM,
ImageCaptioningModel,
generate_caption,
transform,
EMBED_DIM,
HIDDEN_DIM,
)
# FastAPI application instance. CORS is fully open so browser frontends
# hosted on any origin can call the API directly.
app = FastAPI(title="Image Captioning API")

app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Prefer GPU inference when CUDA is available; otherwise run on CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# vocab.pkl ships alongside this file; the trained checkpoint is pulled
# from (and cached by) the Hugging Face Hub at startup.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
VOCAB_PATH = os.path.join(BASE_DIR, "vocab.pkl")
CHECKPOINT_PATH = hf_hub_download(
    filename="best_checkpoint.pth",
    repo_id="VIKRAM989/image-label",
)
# -------------------------
# Load Vocabulary
# -------------------------
class CustomUnpickler(pickle.Unpickler):
    """Unpickler that resolves ``Vocabulary`` to our local class.

    The vocab file was presumably pickled under a different module path
    (e.g. ``__main__``), so any reference named ``Vocabulary`` is mapped
    to ``model.Vocabulary`` regardless of the recorded module.
    NOTE(review): pickle is only acceptable here because vocab.pkl is a
    trusted, bundled artifact — never unpickle untrusted input.
    """

    def find_class(self, module, name):
        # Everything except "Vocabulary" resolves normally.
        if name != "Vocabulary":
            return super().find_class(module, name)
        return Vocabulary
# ---- Load vocabulary -------------------------------------------------
with open(VOCAB_PATH, "rb") as fh:
    vocab = CustomUnpickler(fh).load()
vocab_size = len(vocab)

# ---- Build the captioning model --------------------------------------
encoder = ResNetEncoder(EMBED_DIM)
decoder = DecoderLSTM(EMBED_DIM, HIDDEN_DIM, vocab_size)
model = ImageCaptioningModel(encoder, decoder).to(DEVICE)

# ---- Restore trained weights -----------------------------------------
# NOTE(review): torch.load uses pickle under the hood; on torch >= 2.x
# consider weights_only=True for the downloaded checkpoint — confirm
# the installed torch version supports it.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # inference mode: freezes dropout / batch-norm behavior
print("✅ Model Loaded Successfully")
# -------------------------
# Health Check
# -------------------------
@app.get("/")
def root():
    """Health check: confirms the API process is alive and serving."""
    return {"message": "Image Captioning API Running"}
# -------------------------
# Caption Endpoint
# -------------------------
@app.post("/caption")
async def caption_image(file: UploadFile = File(...)):
    """Generate a caption for an uploaded image.

    The raw upload bytes are decoded with PIL, normalized to RGB,
    preprocessed via the model's transform, and passed through the
    captioning model.
    """
    raw = await file.read()
    pil_image = Image.open(io.BytesIO(raw)).convert("RGB")
    tensor = transform(pil_image)
    return {"caption": generate_caption(model, tensor, vocab)}
if __name__ == "__main__":
    # Bind on all interfaces; 7860 is the default port exposed by
    # Hugging Face Spaces.
    uvicorn.run("main:app", host="0.0.0.0", port=7860)