Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| import torch | |
| import json | |
| import os | |
| import glob | |
| import random | |
| from pathlib import Path | |
| from datetime import datetime, timedelta | |
| import edge_tts | |
| import asyncio | |
| import requests | |
| from collections import defaultdict | |
| import streamlit.components.v1 as components | |
| from urllib.parse import quote | |
| from xml.etree import ElementTree as ET | |
| from datasets import load_dataset | |
| import base64 | |
| import re | |
# -------------------- Configuration & Constants --------------------

# Pool of display names; one is picked at random for each new session.
USER_NAMES = [
    "Alex", "Jordan", "Taylor", "Morgan",
    "Rowan", "Avery", "Riley", "Quinn",
    "Casey", "Jesse", "Reese", "Skyler",
    "Ellis", "Devon", "Aubrey", "Kendall",
    "Parker", "Dakota", "Sage", "Finley",
]

# Number of dataset rows fetched per page.
ROWS_PER_PAGE = 100
# Rows scoring above this are kept by the search even without a keyword hit.
MIN_SEARCH_SCORE = 0.3
# Multiplier applied to a row's score when the whole query matches a word.
EXACT_MATCH_BOOST = 2.0

# Folder that receives the saved markdown inputs; created on import if absent.
SAVED_INPUTS_DIR = "saved_inputs"
Path(SAVED_INPUTS_DIR).mkdir(parents=True, exist_ok=True)
# -------------------- Session State Initialization --------------------

# Default value for every session-state key this script relies on.
SESSION_VARS = {
    'search_history': [],
    'last_voice_input': "",
    'transcript_history': [],
    'should_rerun': False,
    'search_columns': [],
    'initial_search_done': False,
    'tts_voice': "en-US-AriaNeural",
    'arxiv_last_query': "",
    'dataset_loaded': False,
    'current_page': 0,
    'data_cache': None,
    'dataset_info': None,
    'nps_submitted': False,
    'nps_last_shown': None,
    'old_val': None,
    'voice_text': None,
    'user_name': None,  # display name, filled in right below
    'max_items': 100,   # default shown by the sidebar number input
}

# Seed only the keys that are still missing so existing values are preserved.
for var, default in SESSION_VARS.items():
    if var in st.session_state:
        continue
    st.session_state[var] = default

# Give the session a random display name the first time through.
if st.session_state['user_name'] is None:
    st.session_state['user_name'] = random.choice(USER_NAMES)
| # -------------------- Utility Functions -------------------- | |
def create_voice_component():
    """Declare and return the custom "mycomponent" Streamlit component.

    The component is registered under the name ``mycomponent`` and served
    from the ``mycomponent`` path.
    """
    return components.declare_component("mycomponent", path="mycomponent")
def clean_for_speech(text: str) -> str:
    """Strip markup that should not be read aloud and normalise whitespace.

    Newlines and ``</s>`` markers become spaces, ``#`` characters are dropped,
    parenthesised http(s) URLs are removed, and runs of whitespace collapse
    to a single space with the ends trimmed.
    """
    # Literal substitutions, applied in this order.
    for token, substitute in (("\n", " "), ("</s>", " "), ("#", "")):
        text = text.replace(token, substitute)
    without_links = re.sub(r"\(https?:\/\/[^\)]+\)", "", text)
    return re.sub(r"\s+", " ", without_links).strip()
async def edge_tts_generate_audio(text, voice="en-US-AriaNeural", rate=0, pitch=0):
    """Synthesize ``text`` to an MP3 file with Edge TTS.

    The text is first passed through ``clean_for_speech``. ``rate`` and
    ``pitch`` are formatted as signed integers ("+0%" / "+0Hz").

    Returns:
        The name of the written ``speech_<timestamp>.mp3`` file, or ``None``
        when nothing is left to speak after cleaning.
    """
    spoken_text = clean_for_speech(text)
    if not spoken_text.strip():
        return None

    tts_request = edge_tts.Communicate(
        spoken_text,
        voice,
        rate=f"{rate:+d}%",
        pitch=f"{pitch:+d}Hz",
    )
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    output_name = f"speech_{stamp}.mp3"
    await tts_request.save(output_name)
    return output_name
def speak_with_edge_tts(text, voice="en-US-AriaNeural", rate=0, pitch=0):
    """Blocking wrapper that runs ``edge_tts_generate_audio`` to completion.

    Returns whatever the coroutine returns: the audio file name, or ``None``.
    """
    synthesis = edge_tts_generate_audio(text, voice, rate, pitch)
    return asyncio.run(synthesis)
def play_and_download_audio(file_path):
    """Show an audio player and a download link for ``file_path``.

    Does nothing when ``file_path`` is empty or the file does not exist.
    The download link embeds the file as a base64 ``data:`` URI.

    Args:
        file_path: Path of the audio file to play.
    """
    if not file_path or not os.path.exists(file_path):
        return

    st.audio(file_path)

    # Read via a context manager so the handle is closed promptly instead of
    # being left to the garbage collector.
    with open(file_path, "rb") as audio_file:
        encoded_audio = base64.b64encode(audio_file.read()).decode()

    file_name = os.path.basename(file_path)
    dl_link = (
        f'<a href="data:audio/mpeg;base64,{encoded_audio}" '
        f'download="{file_name}">Download {file_name}</a>'
    )
    st.markdown(dl_link, unsafe_allow_html=True)
@st.cache_resource(show_spinner=False)
def get_model():
    """Return the 'all-MiniLM-L6-v2' sentence-embedding model.

    Wrapped in ``st.cache_resource`` so Streamlit reuses one model object
    across script reruns instead of constructing it again on every
    interaction.
    """
    return SentenceTransformer('all-MiniLM-L6-v2')
def load_dataset_page(dataset_id, token, page, rows_per_page):
    """Load one page of the dataset's ``train`` split as a DataFrame.

    Args:
        dataset_id: Dataset identifier passed to ``load_dataset``.
        token: Access token passed to ``load_dataset``.
        page: Zero-based page index.
        rows_per_page: Number of rows in a page.

    Returns:
        A DataFrame with rows ``[page * rows_per_page, (page + 1) * rows_per_page)``;
        an empty DataFrame (after reporting the error in the UI) on any failure.
    """
    try:
        start_idx = page * rows_per_page
        end_idx = start_idx + rows_per_page
        page_slice = f'train[{start_idx}:{end_idx}]'
        page_rows = load_dataset(dataset_id, token=token, streaming=False, split=page_slice)
        return pd.DataFrame(page_rows)
    except Exception as e:
        st.error(f"Error loading page {page}: {str(e)}")
        return pd.DataFrame()
def get_dataset_info(dataset_id, token):
    """Return the ``info`` object of the dataset's ``train`` split.

    The dataset is opened in streaming mode. On any failure the error is
    reported in the UI and ``None`` is returned.
    """
    try:
        streamed = load_dataset(dataset_id, token=token, streaming=True)
        train_split = streamed['train']
        return train_split.info
    except Exception as e:
        st.error(f"Error loading dataset info: {str(e)}")
        return None
def fetch_dataset_info(dataset_id):
    """Fetch dataset metadata from the Hugging Face datasets HTTP API.

    Returns:
        The decoded JSON body when the response status is 200; ``None`` for
        any other status or when the request fails (a warning is shown in
        the failure case).
    """
    info_url = f"https://huggingface.co/api/datasets/{dataset_id}"
    try:
        reply = requests.get(info_url, timeout=30)
        return reply.json() if reply.status_code == 200 else None
    except Exception as e:
        st.warning(f"Error fetching dataset info: {e}")
    return None
def generate_filename(text):
    """Build a ``<timestamp>_<slug>.md`` file name from the start of ``text``.

    Only the first 50 characters of ``text`` are used. Characters other than
    word characters, whitespace and hyphens are dropped, the result is
    trimmed and lower-cased, and runs of hyphens/whitespace become one hyphen.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    head = text[:50]
    cleaned = re.sub(r'[^\w\s-]', '', head)
    slug = re.sub(r'[-\s]+', '-', cleaned.strip().lower())
    return f"{stamp}_{slug}.md"
def save_input_as_md(text):
    """Write ``text`` to a new markdown file under ``SAVED_INPUTS_DIR``.

    The file starts with a heading naming the session's user and a timestamp
    line, followed by the text itself.

    Returns:
        The path of the written file, or ``None`` when ``text`` is blank.
    """
    if not text.strip():
        return

    target_path = os.path.join(SAVED_INPUTS_DIR, generate_filename(text))
    written_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    with open(target_path, 'w', encoding='utf-8') as md_file:
        md_file.write(f"# User: {st.session_state['user_name']}\n")
        md_file.write(f"**Timestamp:** {written_at}\n\n")
        md_file.write(text)
    return target_path
def list_saved_inputs():
    """Return the paths of all ``*.md`` files in ``SAVED_INPUTS_DIR``, sorted by name."""
    md_pattern = os.path.join(SAVED_INPUTS_DIR, "*.md")
    return sorted(glob.glob(md_pattern))
def render_result(result):
    """Render one search result: optional video, its fields, score and a read-aloud control.

    Args:
        result: Mapping of field name to value for one result. The keys
            ``relevance_score``, ``video_embed``, ``description_embed`` and
            ``audio_embed`` are not listed as fields. ``youtube_id`` (with an
            optional ``start_time``) triggers an embedded video, and
            ``video_id`` is used to build the widget keys.
    """
    score = result.get('relevance_score', 0)
    hidden_keys = {'relevance_score', 'video_embed', 'description_embed', 'audio_embed'}
    result_filtered = {k: v for k, v in result.items() if k not in hidden_keys}

    if 'youtube_id' in result:
        st.video(f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}")

    cols = st.columns([2, 1])

    with cols[0]:
        text_content = []
        for key, value in result_filtered.items():
            # Only scalar text/number fields are displayed.
            if not isinstance(value, (str, int, float)):
                continue
            st.write(f"**{key}:** {value}")
            # Only non-blank strings are queued for text-to-speech.
            if isinstance(value, str) and value.strip():
                text_content.append(f"{key}: {value}")

    with cols[1]:
        st.metric("Relevance", f"{score:.2%}")

        voices = {
            "Aria (US Female)": "en-US-AriaNeural",
            "Guy (US Male)": "en-US-GuyNeural",
            "Sonia (UK Female)": "en-GB-SoniaNeural",
            "Tony (UK Male)": "en-GB-TonyNeural",
        }
        # Widget keys are derived from the result's video_id.
        widget_id = result.get('video_id', '')

        selected_voice = st.selectbox(
            "Voice:",
            list(voices.keys()),
            key=f"voice_{widget_id}"
        )

        # NOTE(review): the label's emoji was mis-decoded in the source
        # ("π" plus control bytes); restored as the speaker icon - confirm.
        if st.button("🔊 Read", key=f"read_{widget_id}"):
            text_to_read = ". ".join(text_content)
            audio_file = speak_with_edge_tts(text_to_read, voices[selected_voice])
            if audio_file:
                play_and_download_audio(audio_file)
class FastDatasetSearcher:
    """Page-wise loader and hybrid keyword/embedding search for one dataset.

    The constructor reads the access token from the ``DATASET_KEY``
    environment variable and stops the Streamlit script when it is missing.
    """

    def __init__(self, dataset_id="tomg-group-umd/cinepile"):
        """Set up the text model and token, and cache dataset info in session state.

        Args:
            dataset_id: Identifier of the dataset to search.
        """
        self.dataset_id = dataset_id
        self.text_model = get_model()
        self.token = os.environ.get('DATASET_KEY')
        if not self.token:
            st.error("Please set the DATASET_KEY environment variable")
            st.stop()
        # Dataset info is fetched only while the session has none stored.
        if st.session_state['dataset_info'] is None:
            st.session_state['dataset_info'] = get_dataset_info(self.dataset_id, self.token)

    def load_page(self, page=0):
        """Return page ``page`` (``ROWS_PER_PAGE`` rows) of the dataset as a DataFrame."""
        return load_dataset_page(self.dataset_id, self.token, page, ROWS_PER_PAGE)

    def quick_search(self, query, df):
        """Score every row of ``df`` against ``query`` and return the relevant rows.

        Each row's score is ``0.7 * keyword_score + 0.3 * semantic_score``,
        where ``keyword_score`` is the fraction of query terms found among the
        row's whitespace-separated tokens and ``semantic_score`` is the cosine
        similarity between the query embedding and the embedding of the row's
        text. The score is multiplied by ``EXACT_MATCH_BOOST`` when the whole
        lower-cased query equals a token of some field, else by 1.2 when any
        query term equals a token of some field.

        Args:
            query: Search text.
            df: Rows to search.

        Returns:
            A copy of ``df`` with added ``score`` and ``matched`` columns,
            limited to rows that matched a term or scored above
            ``MIN_SEARCH_SCORE``, sorted by descending score. ``df`` itself is
            returned unchanged when it is empty, when ``query`` is blank, or
            when the search raises (the error is shown in the UI).
        """
        if df.empty or not query.strip():
            return df

        try:
            # Columns whose first value is an array or bytes blob are skipped.
            searchable_cols = [
                col for col in df.columns
                if not isinstance(df[col].iloc[0], (np.ndarray, bytes))
            ]

            # Field order is the same for every row, so build it once:
            # priority fields that exist in the frame first, then the rest.
            priority_fields = ['description', 'matched_text']
            ordered_fields = [col for col in priority_fields if col in df.columns]
            ordered_fields += [col for col in searchable_cols if col not in priority_fields]

            query_lower = query.lower()
            query_terms = set(query_lower.split())
            query_embedding = self.text_model.encode([query], show_progress_bar=False)[0]

            scores = []
            matched_any = []
            for _, row in df.iterrows():
                text_parts = []
                row_matched = False
                exact_match = False

                for col in ordered_fields:
                    val = row[col]
                    if val is None:
                        continue
                    val_text = str(val)
                    # Tokenise once per value; both checks are whole-word tests.
                    val_tokens = set(val_text.lower().split())
                    if query_lower in val_tokens:
                        exact_match = True
                    if not query_terms.isdisjoint(val_tokens):
                        row_matched = True
                    text_parts.append(val_text)

                text = ' '.join(text_parts)
                if text.strip():
                    text_tokens = set(text.lower().split())
                    matching_terms = query_terms.intersection(text_tokens)
                    # query_terms is non-empty here: a blank query returned early.
                    keyword_score = len(matching_terms) / len(query_terms)

                    text_embedding = self.text_model.encode([text], show_progress_bar=False)[0]
                    semantic_score = float(cosine_similarity([query_embedding], [text_embedding])[0][0])

                    combined_score = 0.7 * keyword_score + 0.3 * semantic_score
                    if exact_match:
                        combined_score *= EXACT_MATCH_BOOST
                    elif row_matched:
                        combined_score *= 1.2
                else:
                    combined_score = 0.0
                    row_matched = False

                scores.append(combined_score)
                matched_any.append(row_matched)

            results_df = df.copy()
            results_df['score'] = scores
            results_df['matched'] = matched_any

            filtered_df = results_df[
                (results_df['matched']) |
                (results_df['score'] > MIN_SEARCH_SCORE)
            ]
            return filtered_df.sort_values('score', ascending=False)
        except Exception as e:
            st.error(f"Search error: {str(e)}")
            return df
| # -------------------- Main App -------------------- | |
| def main(): | |
| st.title("π₯ Smart Video & Voice Search") | |
| # Load saved inputs (conversation history) | |
| saved_files = list_saved_inputs() | |
| # Initialize components | |
| voice_component = create_voice_component() | |
| search = FastDatasetSearcher() | |
| # Voice input at top level | |
| voice_val = voice_component(my_input_value="Start speaking...") | |
| # User can override max items | |
| with st.sidebar: | |
| st.write(f"**Current User:** {st.session_state['user_name']}") | |
| st.session_state['max_items'] = st.number_input("Max Items per search iteration:", min_value=1, max_value=1000, value=st.session_state['max_items']) | |
| st.subheader("π Saved Inputs:") | |
| # Show saved md files in order | |
| for fpath in saved_files: | |
| fname = os.path.basename(fpath) | |
| st.write(f"- [{fname}]({fpath})") | |
| if voice_val: | |
| voice_text = str(voice_val).strip() | |
| edited_input = st.text_area("βοΈ Edit Voice Input:", value=voice_text, height=100) | |
| # Auto-run default True now | |
| run_option = st.selectbox("Select Search Type:", | |
| ["Quick Search", "Deep Search", "Voice Summary"]) | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| autorun = st.checkbox("β‘ Auto-Run", value=True) | |
| with col2: | |
| full_audio = st.checkbox("π Full Audio", value=False) | |
| input_changed = (voice_text != st.session_state.get('old_val')) | |
| if autorun and input_changed: | |