VIKRAM989 commited on
Commit
ec85d7b
·
verified ·
1 Parent(s): d75e81d

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +11 -6
main.py CHANGED
@@ -6,6 +6,8 @@ import torch
6
  import pickle
7
  import os
8
  import uvicorn
 
 
9
  # Import from model.py
10
  from model import (
11
  Vocabulary,
@@ -34,12 +36,16 @@ app.add_middleware(
34
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
 
36
  # -------------------------
37
- # Paths (relative to main.py)
38
  # -------------------------
39
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
40
 
41
  VOCAB_PATH = os.path.join(BASE_DIR, "vocab.pkl")
42
- CHECKPOINT_PATH = os.path.join(BASE_DIR, "best_checkpoint.pth")
 
 
 
 
43
 
44
  # -------------------------
45
  # Load Vocabulary
@@ -67,6 +73,7 @@ model = ImageCaptioningModel(encoder, decoder).to(DEVICE)
67
  # Load Weights
68
  # -------------------------
69
  checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
 
70
  model.load_state_dict(checkpoint["model_state_dict"])
71
 
72
  model.eval()
@@ -94,10 +101,8 @@ async def caption_image(file: UploadFile = File(...)):
94
 
95
  caption = generate_caption(model, image, vocab)
96
 
97
- return {
98
- "caption": caption
99
- }
100
 
101
- if __name__ == "__main__":
102
 
 
103
  uvicorn.run("main:app", host="0.0.0.0", port=7860)
 
6
  import pickle
7
  import os
8
  import uvicorn
9
+ from huggingface_hub import hf_hub_download
10
+
11
  # Import from model.py
12
  from model import (
13
  Vocabulary,
 
36
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
 
38
  # -------------------------
39
+ # Paths
40
  # -------------------------
41
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
42
 
43
  VOCAB_PATH = os.path.join(BASE_DIR, "vocab.pkl")
44
+
45
+ CHECKPOINT_PATH = hf_hub_download(
46
+ repo_id="VIKRAM989/image-label",
47
+ filename="best_checkpoint.pth"
48
+ )
49
 
50
  # -------------------------
51
  # Load Vocabulary
 
73
  # Load Weights
74
  # -------------------------
75
  checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
76
+
77
  model.load_state_dict(checkpoint["model_state_dict"])
78
 
79
  model.eval()
 
101
 
102
  caption = generate_caption(model, image, vocab)
103
 
104
+ return {"caption": caption}
 
 
105
 
 
106
 
107
+ if __name__ == "__main__":
108
  uvicorn.run("main:app", host="0.0.0.0", port=7860)