skodan commited on
Commit
7612210
·
1 Parent(s): 48b6a6b

fixed txt2img and img2img error

Browse files
Files changed (2) hide show
  1. app.py +6 -4
  2. models/resnet_lstm_attention/model.py +11 -2
app.py CHANGED
@@ -63,8 +63,9 @@ with tab_text2img:
63
  cols = st.columns(3)
64
  for idx, res in enumerate(results):
65
  with cols[idx % 3]:
66
- if res["image"] is not None:
67
- st.image(res["image"], width=200)
 
68
  st.caption(f"Score: {res['score']:.3f}")
69
  else:
70
  st.caption(f"Score: {res['score']:.3f} (Image not available)")
@@ -101,8 +102,9 @@ with tab_img2img:
101
  cols = st.columns(3)
102
  for idx, res in enumerate(results):
103
  with cols[idx % 3]:
104
- if res["image"] is not None:
105
- st.image(res["image"], width=200)
 
106
  st.caption(f"Score: {res['score']:.3f}")
107
  else:
108
  st.caption(f"Score: {res['score']:.3f} (Image not available)")
 
63
  cols = st.columns(3)
64
  for idx, res in enumerate(results):
65
  with cols[idx % 3]:
66
+ if res["image"]:
67
+ img_bytes = base64.b64decode(res["image"])
68
+ st.image(img_bytes, width=200)
69
  st.caption(f"Score: {res['score']:.3f}")
70
  else:
71
  st.caption(f"Score: {res['score']:.3f} (Image not available)")
 
102
  cols = st.columns(3)
103
  for idx, res in enumerate(results):
104
  with cols[idx % 3]:
105
+ if res["image"]:
106
+ img_bytes = base64.b64decode(res["image"])
107
+ st.image(img_bytes, width=200)
108
  st.caption(f"Score: {res['score']:.3f}")
109
  else:
110
  st.caption(f"Score: {res['score']:.3f} (Image not available)")
models/resnet_lstm_attention/model.py CHANGED
@@ -6,6 +6,8 @@ from PIL import Image
6
  import numpy as np
7
  from typing import List, Dict, Any
8
  from datasets import load_dataset, concatenate_datasets
 
 
9
 
10
  from models.resnet_lstm_attention.loader import load_captioning_model
11
  from models.resnet_lstm_attention.retrieval import RetrievalService
@@ -122,6 +124,13 @@ class ResNetLSTMAttentionModel(UnifiedModelInterface):
122
  )
123
  return " ".join(tokens)
124
 
 
 
 
 
 
 
 
125
  def text_to_image(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
126
  raw_results = self.retrieval_service.text_to_image(text, top_k)
127
 
@@ -132,7 +141,7 @@ class ResNetLSTMAttentionModel(UnifiedModelInterface):
132
  try:
133
  pil_img = self.dataset[idx]["image"]
134
  formatted.append({
135
- "image": pil_img,
136
  "score": float(res["score"])
137
  })
138
  except (IndexError, KeyError):
@@ -156,7 +165,7 @@ class ResNetLSTMAttentionModel(UnifiedModelInterface):
156
  try:
157
  pil_img = self.dataset[idx]["image"]
158
  formatted.append({
159
- "image": pil_img,
160
  "score": float(res["score"])
161
  })
162
  except (IndexError, KeyError):
 
6
  import numpy as np
7
  from typing import List, Dict, Any
8
  from datasets import load_dataset, concatenate_datasets
9
+ import io
10
+ import base64
11
 
12
  from models.resnet_lstm_attention.loader import load_captioning_model
13
  from models.resnet_lstm_attention.retrieval import RetrievalService
 
124
  )
125
  return " ".join(tokens)
126
 
127
+ def _pil_to_base64(self, image: Image.Image) -> str:
128
+ """Convert PIL image to base64 string for JSON serialization."""
129
+ buffered = io.BytesIO()
130
+ image.save(buffered, format="JPEG")
131
+ return base64.b64encode(buffered.getvalue()).decode("utf-8")
132
+
133
+
134
  def text_to_image(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
135
  raw_results = self.retrieval_service.text_to_image(text, top_k)
136
 
 
141
  try:
142
  pil_img = self.dataset[idx]["image"]
143
  formatted.append({
144
+ "image": self._pil_to_base64(pil_img),
145
  "score": float(res["score"])
146
  })
147
  except (IndexError, KeyError):
 
165
  try:
166
  pil_img = self.dataset[idx]["image"]
167
  formatted.append({
168
+ "image": self._pil_to_base64(pil_img),
169
  "score": float(res["score"])
170
  })
171
  except (IndexError, KeyError):