Spaces:
Runtime error
Runtime error
chore: checking with plt obj
Browse files
app.py
CHANGED
|
@@ -43,9 +43,9 @@ with st.spinner("Generating the heat maps... HOLD ON!"):
|
|
| 43 |
image=preprocessed_img_orig
|
| 44 |
)
|
| 45 |
|
| 46 |
-
utils.plot(attentions=attentions, image=preprocessed_img_orig)
|
| 47 |
|
| 48 |
# Show the attention maps
|
| 49 |
st.title("Attention 🔥 Maps")
|
| 50 |
-
image = Image.
|
| 51 |
st.image(image, caption="Attention Heat Maps")
|
|
|
|
| 43 |
image=preprocessed_img_orig
|
| 44 |
)
|
| 45 |
|
| 46 |
+
plt = utils.plot(attentions=attentions, image=preprocessed_img_orig)
|
| 47 |
|
| 48 |
# Show the attention maps
|
| 49 |
st.title("Attention 🔥 Maps")
|
| 50 |
+
image = Image.frombytes('RGB', plt.canvas.get_width_height(), plt.canvas.tostring_rgb())
|
| 51 |
st.image(image, caption="Attention Heat Maps")
|
utils.py
CHANGED
|
@@ -97,4 +97,4 @@ def plot(attentions, image):
|
|
| 97 |
img_count += 1
|
| 98 |
|
| 99 |
plt.tight_layout()
|
| 100 |
-
plt
|
|
|
|
| 97 |
img_count += 1
|
| 98 |
|
| 99 |
plt.tight_layout()
|
| 100 |
+
return plt
|