Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
|
@@ -6,6 +6,8 @@ import torch
|
|
| 6 |
import pickle
|
| 7 |
import os
|
| 8 |
import uvicorn
|
|
|
|
|
|
|
| 9 |
# Import from model.py
|
| 10 |
from model import (
|
| 11 |
Vocabulary,
|
|
@@ -34,12 +36,16 @@ app.add_middleware(
|
|
| 34 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 35 |
|
| 36 |
# -------------------------
|
| 37 |
-
# Paths
|
| 38 |
# -------------------------
|
| 39 |
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 40 |
|
| 41 |
VOCAB_PATH = os.path.join(BASE_DIR, "vocab.pkl")
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
# -------------------------
|
| 45 |
# Load Vocabulary
|
|
@@ -67,6 +73,7 @@ model = ImageCaptioningModel(encoder, decoder).to(DEVICE)
|
|
| 67 |
# Load Weights
|
| 68 |
# -------------------------
|
| 69 |
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
|
|
|
|
| 70 |
model.load_state_dict(checkpoint["model_state_dict"])
|
| 71 |
|
| 72 |
model.eval()
|
|
@@ -94,10 +101,8 @@ async def caption_image(file: UploadFile = File(...)):
|
|
| 94 |
|
| 95 |
caption = generate_caption(model, image, vocab)
|
| 96 |
|
| 97 |
-
return {
|
| 98 |
-
"caption": caption
|
| 99 |
-
}
|
| 100 |
|
| 101 |
-
if __name__ == "__main__":
|
| 102 |
|
|
|
|
| 103 |
uvicorn.run("main:app", host="0.0.0.0", port=7860)
|
|
|
|
| 6 |
import pickle
|
| 7 |
import os
|
| 8 |
import uvicorn
|
| 9 |
+
from huggingface_hub import hf_hub_download
|
| 10 |
+
|
| 11 |
# Import from model.py
|
| 12 |
from model import (
|
| 13 |
Vocabulary,
|
|
|
|
| 36 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 37 |
|
| 38 |
# -------------------------
|
| 39 |
+
# Paths
|
| 40 |
# -------------------------
|
| 41 |
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 42 |
|
| 43 |
VOCAB_PATH = os.path.join(BASE_DIR, "vocab.pkl")
|
| 44 |
+
|
| 45 |
+
CHECKPOINT_PATH = hf_hub_download(
|
| 46 |
+
repo_id="VIKRAM989/image-label",
|
| 47 |
+
filename="best_checkpoint.pth"
|
| 48 |
+
)
|
| 49 |
|
| 50 |
# -------------------------
|
| 51 |
# Load Vocabulary
|
|
|
|
| 73 |
# Load Weights
|
| 74 |
# -------------------------
|
| 75 |
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
|
| 76 |
+
|
| 77 |
model.load_state_dict(checkpoint["model_state_dict"])
|
| 78 |
|
| 79 |
model.eval()
|
|
|
|
| 101 |
|
| 102 |
caption = generate_caption(model, image, vocab)
|
| 103 |
|
| 104 |
+
return {"caption": caption}
|
|
|
|
|
|
|
| 105 |
|
|
|
|
| 106 |
|
| 107 |
+
if __name__ == "__main__":
|
| 108 |
uvicorn.run("main:app", host="0.0.0.0", port=7860)
|