"""Streamlit front end for the multimodal retrieval and captioning backend.

The script launches the FastAPI backend (``uvicorn api:app``) on port 8001
when nothing is listening there yet, then renders one tab per task. Each tab
sends a form-encoded POST request to the backend and displays the response.
"""

import base64  # For displaying retrieved images if needed
import io
import socket
import subprocess
import time

import requests
import streamlit as st
from PIL import Image

API_PORT = 8001
API_BASE = f"http://localhost:{API_PORT}"

# Upper bound (seconds) on how long to wait for a freshly launched backend.
BACKEND_STARTUP_TIMEOUT = 15.0
# Per-request timeout (seconds) so an unresponsive backend cannot hang the UI.
REQUEST_TIMEOUT = 60


def is_port_free(port):
    """Return True when a TCP connection to ``localhost:port`` fails."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex(("localhost", port)) != 0


def _wait_for_backend(proc, port, timeout=BACKEND_STARTUP_TIMEOUT, poll_interval=0.25):
    """Block until ``port`` accepts connections, ``proc`` exits, or ``timeout`` elapses.

    Returns True when the port became reachable, False otherwise.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if not is_port_free(port):
            return True
        if proc.poll() is not None:
            # The backend process already terminated; waiting longer is pointless.
            return False
        time.sleep(poll_interval)
    return False


def _call_backend(endpoint, data, files=None):
    """POST to the backend and return ``(ok, payload)``.

    On any failure (connection problem, non-200 status, non-JSON body) an
    error is shown in the UI and ``(False, None)`` is returned. On success
    ``payload`` is the decoded JSON body.
    """
    try:
        resp = requests.post(
            f"{API_BASE}{endpoint}",
            data=data,
            files=files,
            timeout=REQUEST_TIMEOUT,
        )
    except requests.RequestException as exc:
        st.error(f"Error: could not reach backend - {exc}")
        return False, None

    if resp.status_code != 200:
        st.error(f"Error: {resp.status_code} - {resp.text}")
        return False, None

    try:
        return True, resp.json()
    except ValueError:
        st.error("Error: backend returned a response that is not valid JSON")
        return False, None


def _render_image_grid(results, n_columns=3):
    """Show each result's base64-encoded image and score in a column grid."""
    cols = st.columns(n_columns)
    for idx, res in enumerate(results):
        with cols[idx % n_columns]:
            encoded = res.get("image")  # tolerate a missing key as well as None/""
            if encoded:
                st.image(base64.b64decode(encoded), width=200)
                st.caption(f"Score: {res['score']:.3f}")
            else:
                st.caption(f"Score: {res['score']:.3f} (Image not available)")


# Start FastAPI server in background.
# Streamlit re-executes this script on every interaction, so only wait when
# the backend was launched by this run; otherwise every click would stall.
if is_port_free(API_PORT):
    backend_proc = subprocess.Popen(
        ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", str(API_PORT)]
    )
    if not _wait_for_backend(backend_proc, API_PORT):
        print(f"Backend on port {API_PORT} is not reachable yet")
else:
    print(f"Port {API_PORT} in use - skipping backend startup")

st.set_page_config(page_title="Multimodal Retrieval & Captioning", layout="wide")
st.title("Multimodal Retrieval & Captioning System")

# Model selection (add more later)
model_name = st.sidebar.selectbox(
    "Select Model",
    ["resnet_lstm_attention", "vit_lstm_attention", "vit_transformer"],
    index=0,
)

# Common inputs
input_method = st.sidebar.radio("Image Input", ["Upload", "Camera"])
if input_method == "Upload":
    image_file = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
else:
    image_file = st.camera_input("Capture Image")
text_input = st.text_input("Text Input")
top_k = st.sidebar.slider("Top K", 1, 10, 5)

# Tabs for tasks
tab_caption, tab_text2img, tab_img2text, tab_img2img, tab_text2text = st.tabs([
    "Image → Caption",
    "Text → Image",
    "Image → Text",
    "Image → Image",
    "Text → Text",
])

with tab_caption:
    if image_file and st.button("Generate Caption"):
        ok, payload = _call_backend(
            "/caption",
            data={"model_name": model_name},
            files={"file": image_file.getvalue()},
        )
        if ok:
            st.write("Caption:", payload["caption"])

with tab_text2img:
    if text_input and st.button("Search Images"):
        ok, results = _call_backend(
            "/search/text2img",
            data={"model_name": model_name, "query": text_input, "top_k": top_k},
        )
        if ok:
            if results:
                st.subheader("Retrieved Images")
                _render_image_grid(results)
            else:
                st.info("No results found.")

with tab_img2text:
    if image_file and st.button("Retrieve Text"):
        ok, results = _call_backend(
            "/search/img2text",
            data={"model_name": model_name, "top_k": top_k},
            files={"file": image_file.getvalue()},
        )
        if ok:
            if results:
                st.subheader("Retrieved Texts:")
                for idx, caption in enumerate(results, 1):
                    st.markdown(f"**{idx}.** {caption}")
            else:
                st.info("No results found.")

with tab_img2img:
    if image_file and st.button("Retrieve Similar Images"):
        ok, results = _call_backend(
            "/search/img2img",
            data={"model_name": model_name, "top_k": top_k},
            files={"file": image_file.getvalue()},
        )
        if ok:
            if results:
                st.subheader("Retrieved Similar Images")
                _render_image_grid(results)
            else:
                st.info("No results found.")

with tab_text2text:
    text_input_tt = st.text_input(
        "Enter text to find similar captions",
        placeholder="A child playing with water in the garden",
    )
    if text_input_tt and st.button("Search Similar Captions"):
        ok, results = _call_backend(
            "/search/text2text",
            data={"model_name": model_name, "query": text_input_tt, "top_k": top_k},
        )
        if ok:
            if results:
                st.subheader("Top similar captions:")
                for idx, res in enumerate(results, 1):
                    # Two trailing spaces before the newline force a Markdown
                    # line break, so the score appears under the caption.
                    st.markdown(
                        f"**{idx}.** {res['caption']}  \nScore: `{res['score']:.4f}`"
                    )
            else:
                st.info("No similar captions found.")