Spaces:
Sleeping
Sleeping
| import glob | |
| import os | |
| import streamlit as st | |
| from datastore import ChromaStore | |
| from embeddings import Embedding | |
| from PIL import Image | |
| from tqdm import tqdm | |
| from utils import base64_to_image, image_to_base64 | |
# ---------------------------------------------------------------------------
# Image database
# ---------------------------------------------------------------------------


@st.cache_resource(show_spinner=False)
def _load_thumbnails(paths, size=(64, 64)):
    """Open every image in *paths* and return it resized to *size*.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching every image under ``data/`` would be decoded again each
    time the user touches the slider or the uploader. ``st.cache_resource``
    keeps one shared result per distinct ``paths``/``size`` argument pair;
    adding or removing files changes ``paths`` and therefore misses the cache.

    ``show_spinner=False`` keeps the cached call from emitting any UI element:
    it runs at import time, before ``st.set_page_config`` in ``main``.

    Args:
        paths: Hashable sequence (a tuple) of image file paths.
        size: Target ``(width, height)`` of every thumbnail.

    Returns:
        A list of resized ``PIL.Image.Image`` objects in the order of *paths*.
        The list is shared between sessions, so callers must not mutate it.
    """
    thumbnails = []
    for path in paths:
        # The context manager closes the source file; resize() returns a new,
        # fully loaded image that does not depend on it.
        with Image.open(path) as img:
            thumbnails.append(img.resize(size))
    return thumbnails


root_dir = os.path.join(os.getcwd(), "data")
# glob's result order depends on the filesystem; sort it so the gallery order
# is stable across machines and reruns.
jpg_files = sorted(glob.glob(os.path.join(root_dir, "**", "*.jpg"), recursive=True))
IMAGE_DATABASE = _load_thumbnails(tuple(jpg_files))
def display_image_database():
    """Show the preloaded ``IMAGE_DATABASE`` thumbnails in a collapsible
    section labelled "Image Database"."""
    with st.expander(label="Image Database"):
        st.image(IMAGE_DATABASE)
def display_sample_images():
    """Render 64x64 thumbnails of the example query images in ``./sample_imgs``.

    Entries are shown in sorted filename order, because ``os.listdir`` order is
    arbitrary. Sub-directories and files that Pillow cannot identify as images
    (e.g. ``.DS_Store`` or a README) are skipped instead of aborting the whole
    page with an exception.

    Raises:
        FileNotFoundError: if the ``sample_imgs`` directory does not exist.
    """
    from PIL import UnidentifiedImageError

    sample_img_path = os.path.join(os.getcwd(), "sample_imgs")
    images = []
    for name in sorted(os.listdir(sample_img_path)):
        path = os.path.join(sample_img_path, name)
        if not os.path.isfile(path):
            continue
        try:
            # The context manager closes the file; resize() returns a new,
            # fully loaded image.
            with Image.open(path) as img:
                images.append(img.resize((64, 64)))
        except UnidentifiedImageError:
            # Not an image format Pillow recognises: leave it out of the gallery.
            continue
    if images:
        st.image(images)
def _render_header():
    """Configure the page and draw the centred title and subtitle.

    ``st.set_page_config`` is called first because Streamlit requires it to
    precede the other page elements.
    """
    st.set_page_config(page_icon="🖼️", page_title="image-search-engine", layout="wide")
    st.markdown(
        """<h1 style="text-align: center;">🔍️ Image Search Engine</h1>""",
        unsafe_allow_html=True,
    )
    st.markdown(
        """<h3 style="text-align: center;">Image to Image search using transformer embeddings</h3>""",
        unsafe_allow_html=True,
    )


def _search_similar_images(query_image, top_k):
    """Embed *query_image* and return the first field of each of the *top_k* hits.

    Args:
        query_image: The uploaded query image, as opened by ``PIL.Image.open``.
        top_k: Number of results to request from the vector store.

    Returns:
        A list with ``res[0]`` for every ``res`` returned by
        ``ChromaStore.query``.
    """
    query_embedding = Embedding.encode_image(query_image)

    vectorstore = ChromaStore(collection_name="image_store")
    collection = vectorstore.create()
    st.toast("Vectorstore loaded successfully", icon="✅")

    results = vectorstore.query(
        collection,
        query_embedding,
        top_k=top_k,
    )
    # NOTE(review): assumes each hit is a sequence whose first element is
    # something st.image can render (image/array/path) -- depends on
    # ChromaStore.query; verify against datastore.py.
    return [res[0] for res in results]


def main():
    """Build the two-column search page: query input on the left, results on
    the right, and the full image database in an expander below."""
    _render_header()

    search_col, results_col = st.columns(2)

    with search_col:
        with st.container(border=True, height=550):
            st.markdown(
                """<h3 style="text-align: center;">Search</h3>""",
                unsafe_allow_html=True,
            )
            upload_img = st.file_uploader(
                label="Query Image",
                accept_multiple_files=False,
                type=["jpg", "png", "jpeg"],
            )
            submit = st.button(label="Submit")
            display_sample_images()

    with results_col:
        with st.container(border=True, height=550):
            st.markdown(
                """<h3 style="text-align: center;">Results</h3>""",
                unsafe_allow_html=True,
            )
            top_k = st.slider(label="Search top k results", min_value=3, max_value=10)

            if upload_img is None:
                # Also covers "Submit" being pressed without an upload.
                st.warning("Please upload an image")
            elif submit:
                with Image.open(upload_img) as query_image:
                    res_images = _search_similar_images(query_image, top_k)
                if res_images:
                    st.image(res_images)
                else:
                    st.info("No similar images found")
            else:
                # An image is uploaded but the search has not been run yet;
                # previously this wrongly asked the user to upload an image.
                st.info("Press Submit to search for similar images")

    display_image_database()


if __name__ == "__main__":
    main()