skodan committed on
Commit
aa6ef7a
·
1 Parent(s): 8fc4a6f

fixing incorrect references

Browse files
app.py CHANGED
@@ -20,7 +20,7 @@ def is_port_free(port):
20
  if is_port_free(8001):
21
  subprocess.Popen(["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "8001"])
22
  else:
23
- st.warning("Port 8001 already in use backend may not start. Restart Space.")
24
  time.sleep(5) # longer wait
25
 
26
  API_BASE = "http://localhost:8001"
@@ -75,20 +75,21 @@ with tab_text2img:
75
  if resp.status_code == 200:
76
  results = resp.json()
77
  if results:
 
78
  cols = st.columns(3)
79
  for idx, res in enumerate(results):
80
  with cols[idx % 3]:
81
- try:
82
- st.image(res["image_path"],
83
- caption=f"Score: {res['score']:.3f}",
84
- use_column_width=True)
85
- except Exception as e:
86
- st.warning(f"Could not load: {res['image_path']}")
87
- st.write(f"Score: {res['score']:.3f}")
88
  else:
89
- st.info("No matching images in demo set.")
90
  else:
91
- st.error(f"Backend error: {resp.status_code} - {resp.text}")
92
 
93
  with tab_img2text:
94
  if image_file and st.button("Retrieve Text"):
@@ -121,20 +122,20 @@ with tab_img2img:
121
  if resp.status_code == 200:
122
  results = resp.json()
123
  if results:
 
124
  cols = st.columns(3)
125
  for idx, res in enumerate(results):
126
  with cols[idx % 3]:
127
- try:
128
  st.image(
129
- res["image_path"],
130
- caption=f"Score: {res['score']:.3f}",
131
- use_column_width=True
132
  )
133
- except Exception as e:
134
- st.warning(f"Could not load image: {res['image_path']}")
135
- st.write(f"Score: {res['score']:.3f}")
136
  else:
137
- st.info("No similar images found in the demo set.")
138
  else:
139
  st.error(f"Backend error: {resp.status_code} - {resp.text}")
140
 
 
20
  if is_port_free(8001):
21
  subprocess.Popen(["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "8001"])
22
  else:
23
+ print("Port 8001 in use - skipping backend startup")
24
  time.sleep(5) # longer wait
25
 
26
  API_BASE = "http://localhost:8001"
 
75
  if resp.status_code == 200:
76
  results = resp.json()
77
  if results:
78
+ st.subheader("Retrieved Images")
79
  cols = st.columns(3)
80
  for idx, res in enumerate(results):
81
  with cols[idx % 3]:
82
+ if res["image"] is not None:
83
+ st.image(res["image"], width=200)
84
+ st.caption(f"Score: {res['score']:.3f}")
85
+ if "caption" in res: # if you add caption to results later
86
+ st.write(res["caption"])
87
+ else:
88
+ st.caption(f"Score: {res['score']:.3f} (Image not found)")
89
  else:
90
+ st.info("No results found.")
91
  else:
92
+ st.error(f"Error: {resp.status_code} - {resp.text}")
93
 
94
  with tab_img2text:
95
  if image_file and st.button("Retrieve Text"):
 
122
  if resp.status_code == 200:
123
  results = resp.json()
124
  if results:
125
+ st.subheader("Retrieved Similar Images")
126
  cols = st.columns(3)
127
  for idx, res in enumerate(results):
128
  with cols[idx % 3]:
129
+ if "image" in res and res["image"] is not None:
130
  st.image(
131
+ res["image"],
132
+ width=200, # recommended instead of use_column_width
133
+ caption=f"Score: {res['score']:.3f}"
134
  )
135
+ else:
136
+ st.caption(f"Score: {res['score']:.3f} (Image not available)")
 
137
  else:
138
+ st.info("No similar images found in the dataset.")
139
  else:
140
  st.error(f"Backend error: {resp.status_code} - {resp.text}")
141
 
models/resnet_lstm_attention/model.py CHANGED
@@ -5,6 +5,7 @@ from huggingface_hub import hf_hub_download
5
  from PIL import Image
6
  import numpy as np
7
  from typing import List, Dict, Any
 
8
 
9
  from models.resnet_lstm_attention.loader import load_captioning_model
10
  from models.resnet_lstm_attention.retrieval import RetrievalService
@@ -17,6 +18,7 @@ class ResNetLSTMAttentionModel(UnifiedModelInterface):
17
  self.caption_bundle = None
18
  self.retrieval_service = None
19
  self.device = torch.device("cpu")
 
20
  #self.model_repo = "skodan/resnet-lstm-attention-weights"
21
 
22
  def load(self) -> None:
@@ -91,6 +93,12 @@ class ResNetLSTMAttentionModel(UnifiedModelInterface):
91
  preprocess=preprocess_cfg
92
  )
93
 
 
 
 
 
 
 
94
  print("Model components loaded successfully.")
95
 
96
  @torch.no_grad()
 
5
  from PIL import Image
6
  import numpy as np
7
  from typing import List, Dict, Any
8
+ from datasets import load_dataset
9
 
10
  from models.resnet_lstm_attention.loader import load_captioning_model
11
  from models.resnet_lstm_attention.retrieval import RetrievalService
 
18
  self.caption_bundle = None
19
  self.retrieval_service = None
20
  self.device = torch.device("cpu")
21
+ self.dataset = None
22
  #self.model_repo = "skodan/resnet-lstm-attention-weights"
23
 
24
  def load(self) -> None:
 
93
  preprocess=preprocess_cfg
94
  )
95
 
96
+ if self.dataset is None:
97
+ print("Loading Flickr8k test split from Hugging Face...")
98
+ ds = load_dataset("jxie/flickr8k")
99
+ self.dataset = ds["train"].concatenate(ds["validation"]).concatenate(ds["test"])
100
+ print(f"Loaded {len(self.dataset)} images/captions from full dataset.")
101
+
102
  print("Model components loaded successfully.")
103
 
104
  @torch.no_grad()
models/resnet_lstm_attention/retrieval.py CHANGED
@@ -2,6 +2,7 @@ import faiss
2
  import pickle
3
  import torch
4
  import numpy as np
 
5
  from PIL import Image
6
  from torchvision import transforms
7
 
@@ -32,20 +33,67 @@ class RetrievalService:
32
 
33
  def _normalize(self, x):
34
  return x / np.linalg.norm(x, axis=1, keepdims=True)
 
35
 
36
- def text_to_image(self, text, top_k=5):
37
- with torch.no_grad():
38
- emb = self.clip_model.encode_text(text).cpu().numpy()
39
- emb = self._normalize(emb)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- scores, idxs = self.image_index.search(emb, top_k)
42
- return [
43
- {
44
- "image_path": self.image_id_map[i],
45
- "score": float(scores[0][j])
46
- }
47
- for j, i in enumerate(idxs[0])
48
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  def image_to_text(self, image: Image.Image, top_k=5):
51
  image = self.image_transform(image).unsqueeze(0)
@@ -58,6 +106,7 @@ class RetrievalService:
58
  print(f"DEBUG: Returning results: {results}")
59
  return results
60
 
 
61
  def text_to_text(self, text: str, top_k: int = 5):
62
  with torch.no_grad():
63
  emb = self.clip_model.encode_text(text).cpu().numpy()
@@ -76,20 +125,66 @@ class RetrievalService:
76
  print(f"DEBUG: Text-to-text results: {results}")
77
  return results
78
 
79
- def image_to_image(self, image: Image.Image, top_k=5):
80
- """
81
- Image → Image retrieval: encode input image, search image index, return image IDs and scores.
82
- """
83
- image = self.image_transform(image).unsqueeze(0).to(self.device)
84
- with torch.no_grad():
85
- emb = self.clip_model.encode_image(image).cpu().numpy()
86
- emb = self._normalize(emb)
87
 
88
- scores, idxs = self.image_index.search(emb, top_k)
89
- return [
90
- {
91
- "image_path": self.image_id_map[i], # integer ID
92
- "score": float(scores[0][j])
93
- }
94
- for j, i in enumerate(idxs[0])
95
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import pickle
3
  import torch
4
  import numpy as np
5
+ import os
6
  from PIL import Image
7
  from torchvision import transforms
8
 
 
33
 
34
  def _normalize(self, x):
35
  return x / np.linalg.norm(x, axis=1, keepdims=True)
36
+
37
 
38
+ def text_to_image(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
39
+ raw_results = self.retrieval_service.text_to_image(text, top_k)
40
+
41
+ formatted = []
42
+ for res in raw_results:
43
+ idx = int(res["image_path"]) # the FAISS index (integer)
44
+
45
+ try:
46
+ pil_img = self.dataset[idx]["image"] # directly get PIL.Image
47
+ formatted.append({
48
+ "image": pil_img, # ← pass PIL.Image to UI
49
+ "score": float(res["score"])
50
+ })
51
+ except (IndexError, KeyError):
52
+ formatted.append({
53
+ "image": None,
54
+ "score": float(res["score"])
55
+ })
56
+
57
+ return formatted
58
 
59
+
60
+ # def text_to_image(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
61
+ # raw_results = self.retrieval_service.text_to_image(text, top_k)
62
+
63
+ # formatted = []
64
+ # for res in raw_results:
65
+ # img_id = res["image_path"] # int or str
66
+ # img_id_str = str(img_id)
67
+ # img_filename = f"{img_id_str}.jpg" # always append .jpg, no .endswith
68
+ # full_path = os.path.join("flickr8k_images", img_filename)
69
+
70
+ # if os.path.exists(full_path):
71
+ # formatted.append({
72
+ # "image_path": full_path,
73
+ # "score": float(res["score"])
74
+ # })
75
+ # else:
76
+ # formatted.append({
77
+ # "image_path": "https://via.placeholder.com/300?text=Not+in+demo",
78
+ # "score": float(res["score"])
79
+ # })
80
+
81
+ # return formatted
82
+
83
+
84
+ # def text_to_image(self, text, top_k=5):
85
+ # with torch.no_grad():
86
+ # emb = self.clip_model.encode_text(text).cpu().numpy()
87
+ # emb = self._normalize(emb)
88
+
89
+ # scores, idxs = self.image_index.search(emb, top_k)
90
+ # return [
91
+ # {
92
+ # "image_path": self.image_id_map[i],
93
+ # "score": float(scores[0][j])
94
+ # }
95
+ # for j, i in enumerate(idxs[0])
96
+ # ]
97
 
98
  def image_to_text(self, image: Image.Image, top_k=5):
99
  image = self.image_transform(image).unsqueeze(0)
 
106
  print(f"DEBUG: Returning results: {results}")
107
  return results
108
 
109
+
110
  def text_to_text(self, text: str, top_k: int = 5):
111
  with torch.no_grad():
112
  emb = self.clip_model.encode_text(text).cpu().numpy()
 
125
  print(f"DEBUG: Text-to-text results: {results}")
126
  return results
127
 
 
 
 
 
 
 
 
 
128
 
129
+ # def image_to_image(self, image: Image.Image, top_k=5):
130
+ # """
131
+ # Image → Image retrieval: encode input image, search image index, return image IDs and scores.
132
+ # """
133
+ # image = self.image_transform(image).unsqueeze(0).to(self.device)
134
+ # with torch.no_grad():
135
+ # emb = self.clip_model.encode_image(image).cpu().numpy()
136
+ # emb = self._normalize(emb)
137
+
138
+ # scores, idxs = self.image_index.search(emb, top_k)
139
+ # return [
140
+ # {
141
+ # "image_path": self.image_id_map[i], # integer ID
142
+ # "score": float(scores[0][j])
143
+ # }
144
+ # for j, i in enumerate(idxs[0])
145
+ # ]
146
+
147
+ # def image_to_image(self, image: Image.Image, top_k: int = 5) -> List[Dict[str, Any]]:
148
+ # raw_results = self.retrieval_service.image_to_image(image, top_k) # now exists
149
+ # # ... same logic as above ...
150
+
151
+ # formatted = []
152
+ # for res in raw_results:
153
+ # img_id = res["image_path"]
154
+ # img_id_str = str(img_id)
155
+ # img_filename = f"{img_id_str}.jpg"
156
+ # full_path = os.path.join("flickr8k_images", img_filename)
157
+
158
+ # if os.path.exists(full_path):
159
+ # formatted.append({
160
+ # "image_path": full_path,
161
+ # "score": float(res["score"])
162
+ # })
163
+ # else:
164
+ # formatted.append({
165
+ # "image_path": "https://via.placeholder.com/300?text=Not+in+demo",
166
+ # "score": float(res["score"])
167
+ # })
168
+
169
+ # return formatted
170
+
171
+ def image_to_image(self, image: Image.Image, top_k: int = 5) -> List[Dict[str, Any]]:
172
+ raw_results = self.retrieval_service.image_to_image(image, top_k)
173
+
174
+ formatted = []
175
+ for res in raw_results:
176
+ idx = int(res["image_path"])
177
+
178
+ try:
179
+ pil_img = self.dataset[idx]["image"]
180
+ formatted.append({
181
+ "image": pil_img,
182
+ "score": float(res["score"])
183
+ })
184
+ except (IndexError, KeyError):
185
+ formatted.append({
186
+ "image": None,
187
+ "score": float(res["score"])
188
+ })
189
+
190
+ return formatted
requirements.txt CHANGED
@@ -11,4 +11,5 @@ numpy>=1.26.0
11
  altair
12
  pandas
13
  python-multipart>=0.0.9
14
- matplotlib>=3.9.0
 
 
11
  altair
12
  pandas
13
  python-multipart>=0.0.9
14
+ matplotlib>=3.9.0
15
+ datasets>=2.18.0