Spaces:
No application file
No application file
Update user.py
Browse files
user.py
CHANGED
|
@@ -62,4 +62,80 @@ def home():
|
|
| 62 |
message_u = {"role": "user", "content": user_input}
|
| 63 |
st.session_state.messages.append(message_u)
|
| 64 |
st.session_state.messages.append(message)
|
|
|
|
|
|
|
| 65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
message_u = {"role": "user", "content": user_input}
|
| 63 |
st.session_state.messages.append(message_u)
|
| 64 |
st.session_state.messages.append(message)
|
| 65 |
+
display_images(user_input)
|
| 66 |
+
display_videos_streamlit(user_input)
|
| 67 |
|
| 68 |
+
|
| 69 |
+
def display_images(image_collection, query_text, max_distance=None, debug=False):
|
| 70 |
+
"""
|
| 71 |
+
Display images in a Streamlit app based on a query.
|
| 72 |
+
Args:
|
| 73 |
+
image_collection: The image collection object for querying.
|
| 74 |
+
query_text (str): The text query for images.
|
| 75 |
+
max_distance (float, optional): Maximum allowable distance for filtering.
|
| 76 |
+
debug (bool, optional): Whether to print debug information.
|
| 77 |
+
"""
|
| 78 |
+
results = image_collection.query(
|
| 79 |
+
query_texts=[query_text],
|
| 80 |
+
n_results=10,
|
| 81 |
+
include=['uris', 'distances']
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
uris = results['uris'][0]
|
| 85 |
+
distances = results['distances'][0]
|
| 86 |
+
|
| 87 |
+
# Combine uris and distances, then sort by URI in ascending order
|
| 88 |
+
sorted_results = sorted(zip(uris, distances), key=lambda x: x[0])
|
| 89 |
+
|
| 90 |
+
# Display images side by side, 3 images per row
|
| 91 |
+
cols = st.columns(3) # Create 3 columns for the layout
|
| 92 |
+
|
| 93 |
+
for i, (uri, distance) in enumerate(sorted_results):
|
| 94 |
+
if max_distance is None or distance <= max_distance:
|
| 95 |
+
try:
|
| 96 |
+
img = PILImage.open(uri)
|
| 97 |
+
with cols[i % 3]: # Use modulo to cycle through columns
|
| 98 |
+
st.image(img, use_container_width = True)
|
| 99 |
+
except Exception as e:
|
| 100 |
+
st.error(f"Error loading image: {e}")
|
| 101 |
+
|
| 102 |
+
def display_videos_streamlit(video_collection, query_text, max_distance=None, max_results=5, debug=False):
|
| 103 |
+
"""
|
| 104 |
+
Display videos in a Streamlit app based on a query.
|
| 105 |
+
Args:
|
| 106 |
+
video_collection: The video collection object for querying.
|
| 107 |
+
query_text (str): The text query for videos.
|
| 108 |
+
max_distance (float, optional): Maximum allowable distance for filtering.
|
| 109 |
+
max_results (int, optional): Maximum number of results to display.
|
| 110 |
+
debug (bool, optional): Whether to print debug information.
|
| 111 |
+
"""
|
| 112 |
+
# Deduplication set
|
| 113 |
+
displayed_videos = set()
|
| 114 |
+
|
| 115 |
+
# Query the video collection with the specified text
|
| 116 |
+
results = video_collection.query(
|
| 117 |
+
query_texts=[query_text],
|
| 118 |
+
n_results=max_results, # Adjust the number of results if needed
|
| 119 |
+
include=['uris', 'distances', 'metadatas']
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
# Extract URIs, distances, and metadatas from the result
|
| 123 |
+
uris = results['uris'][0]
|
| 124 |
+
distances = results['distances'][0]
|
| 125 |
+
metadatas = results['metadatas'][0]
|
| 126 |
+
|
| 127 |
+
# Display the videos that meet the distance criteria
|
| 128 |
+
for uri, distance, metadata in zip(uris, distances, metadatas):
|
| 129 |
+
video_uri = metadata['video_uri']
|
| 130 |
+
|
| 131 |
+
# Check if a max_distance filter is applied and the distance is within the allowed range
|
| 132 |
+
if (max_distance is None or distance <= max_distance) and video_uri not in displayed_videos:
|
| 133 |
+
if debug:
|
| 134 |
+
st.write(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance}")
|
| 135 |
+
st.video(video_uri) # Display video in Streamlit
|
| 136 |
+
displayed_videos.add(video_uri) # Add to the set to prevent duplication
|
| 137 |
+
else:
|
| 138 |
+
if debug:
|
| 139 |
+
st.write(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance} (Filtered out)")
|
| 140 |
+
|
| 141 |
+
|