File size: 2,458 Bytes
40243b5
 
 
 
 
 
 
 
ec85d7b
 
40243b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec85d7b
40243b5
 
 
 
ec85d7b
 
 
 
 
40243b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec85d7b
40243b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec85d7b
40243b5
 
ec85d7b
40243b5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import io
import os
import pickle

import torch
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import hf_hub_download
from PIL import Image, UnidentifiedImageError

# Import from model.py
from model import (
    Vocabulary,
    ResNetEncoder,
    DecoderLSTM,
    ImageCaptioningModel,
    generate_caption,
    transform,
    EMBED_DIM,
    HIDDEN_DIM,
)

app = FastAPI(title="Image Captioning API")

# -------------------------
# Enable CORS
# -------------------------
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by the CORS spec for a literal wildcard; Starlette compensates by
# echoing back the request's Origin header, which effectively trusts every
# site. Restrict origins (or drop credentials) before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Run inference on GPU when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------------
# Paths
# -------------------------
# Resolve paths relative to this file so the app works from any working dir.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

VOCAB_PATH = os.path.join(BASE_DIR, "vocab.pkl")

# Download the trained checkpoint from the Hugging Face Hub at import time
# (cached locally after the first run) — cold start needs network access.
CHECKPOINT_PATH = hf_hub_download(
    repo_id="VIKRAM989/image-label",
    filename="best_checkpoint.pth"
)

# -------------------------
# Load Vocabulary
# -------------------------
class CustomUnpickler(pickle.Unpickler):
    """Unpickler that resolves any class pickled under the name 'Vocabulary'
    to this service's local Vocabulary class, regardless of the module path
    recorded in the pickle stream."""

    def find_class(self, module, name):
        # Redirect only the Vocabulary class; defer everything else to the
        # default resolution logic.
        return Vocabulary if name == "Vocabulary" else super().find_class(module, name)

# Deserialize the vocabulary, remapping its pickled class to the local
# Vocabulary definition via CustomUnpickler.
with open(VOCAB_PATH, "rb") as f:
    vocab = CustomUnpickler(f).load()

vocab_size = len(vocab)

# -------------------------
# Build Model
# -------------------------
encoder = ResNetEncoder(EMBED_DIM)
decoder = DecoderLSTM(EMBED_DIM, HIDDEN_DIM, vocab_size)

model = ImageCaptioningModel(encoder, decoder).to(DEVICE)

# -------------------------
# Load Weights
# -------------------------
# NOTE(review): torch.load without weights_only=True unpickles arbitrary
# objects from the downloaded checkpoint — acceptable only because the Hub
# repo is first-party; consider weights_only=True if the checkpoint format
# permits it.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)

model.load_state_dict(checkpoint["model_state_dict"])

# Inference mode: disables dropout and freezes batch-norm statistics.
model.eval()

print("✅ Model Loaded Successfully")

# -------------------------
# Health Check
# -------------------------
@app.get("/")
def root():
    """Health-check endpoint confirming the service is up and reachable."""
    status_payload = {"message": "Image Captioning API Running"}
    return status_payload

# -------------------------
# Caption Endpoint
# -------------------------
@app.post("/caption")
async def caption_image(file: UploadFile = File(...)):
    """Generate a natural-language caption for an uploaded image.

    Accepts any image format PIL can decode and returns ``{"caption": str}``.
    Raises HTTP 400 when the upload is not a decodable image (previously this
    surfaced as an unhandled 500).
    """
    contents = await file.read()

    try:
        # Normalize to RGB so grayscale/RGBA uploads match the model's input.
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except (UnidentifiedImageError, OSError) as exc:
        # Corrupt or non-image payload: a client error, not a server fault.
        raise HTTPException(status_code=400, detail="Invalid image file") from exc

    # Preprocess into the tensor format the model was trained on.
    image = transform(image)

    # NOTE(review): generate_caption is CPU/GPU-bound and runs on the event
    # loop; consider starlette's run_in_threadpool if requests overlap heavily.
    caption = generate_caption(model, image, vocab)

    return {"caption": caption}


if __name__ == "__main__":
    # Launch the ASGI server only when executed as a script (not on import).
    server_host, server_port = "0.0.0.0", 7860
    uvicorn.run("main:app", host=server_host, port=server_port)