Spaces:
Sleeping
Sleeping
fixed txt2img and img2img error
Browse files
- app.py +6 -4
- models/resnet_lstm_attention/model.py +11 -2
app.py
CHANGED
|
@@ -63,8 +63,9 @@ with tab_text2img:
|
|
| 63 |
cols = st.columns(3)
|
| 64 |
for idx, res in enumerate(results):
|
| 65 |
with cols[idx % 3]:
|
| 66 |
-
if res["image"]
|
| 67 |
-
|
|
|
|
| 68 |
st.caption(f"Score: {res['score']:.3f}")
|
| 69 |
else:
|
| 70 |
st.caption(f"Score: {res['score']:.3f} (Image not available)")
|
|
@@ -101,8 +102,9 @@ with tab_img2img:
|
|
| 101 |
cols = st.columns(3)
|
| 102 |
for idx, res in enumerate(results):
|
| 103 |
with cols[idx % 3]:
|
| 104 |
-
if res["image"]
|
| 105 |
-
|
|
|
|
| 106 |
st.caption(f"Score: {res['score']:.3f}")
|
| 107 |
else:
|
| 108 |
st.caption(f"Score: {res['score']:.3f} (Image not available)")
|
|
|
|
| 63 |
cols = st.columns(3)
|
| 64 |
for idx, res in enumerate(results):
|
| 65 |
with cols[idx % 3]:
|
| 66 |
+
if res["image"]:
|
| 67 |
+
img_bytes = base64.b64decode(res["image"])
|
| 68 |
+
st.image(img_bytes, width=200)
|
| 69 |
st.caption(f"Score: {res['score']:.3f}")
|
| 70 |
else:
|
| 71 |
st.caption(f"Score: {res['score']:.3f} (Image not available)")
|
|
|
|
| 102 |
cols = st.columns(3)
|
| 103 |
for idx, res in enumerate(results):
|
| 104 |
with cols[idx % 3]:
|
| 105 |
+
if res["image"]:
|
| 106 |
+
img_bytes = base64.b64decode(res["image"])
|
| 107 |
+
st.image(img_bytes, width=200)
|
| 108 |
st.caption(f"Score: {res['score']:.3f}")
|
| 109 |
else:
|
| 110 |
st.caption(f"Score: {res['score']:.3f} (Image not available)")
|
models/resnet_lstm_attention/model.py
CHANGED
|
@@ -6,6 +6,8 @@ from PIL import Image
|
|
| 6 |
import numpy as np
|
| 7 |
from typing import List, Dict, Any
|
| 8 |
from datasets import load_dataset, concatenate_datasets
|
|
|
|
|
|
|
| 9 |
|
| 10 |
from models.resnet_lstm_attention.loader import load_captioning_model
|
| 11 |
from models.resnet_lstm_attention.retrieval import RetrievalService
|
|
@@ -122,6 +124,13 @@ class ResNetLSTMAttentionModel(UnifiedModelInterface):
|
|
| 122 |
)
|
| 123 |
return " ".join(tokens)
|
| 124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
def text_to_image(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
|
| 126 |
raw_results = self.retrieval_service.text_to_image(text, top_k)
|
| 127 |
|
|
@@ -132,7 +141,7 @@ class ResNetLSTMAttentionModel(UnifiedModelInterface):
|
|
| 132 |
try:
|
| 133 |
pil_img = self.dataset[idx]["image"]
|
| 134 |
formatted.append({
|
| 135 |
-
"image": pil_img,
|
| 136 |
"score": float(res["score"])
|
| 137 |
})
|
| 138 |
except (IndexError, KeyError):
|
|
@@ -156,7 +165,7 @@ class ResNetLSTMAttentionModel(UnifiedModelInterface):
|
|
| 156 |
try:
|
| 157 |
pil_img = self.dataset[idx]["image"]
|
| 158 |
formatted.append({
|
| 159 |
-
"image": pil_img,
|
| 160 |
"score": float(res["score"])
|
| 161 |
})
|
| 162 |
except (IndexError, KeyError):
|
|
|
|
| 6 |
import numpy as np
|
| 7 |
from typing import List, Dict, Any
|
| 8 |
from datasets import load_dataset, concatenate_datasets
|
| 9 |
+
import io
|
| 10 |
+
import base64
|
| 11 |
|
| 12 |
from models.resnet_lstm_attention.loader import load_captioning_model
|
| 13 |
from models.resnet_lstm_attention.retrieval import RetrievalService
|
|
|
|
| 124 |
)
|
| 125 |
return " ".join(tokens)
|
| 126 |
|
| 127 |
+
def _pil_to_base64(self, image: Image.Image) -> str:
|
| 128 |
+
"""Convert PIL image to base64 string for JSON serialization."""
|
| 129 |
+
buffered = io.BytesIO()
|
| 130 |
+
image.save(buffered, format="JPEG")
|
| 131 |
+
return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
| 132 |
+
|
| 133 |
+
|
| 134 |
def text_to_image(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
|
| 135 |
raw_results = self.retrieval_service.text_to_image(text, top_k)
|
| 136 |
|
|
|
|
| 141 |
try:
|
| 142 |
pil_img = self.dataset[idx]["image"]
|
| 143 |
formatted.append({
|
| 144 |
+
"image": self._pil_to_base64(pil_img),
|
| 145 |
"score": float(res["score"])
|
| 146 |
})
|
| 147 |
except (IndexError, KeyError):
|
|
|
|
| 165 |
try:
|
| 166 |
pil_img = self.dataset[idx]["image"]
|
| 167 |
formatted.append({
|
| 168 |
+
"image": self._pil_to_base64(pil_img),
|
| 169 |
"score": float(res["score"])
|
| 170 |
})
|
| 171 |
except (IndexError, KeyError):
|