# Last change: commit 7612210 by skodan - "fixed txt2img and img2img error"
import streamlit as st
import requests
import subprocess
import time
from PIL import Image
import io
import base64 # For displaying retrieved images if needed
import socket
# Start the FastAPI backend (api:app) in the background if nothing is
# listening on its port yet.
BACKEND_PORT = 8001
# Upper bound (seconds) to wait for a freshly launched backend to accept
# connections before letting the UI render anyway.
BACKEND_STARTUP_TIMEOUT = 60.0


def is_port_free(port):
    """Return True if no process accepts TCP connections on localhost:*port*.

    ``connect_ex`` returns 0 only when the connection succeeds, i.e. something
    is already listening there, so any non-zero result is treated as "free".
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # Bound the probe so a stuck listener cannot hang the page.
        s.settimeout(1.0)
        return s.connect_ex(("localhost", port)) != 0


def _wait_for_backend(proc, port, timeout=BACKEND_STARTUP_TIMEOUT, interval=0.25):
    """Poll until *port* accepts connections, *proc* exits, or *timeout* passes.

    Args:
        proc: the ``subprocess.Popen`` handle of the launched backend.
        port: TCP port the backend is expected to listen on.
        timeout: maximum number of seconds to wait.
        interval: seconds to sleep between probes.

    Returns:
        True if the port became reachable, False otherwise.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if not is_port_free(port):
            return True
        if proc.poll() is not None:
            # The launched process has exited; report whether something else
            # (e.g. a concurrently started instance) is serving the port.
            return not is_port_free(port)
        time.sleep(interval)
    return False


# Streamlit re-runs this script on every widget interaction. Only the run that
# actually launches the backend waits for it; later reruns find the port taken
# and continue without any delay.
if is_port_free(BACKEND_PORT):
    try:
        _backend_proc = subprocess.Popen(
            ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", str(BACKEND_PORT)]
        )
    except OSError as exc:  # e.g. uvicorn not installed / not on PATH
        print(f"Could not launch backend: {exc}")
    else:
        if not _wait_for_backend(_backend_proc, BACKEND_PORT):
            print(f"Backend not reachable on port {BACKEND_PORT} after startup attempt")
else:
    print(f"Port {BACKEND_PORT} in use - skipping backend startup")
API_BASE = "http://localhost:8001"

st.set_page_config(page_title="Multimodal Retrieval & Captioning", layout="wide")
st.title("Multimodal Retrieval & Captioning System")

# Captioning/retrieval models offered in the sidebar (extend as needed).
_MODEL_OPTIONS = ["resnet_lstm_attention", "vit_lstm_attention", "vit_transformer"]
model_name = st.sidebar.selectbox("Select Model", _MODEL_OPTIONS, index=0)

# Inputs shared by every task tab below: one image source, one free-text
# query, and the number of results to request.
input_method = st.sidebar.radio("Image Input", ["Upload", "Camera"])
if input_method == "Upload":
    image_file = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
else:
    image_file = st.camera_input("Capture Image")
text_input = st.text_input("Text Input")
top_k = st.sidebar.slider("Top K", min_value=1, max_value=10, value=5)
# One tab per task. The arrow is written as the escape "\u2192" (RIGHTWARDS
# ARROW) so the labels survive any re-encoding of this source file.
tab_caption, tab_text2img, tab_img2text, tab_img2img, tab_text2text = st.tabs([
    "Image \u2192 Caption",
    "Text \u2192 Image",
    "Image \u2192 Text",
    "Image \u2192 Image",
    "Text \u2192 Text",
])
# Upper bound (seconds) for a single backend request, so a hung or overloaded
# API cannot freeze the page indefinitely.
REQUEST_TIMEOUT = 120


def _post_json(endpoint, data, files=None):
    """POST a form request to ``API_BASE + endpoint`` and return the decoded JSON.

    Shows an ``st.error`` and returns None when the backend is unreachable,
    answers with a non-200 status, or sends a body that is not valid JSON.
    Callers therefore treat None as "error already reported".
    """
    try:
        resp = requests.post(
            f"{API_BASE}{endpoint}", data=data, files=files, timeout=REQUEST_TIMEOUT
        )
    except requests.RequestException as exc:
        st.error(f"Could not reach backend at {API_BASE}: {exc}")
        return None
    if resp.status_code != 200:
        st.error(f"Error: {resp.status_code} - {resp.text}")
        return None
    try:
        return resp.json()
    except ValueError:  # requests' JSON decode errors subclass ValueError
        st.error(f"Backend returned invalid JSON: {resp.text[:200]}")
        return None


def _render_image_grid(results, n_cols=3):
    """Lay out retrieved images in *n_cols* columns, each with its score.

    Each result is read as a dict with a ``score`` and an optional
    base64-encoded ``image``; results without an image show the score only.
    """
    cols = st.columns(n_cols)
    for idx, res in enumerate(results):
        with cols[idx % n_cols]:
            encoded = res.get("image")
            if encoded:
                st.image(base64.b64decode(encoded), width=200)
                st.caption(f"Score: {res['score']:.3f}")
            else:
                st.caption(f"Score: {res['score']:.3f} (Image not available)")


with tab_caption:
    if image_file and st.button("Generate Caption"):
        result = _post_json(
            "/caption",
            data={"model_name": model_name},
            files={"file": image_file.getvalue()},
        )
        if result is not None:
            caption = result.get("caption") if isinstance(result, dict) else None
            if caption is None:
                st.error(f"Unexpected response from backend: {result}")
            else:
                st.write("Caption:", caption)

with tab_text2img:
    if text_input and st.button("Search Images"):
        results = _post_json(
            "/search/text2img",
            data={"model_name": model_name, "query": text_input, "top_k": top_k},
        )
        if results:
            st.subheader("Retrieved Images")
            _render_image_grid(results)
        elif results is not None:
            st.info("No results found.")

with tab_img2text:
    if image_file and st.button("Retrieve Text"):
        results = _post_json(
            "/search/img2text",
            data={"model_name": model_name, "top_k": top_k},
            files={"file": image_file.getvalue()},
        )
        if results:
            st.subheader("Retrieved Texts:")
            for rank, caption in enumerate(results, 1):
                st.markdown(f"**{rank}.** {caption}")
        elif results is not None:
            st.info("No results found.")

with tab_img2img:
    if image_file and st.button("Retrieve Similar Images"):
        results = _post_json(
            "/search/img2img",
            data={"model_name": model_name, "top_k": top_k},
            files={"file": image_file.getvalue()},
        )
        if results:
            st.subheader("Retrieved Similar Images")
            _render_image_grid(results)
        elif results is not None:
            st.info("No results found.")

with tab_text2text:
    text_input_tt = st.text_input("Enter text to find similar captions",
                                  placeholder="A child playing with water in the garden")
    if text_input_tt and st.button("Search Similar Captions"):
        results = _post_json(
            "/search/text2text",
            data={"model_name": model_name, "query": text_input_tt, "top_k": top_k},
        )
        if results:
            st.subheader("Top similar captions:")
            for rank, res in enumerate(results, 1):
                # Two trailing spaces before "\n" make a Markdown hard line break.
                st.markdown(f"**{rank}.** {res['caption']}  \nScore: `{res['score']:.4f}`")
        elif results is not None:
            st.info("No similar captions found.")
# Old code (superseded by the implementation above; kept for reference only)
# # app.py
# import streamlit as st
# import requests
# import subprocess
# import time
# from PIL import Image
# import io
# import base64 # For displaying retrieved images if needed
# import socket
# # Start FastAPI server in background
# # subprocess.Popen(["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "8001"])
# # time.sleep(2) # Wait for server to start
# # Check if port is free
# def is_port_free(port):
# with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
# return s.connect_ex(('localhost', port)) != 0
# if is_port_free(8001):
# subprocess.Popen(["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "8001"])
# else:
# print("Port 8001 in use - skipping backend startup")
# time.sleep(5) # longer wait
# API_BASE = "http://localhost:8001"
# st.set_page_config(page_title="Multimodal Retrieval & Captioning", layout="wide")
# st.title("Multimodal Retrieval & Captioning System")
# # Model selection (add more later)
# model_name = st.sidebar.selectbox("Select Model", ["resnet_lstm_attention", "vit_lstm_attention", "vit_transformer"], index=0)
# # Common inputs
# input_method = st.sidebar.radio("Image Input", ["Upload", "Camera"])
# image_file = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"]) if input_method == "Upload" else st.camera_input("Capture Image")
# text_input = st.text_input("Text Input")
# top_k = st.sidebar.slider("Top K", 1, 10, 5)
# # Tabs for tasks
# tab_caption, tab_text2img, tab_img2text, tab_img2img, tab_text2text = st.tabs([
# "Image β†’ Caption",
# "Text β†’ Image",
# "Image β†’ Text",
# "Image β†’ Image",
# "Text β†’ Text"
# ])
# with tab_caption:
# if image_file and st.button("Generate Caption"):
# files = {"file": image_file.getvalue()}
# data = {"model_name": model_name}
# resp = requests.post(f"{API_BASE}/caption", files=files, data=data)
# if resp.status_code == 200:
# st.write("Caption:", resp.json()["caption"])
# else:
# st.error("Error: " + resp.text)
# # with tab_text2img:
# # if text_input and st.button("Search Images"):
# # data = {"model_name": model_name, "query": text_input, "top_k": top_k}
# # resp = requests.post(f"{API_BASE}/search/text2img", data=data)
# # if resp.status_code == 200:
# # results = resp.json()
# # for res in results:
# # st.image(res["image_path"], caption=f"Score: {res['score']:.3f}")
# # else:
# # st.error("Error: " + resp.text)
# with tab_text2img:
# if text_input and st.button("Search Images"):
# data = {"model_name": model_name, "query": text_input, "top_k": top_k}
# resp = requests.post(f"{API_BASE}/search/text2img", data=data)
# if resp.status_code == 200:
# results = resp.json()
# if results:
# st.subheader("Retrieved Images")
# cols = st.columns(3)
# for idx, res in enumerate(results):
# with cols[idx % 3]:
# if res["image"] is not None:
# st.image(res["image"], width=200)
# st.caption(f"Score: {res['score']:.3f}")
# if "caption" in res: # if you add caption to results later
# st.write(res["caption"])
# else:
# st.caption(f"Score: {res['score']:.3f} (Image not found)")
# else:
# st.info("No results found.")
# else:
# st.error(f"Error: {resp.status_code} - {resp.text}")
# with tab_img2text:
# if image_file and st.button("Retrieve Text"):
# files = {"file": image_file.getvalue()}
# data = {"model_name": model_name, "top_k": top_k}
# resp = requests.post(f"{API_BASE}/search/img2text", files=files, data=data)
# if resp.status_code == 200:
# st.write("Retrieved Texts:", resp.json())
# else:
# st.error("Error: " + resp.text)
# # with tab_img2img:
# # if image_file and st.button("Retrieve Similar Images"):
# # files = {"file": image_file.getvalue()}
# # data = {"model_name": model_name, "top_k": top_k}
# # resp = requests.post(f"{API_BASE}/search/img2img", files=files, data=data)
# # if resp.status_code == 200:
# # results = resp.json()
# # for res in results:
# # st.image(res["image_path"], caption=f"Score: {res['score']:.3f}")
# # else:
# # st.error("Error: " + resp.text)
# with tab_img2img:
# if image_file and st.button("Retrieve Similar Images"):
# files = {"file": image_file.getvalue()}
# data = {"model_name": model_name, "top_k": top_k}
# resp = requests.post(f"{API_BASE}/search/img2img", files=files, data=data)
# if resp.status_code == 200:
# results = resp.json()
# if results:
# st.subheader("Retrieved Similar Images")
# cols = st.columns(3)
# for idx, res in enumerate(results):
# with cols[idx % 3]:
# if "image" in res and res["image"] is not None:
# st.image(
# res["image"],
# width=200, # recommended instead of use_column_width
# caption=f"Score: {res['score']:.3f}"
# )
# else:
# st.caption(f"Score: {res['score']:.3f} (Image not available)")
# else:
# st.info("No similar images found in the dataset.")
# else:
# st.error(f"Backend error: {resp.status_code} - {resp.text}")
# with tab_text2text:
# text_input_tt = st.text_input("Enter text to find similar captions",
# placeholder="A child playing with water in the garden")
# if text_input_tt and st.button("Search Similar Captions"):
# data = {"model_name": model_name, "query": text_input_tt, "top_k": top_k}
# resp = requests.post(f"{API_BASE}/search/text2text", data=data)
# if resp.status_code == 200:
# results = resp.json()
# if results:
# st.subheader("Top similar captions:")
# for idx, res in enumerate(results, 1):
# st.markdown(f"**{idx}.** {res['caption']} \nScore: `{res['score']:.4f}`")
# else:
# st.info("No similar captions found.")
# else:
# st.error(f"Error: {resp.status_code} - {resp.text}")