skodan commited on
Commit
5d50d54
·
1 Parent(s): b1aa711

fixing incorrect references

Browse files
api.py CHANGED
@@ -1,11 +1,8 @@
1
- # api.py
2
  from fastapi import FastAPI, UploadFile, File, Form
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from PIL import Image
5
  from typing import List
6
  from pydantic import BaseModel
7
- from models.resnet_lstm_attention.loader import load_captioning_model
8
- from models.resnet_lstm_attention.cap_mod_defs import EncoderCNN
9
 
10
  from model_registry import get_model
11
  from models.resnet_lstm_attention.schemas import CaptionResult, ImageResult, TextQuery
@@ -23,7 +20,6 @@ class InferenceRequest(BaseModel):
23
  model_name: str
24
  top_k: int = 5
25
 
26
- #@app.post("/caption", response_model=CaptionResult)
27
  @app.post("/caption")
28
  async def caption_image(model_name: str = Form(...), file: UploadFile = File(...)):
29
  image = Image.open(file.file).convert("RGB")
@@ -31,7 +27,6 @@ async def caption_image(model_name: str = Form(...), file: UploadFile = File(...
31
  caption = model.generate_caption(image)
32
  return {"caption": caption}
33
 
34
- #@app.post("/search/text2img", response_model=List[ImageResult])
35
  @app.post("/search/text2img")
36
  async def text_to_image(model_name: str = Form(...), query: str = Form(...), top_k: int = Form(5)):
37
  model = get_model(model_name)
@@ -45,7 +40,6 @@ async def image_to_text(model_name: str = Form(...), file: UploadFile = File(...
45
  results = model.image_to_text(image, top_k)
46
  return results
47
 
48
- #@app.post("/search/img2img", response_model=List[ImageResult])
49
  @app.post("/search/img2img")
50
  async def image_to_image(model_name: str = Form(...), file: UploadFile = File(...), top_k: int = Form(5)):
51
  image = Image.open(file.file).convert("RGB")
@@ -61,4 +55,71 @@ async def text_to_text(model_name: str = Form(...), query: str = Form(...), top_
61
 
62
  @app.get("/health")
63
  def health_check():
64
- return {"status": "healthy"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from fastapi import FastAPI, UploadFile, File, Form
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from PIL import Image
4
  from typing import List
5
  from pydantic import BaseModel
 
 
6
 
7
  from model_registry import get_model
8
  from models.resnet_lstm_attention.schemas import CaptionResult, ImageResult, TextQuery
 
20
  model_name: str
21
  top_k: int = 5
22
 
 
23
  @app.post("/caption")
24
  async def caption_image(model_name: str = Form(...), file: UploadFile = File(...)):
25
  image = Image.open(file.file).convert("RGB")
 
27
  caption = model.generate_caption(image)
28
  return {"caption": caption}
29
 
 
30
  @app.post("/search/text2img")
31
  async def text_to_image(model_name: str = Form(...), query: str = Form(...), top_k: int = Form(5)):
32
  model = get_model(model_name)
 
40
  results = model.image_to_text(image, top_k)
41
  return results
42
 
 
43
  @app.post("/search/img2img")
44
  async def image_to_image(model_name: str = Form(...), file: UploadFile = File(...), top_k: int = Form(5)):
45
  image = Image.open(file.file).convert("RGB")
 
55
 
56
  @app.get("/health")
57
  def health_check():
58
+ return {"status": "healthy"}
59
+
60
+
61
+
62
+ # # api.py
63
+ # from fastapi import FastAPI, UploadFile, File, Form
64
+ # from fastapi.middleware.cors import CORSMiddleware
65
+ # from PIL import Image
66
+ # from typing import List
67
+ # from pydantic import BaseModel
68
+ # from models.resnet_lstm_attention.loader import load_captioning_model
69
+ # from models.resnet_lstm_attention.cap_mod_defs import EncoderCNN
70
+
71
+ # from model_registry import get_model
72
+ # from models.resnet_lstm_attention.schemas import CaptionResult, ImageResult, TextQuery
73
+
74
+ # app = FastAPI(title="Multimodal Retrieval & Captioning API")
75
+
76
+ # app.add_middleware(
77
+ # CORSMiddleware,
78
+ # allow_origins=["*"],
79
+ # allow_methods=["*"],
80
+ # allow_headers=["*"],
81
+ # )
82
+
83
+ # class InferenceRequest(BaseModel):
84
+ # model_name: str
85
+ # top_k: int = 5
86
+
87
+ # #@app.post("/caption", response_model=CaptionResult)
88
+ # @app.post("/caption")
89
+ # async def caption_image(model_name: str = Form(...), file: UploadFile = File(...)):
90
+ # image = Image.open(file.file).convert("RGB")
91
+ # model = get_model(model_name)
92
+ # caption = model.generate_caption(image)
93
+ # return {"caption": caption}
94
+
95
+ # #@app.post("/search/text2img", response_model=List[ImageResult])
96
+ # @app.post("/search/text2img")
97
+ # async def text_to_image(model_name: str = Form(...), query: str = Form(...), top_k: int = Form(5)):
98
+ # model = get_model(model_name)
99
+ # results = model.text_to_image(query, top_k)
100
+ # return results
101
+
102
+ # @app.post("/search/img2text")
103
+ # async def image_to_text(model_name: str = Form(...), file: UploadFile = File(...), top_k: int = Form(5)):
104
+ # image = Image.open(file.file).convert("RGB")
105
+ # model = get_model(model_name)
106
+ # results = model.image_to_text(image, top_k)
107
+ # return results
108
+
109
+ # #@app.post("/search/img2img", response_model=List[ImageResult])
110
+ # @app.post("/search/img2img")
111
+ # async def image_to_image(model_name: str = Form(...), file: UploadFile = File(...), top_k: int = Form(5)):
112
+ # image = Image.open(file.file).convert("RGB")
113
+ # model = get_model(model_name)
114
+ # results = model.image_to_image(image, top_k)
115
+ # return results
116
+
117
+ # @app.post("/search/text2text")
118
+ # async def text_to_text(model_name: str = Form(...), query: str = Form(...), top_k: int = Form(5)):
119
+ # model = get_model(model_name)
120
+ # results = model.text_to_text(query, top_k)
121
+ # return results
122
+
123
+ # @app.get("/health")
124
+ # def health_check():
125
+ # return {"status": "healthy"}
app.py CHANGED
@@ -1,4 +1,3 @@
1
- # app.py
2
  import streamlit as st
3
  import requests
4
  import subprocess
@@ -9,10 +8,6 @@ import base64 # For displaying retrieved images if needed
9
  import socket
10
 
11
  # Start FastAPI server in background
12
- # subprocess.Popen(["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "8001"])
13
- # time.sleep(2) # Wait for server to start
14
-
15
- # Check if port is free
16
  def is_port_free(port):
17
  with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
18
  return s.connect_ex(('localhost', port)) != 0
@@ -57,17 +52,6 @@ with tab_caption:
57
  else:
58
  st.error("Error: " + resp.text)
59
 
60
- # with tab_text2img:
61
- # if text_input and st.button("Search Images"):
62
- # data = {"model_name": model_name, "query": text_input, "top_k": top_k}
63
- # resp = requests.post(f"{API_BASE}/search/text2img", data=data)
64
- # if resp.status_code == 200:
65
- # results = resp.json()
66
- # for res in results:
67
- # st.image(res["image_path"], caption=f"Score: {res['score']:.3f}")
68
- # else:
69
- # st.error("Error: " + resp.text)
70
-
71
  with tab_text2img:
72
  if text_input and st.button("Search Images"):
73
  data = {"model_name": model_name, "query": text_input, "top_k": top_k}
@@ -82,10 +66,8 @@ with tab_text2img:
82
  if res["image"] is not None:
83
  st.image(res["image"], width=200)
84
  st.caption(f"Score: {res['score']:.3f}")
85
- if "caption" in res: # if you add caption to results later
86
- st.write(res["caption"])
87
  else:
88
- st.caption(f"Score: {res['score']:.3f} (Image not found)")
89
  else:
90
  st.info("No results found.")
91
  else:
@@ -97,28 +79,21 @@ with tab_img2text:
97
  data = {"model_name": model_name, "top_k": top_k}
98
  resp = requests.post(f"{API_BASE}/search/img2text", files=files, data=data)
99
  if resp.status_code == 200:
100
- st.write("Retrieved Texts:", resp.json())
 
 
 
 
 
 
101
  else:
102
- st.error("Error: " + resp.text)
103
-
104
- # with tab_img2img:
105
- # if image_file and st.button("Retrieve Similar Images"):
106
- # files = {"file": image_file.getvalue()}
107
- # data = {"model_name": model_name, "top_k": top_k}
108
- # resp = requests.post(f"{API_BASE}/search/img2img", files=files, data=data)
109
- # if resp.status_code == 200:
110
- # results = resp.json()
111
- # for res in results:
112
- # st.image(res["image_path"], caption=f"Score: {res['score']:.3f}")
113
- # else:
114
- # st.error("Error: " + resp.text)
115
 
116
  with tab_img2img:
117
  if image_file and st.button("Retrieve Similar Images"):
118
  files = {"file": image_file.getvalue()}
119
  data = {"model_name": model_name, "top_k": top_k}
120
  resp = requests.post(f"{API_BASE}/search/img2img", files=files, data=data)
121
-
122
  if resp.status_code == 200:
123
  results = resp.json()
124
  if results:
@@ -126,18 +101,15 @@ with tab_img2img:
126
  cols = st.columns(3)
127
  for idx, res in enumerate(results):
128
  with cols[idx % 3]:
129
- if "image" in res and res["image"] is not None:
130
- st.image(
131
- res["image"],
132
- width=200, # recommended instead of use_column_width
133
- caption=f"Score: {res['score']:.3f}"
134
- )
135
  else:
136
  st.caption(f"Score: {res['score']:.3f} (Image not available)")
137
  else:
138
- st.info("No similar images found in the dataset.")
139
  else:
140
- st.error(f"Backend error: {resp.status_code} - {resp.text}")
141
 
142
  with tab_text2text:
143
  text_input_tt = st.text_input("Enter text to find similar captions",
@@ -156,4 +128,167 @@ with tab_text2text:
156
  else:
157
  st.info("No similar captions found.")
158
  else:
159
- st.error(f"Error: {resp.status_code} - {resp.text}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import requests
3
  import subprocess
 
8
  import socket
9
 
10
  # Start FastAPI server in background
 
 
 
 
11
  def is_port_free(port):
12
  with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
13
  return s.connect_ex(('localhost', port)) != 0
 
52
  else:
53
  st.error("Error: " + resp.text)
54
 
 
 
 
 
 
 
 
 
 
 
 
55
  with tab_text2img:
56
  if text_input and st.button("Search Images"):
57
  data = {"model_name": model_name, "query": text_input, "top_k": top_k}
 
66
  if res["image"] is not None:
67
  st.image(res["image"], width=200)
68
  st.caption(f"Score: {res['score']:.3f}")
 
 
69
  else:
70
+ st.caption(f"Score: {res['score']:.3f} (Image not available)")
71
  else:
72
  st.info("No results found.")
73
  else:
 
79
  data = {"model_name": model_name, "top_k": top_k}
80
  resp = requests.post(f"{API_BASE}/search/img2text", files=files, data=data)
81
  if resp.status_code == 200:
82
+ results = resp.json()
83
+ if results:
84
+ st.subheader("Retrieved Texts:")
85
+ for idx, caption in enumerate(results, 1):
86
+ st.markdown(f"**{idx}.** {caption}")
87
+ else:
88
+ st.info("No results found.")
89
  else:
90
+ st.error(f"Error: {resp.status_code} - {resp.text}")
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  with tab_img2img:
93
  if image_file and st.button("Retrieve Similar Images"):
94
  files = {"file": image_file.getvalue()}
95
  data = {"model_name": model_name, "top_k": top_k}
96
  resp = requests.post(f"{API_BASE}/search/img2img", files=files, data=data)
 
97
  if resp.status_code == 200:
98
  results = resp.json()
99
  if results:
 
101
  cols = st.columns(3)
102
  for idx, res in enumerate(results):
103
  with cols[idx % 3]:
104
+ if res["image"] is not None:
105
+ st.image(res["image"], width=200)
106
+ st.caption(f"Score: {res['score']:.3f}")
 
 
 
107
  else:
108
  st.caption(f"Score: {res['score']:.3f} (Image not available)")
109
  else:
110
+ st.info("No results found.")
111
  else:
112
+ st.error(f"Error: {resp.status_code} - {resp.text}")
113
 
114
  with tab_text2text:
115
  text_input_tt = st.text_input("Enter text to find similar captions",
 
128
  else:
129
  st.info("No similar captions found.")
130
  else:
131
+ st.error(f"Error: {resp.status_code} - {resp.text}")
132
+
133
+
134
+ # Old Code
135
+
136
+ # # app.py
137
+ # import streamlit as st
138
+ # import requests
139
+ # import subprocess
140
+ # import time
141
+ # from PIL import Image
142
+ # import io
143
+ # import base64 # For displaying retrieved images if needed
144
+ # import socket
145
+
146
+ # # Start FastAPI server in background
147
+ # # subprocess.Popen(["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "8001"])
148
+ # # time.sleep(2) # Wait for server to start
149
+
150
+ # # Check if port is free
151
+ # def is_port_free(port):
152
+ # with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
153
+ # return s.connect_ex(('localhost', port)) != 0
154
+
155
+ # if is_port_free(8001):
156
+ # subprocess.Popen(["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "8001"])
157
+ # else:
158
+ # print("Port 8001 in use - skipping backend startup")
159
+ # time.sleep(5) # longer wait
160
+
161
+ # API_BASE = "http://localhost:8001"
162
+
163
+ # st.set_page_config(page_title="Multimodal Retrieval & Captioning", layout="wide")
164
+
165
+ # st.title("Multimodal Retrieval & Captioning System")
166
+
167
+ # # Model selection (add more later)
168
+ # model_name = st.sidebar.selectbox("Select Model", ["resnet_lstm_attention", "vit_lstm_attention", "vit_transformer"], index=0)
169
+
170
+ # # Common inputs
171
+ # input_method = st.sidebar.radio("Image Input", ["Upload", "Camera"])
172
+ # image_file = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"]) if input_method == "Upload" else st.camera_input("Capture Image")
173
+ # text_input = st.text_input("Text Input")
174
+ # top_k = st.sidebar.slider("Top K", 1, 10, 5)
175
+
176
+ # # Tabs for tasks
177
+ # tab_caption, tab_text2img, tab_img2text, tab_img2img, tab_text2text = st.tabs([
178
+ # "Image → Caption",
179
+ # "Text → Image",
180
+ # "Image → Text",
181
+ # "Image → Image",
182
+ # "Text → Text"
183
+ # ])
184
+
185
+ # with tab_caption:
186
+ # if image_file and st.button("Generate Caption"):
187
+ # files = {"file": image_file.getvalue()}
188
+ # data = {"model_name": model_name}
189
+ # resp = requests.post(f"{API_BASE}/caption", files=files, data=data)
190
+ # if resp.status_code == 200:
191
+ # st.write("Caption:", resp.json()["caption"])
192
+ # else:
193
+ # st.error("Error: " + resp.text)
194
+
195
+ # # with tab_text2img:
196
+ # # if text_input and st.button("Search Images"):
197
+ # # data = {"model_name": model_name, "query": text_input, "top_k": top_k}
198
+ # # resp = requests.post(f"{API_BASE}/search/text2img", data=data)
199
+ # # if resp.status_code == 200:
200
+ # # results = resp.json()
201
+ # # for res in results:
202
+ # # st.image(res["image_path"], caption=f"Score: {res['score']:.3f}")
203
+ # # else:
204
+ # # st.error("Error: " + resp.text)
205
+
206
+ # with tab_text2img:
207
+ # if text_input and st.button("Search Images"):
208
+ # data = {"model_name": model_name, "query": text_input, "top_k": top_k}
209
+ # resp = requests.post(f"{API_BASE}/search/text2img", data=data)
210
+ # if resp.status_code == 200:
211
+ # results = resp.json()
212
+ # if results:
213
+ # st.subheader("Retrieved Images")
214
+ # cols = st.columns(3)
215
+ # for idx, res in enumerate(results):
216
+ # with cols[idx % 3]:
217
+ # if res["image"] is not None:
218
+ # st.image(res["image"], width=200)
219
+ # st.caption(f"Score: {res['score']:.3f}")
220
+ # if "caption" in res: # if you add caption to results later
221
+ # st.write(res["caption"])
222
+ # else:
223
+ # st.caption(f"Score: {res['score']:.3f} (Image not found)")
224
+ # else:
225
+ # st.info("No results found.")
226
+ # else:
227
+ # st.error(f"Error: {resp.status_code} - {resp.text}")
228
+
229
+ # with tab_img2text:
230
+ # if image_file and st.button("Retrieve Text"):
231
+ # files = {"file": image_file.getvalue()}
232
+ # data = {"model_name": model_name, "top_k": top_k}
233
+ # resp = requests.post(f"{API_BASE}/search/img2text", files=files, data=data)
234
+ # if resp.status_code == 200:
235
+ # st.write("Retrieved Texts:", resp.json())
236
+ # else:
237
+ # st.error("Error: " + resp.text)
238
+
239
+ # # with tab_img2img:
240
+ # # if image_file and st.button("Retrieve Similar Images"):
241
+ # # files = {"file": image_file.getvalue()}
242
+ # # data = {"model_name": model_name, "top_k": top_k}
243
+ # # resp = requests.post(f"{API_BASE}/search/img2img", files=files, data=data)
244
+ # # if resp.status_code == 200:
245
+ # # results = resp.json()
246
+ # # for res in results:
247
+ # # st.image(res["image_path"], caption=f"Score: {res['score']:.3f}")
248
+ # # else:
249
+ # # st.error("Error: " + resp.text)
250
+
251
+ # with tab_img2img:
252
+ # if image_file and st.button("Retrieve Similar Images"):
253
+ # files = {"file": image_file.getvalue()}
254
+ # data = {"model_name": model_name, "top_k": top_k}
255
+ # resp = requests.post(f"{API_BASE}/search/img2img", files=files, data=data)
256
+
257
+ # if resp.status_code == 200:
258
+ # results = resp.json()
259
+ # if results:
260
+ # st.subheader("Retrieved Similar Images")
261
+ # cols = st.columns(3)
262
+ # for idx, res in enumerate(results):
263
+ # with cols[idx % 3]:
264
+ # if "image" in res and res["image"] is not None:
265
+ # st.image(
266
+ # res["image"],
267
+ # width=200, # recommended instead of use_column_width
268
+ # caption=f"Score: {res['score']:.3f}"
269
+ # )
270
+ # else:
271
+ # st.caption(f"Score: {res['score']:.3f} (Image not available)")
272
+ # else:
273
+ # st.info("No similar images found in the dataset.")
274
+ # else:
275
+ # st.error(f"Backend error: {resp.status_code} - {resp.text}")
276
+
277
+ # with tab_text2text:
278
+ # text_input_tt = st.text_input("Enter text to find similar captions",
279
+ # placeholder="A child playing with water in the garden")
280
+
281
+ # if text_input_tt and st.button("Search Similar Captions"):
282
+ # data = {"model_name": model_name, "query": text_input_tt, "top_k": top_k}
283
+ # resp = requests.post(f"{API_BASE}/search/text2text", data=data)
284
+
285
+ # if resp.status_code == 200:
286
+ # results = resp.json()
287
+ # if results:
288
+ # st.subheader("Top similar captions:")
289
+ # for idx, res in enumerate(results, 1):
290
+ # st.markdown(f"**{idx}.** {res['caption']} \nScore: `{res['score']:.4f}`")
291
+ # else:
292
+ # st.info("No similar captions found.")
293
+ # else:
294
+ # st.error(f"Error: {resp.status_code} - {resp.text}")
models/resnet_lstm_attention/model.py CHANGED
@@ -19,7 +19,6 @@ class ResNetLSTMAttentionModel(UnifiedModelInterface):
19
  self.retrieval_service = None
20
  self.device = torch.device("cpu")
21
  self.dataset = None
22
- #self.model_repo = "skodan/resnet-lstm-attention-weights"
23
 
24
  def load(self) -> None:
25
  if self.caption_bundle is not None and self.retrieval_service is not None:
@@ -28,14 +27,14 @@ class ResNetLSTMAttentionModel(UnifiedModelInterface):
28
  MODEL_REPO = "skodan/resnet-lstm-attention-weights"
29
 
30
  files_to_download = [
31
- "caption_model.pth",
32
- "flickr8k_retrieval_model.pth",
33
- "image_embeddings.faiss",
34
- "text_embeddings.faiss",
35
- "image_id_map.pkl",
36
- "text_id_map.pkl",
37
- "vocab.pkl"
38
- ]
39
 
40
  downloaded_paths = {}
41
  for fname in files_to_download:
@@ -43,33 +42,33 @@ class ResNetLSTMAttentionModel(UnifiedModelInterface):
43
  path = hf_hub_download(
44
  repo_id=MODEL_REPO,
45
  filename=fname,
46
- repo_type="model",
47
  )
48
  downloaded_paths[fname] = path
49
  except Exception as e:
50
  raise RuntimeError(f"Failed to download {fname} from {MODEL_REPO}: {e}")
51
 
52
- # Download large files from HF Hub
53
  caption_pth = downloaded_paths["caption_model.pth"]
54
  retrieval_pth = downloaded_paths["flickr8k_retrieval_model.pth"]
55
  image_index_faiss = downloaded_paths["image_embeddings.faiss"]
56
  text_index_faiss = downloaded_paths["text_embeddings.faiss"]
57
  image_map_pkl = downloaded_paths["image_id_map.pkl"]
58
  text_map_pkl = downloaded_paths["text_id_map.pkl"]
59
- vocab_pkl = downloaded_paths["vocab.pkl"]
60
 
61
- # Load configs (assume small, committed to repo)
62
  base_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) # go up to project root
63
  config_path = os.path.join(base_dir, "configs", "caption_config.json")
64
  preprocess_cfg_path = os.path.join(base_dir, "configs", "preprocess_config.json")
65
 
 
 
 
66
  with open(config_path, "r") as f:
67
  caption_config = json.load(f)
68
 
69
  with open(preprocess_cfg_path, "r") as f:
70
  preprocess_cfg = json.load(f)
71
 
72
- # Load captioning
73
  self.caption_bundle = load_captioning_model(
74
  model_path=caption_pth,
75
  vocab_path=vocab_pkl,
@@ -77,7 +76,6 @@ class ResNetLSTMAttentionModel(UnifiedModelInterface):
77
  device=self.device
78
  )
79
 
80
- # Load retrieval
81
  clip_model = load_clip_model(
82
  model_path=retrieval_pth,
83
  vocab=self.caption_bundle["vocab"],
@@ -94,19 +92,18 @@ class ResNetLSTMAttentionModel(UnifiedModelInterface):
94
  )
95
 
96
  if self.dataset is None:
97
- print("Loading Flickr8k test split from Hugging Face...")
98
  ds = load_dataset("jxie/flickr8k")
99
- self.dataset = concatenate_datasets([
100
- ds["train"],
101
- ds["validation"],
102
- ds["test"]
103
- ])
104
  print(f"Loaded {len(self.dataset)} images/captions from full dataset.")
105
 
106
  print("Model components loaded successfully.")
107
 
108
  @torch.no_grad()
109
  def generate_caption(self, image: Image.Image) -> str:
 
 
 
110
  encoder = self.caption_bundle["encoder"]
111
  decoder = self.caption_bundle["decoder"]
112
  vocab = self.caption_bundle["vocab"]
@@ -115,6 +112,7 @@ class ResNetLSTMAttentionModel(UnifiedModelInterface):
115
  transform = self.caption_bundle["transform"]
116
 
117
  image_tensor = transform(image).unsqueeze(0).to(self.device)
 
118
  features = encoder(image_tensor)
119
  tokens = decoder.generate(
120
  features,
@@ -124,125 +122,303 @@ class ResNetLSTMAttentionModel(UnifiedModelInterface):
124
  )
125
  return " ".join(tokens)
126
 
127
- # def text_to_image(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
128
- # return self.retrieval_service.text_to_image(text, top_k)
129
-
130
  def text_to_image(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
131
  raw_results = self.retrieval_service.text_to_image(text, top_k)
132
- return self._format_retrieval_results(raw_results)
133
-
134
- # def text_to_image(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
135
- # results = self.retrieval_service.text_to_image(text, top_k)
136
 
137
- # formatted_results = []
138
- # for res in results:
139
- # img_id_str = str(res["image_path"]) # this is likely the ID or filename without .jpg
140
- # img_file = f"{img_id_str}.jpg"
141
- # #img_file = f"{img_id}.jpg" if not img_id.endswith('.jpg') else img_id
142
-
143
- # full_path = os.path.join("flickr8k_images", img_file)
144
 
145
- # # Only include if the file actually exists in the demo folder
146
- # if os.path.exists(full_path):
147
- # formatted_results.append({
148
- # "image_path": full_path,
149
- # "score": res["score"]
150
- # })
151
- # # Optional: skip or use placeholder if missing
152
- # else:
153
- # formatted_results.append({
154
- # "image_path": "https://via.placeholder.com/300?text=Not+in+demo",
155
- # "score": res["score"]
156
- # })
157
 
158
- # return formatted_results
159
-
160
- # def image_to_text(self, image: Image.Image, top_k: int = 5) -> List[str]:
161
- # return self.retrieval_service.image_to_text(image, top_k)
162
 
163
  def image_to_text(self, image: Image.Image, top_k: int = 5) -> List[str]:
164
  return self.retrieval_service.image_to_text(image, top_k)
165
 
166
- # def image_to_text(self, image: Image.Image, top_k: int = 5) -> List[str]:
167
- # results = self.retrieval_service.image_to_text(image, top_k) # assuming this returns list of dicts
168
-
169
- # formatted_results = []
170
- # for res in results:
171
- # img_id = res["image_path"] # same as above
172
- # img_file = f"{img_id}.jpg"# if not img_id.endswith('.jpg') else img_id
173
-
174
- # full_path = os.path.join("flickr8k_images", img_file)
175
-
176
- # if os.path.exists(full_path):
177
- # formatted_results.append({
178
- # "image_path": full_path,
179
- # "score": res["score"]
180
- # })
181
- # else:
182
- # # Optional fallback so UI doesn't crash
183
- # formatted.append({
184
- # "image_path": "https://via.placeholder.com/300x200?text=Not+in+demo",
185
- # "score": float(res["score"])
186
- # })
187
-
188
- # return formatted_results
189
-
190
  def image_to_image(self, image: Image.Image, top_k: int = 5) -> List[Dict[str, Any]]:
191
- raw_results = self.retrieval_service.image_to_image(image, top_k) # new call
192
- return self._format_retrieval_results(raw_results)
193
-
194
-
195
- def _format_retrieval_results(self, raw_results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
196
  formatted = []
197
  for res in raw_results:
198
- img_id = res["image_path"] # this is integer ID
199
- img_filename = f"{img_id}.jpg" # always append .jpg, no .endswith needed
200
- full_path = os.path.join("flickr8k_images", img_filename)
201
 
202
- if os.path.exists(full_path):
 
203
  formatted.append({
204
- "image_path": full_path,
205
- "score": res["score"]
206
  })
207
- else:
208
  formatted.append({
209
- "image_path": "https://via.placeholder.com/300?text=Not+in+demo",
210
- "score": res["score"]
211
  })
 
212
  return formatted
213
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
 
215
- # def image_to_image(self, image: Image.Image, top_k: int = 5) -> List[Dict[str, Any]]:
216
- # # image_tensor = self.retrieval_service.image_transform(image).unsqueeze(0).to(self.device)
217
- # # with torch.no_grad():
218
- # # emb = self.retrieval_service.clip_model.encode_image(image_tensor).cpu().numpy()
219
- # # emb = self.retrieval_service._normalize(emb)
220
- # # scores, idxs = self.retrieval_service.image_index.search(emb, top_k)
221
- # # return [
222
- # # {"image_path": self.retrieval_service.image_id_map[i], "score": float(scores[0][j])}
223
- # # for j, i in enumerate(idxs[0])
224
- # # ]
225
- # raw_results = self.retrieval_service.image_to_image(image, top_k)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
- # formatted = []
228
- # for res in raw_results:
229
- # img_id_str = str(res["image_path"])
 
230
 
231
- # img_filename = f"{img_id_str}.jpg"
232
- # full_path = os.path.join("flickr8k_images", img_filename)
233
 
234
- # if os.path.exists(full_path):
235
- # formatted.append({
236
- # "image_path": full_path,
237
- # "score": float(res["score"])
238
- # })
239
- # else:
240
- # formatted.append({
241
- # "image_path": "https://via.placeholder.com/300x200?text=Not+in+demo",
242
- # "score": float(res["score"])
243
- # })
 
244
 
245
- # return formatted
246
 
247
- def text_to_text(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
248
- return self.retrieval_service.text_to_text(text, top_k)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  self.retrieval_service = None
20
  self.device = torch.device("cpu")
21
  self.dataset = None
 
22
 
23
  def load(self) -> None:
24
  if self.caption_bundle is not None and self.retrieval_service is not None:
 
27
  MODEL_REPO = "skodan/resnet-lstm-attention-weights"
28
 
29
  files_to_download = [
30
+ "caption_model.pth",
31
+ "flickr8k_retrieval_model.pth",
32
+ "image_embeddings.faiss",
33
+ "text_embeddings.faiss",
34
+ "image_id_map.pkl",
35
+ "text_id_map.pkl",
36
+ "vocab.pkl" # only if large; otherwise commit it
37
+ ]
38
 
39
  downloaded_paths = {}
40
  for fname in files_to_download:
 
42
  path = hf_hub_download(
43
  repo_id=MODEL_REPO,
44
  filename=fname,
45
+ repo_type="model"
46
  )
47
  downloaded_paths[fname] = path
48
  except Exception as e:
49
  raise RuntimeError(f"Failed to download {fname} from {MODEL_REPO}: {e}")
50
 
 
51
  caption_pth = downloaded_paths["caption_model.pth"]
52
  retrieval_pth = downloaded_paths["flickr8k_retrieval_model.pth"]
53
  image_index_faiss = downloaded_paths["image_embeddings.faiss"]
54
  text_index_faiss = downloaded_paths["text_embeddings.faiss"]
55
  image_map_pkl = downloaded_paths["image_id_map.pkl"]
56
  text_map_pkl = downloaded_paths["text_id_map.pkl"]
57
+ vocab_pkl = downloaded_paths["vocab.pkl"]
58
 
 
59
  base_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) # go up to project root
60
  config_path = os.path.join(base_dir, "configs", "caption_config.json")
61
  preprocess_cfg_path = os.path.join(base_dir, "configs", "preprocess_config.json")
62
 
63
+ if not os.path.exists(config_path):
64
+ raise FileNotFoundError(f"Config not found: {config_path}")
65
+
66
  with open(config_path, "r") as f:
67
  caption_config = json.load(f)
68
 
69
  with open(preprocess_cfg_path, "r") as f:
70
  preprocess_cfg = json.load(f)
71
 
 
72
  self.caption_bundle = load_captioning_model(
73
  model_path=caption_pth,
74
  vocab_path=vocab_pkl,
 
76
  device=self.device
77
  )
78
 
 
79
  clip_model = load_clip_model(
80
  model_path=retrieval_pth,
81
  vocab=self.caption_bundle["vocab"],
 
92
  )
93
 
94
  if self.dataset is None:
95
+ print("Loading full Flickr8k dataset from Hugging Face...")
96
  ds = load_dataset("jxie/flickr8k")
97
+ self.dataset = concatenate_datasets([ds["train"], ds["validation"], ds["test"]])
 
 
 
 
98
  print(f"Loaded {len(self.dataset)} images/captions from full dataset.")
99
 
100
  print("Model components loaded successfully.")
101
 
102
  @torch.no_grad()
103
  def generate_caption(self, image: Image.Image) -> str:
104
+ if self.caption_bundle is None:
105
+ raise RuntimeError("Model not loaded. Call load() first.")
106
+
107
  encoder = self.caption_bundle["encoder"]
108
  decoder = self.caption_bundle["decoder"]
109
  vocab = self.caption_bundle["vocab"]
 
112
  transform = self.caption_bundle["transform"]
113
 
114
  image_tensor = transform(image).unsqueeze(0).to(self.device)
115
+
116
  features = encoder(image_tensor)
117
  tokens = decoder.generate(
118
  features,
 
122
  )
123
  return " ".join(tokens)
124
 
 
 
 
125
  def text_to_image(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
126
  raw_results = self.retrieval_service.text_to_image(text, top_k)
 
 
 
 
127
 
128
+ formatted = []
129
+ for res in raw_results:
130
+ idx = int(res["image_path"])
 
 
 
 
131
 
132
+ try:
133
+ pil_img = self.dataset[idx]["image"]
134
+ formatted.append({
135
+ "image": pil_img,
136
+ "score": float(res["score"])
137
+ })
138
+ except (IndexError, KeyError):
139
+ formatted.append({
140
+ "image": None,
141
+ "score": float(res["score"])
142
+ })
 
143
 
144
+ return formatted
 
 
 
145
 
146
  def image_to_text(self, image: Image.Image, top_k: int = 5) -> List[str]:
147
  return self.retrieval_service.image_to_text(image, top_k)
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  def image_to_image(self, image: Image.Image, top_k: int = 5) -> List[Dict[str, Any]]:
150
+ raw_results = self.retrieval_service.image_to_image(image, top_k)
151
+
 
 
 
152
  formatted = []
153
  for res in raw_results:
154
+ idx = int(res["image_path"])
 
 
155
 
156
+ try:
157
+ pil_img = self.dataset[idx]["image"]
158
  formatted.append({
159
+ "image": pil_img,
160
+ "score": float(res["score"])
161
  })
162
+ except (IndexError, KeyError):
163
  formatted.append({
164
+ "image": None,
165
+ "score": float(res["score"])
166
  })
167
+
168
  return formatted
169
 
170
+ def text_to_text(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
171
+ return self.retrieval_service.text_to_text(text, top_k)
172
+
173
+
174
+
175
+ # Old code
176
+
177
+ # import os
178
+ # import json
179
+ # import torch
180
+ # from huggingface_hub import hf_hub_download
181
+ # from PIL import Image
182
+ # import numpy as np
183
+ # from typing import List, Dict, Any
184
+ # from datasets import load_dataset, concatenate_datasets
185
+
186
+ # from models.resnet_lstm_attention.loader import load_captioning_model
187
+ # from models.resnet_lstm_attention.retrieval import RetrievalService
188
+ # from models.resnet_lstm_attention.clip_loader import load_clip_model
189
+ # from models.resnet_lstm_attention.captioning import CaptioningService # Not directly used, but for reference
190
+ # from utils.interfaces import UnifiedModelInterface # Adjust path if needed
191
+
192
+ # class ResNetLSTMAttentionModel(UnifiedModelInterface):
193
+ # def __init__(self):
194
+ # self.caption_bundle = None
195
+ # self.retrieval_service = None
196
+ # self.device = torch.device("cpu")
197
+ # self.dataset = None
198
+ # #self.model_repo = "skodan/resnet-lstm-attention-weights"
199
+
200
+ # def load(self) -> None:
201
+ # if self.caption_bundle is not None and self.retrieval_service is not None:
202
+ # return
203
+
204
+ # MODEL_REPO = "skodan/resnet-lstm-attention-weights"
205
 
206
+ # files_to_download = [
207
+ # "caption_model.pth",
208
+ # "flickr8k_retrieval_model.pth",
209
+ # "image_embeddings.faiss",
210
+ # "text_embeddings.faiss",
211
+ # "image_id_map.pkl",
212
+ # "text_id_map.pkl",
213
+ # "vocab.pkl"
214
+ # ]
215
+
216
+ # downloaded_paths = {}
217
+ # for fname in files_to_download:
218
+ # try:
219
+ # path = hf_hub_download(
220
+ # repo_id=MODEL_REPO,
221
+ # filename=fname,
222
+ # repo_type="model",
223
+ # )
224
+ # downloaded_paths[fname] = path
225
+ # except Exception as e:
226
+ # raise RuntimeError(f"Failed to download {fname} from {MODEL_REPO}: {e}")
227
+
228
+ # # Download large files from HF Hub
229
+ # caption_pth = downloaded_paths["caption_model.pth"]
230
+ # retrieval_pth = downloaded_paths["flickr8k_retrieval_model.pth"]
231
+ # image_index_faiss = downloaded_paths["image_embeddings.faiss"]
232
+ # text_index_faiss = downloaded_paths["text_embeddings.faiss"]
233
+ # image_map_pkl = downloaded_paths["image_id_map.pkl"]
234
+ # text_map_pkl = downloaded_paths["text_id_map.pkl"]
235
+ # vocab_pkl = downloaded_paths["vocab.pkl"]
236
+
237
+ # # Load configs (assume small, committed to repo)
238
+ # base_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) # go up to project root
239
+ # config_path = os.path.join(base_dir, "configs", "caption_config.json")
240
+ # preprocess_cfg_path = os.path.join(base_dir, "configs", "preprocess_config.json")
241
+
242
+ # with open(config_path, "r") as f:
243
+ # caption_config = json.load(f)
244
+
245
+ # with open(preprocess_cfg_path, "r") as f:
246
+ # preprocess_cfg = json.load(f)
247
+
248
+ # # Load captioning
249
+ # self.caption_bundle = load_captioning_model(
250
+ # model_path=caption_pth,
251
+ # vocab_path=vocab_pkl,
252
+ # config_path=config_path,
253
+ # device=self.device
254
+ # )
255
+
256
+ # # Load retrieval
257
+ # clip_model = load_clip_model(
258
+ # model_path=retrieval_pth,
259
+ # vocab=self.caption_bundle["vocab"],
260
+ # device=self.device
261
+ # )
262
+
263
+ # self.retrieval_service = RetrievalService(
264
+ # clip_model=clip_model,
265
+ # image_index_path=image_index_faiss,
266
+ # text_index_path=text_index_faiss,
267
+ # image_map_path=image_map_pkl,
268
+ # text_map_path=text_map_pkl,
269
+ # preprocess=preprocess_cfg
270
+ # )
271
+
272
+ # if self.dataset is None:
273
+ # print("Loading Flickr8k test split from Hugging Face...")
274
+ # ds = load_dataset("jxie/flickr8k")
275
+ # self.dataset = concatenate_datasets([
276
+ # ds["train"],
277
+ # ds["validation"],
278
+ # ds["test"]
279
+ # ])
280
+ # print(f"Loaded {len(self.dataset)} images/captions from full dataset.")
281
+
282
+ # print("Model components loaded successfully.")
283
+
284
+ # @torch.no_grad()
285
+ # def generate_caption(self, image: Image.Image) -> str:
286
+ # encoder = self.caption_bundle["encoder"]
287
+ # decoder = self.caption_bundle["decoder"]
288
+ # vocab = self.caption_bundle["vocab"]
289
+ # inv_vocab = self.caption_bundle["inv_vocab"]
290
+ # max_len = self.caption_bundle["max_len"]
291
+ # transform = self.caption_bundle["transform"]
292
+
293
+ # image_tensor = transform(image).unsqueeze(0).to(self.device)
294
+ # features = encoder(image_tensor)
295
+ # tokens = decoder.generate(
296
+ # features,
297
+ # vocab=vocab,
298
+ # inv_vocab=inv_vocab,
299
+ # max_len=max_len
300
+ # )
301
+ # return " ".join(tokens)
302
+
303
+ # # def text_to_image(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
304
+ # # return self.retrieval_service.text_to_image(text, top_k)
305
+
306
+ # def text_to_image(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
307
+ # raw_results = self.retrieval_service.text_to_image(text, top_k)
308
+ # return self._format_retrieval_results(raw_results)
309
+
310
+ # # def text_to_image(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
311
+ # # results = self.retrieval_service.text_to_image(text, top_k)
312
+
313
+ # # formatted_results = []
314
+ # # for res in results:
315
+ # # img_id_str = str(res["image_path"]) # this is likely the ID or filename without .jpg
316
+ # # img_file = f"{img_id_str}.jpg"
317
+ # # #img_file = f"{img_id}.jpg" if not img_id.endswith('.jpg') else img_id
318
+
319
+ # # full_path = os.path.join("flickr8k_images", img_file)
320
+
321
+ # # # Only include if the file actually exists in the demo folder
322
+ # # if os.path.exists(full_path):
323
+ # # formatted_results.append({
324
+ # # "image_path": full_path,
325
+ # # "score": res["score"]
326
+ # # })
327
+ # # # Optional: skip or use placeholder if missing
328
+ # # else:
329
+ # # formatted_results.append({
330
+ # # "image_path": "https://via.placeholder.com/300?text=Not+in+demo",
331
+ # # "score": res["score"]
332
+ # # })
333
+
334
+ # # return formatted_results
335
+
336
+ # # def image_to_text(self, image: Image.Image, top_k: int = 5) -> List[str]:
337
+ # # return self.retrieval_service.image_to_text(image, top_k)
338
+
339
+ # def image_to_text(self, image: Image.Image, top_k: int = 5) -> List[str]:
340
+ # return self.retrieval_service.image_to_text(image, top_k)
341
+
342
+ # # def image_to_text(self, image: Image.Image, top_k: int = 5) -> List[str]:
343
+ # # results = self.retrieval_service.image_to_text(image, top_k) # assuming this returns list of dicts
344
 
345
+ # # formatted_results = []
346
+ # # for res in results:
347
+ # # img_id = res["image_path"] # same as above
348
+ # # img_file = f"{img_id}.jpg"# if not img_id.endswith('.jpg') else img_id
349
 
350
+ # # full_path = os.path.join("flickr8k_images", img_file)
 
351
 
352
+ # # if os.path.exists(full_path):
353
+ # # formatted_results.append({
354
+ # # "image_path": full_path,
355
+ # # "score": res["score"]
356
+ # # })
357
+ # # else:
358
+ # # # Optional fallback so UI doesn't crash
359
+ # # formatted.append({
360
+ # # "image_path": "https://via.placeholder.com/300x200?text=Not+in+demo",
361
+ # # "score": float(res["score"])
362
+ # # })
363
 
364
+ # # return formatted_results
365
 
366
+ # def image_to_image(self, image: Image.Image, top_k: int = 5) -> List[Dict[str, Any]]:
367
+ # raw_results = self.retrieval_service.image_to_image(image, top_k) # new call
368
+ # return self._format_retrieval_results(raw_results)
369
+
370
+
371
+ # def _format_retrieval_results(self, raw_results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
372
+ # formatted = []
373
+ # for res in raw_results:
374
+ # img_id = res["image_path"] # this is integer ID
375
+ # img_filename = f"{img_id}.jpg" # always append .jpg, no .endswith needed
376
+ # full_path = os.path.join("flickr8k_images", img_filename)
377
+
378
+ # if os.path.exists(full_path):
379
+ # formatted.append({
380
+ # "image_path": full_path,
381
+ # "score": res["score"]
382
+ # })
383
+ # else:
384
+ # formatted.append({
385
+ # "image_path": "https://via.placeholder.com/300?text=Not+in+demo",
386
+ # "score": res["score"]
387
+ # })
388
+ # return formatted
389
+
390
+
391
+ # # def image_to_image(self, image: Image.Image, top_k: int = 5) -> List[Dict[str, Any]]:
392
+ # # # image_tensor = self.retrieval_service.image_transform(image).unsqueeze(0).to(self.device)
393
+ # # # with torch.no_grad():
394
+ # # # emb = self.retrieval_service.clip_model.encode_image(image_tensor).cpu().numpy()
395
+ # # # emb = self.retrieval_service._normalize(emb)
396
+ # # # scores, idxs = self.retrieval_service.image_index.search(emb, top_k)
397
+ # # # return [
398
+ # # # {"image_path": self.retrieval_service.image_id_map[i], "score": float(scores[0][j])}
399
+ # # # for j, i in enumerate(idxs[0])
400
+ # # # ]
401
+ # # raw_results = self.retrieval_service.image_to_image(image, top_k)
402
+
403
+ # # formatted = []
404
+ # # for res in raw_results:
405
+ # # img_id_str = str(res["image_path"])
406
+
407
+ # # img_filename = f"{img_id_str}.jpg"
408
+ # # full_path = os.path.join("flickr8k_images", img_filename)
409
+
410
+ # # if os.path.exists(full_path):
411
+ # # formatted.append({
412
+ # # "image_path": full_path,
413
+ # # "score": float(res["score"])
414
+ # # })
415
+ # # else:
416
+ # # formatted.append({
417
+ # # "image_path": "https://via.placeholder.com/300x200?text=Not+in+demo",
418
+ # # "score": float(res["score"])
419
+ # # })
420
+
421
+ # # return formatted
422
+
423
+ # def text_to_text(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
424
+ # return self.retrieval_service.text_to_text(text, top_k)
models/resnet_lstm_attention/retrieval.py CHANGED
@@ -2,7 +2,6 @@ import faiss
2
  import pickle
3
  import torch
4
  import numpy as np
5
- import os
6
  from PIL import Image
7
  from torchvision import transforms
8
  from typing import List, Dict, Any
@@ -34,70 +33,23 @@ class RetrievalService:
34
 
35
  def _normalize(self, x):
36
  return x / np.linalg.norm(x, axis=1, keepdims=True)
37
-
38
 
39
  def text_to_image(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
40
- raw_results = self.retrieval_service.text_to_image(text, top_k)
41
-
42
- formatted = []
43
- for res in raw_results:
44
- idx = int(res["image_path"]) # the FAISS index (integer)
45
-
46
- try:
47
- pil_img = self.dataset[idx]["image"] # directly get PIL.Image
48
- formatted.append({
49
- "image": pil_img, # ← pass PIL.Image to UI
50
- "score": float(res["score"])
51
- })
52
- except (IndexError, KeyError):
53
- formatted.append({
54
- "image": None,
55
- "score": float(res["score"])
56
- })
57
-
58
- return formatted
59
-
60
-
61
- # def text_to_image(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
62
- # raw_results = self.retrieval_service.text_to_image(text, top_k)
63
-
64
- # formatted = []
65
- # for res in raw_results:
66
- # img_id = res["image_path"] # int or str
67
- # img_id_str = str(img_id)
68
- # img_filename = f"{img_id_str}.jpg" # always append .jpg, no .endswith
69
- # full_path = os.path.join("flickr8k_images", img_filename)
70
-
71
- # if os.path.exists(full_path):
72
- # formatted.append({
73
- # "image_path": full_path,
74
- # "score": float(res["score"])
75
- # })
76
- # else:
77
- # formatted.append({
78
- # "image_path": "https://via.placeholder.com/300?text=Not+in+demo",
79
- # "score": float(res["score"])
80
- # })
81
-
82
- # return formatted
83
-
84
-
85
- # def text_to_image(self, text, top_k=5):
86
- # with torch.no_grad():
87
- # emb = self.clip_model.encode_text(text).cpu().numpy()
88
- # emb = self._normalize(emb)
89
 
90
- # scores, idxs = self.image_index.search(emb, top_k)
91
- # return [
92
- # {
93
- # "image_path": self.image_id_map[i],
94
- # "score": float(scores[0][j])
95
- # }
96
- # for j, i in enumerate(idxs[0])
97
- # ]
98
 
99
  def image_to_text(self, image: Image.Image, top_k=5):
100
- image = self.image_transform(image).unsqueeze(0)
101
  with torch.no_grad():
102
  emb = self.clip_model.encode_image(image).cpu().numpy()
103
  emb = self._normalize(emb)
@@ -107,7 +59,6 @@ class RetrievalService:
107
  print(f"DEBUG: Returning results: {results}")
108
  return results
109
 
110
-
111
  def text_to_text(self, text: str, top_k: int = 5):
112
  with torch.no_grad():
113
  emb = self.clip_model.encode_text(text).cpu().numpy()
@@ -125,67 +76,213 @@ class RetrievalService:
125
 
126
  print(f"DEBUG: Text-to-text results: {results}")
127
  return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
 
130
- # def image_to_image(self, image: Image.Image, top_k=5):
131
- # """
132
- # Image → Image retrieval: encode input image, search image index, return image IDs and scores.
133
- # """
134
- # image = self.image_transform(image).unsqueeze(0).to(self.device)
135
- # with torch.no_grad():
136
- # emb = self.clip_model.encode_image(image).cpu().numpy()
137
- # emb = self._normalize(emb)
138
-
139
- # scores, idxs = self.image_index.search(emb, top_k)
140
- # return [
141
- # {
142
- # "image_path": self.image_id_map[i], # integer ID
143
- # "score": float(scores[0][j])
144
- # }
145
- # for j, i in enumerate(idxs[0])
146
- # ]
147
 
148
- # def image_to_image(self, image: Image.Image, top_k: int = 5) -> List[Dict[str, Any]]:
149
- # raw_results = self.retrieval_service.image_to_image(image, top_k) # now exists
150
- # # ... same logic as above ...
151
 
152
- # formatted = []
153
- # for res in raw_results:
154
- # img_id = res["image_path"]
155
- # img_id_str = str(img_id)
156
- # img_filename = f"{img_id_str}.jpg"
157
- # full_path = os.path.join("flickr8k_images", img_filename)
158
 
159
- # if os.path.exists(full_path):
160
- # formatted.append({
161
- # "image_path": full_path,
162
- # "score": float(res["score"])
163
- # })
164
- # else:
165
- # formatted.append({
166
- # "image_path": "https://via.placeholder.com/300?text=Not+in+demo",
167
- # "score": float(res["score"])
168
- # })
169
 
170
- # return formatted
171
 
172
- def image_to_image(self, image: Image.Image, top_k: int = 5) -> List[Dict[str, Any]]:
173
- raw_results = self.retrieval_service.image_to_image(image, top_k)
174
 
175
- formatted = []
176
- for res in raw_results:
177
- idx = int(res["image_path"])
178
 
179
- try:
180
- pil_img = self.dataset[idx]["image"]
181
- formatted.append({
182
- "image": pil_img,
183
- "score": float(res["score"])
184
- })
185
- except (IndexError, KeyError):
186
- formatted.append({
187
- "image": None,
188
- "score": float(res["score"])
189
- })
190
 
191
- return formatted
 
2
  import pickle
3
  import torch
4
  import numpy as np
 
5
  from PIL import Image
6
  from torchvision import transforms
7
  from typing import List, Dict, Any
 
33
 
34
  def _normalize(self, x):
35
  return x / np.linalg.norm(x, axis=1, keepdims=True)
 
36
 
37
  def text_to_image(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
38
+ with torch.no_grad():
39
+ emb = self.clip_model.encode_text(text).cpu().numpy()
40
+ emb = self._normalize(emb)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
+ scores, idxs = self.image_index.search(emb, top_k)
43
+ return [
44
+ {
45
+ "image_path": self.image_id_map[i], # integer ID
46
+ "score": float(scores[0][j])
47
+ }
48
+ for j, i in enumerate(idxs[0])
49
+ ]
50
 
51
  def image_to_text(self, image: Image.Image, top_k=5):
52
+ image = self.image_transform(image).unsqueeze(0).to(self.device)
53
  with torch.no_grad():
54
  emb = self.clip_model.encode_image(image).cpu().numpy()
55
  emb = self._normalize(emb)
 
59
  print(f"DEBUG: Returning results: {results}")
60
  return results
61
 
 
62
  def text_to_text(self, text: str, top_k: int = 5):
63
  with torch.no_grad():
64
  emb = self.clip_model.encode_text(text).cpu().numpy()
 
76
 
77
  print(f"DEBUG: Text-to-text results: {results}")
78
  return results
79
+
80
+ def image_to_image(self, image: Image.Image, top_k: int = 5):
81
+ image = self.image_transform(image).unsqueeze(0).to(self.device)
82
+ with torch.no_grad():
83
+ emb = self.clip_model.encode_image(image).cpu().numpy()
84
+ emb = self._normalize(emb)
85
+
86
+ scores, idxs = self.image_index.search(emb, top_k)
87
+ return [
88
+ {
89
+ "image_path": self.image_id_map[i], # integer ID
90
+ "score": float(scores[0][j])
91
+ }
92
+ for j, i in enumerate(idxs[0])
93
+ ]
94
+
95
+
96
+ # Old Code
97
+
98
+ # import faiss
99
+ # import pickle
100
+ # import torch
101
+ # import numpy as np
102
+ # import os
103
+ # from PIL import Image
104
+ # from torchvision import transforms
105
+ # from typing import List, Dict, Any
106
+
107
+ # class RetrievalService:
108
+ # def __init__(self, clip_model, image_index_path, text_index_path,
109
+ # image_map_path, text_map_path, preprocess):
110
+
111
+ # self.device = torch.device("cpu")
112
+ # self.clip_model = clip_model
113
+
114
+ # self.image_index = faiss.read_index(image_index_path)
115
+ # self.text_index = faiss.read_index(text_index_path)
116
+
117
+ # with open(image_map_path, "rb") as f:
118
+ # self.image_id_map = pickle.load(f)
119
+
120
+ # with open(text_map_path, "rb") as f:
121
+ # self.text_id_map = pickle.load(f)
122
+
123
+ # self.image_transform = transforms.Compose([
124
+ # transforms.Resize((224, 224)),
125
+ # transforms.ToTensor(),
126
+ # transforms.Normalize(
127
+ # mean=preprocess["mean"],
128
+ # std=preprocess["std"]
129
+ # )
130
+ # ])
131
+
132
+ # def _normalize(self, x):
133
+ # return x / np.linalg.norm(x, axis=1, keepdims=True)
134
+
135
+
136
+ # def text_to_image(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
137
+ # raw_results = self.retrieval_service.text_to_image(text, top_k)
138
+
139
+ # formatted = []
140
+ # for res in raw_results:
141
+ # idx = int(res["image_path"]) # the FAISS index (integer)
142
+
143
+ # try:
144
+ # pil_img = self.dataset[idx]["image"] # directly get PIL.Image
145
+ # formatted.append({
146
+ # "image": pil_img, # ← pass PIL.Image to UI
147
+ # "score": float(res["score"])
148
+ # })
149
+ # except (IndexError, KeyError):
150
+ # formatted.append({
151
+ # "image": None,
152
+ # "score": float(res["score"])
153
+ # })
154
+
155
+ # return formatted
156
+
157
+
158
+ # # def text_to_image(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
159
+ # # raw_results = self.retrieval_service.text_to_image(text, top_k)
160
+
161
+ # # formatted = []
162
+ # # for res in raw_results:
163
+ # # img_id = res["image_path"] # int or str
164
+ # # img_id_str = str(img_id)
165
+ # # img_filename = f"{img_id_str}.jpg" # always append .jpg, no .endswith
166
+ # # full_path = os.path.join("flickr8k_images", img_filename)
167
+
168
+ # # if os.path.exists(full_path):
169
+ # # formatted.append({
170
+ # # "image_path": full_path,
171
+ # # "score": float(res["score"])
172
+ # # })
173
+ # # else:
174
+ # # formatted.append({
175
+ # # "image_path": "https://via.placeholder.com/300?text=Not+in+demo",
176
+ # # "score": float(res["score"])
177
+ # # })
178
+
179
+ # # return formatted
180
+
181
+
182
+ # # def text_to_image(self, text, top_k=5):
183
+ # # with torch.no_grad():
184
+ # # emb = self.clip_model.encode_text(text).cpu().numpy()
185
+ # # emb = self._normalize(emb)
186
+
187
+ # # scores, idxs = self.image_index.search(emb, top_k)
188
+ # # return [
189
+ # # {
190
+ # # "image_path": self.image_id_map[i],
191
+ # # "score": float(scores[0][j])
192
+ # # }
193
+ # # for j, i in enumerate(idxs[0])
194
+ # # ]
195
+
196
+ # def image_to_text(self, image: Image.Image, top_k=5):
197
+ # image = self.image_transform(image).unsqueeze(0)
198
+ # with torch.no_grad():
199
+ # emb = self.clip_model.encode_image(image).cpu().numpy()
200
+ # emb = self._normalize(emb)
201
+
202
+ # scores, idxs = self.text_index.search(emb, top_k)
203
+ # results = [self.text_id_map[i] for i in idxs[0]]
204
+ # print(f"DEBUG: Returning results: {results}")
205
+ # return results
206
+
207
+
208
+ # def text_to_text(self, text: str, top_k: int = 5):
209
+ # with torch.no_grad():
210
+ # emb = self.clip_model.encode_text(text).cpu().numpy()
211
+ # emb = self._normalize(emb)
212
+
213
+ # scores, idxs = self.text_index.search(emb, top_k)
214
+
215
+ # results = []
216
+ # for j, i in enumerate(idxs[0]):
217
+ # caption = self.text_id_map[i] # assuming text_id_map stores the actual caption string
218
+ # results.append({
219
+ # "caption": caption,
220
+ # "score": float(scores[0][j])
221
+ # })
222
+
223
+ # print(f"DEBUG: Text-to-text results: {results}")
224
+ # return results
225
 
226
 
227
+ # # def image_to_image(self, image: Image.Image, top_k=5):
228
+ # # """
229
+ # # Image → Image retrieval: encode input image, search image index, return image IDs and scores.
230
+ # # """
231
+ # # image = self.image_transform(image).unsqueeze(0).to(self.device)
232
+ # # with torch.no_grad():
233
+ # # emb = self.clip_model.encode_image(image).cpu().numpy()
234
+ # # emb = self._normalize(emb)
235
+
236
+ # # scores, idxs = self.image_index.search(emb, top_k)
237
+ # # return [
238
+ # # {
239
+ # # "image_path": self.image_id_map[i], # integer ID
240
+ # # "score": float(scores[0][j])
241
+ # # }
242
+ # # for j, i in enumerate(idxs[0])
243
+ # # ]
244
 
245
+ # # def image_to_image(self, image: Image.Image, top_k: int = 5) -> List[Dict[str, Any]]:
246
+ # # raw_results = self.retrieval_service.image_to_image(image, top_k) # now exists
247
+ # # # ... same logic as above ...
248
 
249
+ # # formatted = []
250
+ # # for res in raw_results:
251
+ # # img_id = res["image_path"]
252
+ # # img_id_str = str(img_id)
253
+ # # img_filename = f"{img_id_str}.jpg"
254
+ # # full_path = os.path.join("flickr8k_images", img_filename)
255
 
256
+ # # if os.path.exists(full_path):
257
+ # # formatted.append({
258
+ # # "image_path": full_path,
259
+ # # "score": float(res["score"])
260
+ # # })
261
+ # # else:
262
+ # # formatted.append({
263
+ # # "image_path": "https://via.placeholder.com/300?text=Not+in+demo",
264
+ # # "score": float(res["score"])
265
+ # # })
266
 
267
+ # # return formatted
268
 
269
+ # def image_to_image(self, image: Image.Image, top_k: int = 5) -> List[Dict[str, Any]]:
270
+ # raw_results = self.retrieval_service.image_to_image(image, top_k)
271
 
272
+ # formatted = []
273
+ # for res in raw_results:
274
+ # idx = int(res["image_path"])
275
 
276
+ # try:
277
+ # pil_img = self.dataset[idx]["image"]
278
+ # formatted.append({
279
+ # "image": pil_img,
280
+ # "score": float(res["score"])
281
+ # })
282
+ # except (IndexError, KeyError):
283
+ # formatted.append({
284
+ # "image": None,
285
+ # "score": float(res["score"])
286
+ # })
287
 
288
+ # return formatted