Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| import torch | |
| import json | |
| import os | |
| import glob | |
| from pathlib import Path | |
| from datetime import datetime | |
| import edge_tts | |
| import asyncio | |
| import requests | |
| from collections import defaultdict | |
| from audio_recorder_streamlit import audio_recorder | |
| import streamlit.components.v1 as components | |
| from urllib.parse import quote | |
| from xml.etree import ElementTree as ET | |
| from datasets import load_dataset | |
# Initialize session state variables
# Default value for every key this app stores in Streamlit's session state.
SESSION_VARS = {
    'search_history': [],             # list of past search entries
    'last_voice_input': "",           # most recent voice transcript
    'transcript_history': [],         # conversation history
    'should_rerun': False,            # flag used to request a UI refresh
    'search_columns': [],             # column names offered in "Search in:"
    'initial_search_done': False,     # set once the first search has run
    'tts_voice': "en-US-AriaNeural",  # default text-to-speech voice
    'arxiv_last_query': "",           # last ArXiv query text
    'dataset_loaded': False,          # dataset load status
    'current_page': 0,                # current data page
    'data_cache': None,               # data cache
    'dataset_info': None              # dataset metadata
}

# Number of dataset rows requested per page.
ROWS_PER_PAGE = 100

# Seed only the keys that are missing, so values already held by the session
# survive Streamlit's script reruns.
for var, default in SESSION_VARS.items():
    if var in st.session_state:
        continue
    st.session_state[var] = default
def get_model():
    """Build and return the ``all-MiniLM-L6-v2`` SentenceTransformer."""
    model_name = 'all-MiniLM-L6-v2'
    encoder = SentenceTransformer(model_name)
    return encoder
def load_dataset_page(dataset_id, token, page, rows_per_page):
    """Load one page of the ``train`` split into a DataFrame.

    The page covers rows ``page * rows_per_page`` up to (but excluding)
    ``page * rows_per_page + rows_per_page``. Any failure is reported with
    ``st.error`` and an empty DataFrame is returned instead.
    """
    try:
        first_row = page * rows_per_page
        last_row = first_row + rows_per_page
        page_slice = f'train[{first_row}:{last_row}]'
        page_data = load_dataset(
            dataset_id,
            token=token,
            streaming=False,
            split=page_slice,
        )
        frame = pd.DataFrame(page_data)
    except Exception as e:
        st.error(f"Error loading page {page}: {str(e)}")
        frame = pd.DataFrame()
    return frame
def get_dataset_info(dataset_id, token):
    """Return the ``info`` object of the dataset's ``train`` split.

    The dataset is opened in streaming mode. On any failure the error is
    shown with ``st.error`` and ``None`` is returned.
    """
    try:
        streamed = load_dataset(dataset_id, token=token, streaming=True)
        train_split = streamed['train']
        return train_split.info
    except Exception as e:
        st.error(f"Error loading dataset info: {str(e)}")
    return None
def fetch_dataset_info(dataset_id):
    """Fetch dataset metadata as parsed JSON from the Hugging Face API.

    Returns ``None`` when the response status is not 200 or when the request
    raises (the exception is surfaced through ``st.warning``).
    """
    info_url = f"https://huggingface.co/api/datasets/{dataset_id}"
    try:
        response = requests.get(info_url, timeout=30)
        if response.status_code != 200:
            return None
        return response.json()
    except Exception as e:
        st.warning(f"Error fetching dataset info: {e}")
    return None
def fetch_dataset_rows(dataset_id, config="default", split="train", max_rows=100):
    """Fetch the preview rows of one dataset split from the datasets-server.

    Args:
        dataset_id: Hub dataset identifier.
        config: Dataset configuration name.
        split: Split name.
        max_rows: Upper bound on the number of rows returned; ``None`` keeps
            every row the endpoint sends back.

    Returns:
        A list of row dicts. String values in columns whose name contains
        ``embed``, ``vector`` or ``encoding`` are parsed into lists of floats
        when possible, and every row is tagged with ``_config`` and
        ``_split``. An empty list is returned on a non-200 response, when the
        payload has no ``rows`` key, or when the request fails (the error is
        reported with ``st.warning``).
    """
    url = "https://datasets-server.huggingface.co/first-rows"
    # Let requests URL-encode the values instead of pasting them into the URL.
    query = {"dataset": dataset_id, "config": config, "split": split}
    embedding_terms = ('embed', 'vector', 'encoding')
    try:
        response = requests.get(url, params=query, timeout=30)
        if response.status_code != 200:
            return []
        data = response.json()
        if 'rows' not in data:
            return []

        processed_rows = []
        for row_data in data['rows'][:max_rows]:
            row = row_data.get('row', row_data)
            # Snapshot the items because values are replaced inside the loop.
            for key, value in list(row.items()):
                if not isinstance(value, str):
                    continue
                if not any(term in key.lower() for term in embedding_terms):
                    continue
                try:
                    row[key] = [
                        float(part)
                        for part in value.strip('[]').split(',')
                        if part.strip()
                    ]
                except ValueError:
                    # Not a numeric list after all: keep the original string.
                    pass
            row['_config'] = config
            row['_split'] = split
            processed_rows.append(row)
        return processed_rows
    except Exception as e:
        st.warning(f"Error fetching rows: {e}")
    return []
class FastDatasetSearcher:
    """Page-wise loader and keyword/semantic searcher for a Hub dataset.

    The access token is read from the ``DATASET_KEY`` environment variable;
    the app is stopped with an error message when it is not set.
    """

    def __init__(self, dataset_id="tomg-group-umd/cinepile"):
        """Read the token, load the text model and cache dataset metadata."""
        self.dataset_id = dataset_id
        # Check the token first so a misconfigured app stops before paying
        # for the model load.
        self.token = os.environ.get('DATASET_KEY')
        if not self.token:
            st.error("Please set the DATASET_KEY environment variable")
            st.stop()
        self.text_model = get_model()
        if st.session_state['dataset_info'] is None:
            st.session_state['dataset_info'] = get_dataset_info(self.dataset_id, self.token)

    def load_page(self, page=0):
        """Return page ``page`` of the dataset (``ROWS_PER_PAGE`` rows)."""
        return load_dataset_page(self.dataset_id, self.token, page, ROWS_PER_PAGE)

    @staticmethod
    def _cosine(vec_a, vec_b):
        """Cosine similarity of two 1-D vectors; 0.0 if either has zero norm."""
        vec_a = np.asarray(vec_a, dtype=float)
        vec_b = np.asarray(vec_b, dtype=float)
        denom = float(np.linalg.norm(vec_a) * np.linalg.norm(vec_b))
        if denom == 0.0:
            return 0.0
        return float(np.dot(vec_a, vec_b) / denom)

    def quick_search(self, query, df):
        """Rank the rows of ``df`` against ``query``.

        Each row's text is the space-joined string form of its non-None
        values, with ``description`` / ``matched_text`` first when those
        columns exist. The row score is
        ``0.8 * keyword_overlap + 0.2 * cosine_similarity``, multiplied by
        2.0 when the whole lower-cased query equals one token of a field and
        by 1.2 when at least one query term equals a token.

        Args:
            query: Free-text query.
            df: DataFrame to search.

        Returns:
            ``df`` itself when it is empty or the query is blank. Otherwise a
            copy with ``score`` and ``matched`` columns, restricted to rows
            that matched a term or scored above the relevance threshold and
            sorted by descending score. If scoring raises, the error is shown
            with ``st.error`` and the unfiltered ``df`` is returned.
        """
        if df.empty or not query.strip():
            return df
        try:
            MIN_SEMANTIC_SCORE = 0.5   # score needed by rows without a token match
            EXACT_MATCH_BOOST = 2.0    # whole query equals one token
            PARTIAL_MATCH_BOOST = 1.2  # at least one query term matched

            # A column is searchable unless its first value is an array/bytes.
            searchable_cols = [
                col for col in df.columns
                if not isinstance(df[col].iloc[0], (np.ndarray, bytes))
            ]
            # Priority fields are read first; both lists are loop-invariant.
            priority_fields = ['description', 'matched_text']
            ordered_fields = [col for col in priority_fields if col in df.columns]
            ordered_fields += [col for col in searchable_cols if col not in priority_fields]

            query_lower = query.lower()
            query_terms = set(query_lower.split())
            query_embedding = self.text_model.encode([query], show_progress_bar=False)[0]

            scores = []
            matched_any = []
            for _, row in df.iterrows():
                exact_match = False
                row_matched = False
                text_parts = []
                for col in ordered_fields:
                    val = row[col]
                    if val is None:
                        continue
                    raw = str(val)
                    tokens = raw.lower().split()
                    if query_lower in tokens:
                        exact_match = True
                    if not query_terms.isdisjoint(tokens):
                        row_matched = True
                    text_parts.append(raw)

                text = ' '.join(text_parts)
                if not text.strip():
                    # Nothing to score for this row.
                    scores.append(0.0)
                    matched_any.append(False)
                    continue

                text_tokens = set(text.lower().split())
                keyword_score = len(query_terms & text_tokens) / len(query_terms)
                text_embedding = self.text_model.encode([text], show_progress_bar=False)[0]
                semantic_score = self._cosine(query_embedding, text_embedding)

                combined_score = 0.8 * keyword_score + 0.2 * semantic_score
                if exact_match:
                    combined_score *= EXACT_MATCH_BOOST
                elif row_matched:
                    combined_score *= PARTIAL_MATCH_BOOST
                scores.append(combined_score)
                matched_any.append(row_matched)

            results_df = df.copy()
            results_df['score'] = scores
            results_df['matched'] = matched_any
            # Keep direct matches plus anything above the relevance threshold.
            # (The original compared against an undefined name here, so every
            # search fell into the except branch and returned `df` unfiltered.)
            keep = results_df['matched'] | (results_df['score'] > MIN_SEMANTIC_SCORE)
            return results_df[keep].sort_values('score', ascending=False)
        except Exception as e:
            st.error(f"Search error: {str(e)}")
            return df
class VideoSearch:
    """Embedding-based search over rows of ``omegalabsinc/omega-multimodal``.

    Attributes set during construction:
        dataset: DataFrame of fetched rows (or the built-in example row).
        video_embeds / text_embeds: 2-D float arrays with one row per
            dataset row, used for similarity scoring.
    """

    # Width of the placeholder embeddings used when the data has none.
    PLACEHOLDER_DIM = 384
    # Columns that hold raw embeddings and are never shown in results.
    EMBED_COLUMNS = ('video_embed', 'description_embed', 'audio_embed')

    def __init__(self):
        """Load the text model, then fetch rows and prepare embeddings."""
        self.text_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.dataset_id = "omegalabsinc/omega-multimodal"
        self.load_dataset()

    def fetch_dataset_rows(self):
        """Fetch rows through ``search_dataset`` using an empty search string.

        On success the non-embedding, non-underscore column names are stored
        in ``st.session_state['search_columns']``. When nothing is returned
        or an error occurs, the example data is returned instead.
        """
        try:
            df, configs, splits = search_dataset(
                self.dataset_id,
                "",
                include_configs=None,
                include_splits=None
            )
            if not df.empty:
                st.session_state['search_columns'] = [
                    col for col in df.columns
                    if col not in self.EMBED_COLUMNS and not col.startswith('_')
                ]
                return df
            return self.load_example_data()
        except Exception as e:
            st.warning(f"Error loading videos: {e}")
            return self.load_example_data()

    def load_example_data(self):
        """Return a one-row DataFrame describing a sample video."""
        example_data = [{
            "video_id": "sample-123",
            "youtube_id": "dQw4w9WgXcQ",
            "description": "An example video",
            "views": 12345,
            "start_time": 0,
            "end_time": 60
        }]
        return pd.DataFrame(example_data)

    def load_dataset(self):
        """Populate ``self.dataset`` and derive the embedding matrices."""
        self.dataset = self.fetch_dataset_rows()
        self.prepare_features()

    @staticmethod
    def _parse_vector(cell):
        """Turn one embedding cell (string or sequence) into a list of numbers.

        Raises TypeError for any other cell type so the caller can drop the
        whole column rather than silently misaligning rows.
        """
        if isinstance(cell, str):
            return [float(part) for part in cell.strip('[]').split(',') if part.strip()]
        if isinstance(cell, (list, tuple, np.ndarray)):
            return list(cell)
        raise TypeError(f"unsupported embedding cell type: {type(cell).__name__}")

    def prepare_features(self):
        """Build ``video_embeds`` / ``text_embeds`` from embedding columns.

        A column is used only if every cell parses and the result is a 2-D
        array with one row per dataset row; otherwise it is ignored, because
        a partially parsed column cannot be aligned with the rows.
        ``video_embed`` falls back to the first usable column and
        ``description_embed`` falls back to the video embeddings. When no
        column is usable, random placeholder matrices are created (the same
        fallback the original used on errors) so ``search`` never sees
        ``None``.
        """
        num_rows = len(self.dataset)
        embeddings = {}
        for col in self.dataset.columns:
            if not any(term in str(col).lower() for term in ('embed', 'vector', 'encoding')):
                continue
            try:
                matrix = np.array(
                    [self._parse_vector(cell) for cell in self.dataset[col]],
                    dtype=float,
                )
            except (ValueError, TypeError):
                continue
            if matrix.ndim == 2 and matrix.shape[0] == num_rows:
                embeddings[col] = matrix

        first_available = next(iter(embeddings.values()), None)
        self.video_embeds = embeddings.get('video_embed', first_available)
        self.text_embeds = embeddings.get('description_embed', self.video_embeds)
        if self.video_embeds is None:
            # NOTE(review): random vectors carry no meaning; they only keep
            # the search path from crashing when the data has no embeddings.
            self.video_embeds = np.random.randn(num_rows, self.PLACEHOLDER_DIM)
            self.text_embeds = np.random.randn(num_rows, self.PLACEHOLDER_DIM)

    @staticmethod
    def _cosine_sims(query_vec, matrix, num_rows):
        """Cosine similarity of ``query_vec`` against each row of ``matrix``.

        Returns zeros when ``matrix`` is missing or its shape is not
        ``(num_rows, len(query_vec))``; zero-norm rows also score 0.
        """
        if matrix is None:
            return np.zeros(num_rows)
        matrix = np.asarray(matrix, dtype=float)
        if matrix.ndim != 2 or matrix.shape != (num_rows, query_vec.shape[0]):
            return np.zeros(num_rows)
        norms = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query_vec)
        dots = matrix @ query_vec
        return np.divide(dots, norms, out=np.zeros(num_rows), where=norms > 0)

    def search(self, query, column=None, top_k=20):
        """Return up to ``top_k`` rows ranked by relevance to ``query``.

        Relevance is ``0.7 * text similarity + 0.3 * video similarity``.
        When ``column`` names a dataset column, rows whose value contains the
        query as a literal, case-insensitive substring get a 1.5x boost.
        Rows scoring below 0.3 are dropped.

        Returns:
            A list of dicts, best first, each holding ``relevance_score``
            plus every non-embedding column of the row. Empty when nothing
            reaches the threshold.
        """
        MIN_RELEVANCE = 0.3  # Minimum relevance threshold
        num_rows = len(self.dataset)
        query_embedding = np.asarray(self.text_model.encode([query])[0], dtype=float)
        video_sims = self._cosine_sims(query_embedding, self.video_embeds, num_rows)
        text_sims = self._cosine_sims(query_embedding, self.text_embeds, num_rows)
        combined_sims = 0.7 * text_sims + 0.3 * video_sims  # Favor text matches

        if column and column in self.dataset.columns and column != "All Fields":
            # regex=False: user text such as "c++" must not be parsed as a pattern.
            matches = (
                self.dataset[column]
                .astype(str)
                .str.contains(query, case=False, regex=False)
                .to_numpy(dtype=bool)
            )
            combined_sims[matches] *= 1.5  # Boost literal matches

        relevant_indices = np.where(combined_sims >= MIN_RELEVANCE)[0]
        if len(relevant_indices) == 0:
            return []

        # Sort descending and slice from the front so top_k <= 0 yields nothing.
        order = np.argsort(-combined_sims[relevant_indices])[:max(top_k, 0)]
        top_indices = relevant_indices[order]

        visible_cols = [col for col in self.dataset.columns if col not in self.EMBED_COLUMNS]
        results = []
        for idx in top_indices:
            row = self.dataset.iloc[idx]
            result = {'relevance_score': float(combined_sims[idx])}
            for col in visible_cols:
                result[col] = row[col]
            results.append(result)
        return results
def search_dataset(dataset_id, search_text, include_configs=None, include_splits=None):
    """Collect preview rows whose text contains ``search_text``.

    For each config the splits are looked up on the datasets-server; every
    split's rows are fetched with ``fetch_dataset_rows`` and kept when the
    lower-cased search text occurs in the row's joined str/int/float values.
    Kept rows gain ``_matched_text`` and ``_relevance_score`` (occurrence
    count).

    Returns:
        ``(df, configs, splits)`` where ``df`` is sorted by descending
        ``_relevance_score`` (empty when nothing matched), ``configs`` is the
        list of configs examined and ``splits`` the split names seen. When
        dataset info cannot be fetched the result is ``(empty df, [], [])``.
    """
    dataset_info = fetch_dataset_info(dataset_id)
    if not dataset_info:
        return pd.DataFrame(), [], []

    if include_configs:
        configs = include_configs
    else:
        configs = dataset_info.get('config_names', ['default'])

    matched_rows = []
    seen_splits = set()
    for config in configs:
        try:
            splits_url = f"https://datasets-server.huggingface.co/splits?dataset={dataset_id}&config={config}"
            splits_response = requests.get(splits_url, timeout=30)
            if splits_response.status_code != 200:
                continue
            splits_data = splits_response.json()
            splits = [entry['split'] for entry in splits_data.get('splits', [])]
            if not splits:
                splits = ['train']
            if include_splits:
                splits = [name for name in splits if name in include_splits]
            seen_splits.update(splits)
            for split in splits:
                for row in fetch_dataset_rows(dataset_id, config, split):
                    text_content = ' '.join(
                        str(value) for value in row.values()
                        if isinstance(value, (str, int, float))
                    )
                    needle = search_text.lower()
                    haystack = text_content.lower()
                    if needle not in haystack:
                        continue
                    row['_matched_text'] = text_content
                    row['_relevance_score'] = haystack.count(needle)
                    matched_rows.append(row)
        except Exception as e:
            st.warning(f"Error processing config {config}: {e}")
            continue

    if not matched_rows:
        return pd.DataFrame(), configs, list(seen_splits)
    df = pd.DataFrame(matched_rows)
    df = df.sort_values('_relevance_score', ascending=False)
    return df, configs, list(seen_splits)
def get_speech_model():
    """Return the ``edge_tts.Communicate`` class itself (not an instance)."""
    communicate_cls = edge_tts.Communicate
    return communicate_cls
async def generate_speech(text, voice=None):
    """Synthesize ``text`` to an MP3 file and return the file name.

    Args:
        text: Text to speak. ``None``, empty or whitespace-only text is
            skipped.
        voice: Voice name; defaults to ``st.session_state['tts_voice']``.

    Returns:
        The name of the written ``speech_<timestamp>.mp3`` file, or ``None``
        when there is nothing to speak or synthesis fails (the failure is
        reported with ``st.error``).
    """
    if not text or not text.strip():
        return None
    if not voice:
        voice = st.session_state['tts_voice']
    try:
        communicate = get_speech_model()(text, voice)
        # Microseconds keep two clips generated within the same second from
        # overwriting each other.
        stamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
        audio_file = f"speech_{stamp}.mp3"
        await communicate.save(audio_file)
        return audio_file
    except Exception as e:
        st.error(f"Error generating speech: {e}")
        return None
def transcribe_audio(audio_path):
    """Return a fixed placeholder message; ``audio_path`` is not read."""
    placeholder = "ASR not implemented. Add your preferred speech recognition here!"
    return placeholder
def arxiv_search(query, max_results=5):
    """Query the arXiv Atom API and return ``(title, summary, link)`` tuples.

    Args:
        query: Search expression; it is URL-quoted before being sent.
        max_results: Maximum number of entries requested.

    Returns:
        One tuple per feed entry. ``title`` and ``summary`` are stripped
        strings (empty when the element is missing), ``link`` is the href of
        the entry's ``text/html`` link or ``None``. An empty list is returned
        on a non-200 response or on any error (reported with ``st.error``).
    """
    base_url = "http://export.arxiv.org/api/query?"
    search_url = base_url + f"search_query={quote(query)}&start=0&max_results={max_results}"
    ns = {'atom': 'http://www.w3.org/2005/Atom'}
    try:
        # Same timeout as the other HTTP helpers in this file.
        r = requests.get(search_url, timeout=30)
        if r.status_code != 200:
            return []
        # Parse the raw bytes so the XML declaration decides the encoding.
        root = ET.fromstring(r.content)
        results = []
        for entry in root.findall('atom:entry', ns):
            # findtext + `or ''` tolerates a missing or empty element instead
            # of losing the whole result list to an AttributeError.
            title = (entry.findtext('atom:title', default='', namespaces=ns) or '').strip()
            summary = (entry.findtext('atom:summary', default='', namespaces=ns) or '').strip()
            link = next(
                (node.get('href') for node in entry.findall('atom:link', ns)
                 if node.get('type') == 'text/html'),
                None,
            )
            results.append((title, summary, link))
        return results
    except Exception as e:
        st.error(f"ArXiv search error: {e}")
    return []
def show_file_manager():
    """Render the file manager tab: upload, list, preview and delete files.

    Works on ``*.txt``, ``*.md`` and ``*.mp3`` files in the current working
    directory.
    """
    st.subheader("π File Manager")
    # st.experimental_rerun was superseded by st.rerun; use whichever exists.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    managed_patterns = ("*.txt", "*.md", "*.mp3")

    def list_managed_files():
        # Collect every managed file in the working directory.
        return [path for pattern in managed_patterns for path in glob.glob(pattern)]

    col1, col2 = st.columns(2)
    with col1:
        uploaded_file = st.file_uploader("Upload File", type=['txt', 'md', 'mp3'])
        if uploaded_file:
            # The uploader keeps returning the same file on every rerun, so
            # save each upload once. Rerunning here (as the original did)
            # would loop forever; the listing below already shows the file.
            upload_sig = (uploaded_file.name, uploaded_file.size)
            if st.session_state.get('_saved_upload') != upload_sig:
                # Drop any directory part so the file lands in the working dir.
                safe_name = os.path.basename(uploaded_file.name)
                with open(safe_name, "wb") as f:
                    f.write(uploaded_file.getvalue())
                st.session_state['_saved_upload'] = upload_sig
                st.success(f"Uploaded: {safe_name}")
        else:
            st.session_state['_saved_upload'] = None
    with col2:
        # NOTE(review): this deletes every matching file in the working
        # directory, not only uploads (e.g. a README.md living there).
        if st.button("π Clear Files"):
            for f in list_managed_files():
                os.remove(f)
            st.success("All files cleared!")
            rerun()

    files = list_managed_files()
    if files:
        st.write("### Existing Files")
        for f in files:
            with st.expander(f"π {os.path.basename(f)}"):
                if f.endswith('.mp3'):
                    st.audio(f)
                else:
                    # errors='replace' keeps a non-UTF-8 file from crashing
                    # the page; the key keeps identical contents from
                    # producing duplicate widget ids.
                    with open(f, 'r', encoding='utf-8', errors='replace') as file:
                        st.text_area("Content", file.read(), height=100, key=f"content_{f}")
                if st.button(f"Delete {os.path.basename(f)}", key=f"del_{f}"):
                    os.remove(f)
                    rerun()
def perform_arxiv_lookup(query, vocal_summary=True, titles_summary=True, full_audio=False):
    """Show up to five ArXiv results for ``query`` and optionally read them.

    Args:
        query: Text passed to ``arxiv_search``.
        vocal_summary: When true, play a short spoken summary.
        titles_summary: Selects what the short summary contains: all titles
            when true, otherwise the first 200 characters of the first
            abstract.
        full_audio: When true, also play every title and abstract.
    """
    results = arxiv_search(query, max_results=5)
    if not results:
        st.write("No results found.")
        return

    st.markdown(f"**ArXiv Results for '{query}':**")
    for position, (title, summary, link) in enumerate(results, start=1):
        st.markdown(f"**{position}. {title}**")
        st.write(summary)
        if link:
            st.markdown(f"[View Paper]({link})")

    if vocal_summary:
        intro = f"Here are ArXiv results for {query}. "
        if titles_summary:
            detail = " Titles: " + ", ".join(item[0] for item in results)
        else:
            detail = " " + results[0][1][:200]
        audio_file = asyncio.run(generate_speech(intro + detail))
        if audio_file:
            st.audio(audio_file)

    if full_audio:
        full_text = "".join(
            f"Result {position}: {title}. {summary} "
            for position, (title, summary, _) in enumerate(results, start=1)
        )
        audio_file_full = asyncio.run(generate_speech(full_text))
        if audio_file_full:
            st.write("### Full Audio Summary")
            st.audio(audio_file_full)
def render_result(result):
    """Render one search result with its video, fields, score and TTS controls.

    Args:
        result: Mapping of field name to value for one hit. ``relevance_score``
            is shown as a metric, embedding fields are hidden, and
            ``youtube_id`` / ``start_time`` drive the embedded player when
            present.
    """
    import hashlib  # local import: only needed to derive widget keys

    score = result.get('relevance_score', 0)
    hidden_fields = ('relevance_score', 'video_embed', 'description_embed', 'audio_embed')
    result_filtered = {k: v for k, v in result.items() if k not in hidden_fields}

    if 'youtube_id' in result:
        # st.video takes the offset through `start_time`; coerce defensively
        # because the stored value may be missing, NaN or a float.
        try:
            start_seconds = max(int(float(result.get('start_time') or 0)), 0)
        except (TypeError, ValueError, OverflowError):
            start_seconds = 0
        st.video(
            f"https://youtube.com/watch?v={result['youtube_id']}",
            start_time=start_seconds,
        )

    # Widget keys must be unique per rendered result. `video_id` alone is not
    # (it can be missing or repeated), so add a digest of the visible fields.
    digest = hashlib.md5(
        repr(list(result_filtered.items())).encode('utf-8', errors='replace')
    ).hexdigest()[:12]
    widget_id = f"{result.get('video_id', '')}_{digest}"

    cols = st.columns([2, 1])
    with cols[0]:
        text_content = []  # Collect text for TTS
        for key, value in result_filtered.items():
            if isinstance(value, (str, int, float)):
                st.write(f"**{key}:** {value}")
                if isinstance(value, str) and len(value.strip()) > 0:
                    text_content.append(f"{key}: {value}")
    with cols[1]:
        st.metric("Relevance Score", f"{score:.2%}")
        # Voice selection for TTS
        voices = {
            "Aria (US Female)": "en-US-AriaNeural",
            "Guy (US Male)": "en-US-GuyNeural",
            "Sonia (UK Female)": "en-GB-SoniaNeural",
            "Tony (UK Male)": "en-GB-TonyNeural",
            "Jenny (US Female)": "en-US-JennyNeural"
        }
        selected_voice = st.selectbox(
            "Select Voice",
            list(voices.keys()),
            key=f"voice_{widget_id}"
        )
        if st.button("π Read Description", key=f"read_{widget_id}"):
            text_to_read = ". ".join(text_content)
            audio_file = asyncio.run(generate_speech(text_to_read, voices[selected_voice]))
            if audio_file:
                st.audio(audio_file)
def main():
    """Build the Streamlit page: search, voice, ArXiv and file tabs plus sidebar."""
    st.title("π₯ Advanced Video & Dataset Search with Voice")
    # st.experimental_rerun was superseded by st.rerun; use whichever exists.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun

    # Streamlit re-executes this script on every interaction. Keep one
    # VideoSearch per session so the model and dataset are not reloaded.
    if 'video_search' not in st.session_state:
        st.session_state['video_search'] = VideoSearch()
    search = st.session_state['video_search']

    # Create tabs
    tab1, tab2, tab3, tab4 = st.tabs([
        "π Search", "ποΈ Voice Input", "π ArXiv", "π Files"
    ])

    # Search Tab
    with tab1:
        st.subheader("Search Videos")
        col1, col2 = st.columns([3, 1])
        with col1:
            # The default must stay constant: changing `value` between runs
            # creates a new widget and discards what the user just typed.
            query = st.text_input("Enter search query:", value="aliens")
        with col2:
            search_column = st.selectbox(
                "Search in:",
                ["All Fields"] + st.session_state['search_columns'])
        col3, col4 = st.columns(2)
        with col3:
            num_results = st.slider("Max results:", 1, 100, 20)
        with col4:
            search_button = st.button("π Search")

        # Search on button press, and once automatically on first load.
        if (search_button or not st.session_state['initial_search_done']) and query:
            st.session_state['initial_search_done'] = True
            selected_column = None if search_column == "All Fields" else search_column
            with st.spinner("Searching..."):
                results = search.search(query, selected_column, num_results)
            if results:
                st.session_state['search_history'].append({
                    'query': query,
                    'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                    'results': results[:5]
                })
                st.write(f"Found {len(results)} results:")
                for i, result in enumerate(results, 1):
                    with st.expander(f"Result {i}", expanded=(i == 1)):
                        render_result(result)
            else:
                st.warning("No matching results found.")

    # Voice Input Tab
    with tab2:
        st.subheader("Voice Search")
        st.write("ποΈ Record your query:")
        audio_bytes = audio_recorder()
        if audio_bytes:
            audio_path = f"temp_audio_{datetime.now().strftime('%Y%m%d_%H%M%S')}.wav"
            try:
                with st.spinner("Processing audio..."):
                    with open(audio_path, "wb") as f:
                        f.write(audio_bytes)
                    voice_query = transcribe_audio(audio_path)
                st.markdown("**Transcribed Text:**")
                st.write(voice_query)
                st.session_state['last_voice_input'] = voice_query
                if st.button("π Search from Voice"):
                    results = search.search(voice_query, None, 20)
                    if not results:
                        st.warning("No matching results found.")
                    for i, result in enumerate(results, 1):
                        with st.expander(f"Result {i}", expanded=(i == 1)):
                            render_result(result)
            finally:
                # Remove the temporary recording even if rendering failed.
                if os.path.exists(audio_path):
                    os.remove(audio_path)

    # ArXiv Tab
    with tab3:
        st.subheader("ArXiv Search")
        arxiv_query = st.text_input("Search ArXiv:", value=st.session_state['arxiv_last_query'])
        vocal_summary = st.checkbox("π Quick Audio Summary", value=True)
        titles_summary = st.checkbox("π Titles Only", value=True)
        full_audio = st.checkbox("π Full Audio Summary", value=False)
        if st.button("π Search ArXiv"):
            st.session_state['arxiv_last_query'] = arxiv_query
            perform_arxiv_lookup(arxiv_query, vocal_summary, titles_summary, full_audio)

    # File Manager Tab
    with tab4:
        show_file_manager()

    # Sidebar
    with st.sidebar:
        st.subheader("βοΈ Settings & History")
        if st.button("ποΈ Clear History"):
            st.session_state['search_history'] = []
            rerun()
        st.markdown("### Recent Searches")
        for entry in reversed(st.session_state['search_history'][-5:]):
            with st.expander(f"{entry['timestamp']}: {entry['query']}"):
                for i, result in enumerate(entry['results'], 1):
                    # str(): the description may be a non-string value.
                    st.write(f"{i}. {str(result.get('description', ''))[:100]}...")
        st.markdown("### Voice Settings")
        st.selectbox("TTS Voice:", [
            "en-US-AriaNeural",
            "en-US-GuyNeural",
            "en-GB-SoniaNeural"
        ], key="tts_voice")


if __name__ == "__main__":
    main()