Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| import os | |
| from datetime import datetime | |
| import requests | |
| from datasets import load_dataset | |
| from urllib.parse import quote | |
# Seed per-session defaults. Streamlit re-executes this script on every
# interaction, so keys that already exist keep their current values and
# only missing ones are created (each with a fresh default object).
for _state_key, _make_default in (
    ('search_history', list),        # past queries with their top results
    ('search_columns', list),        # text columns offered in the field picker
    ('initial_search_done', bool),   # set to True once a search has run
    ('dataset', lambda: None),       # result of datasets.load_dataset, filled lazily
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
del _state_key, _make_default
class DatasetSearcher:
    """Hybrid semantic + keyword search over one split of a Hugging Face dataset.

    Construction loads the dataset (kept in ``st.session_state['dataset']``
    across reruns), converts its first split to a DataFrame, picks the
    text-like columns, and builds one sentence embedding per row. Missing
    credentials or load failures are reported in the UI and halt the script
    via ``st.stop()``.
    """

    # Sentence-transformers model used for both row and query embeddings.
    MODEL_NAME = 'all-MiniLM-L6-v2'
    # Embedding width assumed for MODEL_NAME when the model cannot report it.
    FALLBACK_EMBEDDING_DIM = 384
    # Number of rows passed to the encoder per forward pass.
    ENCODE_BATCH_SIZE = 32
    # Columns whose (lower-cased) name contains any of these terms are
    # treated as precomputed vectors rather than searchable text.
    NON_TEXT_NAME_TERMS = ('embed', 'vector', 'encoding')

    def __init__(self, dataset_id="tomg-group-umd/cinepile"):
        """Load the embedding model and the dataset, then prepare embeddings.

        Args:
            dataset_id: Hugging Face Hub dataset repository id.

        The Hub token is read from the ``DATASET_KEY`` environment variable;
        when it is unset an error is shown and the script is stopped.
        """
        self.dataset_id = dataset_id
        self.text_model = self._get_text_model()
        self.token = os.environ.get('DATASET_KEY')
        if not self.token:
            st.error("Please set the DATASET_KEY environment variable with your Hugging Face token.")
            st.stop()
        self.load_dataset()

    def _get_text_model(self):
        """Return the sentence-transformers model, reusing a session-cached copy."""
        # Streamlit re-executes the script on every widget interaction;
        # without caching, the model weights would be reloaded each time.
        cache_key = f"text_model::{self.MODEL_NAME}"
        model = st.session_state.get(cache_key)
        if model is None:
            model = SentenceTransformer(self.MODEL_NAME)
            st.session_state[cache_key] = model
        return model

    def load_dataset(self):
        """Load the dataset (cached in session state) and derive the search table.

        Sets ``self.dataset``, ``self.df`` (first split as a DataFrame),
        ``self.columns``, ``self.text_columns`` and, via
        :meth:`prepare_features`, ``self.text_embeddings``. Any failure is
        shown in the UI and stops the script.
        """
        try:
            # Reload when nothing is cached yet or the cache holds another dataset.
            if (st.session_state.get('dataset') is None
                    or st.session_state.get('dataset_source') != self.dataset_id):
                with st.spinner("Loading dataset..."):
                    st.session_state['dataset'] = load_dataset(
                        self.dataset_id,
                        token=self.token,
                        streaming=False
                    )
                    st.session_state['dataset_source'] = self.dataset_id
            self.dataset = st.session_state['dataset']

            # Only the first split is searched.
            first_split = next(iter(self.dataset.values()))
            self.df = pd.DataFrame(first_split)

            self.columns = list(self.df.columns)
            # object dtype covers str/list cells; is_string_dtype also covers
            # pandas' dedicated string dtype.
            self.text_columns = [
                col for col in self.columns
                if (self.df[col].dtype == object
                    or pd.api.types.is_string_dtype(self.df[col].dtype))
                and not any(term in col.lower() for term in self.NON_TEXT_NAME_TERMS)
            ]
            st.session_state['search_columns'] = self.text_columns

            self.prepare_features()
        except Exception as e:
            st.error(f"Error loading dataset: {str(e)}")
            st.error("Please check your authentication token and internet connection.")
            st.stop()

    @staticmethod
    def _to_text(value):
        """Render one cell as plain text for embedding and keyword matching.

        None/NA/NaN become ``''``; lists, tuples and numpy arrays are
        flattened into space-joined text; anything else goes through ``str``.
        """
        if value is None or value is pd.NA:
            return ''
        if isinstance(value, str):
            return value
        if isinstance(value, (list, tuple, np.ndarray)):
            return ' '.join(DatasetSearcher._to_text(item) for item in value)
        if isinstance(value, float) and np.isnan(value):
            return ''
        return str(value)

    @staticmethod
    def _combined_texts(df, columns):
        """Return one string per row of ``df``: its ``columns`` joined by spaces."""
        if not columns:
            return [''] * len(df)
        per_column = [df[col].map(DatasetSearcher._to_text) for col in columns]
        return [' '.join(parts) for parts in zip(*per_column)]

    @staticmethod
    def _keyword_scores(df, columns, query):
        """Count case-insensitive literal occurrences of ``query`` per row.

        Columns not present in ``df`` are ignored. Returns a float array of
        length ``len(df)`` holding the total count across ``columns``.
        """
        # Plain substring counting: the query is user text, not a regex.
        needle = query.lower()
        scores = np.zeros(len(df))
        for col in columns:
            if col not in df.columns:
                continue
            counts = df[col].map(
                lambda value: DatasetSearcher._to_text(value).lower().count(needle)
            )
            scores += counts.to_numpy(dtype=float)
        return scores

    def _embedding_dim(self):
        """Return the model's embedding width, or the fallback if unknown."""
        getter = getattr(self.text_model, 'get_sentence_embedding_dimension', None)
        dim = getter() if callable(getter) else None
        return dim or self.FALLBACK_EMBEDDING_DIM

    def prepare_features(self):
        """Compute ``self.text_embeddings``: one embedding per DataFrame row.

        Each row's text columns are concatenated and encoded. The result is
        cached in session state keyed by dataset id, text columns and row
        count, so reruns do not re-encode the whole dataset. If encoding
        fails a warning is shown and a zero matrix is used. Zero vectors have
        cosine similarity 0 with any query, so search falls back to keyword
        matching instead of ranking by random noise.
        """
        cache_key = (self.dataset_id, tuple(self.text_columns), len(self.df))
        cached = st.session_state.get('text_embeddings_cache')
        if cached is not None and cached[0] == cache_key:
            self.text_embeddings = cached[1]
            return

        try:
            texts = self._combined_texts(self.df, self.text_columns)
            with st.spinner("Preparing search features..."):
                if texts:
                    self.text_embeddings = np.asarray(
                        self.text_model.encode(texts, batch_size=self.ENCODE_BATCH_SIZE)
                    )
                else:
                    self.text_embeddings = np.zeros((0, self._embedding_dim()))
            st.session_state['text_embeddings_cache'] = (cache_key, self.text_embeddings)
        except Exception as e:
            st.warning(f"Error preparing features: {str(e)}")
            self.text_embeddings = np.zeros((len(self.df), self._embedding_dim()))

    def search(self, query, column=None, top_k=20):
        """Rank rows by an equal blend of semantic and keyword relevance.

        Args:
            query: Free-text query.
            column: Restrict keyword matching to this column; ``None`` or
                ``"All Fields"`` searches every text column. Semantic scores
                always use the combined text of all text columns.
            top_k: Maximum number of results; values <= 0 yield ``[]``.

        Returns:
            Up to ``top_k`` dicts, best first. Each holds ``relevance_score``,
            ``semantic_score`` (cosine similarity), ``keyword_score`` (raw
            occurrence count) and all of the row's columns.
        """
        # Guard top_k explicitly: a slice of [-0:] would return every row.
        if self.df.empty or top_k <= 0:
            return []

        query_embedding = self.text_model.encode([query])[0]
        similarities = cosine_similarity([query_embedding], self.text_embeddings)[0]

        search_columns = [column] if column and column != "All Fields" else self.text_columns
        keyword_scores = self._keyword_scores(self.df, search_columns, query)

        # Scale keyword counts into [0, 1] by the best row; max(1, ...) avoids
        # dividing by zero when nothing matched.
        normalized_keywords = keyword_scores / max(1, keyword_scores.max())
        combined_scores = 0.5 * similarities + 0.5 * normalized_keywords

        top_k = min(top_k, len(combined_scores))
        top_indices = np.argsort(combined_scores)[-top_k:][::-1]

        return [
            {
                'relevance_score': float(combined_scores[idx]),
                'semantic_score': float(similarities[idx]),
                'keyword_score': float(keyword_scores[idx]),
                **self.df.iloc[idx].to_dict(),
            }
            for idx in top_indices
        ]

    def get_dataset_info(self):
        """Summarise the loaded dataset for the info tab; ``{}`` if nothing is loaded."""
        dataset = getattr(self, 'dataset', None)
        if not dataset:
            return {}
        return {
            'splits': list(dataset.keys()),
            'total_rows': sum(split.num_rows for split in dataset.values()),
            'columns': self.columns,
            'text_columns': self.text_columns,
            'sample_rows': len(self.df),
            'embeddings_shape': self.text_embeddings.shape,
        }
def _youtube_video_args(result):
    """Return ``(url, start_seconds)`` for a result's YouTube clip, or ``None``.

    ``None`` is returned when ``youtube_id`` is absent, None/NaN or blank.
    ``start_seconds`` is ``start_time`` truncated to a non-negative int, or
    0 when it is absent or not numeric. The id is URL-quoted so odd
    characters cannot inject extra query parameters.
    """
    video_id = result.get('youtube_id')
    if video_id is None or (isinstance(video_id, float) and video_id != video_id):
        return None
    video_id = str(video_id).strip()
    if not video_id:
        return None
    try:
        start = int(float(result.get('start_time', 0)))
    except (TypeError, ValueError, OverflowError):
        start = 0
    url = f"https://youtube.com/watch?v={quote(video_id, safe='')}"
    return url, max(start, 0)


def render_video_result(result):
    """Render one search result dict produced by ``DatasetSearcher.search``.

    The left column shows the title, description, time range and remaining
    metadata fields. The right column shows the three search scores and,
    when the row has a ``youtube_id``, an embedded video.
    """
    # Keys rendered specially (or deliberately hidden) instead of in the
    # generic metadata list.
    special_keys = {
        'title', 'description', 'start_time', 'end_time', 'duration',
        'relevance_score', 'semantic_score', 'keyword_score',
        'video_id', 'youtube_id',
    }

    col1, col2 = st.columns([2, 1])
    with col1:
        if 'title' in result:
            st.markdown(f"**Title:** {result['title']}")
        if 'description' in result:
            st.markdown("**Description:**")
            st.write(result['description'])
        if 'start_time' in result and 'end_time' in result:
            st.markdown(f"**Time Range:** {result['start_time']}s - {result['end_time']}s")
        for key, value in result.items():
            if key not in special_keys:
                st.markdown(f"**{key.replace('_', ' ').title()}:** {value}")

    with col2:
        st.markdown("**Search Scores:**")
        cols = st.columns(3)
        cols[0].metric("Overall", f"{result['relevance_score']:.2%}")
        cols[1].metric("Semantic", f"{result['semantic_score']:.2%}")
        cols[2].metric("Keyword", f"{result['keyword_score']:.0f} matches")

        video = _youtube_video_args(result)
        if video is not None:
            url, start = video
            # st.video rebuilds YouTube links as embed URLs, so a `&t=` query
            # parameter is not honoured; pass the offset via start_time.
            st.video(url, start_time=start)
def _result_label(result, max_len=100):
    """Return a short display label for a search result dict.

    Uses the first non-empty value among ``title`` then ``description``,
    truncated to ``max_len`` characters with a trailing ``...`` only when
    text was cut. Returns ``'No title'`` when neither is usable.
    """
    for key in ('title', 'description'):
        value = result.get(key)
        # Rows may carry None/NaN for missing text; skip them rather than
        # crashing on slicing or displaying "None"/"nan".
        if value is None or (isinstance(value, float) and value != value):
            continue
        text = str(value).strip()
        if text:
            return text if len(text) <= max_len else text[:max_len] + '...'
    return 'No title'


def _rerun_app():
    """Rerun the Streamlit script using whichever rerun API is available."""
    # st.experimental_rerun is deprecated (removed in newer Streamlit
    # releases) in favour of st.rerun; fall back only on old versions.
    rerun = getattr(st, 'rerun', None) or st.experimental_rerun
    rerun()


def _render_search_tab(searcher):
    """Render the query controls; on submit, run the search and show results."""
    st.subheader("Search Videos")
    col1, col2 = st.columns([3, 1])
    with col1:
        query = st.text_input("Search query:", value="")
    with col2:
        search_column = st.selectbox("Search in field:",
                                     ["All Fields"] + st.session_state['search_columns'])
    col3, col4 = st.columns(2)
    with col3:
        num_results = st.slider("Number of results:", 1, 100, 20)
    with col4:
        search_button = st.button("π Search")

    # Whitespace-only input is treated as no query.
    query = query.strip()
    if not (search_button and query):
        return

    st.session_state['initial_search_done'] = True
    selected_column = None if search_column == "All Fields" else search_column
    with st.spinner("Searching..."):
        results = searcher.search(query, selected_column, num_results)
    # Only the top five results are kept per history entry for the sidebar.
    st.session_state['search_history'].append({
        'query': query,
        'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        'results': results[:5]
    })
    for i, result in enumerate(results, 1):
        with st.expander(f"Result {i}: {_result_label(result)}", expanded=(i == 1)):
            render_video_result(result)


def _render_info_tab(searcher):
    """Render dataset statistics, a sample of rows and per-column dtypes."""
    st.subheader("Dataset Information")
    info = searcher.get_dataset_info()
    if not info:
        return
    st.write(f"### Dataset: {searcher.dataset_id}")
    st.write(f"- Total rows: {info['total_rows']:,}")
    st.write(f"- Available splits: {', '.join(info['splits'])}")
    st.write(f"- Number of columns: {len(info['columns'])}")
    st.write(f"- Searchable text columns: {', '.join(info['text_columns'])}")
    st.write("### Sample Data")
    st.dataframe(searcher.df.head())
    st.write("### Column Details")
    for col in info['columns']:
        st.write(f"- **{col}**: {searcher.df[col].dtype}")


def _render_sidebar():
    """Render the search-history sidebar with a clear button and the last 5 queries."""
    with st.sidebar:
        st.subheader("βοΈ Search History")
        if st.button("ποΈ Clear History"):
            st.session_state['search_history'] = []
            _rerun_app()
        st.markdown("### Recent Searches")
        for entry in reversed(st.session_state['search_history'][-5:]):
            with st.expander(f"{entry['timestamp']}: {entry['query']}"):
                for i, result in enumerate(entry['results'], 1):
                    st.write(f"{i}. {_result_label(result)}")


def main():
    """App entry point: build the searcher, then render both tabs and the sidebar."""
    # NOTE(review): the emoji in the UI labels appear mis-encoded (UTF-8
    # bytes decoded with a legacy code page); restore the intended characters.
    st.title("π₯ Video Dataset Search")
    searcher = DatasetSearcher()
    tab1, tab2 = st.tabs(["π Search", "π Dataset Info"])
    with tab1:
        _render_search_tab(searcher)
    with tab2:
        _render_info_tab(searcher)
    _render_sidebar()
if __name__ == '__main__':
    # Launch the app only when this file is executed as the main module,
    # not when it is imported.
    main()