Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import requests | |
| import subprocess | |
| import time | |
| from PIL import Image | |
| import io | |
| import base64 # For displaying retrieved images if needed | |
| import socket | |
| # Start FastAPI server in background | |
def is_port_free(port):
    """Report whether nothing on localhost is accepting TCP connections on *port*.

    A connection attempt that fails for any reason (typically "connection
    refused") counts as free; a successful connect means some process is
    already listening there.

    Args:
        port: TCP port number to probe on ``localhost``.

    Returns:
        True if the connect attempt failed, False if it succeeded.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        status = probe.connect_ex(("localhost", port))
    finally:
        probe.close()
    return status != 0
# Launch the FastAPI backend (uvicorn serving ``api:app``) in the background.
# Streamlit re-runs this entire script on every widget interaction, so the
# launch -- and the wait for the server to come up -- only happens when nothing
# is listening on the port yet; reruns with the backend already up continue
# immediately instead of sleeping.
_BACKEND_PORT = 8001
_BACKEND_STARTUP_TIMEOUT_S = 30.0  # stop waiting for uvicorn after this long
_BACKEND_POLL_INTERVAL_S = 0.25


def _start_backend_and_wait(port, timeout_s, poll_interval_s):
    """Spawn uvicorn on *port* and wait until it accepts TCP connections.

    Polls the port instead of sleeping a fixed time, so the UI proceeds as
    soon as the server is ready. Stops waiting early if the child process
    exits (e.g. ``api`` fails to import) or after *timeout_s* seconds; the
    app then continues and requests to the backend will fail.

    Args:
        port: TCP port uvicorn should bind (on all interfaces).
        timeout_s: maximum number of seconds to wait for the port to open.
        poll_interval_s: delay between readiness checks.
    """
    proc = subprocess.Popen(
        ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", str(port)]
    )
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        if not is_port_free(port):
            return  # server is listening
        if proc.poll() is not None:
            print(f"Backend process exited early with code {proc.returncode}")
            return
        time.sleep(poll_interval_s)
    print(f"Backend did not start listening on port {port} within {timeout_s:g}s")


if is_port_free(_BACKEND_PORT):
    _start_backend_and_wait(
        _BACKEND_PORT, _BACKEND_STARTUP_TIMEOUT_S, _BACKEND_POLL_INTERVAL_S
    )
else:
    print(f"Port {_BACKEND_PORT} in use - skipping backend startup")
# Base URL of the FastAPI backend (uvicorn on local port 8001).
API_BASE = "http://localhost:8001"

st.set_page_config(page_title="Multimodal Retrieval & Captioning", layout="wide")
st.title("Multimodal Retrieval & Captioning System")

# Sidebar: which backend model serves every request below (extend this list
# as more models become available).
AVAILABLE_MODELS = [
    "resnet_lstm_attention",
    "vit_lstm_attention",
    "vit_transformer",
]
model_name = st.sidebar.selectbox("Select Model", AVAILABLE_MODELS, index=0)

# Inputs consumed by the task tabs: an image (uploaded file or camera
# snapshot), a free-text query, and how many results the search endpoints
# should return.
input_method = st.sidebar.radio("Image Input", ["Upload", "Camera"])
if input_method == "Upload":
    image_file = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
else:
    image_file = st.camera_input("Capture Image")
text_input = st.text_input("Text Input")
top_k = st.sidebar.slider("Top K", 1, 10, 5)
# One tab per task, each labelled "input → output".
tab_caption, tab_text2img, tab_img2text, tab_img2img, tab_text2text = st.tabs([
    "Image → Caption",
    "Text → Image",
    "Image → Text",
    "Image → Image",
    "Text → Text",
])
# (connect, read) timeouts in seconds for backend calls; without a timeout a
# hung backend would block this script run indefinitely.
_REQUEST_TIMEOUT_S = (10, 300)


def _post_json(endpoint, data, files=None):
    """POST form *data* (and optional multipart *files*) to ``API_BASE + endpoint``.

    Returns:
        The decoded JSON body when the backend answers HTTP 200, otherwise
        None. On a connection/timeout error, a non-200 status, or a body that
        is not valid JSON, an ``st.error`` is shown before returning None --
        so callers can tell "request failed" (None) from "no results" ([]).
    """
    try:
        resp = requests.post(
            f"{API_BASE}{endpoint}", data=data, files=files, timeout=_REQUEST_TIMEOUT_S
        )
    except requests.RequestException as exc:
        st.error(f"Error: could not reach backend at {API_BASE} - {exc}")
        return None
    if resp.status_code != 200:
        st.error(f"Error: {resp.status_code} - {resp.text}")
        return None
    try:
        return resp.json()
    except ValueError:  # covers requests' JSONDecodeError across versions
        st.error(f"Error: backend returned invalid JSON - {resp.text}")
        return None


def _render_image_grid(results, heading, n_cols=3):
    """Show retrieval hits as an *n_cols*-wide image grid under *heading*.

    Each result is expected to be a dict with a numeric ``"score"`` and an
    optional base64-encoded ``"image"``; hits without an image still show
    their score.
    """
    st.subheader(heading)
    cols = st.columns(n_cols)
    for idx, res in enumerate(results):
        with cols[idx % n_cols]:
            encoded = res.get("image")
            if encoded:
                st.image(base64.b64decode(encoded), width=200)
                st.caption(f"Score: {res['score']:.3f}")
            else:
                st.caption(f"Score: {res['score']:.3f} (Image not available)")


with tab_caption:
    if image_file and st.button("Generate Caption"):
        payload = _post_json(
            "/caption", {"model_name": model_name}, files={"file": image_file.getvalue()}
        )
        if payload is not None:
            caption = payload.get("caption") if isinstance(payload, dict) else None
            if caption is not None:
                st.write("Caption:", caption)
            else:
                st.error(f"Error: response has no 'caption' field - {payload}")

with tab_text2img:
    if text_input and st.button("Search Images"):
        results = _post_json(
            "/search/text2img",
            {"model_name": model_name, "query": text_input, "top_k": top_k},
        )
        if results:
            _render_image_grid(results, "Retrieved Images")
        elif results is not None:
            st.info("No results found.")

with tab_img2text:
    if image_file and st.button("Retrieve Text"):
        results = _post_json(
            "/search/img2text",
            {"model_name": model_name, "top_k": top_k},
            files={"file": image_file.getvalue()},
        )
        if results:
            st.subheader("Retrieved Texts:")
            for idx, caption in enumerate(results, 1):
                st.markdown(f"**{idx}.** {caption}")
        elif results is not None:
            st.info("No results found.")

with tab_img2img:
    if image_file and st.button("Retrieve Similar Images"):
        results = _post_json(
            "/search/img2img",
            {"model_name": model_name, "top_k": top_k},
            files={"file": image_file.getvalue()},
        )
        if results:
            _render_image_grid(results, "Retrieved Similar Images")
        elif results is not None:
            st.info("No results found.")

with tab_text2text:
    text_input_tt = st.text_input(
        "Enter text to find similar captions",
        placeholder="A child playing with water in the garden",
    )
    if text_input_tt and st.button("Search Similar Captions"):
        results = _post_json(
            "/search/text2text",
            {"model_name": model_name, "query": text_input_tt, "top_k": top_k},
        )
        if results:
            st.subheader("Top similar captions:")
            for idx, res in enumerate(results, 1):
                # Two trailing spaces before "\n" force a Markdown line break.
                st.markdown(f"**{idx}.** {res['caption']}  \nScore: `{res['score']:.4f}`")
        elif results is not None:
            st.info("No similar captions found.")
# Old code: an earlier version of this app, kept commented out for reference only.
| # # app.py | |
| # import streamlit as st | |
| # import requests | |
| # import subprocess | |
| # import time | |
| # from PIL import Image | |
| # import io | |
| # import base64 # For displaying retrieved images if needed | |
| # import socket | |
| # # Start FastAPI server in background | |
| # # subprocess.Popen(["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "8001"]) | |
| # # time.sleep(2) # Wait for server to start | |
| # # Check if port is free | |
| # def is_port_free(port): | |
| # with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: | |
| # return s.connect_ex(('localhost', port)) != 0 | |
| # if is_port_free(8001): | |
| # subprocess.Popen(["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "8001"]) | |
| # else: | |
| # print("Port 8001 in use - skipping backend startup") | |
| # time.sleep(5) # longer wait | |
| # API_BASE = "http://localhost:8001" | |
| # st.set_page_config(page_title="Multimodal Retrieval & Captioning", layout="wide") | |
| # st.title("Multimodal Retrieval & Captioning System") | |
| # # Model selection (add more later) | |
| # model_name = st.sidebar.selectbox("Select Model", ["resnet_lstm_attention", "vit_lstm_attention", "vit_transformer"], index=0) | |
| # # Common inputs | |
| # input_method = st.sidebar.radio("Image Input", ["Upload", "Camera"]) | |
| # image_file = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"]) if input_method == "Upload" else st.camera_input("Capture Image") | |
| # text_input = st.text_input("Text Input") | |
| # top_k = st.sidebar.slider("Top K", 1, 10, 5) | |
| # # Tabs for tasks | |
| # tab_caption, tab_text2img, tab_img2text, tab_img2img, tab_text2text = st.tabs([ | |
#     "Image → Caption",
#     "Text → Image",
#     "Image → Text",
#     "Image → Image",
#     "Text → Text"
| # ]) | |
| # with tab_caption: | |
| # if image_file and st.button("Generate Caption"): | |
| # files = {"file": image_file.getvalue()} | |
| # data = {"model_name": model_name} | |
| # resp = requests.post(f"{API_BASE}/caption", files=files, data=data) | |
| # if resp.status_code == 200: | |
| # st.write("Caption:", resp.json()["caption"]) | |
| # else: | |
| # st.error("Error: " + resp.text) | |
| # # with tab_text2img: | |
| # # if text_input and st.button("Search Images"): | |
| # # data = {"model_name": model_name, "query": text_input, "top_k": top_k} | |
| # # resp = requests.post(f"{API_BASE}/search/text2img", data=data) | |
| # # if resp.status_code == 200: | |
| # # results = resp.json() | |
| # # for res in results: | |
| # # st.image(res["image_path"], caption=f"Score: {res['score']:.3f}") | |
| # # else: | |
| # # st.error("Error: " + resp.text) | |
| # with tab_text2img: | |
| # if text_input and st.button("Search Images"): | |
| # data = {"model_name": model_name, "query": text_input, "top_k": top_k} | |
| # resp = requests.post(f"{API_BASE}/search/text2img", data=data) | |
| # if resp.status_code == 200: | |
| # results = resp.json() | |
| # if results: | |
| # st.subheader("Retrieved Images") | |
| # cols = st.columns(3) | |
| # for idx, res in enumerate(results): | |
| # with cols[idx % 3]: | |
| # if res["image"] is not None: | |
| # st.image(res["image"], width=200) | |
| # st.caption(f"Score: {res['score']:.3f}") | |
| # if "caption" in res: # if you add caption to results later | |
| # st.write(res["caption"]) | |
| # else: | |
| # st.caption(f"Score: {res['score']:.3f} (Image not found)") | |
| # else: | |
| # st.info("No results found.") | |
| # else: | |
| # st.error(f"Error: {resp.status_code} - {resp.text}") | |
| # with tab_img2text: | |
| # if image_file and st.button("Retrieve Text"): | |
| # files = {"file": image_file.getvalue()} | |
| # data = {"model_name": model_name, "top_k": top_k} | |
| # resp = requests.post(f"{API_BASE}/search/img2text", files=files, data=data) | |
| # if resp.status_code == 200: | |
| # st.write("Retrieved Texts:", resp.json()) | |
| # else: | |
| # st.error("Error: " + resp.text) | |
| # # with tab_img2img: | |
| # # if image_file and st.button("Retrieve Similar Images"): | |
| # # files = {"file": image_file.getvalue()} | |
| # # data = {"model_name": model_name, "top_k": top_k} | |
| # # resp = requests.post(f"{API_BASE}/search/img2img", files=files, data=data) | |
| # # if resp.status_code == 200: | |
| # # results = resp.json() | |
| # # for res in results: | |
| # # st.image(res["image_path"], caption=f"Score: {res['score']:.3f}") | |
| # # else: | |
| # # st.error("Error: " + resp.text) | |
| # with tab_img2img: | |
| # if image_file and st.button("Retrieve Similar Images"): | |
| # files = {"file": image_file.getvalue()} | |
| # data = {"model_name": model_name, "top_k": top_k} | |
| # resp = requests.post(f"{API_BASE}/search/img2img", files=files, data=data) | |
| # if resp.status_code == 200: | |
| # results = resp.json() | |
| # if results: | |
| # st.subheader("Retrieved Similar Images") | |
| # cols = st.columns(3) | |
| # for idx, res in enumerate(results): | |
| # with cols[idx % 3]: | |
| # if "image" in res and res["image"] is not None: | |
| # st.image( | |
| # res["image"], | |
| # width=200, # recommended instead of use_column_width | |
| # caption=f"Score: {res['score']:.3f}" | |
| # ) | |
| # else: | |
| # st.caption(f"Score: {res['score']:.3f} (Image not available)") | |
| # else: | |
| # st.info("No similar images found in the dataset.") | |
| # else: | |
| # st.error(f"Backend error: {resp.status_code} - {resp.text}") | |
| # with tab_text2text: | |
| # text_input_tt = st.text_input("Enter text to find similar captions", | |
| # placeholder="A child playing with water in the garden") | |
| # if text_input_tt and st.button("Search Similar Captions"): | |
| # data = {"model_name": model_name, "query": text_input_tt, "top_k": top_k} | |
| # resp = requests.post(f"{API_BASE}/search/text2text", data=data) | |
| # if resp.status_code == 200: | |
| # results = resp.json() | |
| # if results: | |
| # st.subheader("Top similar captions:") | |
| # for idx, res in enumerate(results, 1): | |
| # st.markdown(f"**{idx}.** {res['caption']} \nScore: `{res['score']:.4f}`") | |
| # else: | |
| # st.info("No similar captions found.") | |
| # else: | |
| # st.error(f"Error: {resp.status_code} - {resp.text}") |