Spaces:
Sleeping
Sleeping
| from qdrant_client import QdrantClient | |
| from io import BytesIO | |
| import streamlit as st | |
| import base64 | |
# 1. Name of the Qdrant collection that holds every image's metadata and vector.
collection_name = "animal_images"

# 2. Create the session-state slot the rest of the app reads and writes.
#    It stays None until the user clicks a "Find similar images" button.
if "selected_record" not in st.session_state:
    st.session_state["selected_record"] = None
def set_selected_record(_new_record):
    """Remember *_new_record* as the record the user selected.

    The grid's "Find similar images" buttons register this function as their
    ``on_click`` callback, passing the clicked record as the only argument.

    Args:
        _new_record: the record to store under ``selected_record`` in
            ``st.session_state``.
    """
    st.session_state["selected_record"] = _new_record
def get_client():
    """Build a Qdrant client from the app's Streamlit secrets.

    The connection settings come from ``st.secrets``, which is normally
    populated from ``.streamlit/secrets.toml``. The keys read are
    ``qdrant_db_url`` and ``qdrant_api_key``. A key that is not defined is
    passed to the client as ``None``.

    Returns:
        A newly constructed ``QdrantClient``. Each call creates a new one.
    """
    secrets = st.secrets
    db_url = secrets.get("qdrant_db_url")
    api_key = secrets.get("qdrant_api_key")
    return QdrantClient(url=db_url, api_key=api_key)
def get_initial_records():
    """Fetch a small starter sample of records to show on first load.

    Scrolls the first page of ``collection_name``, up to 12 points, without
    loading their vectors. The next-page offset that ``scroll`` also returns
    is discarded.

    Returns:
        The list of records on that first page.
    """
    first_page, _next_offset = get_client().scroll(
        collection_name=collection_name,
        limit=12,
        with_vectors=False,
    )
    return first_page
def get_similar_records():
    """Return records similar to the one the user selected.

    Asks Qdrant's ``recommend`` for up to 12 points, using the selected
    record's id as the only positive example. If no record is selected, it
    falls back to the starter sample from ``get_initial_records()``, so the
    caller always receives a list of records.

    Returns:
        A list of records (recommendations, or the starter sample).
    """
    selected = st.session_state.selected_record
    if selected is None:
        # Do not read the module-level `records` here. That name is only
        # assigned from this function's return value, so it does not exist
        # yet and would raise NameError.
        return get_initial_records()

    client = get_client()
    return client.recommend(
        collection_name=collection_name,
        positive=[selected.id],
        limit=12,
    )
def get_bytes_from_base64(base64_string):
    """Decode base64 data into an in-memory binary stream.

    Args:
        base64_string: base64-encoded data, as ``str`` or ``bytes``.

    Returns:
        A ``BytesIO`` over the decoded bytes, positioned at offset 0. The app
        passes it to ``st.image``.
    """
    decoded = base64.b64decode(base64_string)
    return BytesIO(decoded)
# Choose what to display: neighbours of the selected image, or a starter sample.
if st.session_state.selected_record is None:
    records = get_initial_records()
else:
    records = get_similar_records()
# 9. When the user has picked a record, show its image above the results.
if st.session_state.selected_record:
    # Decode before writing anything, so a bad payload fails before the
    # header is drawn.
    image_bytes = get_bytes_from_base64(
        st.session_state.selected_record.payload["base64"]
    )
    st.header("Images similar to:")
    st.image(image=image_bytes)
    st.divider()
# 10. Lay the results out in a three-column grid.
columns = st.columns(3)

# 11. For every fetched record, render a preview decoded from its base64
#     payload, followed by a button whose callback selects that record.
for position, record in enumerate(records):
    image_bytes = get_bytes_from_base64(record.payload["base64"])
    with columns[position % 3]:
        st.image(image=image_bytes)
        st.button(
            label="Find similar images",
            key=record.id,
            on_click=set_selected_record,
            args=[record],
        )