# Video Analyzer — HuggingFace Spaces app: download a YouTube video, transcribe
# the audio, caption sampled frames, and answer questions over the content (RAG).
| from __future__ import annotations | |
| import asyncio | |
| import os | |
| import subprocess | |
| import tempfile | |
| import uuid | |
| from pathlib import Path | |
| from typing import Any | |
| import chromadb | |
| import cv2 | |
| import edge_tts | |
| import gradio as gr | |
| import torch | |
| import yt_dlp | |
| from huggingface_hub import InferenceClient | |
| from PIL import Image | |
| from sentence_transformers import SentenceTransformer | |
| from transformers import BlipForConditionalGeneration, BlipProcessor, pipeline | |
# Try to import spaces for ZeroGPU support
try:
    import spaces  # noqa: F401 — imported only to detect the ZeroGPU runtime
    ZEROGPU_AVAILABLE = True
except ImportError:
    # Not running on a ZeroGPU Space; CPU-only code paths are used.
    # NOTE(review): ZEROGPU_AVAILABLE is not referenced elsewhere in this file —
    # confirm it is consumed externally before removing.
    ZEROGPU_AVAILABLE = False
def get_inference_token() -> str | None:
    """Return the HuggingFace Inference API token from the HF_TOKEN env var, or None."""
    token = os.environ.get("HF_TOKEN")
    return token
# Global embedding model (shared - stateless)
_embedding_model = None


def get_embedding_model():
    """Lazily build and cache the shared sentence-embedding model."""
    global _embedding_model
    model = _embedding_model
    if model is None:
        # First call pays the load cost; later calls reuse the cached instance.
        model = SentenceTransformer("all-MiniLM-L6-v2")
        _embedding_model = model
    return model
# Session state class for per-user storage
class SessionState:
    """Per-session state wrapping a private ChromaDB collection.

    Each session gets its own collection named ``video_knowledge_<session_id>``
    so concurrent users never see each other's analyzed videos.
    """

    def __init__(self, session_id: str | None = None):
        # Generate a random hex id when the caller does not supply one.
        self.session_id = session_id or uuid.uuid4().hex
        self._client = chromadb.Client()
        self._collection = self._client.get_or_create_collection(
            name=f"video_knowledge_{self.session_id}",
            metadata={"hnsw:space": "cosine"}
        )

    @property
    def collection(self):
        """The session's ChromaDB collection.

        Declared as a property (bug fix): every call site in this module
        accesses it as an attribute (e.g. ``session_state.collection.count()``);
        as a plain method those call sites received the bound-method object
        and failed with AttributeError.
        """
        return self._collection

    def clear(self):
        """Delete and recreate the collection, discarding all stored chunks."""
        try:
            self._client.delete_collection(f"video_knowledge_{self.session_id}")
        except Exception:
            # Collection may not exist yet; recreate below regardless.
            pass
        self._collection = self._client.get_or_create_collection(
            name=f"video_knowledge_{self.session_id}",
            metadata={"hnsw:space": "cosine"}
        )
def create_session_state():
    """Create a new per-user SessionState keyed by a random hex ID."""
    new_id = uuid.uuid4().hex
    return SessionState(new_id)
# Default collection for backward compatibility (used by tests)
_default_client = chromadb.Client()
# Module-level shared collection; session-aware code paths prefer the
# SessionState per-user collection and fall back to this one when no
# session_state is supplied.
collection = _default_client.get_or_create_collection(
    name="video_knowledge",
    metadata={"hnsw:space": "cosine"}
)
def get_device():
    """Return the device string used for model inference.

    Pinned to CPU so processing is not subject to ZeroGPU duration limits.
    """
    return "cpu"
def get_whisper_model():
    """Build a Whisper (base) automatic-speech-recognition pipeline."""
    device = get_device()
    return pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-base",
        device=device,
    )
def get_vision_model():
    """Load the BLIP captioning processor and model on the configured device."""
    checkpoint = "Salesforce/blip-image-captioning-base"
    processor = BlipProcessor.from_pretrained(checkpoint)
    model = BlipForConditionalGeneration.from_pretrained(checkpoint).to(get_device())
    return processor, model
# Chat models - tested and working with HF Inference API
# Tried in order by chat_with_videos(); the first successful completion wins.
CHAT_MODELS = [
    "Qwen/Qwen2.5-72B-Instruct",  # Primary - works with token
    "meta-llama/Llama-3.1-70B-Instruct",  # Fallback
]
def get_proxy_url() -> str | None:
    """Return the PROXY_URL env var (used for YouTube downloads), or None."""
    proxy = os.environ.get("PROXY_URL")
    return proxy
def download_video(url: str, output_dir: str) -> list[dict]:
    """Download a YouTube video or playlist into *output_dir*.

    Args:
        url: YouTube video or playlist URL.
        output_dir: Directory the media files are written to.

    Returns:
        A list of dicts with keys ``title``, ``path`` and ``duration``,
        one per downloaded entry.

    Raises:
        RuntimeError: On any extraction/download failure.
    """
    ydl_opts = {
        "format": "best[height<=720]/best",
        "outtmpl": os.path.join(output_dir, "%(title)s.%(ext)s"),
        "quiet": True,
        "no_warnings": True,
        "ignoreerrors": True,
        "retries": 3,
    }
    # Route traffic through the configured proxy when one is set.
    proxy_url = get_proxy_url()
    if proxy_url:
        ydl_opts["proxy"] = proxy_url
    downloaded: list[dict] = []
    try:
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            info = ydl.extract_info(url, download=True)
            if info is None:
                raise ValueError("Could not extract video information")
            # Playlists expose their items under "entries"; single videos don't.
            entries = info["entries"] if "entries" in info else [info]
            for entry in entries:
                if not entry:
                    continue
                downloaded.append({
                    "title": entry.get("title", "Unknown"),
                    # Bug fix: prepare_filename() reproduces yt-dlp's sanitized
                    # output path. Reconstructing f"{title}.{ext}" by hand broke
                    # for titles containing characters yt-dlp replaces (/, |, "),
                    # and entry["title"] could raise KeyError.
                    "path": ydl.prepare_filename(entry),
                    "duration": entry.get("duration", 0),
                })
    except Exception as e:
        raise RuntimeError(f"Failed to download video: {e!s}") from e
    return downloaded
def extract_audio(video_path: str, output_dir: str) -> str:
    """Extract the audio track of *video_path* to an MP3 in *output_dir*.

    Raises RuntimeError on timeout, missing FFmpeg, or conversion failure.
    """
    audio_path = os.path.join(output_dir, "audio.mp3")
    command = [
        "ffmpeg", "-i", video_path,
        "-vn",                      # drop the video stream
        "-acodec", "libmp3lame",
        "-q:a", "2",
        audio_path,
        "-y",                       # overwrite output from a previous run
    ]
    try:
        proc = subprocess.run(command, capture_output=True, timeout=300)
    except subprocess.TimeoutExpired as e:
        raise RuntimeError("Audio extraction timed out") from e
    except FileNotFoundError as e:
        raise RuntimeError("FFmpeg not found. Please install FFmpeg.") from e
    # Tolerate a nonzero exit status as long as the output file was produced.
    if proc.returncode != 0 and not os.path.exists(audio_path):
        raise RuntimeError(f"FFmpeg failed: {proc.stderr.decode()}")
    return audio_path
def extract_frames(video_path: str, num_frames: int = 5) -> list[Image.Image]:
    """Grab *num_frames* evenly spaced frames from the video as PIL images."""
    capture = cv2.VideoCapture(video_path)
    frame_total = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    if frame_total == 0:
        capture.release()
        return []
    images: list[Image.Image] = []
    # Sample interior positions only, skipping the very start and end.
    for n in range(1, num_frames + 1):
        position = int(n * frame_total / (num_frames + 1))
        capture.set(cv2.CAP_PROP_POS_FRAMES, position)
        ok, bgr = capture.read()
        if ok:
            rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            images.append(Image.fromarray(rgb))
    capture.release()
    return images
def describe_frame(image: Image.Image, processor, model) -> str:
    """Caption one frame with the BLIP model (at most 50 new tokens)."""
    device = get_device()
    batch = processor(image, return_tensors="pt").to(device)
    generated = model.generate(**batch, max_new_tokens=50)
    caption = processor.decode(generated[0], skip_special_tokens=True)
    return caption
# Text-to-Speech settings
TTS_VOICE = "en-US-AriaNeural"  # Natural female voice for edge-tts
# Try to load Parler-TTS for SOTA quality (requires GPU)
# Lazily populated by get_parler_model(); both stay None when the
# parler_tts package is unavailable.
_parler_model = None
_parler_tokenizer = None
def get_parler_model():
    """Lazily load and cache the Parler-TTS model and tokenizer.

    Returns (None, None) when the parler_tts package or model is unavailable.
    """
    global _parler_model, _parler_tokenizer
    if _parler_model is not None:
        return _parler_model, _parler_tokenizer
    try:
        from parler_tts import ParlerTTSForConditionalGeneration
        from transformers import AutoTokenizer

        repo = "parler-tts/parler-tts-mini-v1"
        _parler_model = ParlerTTSForConditionalGeneration.from_pretrained(repo).to(get_device())
        _parler_tokenizer = AutoTokenizer.from_pretrained(repo)
    except Exception as e:
        print(f"Could not load Parler-TTS: {e}")
        return None, None
    return _parler_model, _parler_tokenizer
async def _edge_tts_async(text: str, output_path: str) -> str:
    """Synthesize *text* to *output_path* with edge-tts and return the path."""
    await edge_tts.Communicate(text, TTS_VOICE).save(output_path)
    return output_path
def text_to_speech_parler(text: str) -> str | None:
    """Convert *text* to speech with Parler-TTS; return a WAV path or None."""
    import soundfile as sf

    model, tokenizer = get_parler_model()
    if model is None:
        return None
    try:
        device = get_device()
        # Voice prompt steering Parler toward a clear, warm female voice.
        description = "A female speaker with a clear, natural voice speaks at a moderate pace with a warm tone."
        description_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
        text_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
        waveform = model.generate(input_ids=description_ids, prompt_input_ids=text_ids)
        audio = waveform.cpu().numpy().squeeze()
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            wav_path = tmp.name
        sf.write(wav_path, audio, model.config.sampling_rate)
        return wav_path
    except Exception as e:
        print(f"Parler-TTS error: {e}")
        return None
def text_to_speech(text: str, use_parler: bool = False) -> str | None:
    """Convert text to speech and return the audio file path (or None).

    Args:
        text: Text to convert to speech
        use_parler: If True and GPU available, use Parler-TTS for SOTA quality.
                    Otherwise uses Edge-TTS (faster, no GPU needed).
    """
    if not text or not text.strip():
        return None
    # Strip markdown markers and flatten newlines so the TTS reads cleanly.
    clean_text = text.replace("**", "").replace("*", "").replace("`", "")
    clean_text = clean_text.replace("\n\n", ". ").replace("\n", " ")
    # Drop the trailing "Sources:"/"Model:" attribution sentences.
    sentences = clean_text.split(". ")
    kept = [s for s in sentences if not (s.startswith("Sources:") or s.startswith("Model:"))]
    clean_text = ". ".join(kept)
    if not clean_text.strip():
        return None
    # Prefer Parler-TTS when requested and a GPU is present.
    if use_parler and torch.cuda.is_available():
        parler_path = text_to_speech_parler(clean_text)
        if parler_path:
            return parler_path
        # Parler failed; fall through to edge-tts below.
    # Edge-TTS: fast, CPU-only.
    try:
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
            mp3_path = tmp.name
        asyncio.run(_edge_tts_async(clean_text, mp3_path))
        return mp3_path
    except Exception as e:
        print(f"TTS error: {e}")
        return None
def transcribe_voice_input(audio_path: str) -> str:
    """Transcribe a recorded voice message with Whisper; '' on any failure."""
    if audio_path is None or not os.path.exists(audio_path):
        return ""
    try:
        asr = get_whisper_model()
        transcription = asr(audio_path, return_timestamps=True)
    except Exception as e:
        print(f"Voice transcription error: {e}")
        return ""
    return transcription.get("text", "").strip()
def transcribe_audio(audio_path: str, whisper_model) -> str:
    """Run *whisper_model* over the audio file; '' when the file is missing."""
    if not os.path.exists(audio_path):
        return ""
    output = whisper_model(audio_path, return_timestamps=True)
    return output["text"]
def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str]:
    """Split *text* into overlapping word-based chunks.

    Args:
        text: Source text; split on whitespace.
        chunk_size: Maximum number of words per chunk.
        overlap: Number of words shared between consecutive chunks.

    Returns:
        List of chunk strings (empty for empty/whitespace-only text).

    Raises:
        ValueError: If overlap >= chunk_size. (Previously overlap > chunk_size
            silently produced an empty result, and equality raised an opaque
            range() step error.)
    """
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    words = text.split()
    step = chunk_size - overlap
    # Slices never start past len(words), so every chunk is non-empty.
    return [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), step)]
def add_to_vector_db(
    title: str,
    transcript: str,
    visual_contexts: list[str],
    session_state=None,
):
    """Embed transcript chunks and frame captions and store them for retrieval.

    Returns the number of documents written to the collection.
    """
    embed_model = get_embedding_model()
    coll = session_state.collection if session_state else collection

    documents: list[str] = []
    metadatas: list[dict] = []
    ids: list[str] = []

    def _register(doc: str, meta: dict, kind: str, index: int) -> None:
        # Random suffix prevents id collisions when a video is re-analyzed.
        documents.append(doc)
        metadatas.append(meta)
        ids.append(f"{title}_{kind}_{index}_{uuid.uuid4().hex[:8]}")

    # Transcript chunks.
    if transcript:
        for i, chunk in enumerate(chunk_text(transcript)):
            _register(
                chunk,
                {"title": title, "type": "transcript", "chunk_index": i},
                "transcript",
                i,
            )

    # Frame captions.
    for i, context in enumerate(visual_contexts):
        _register(
            f"Visual scene from {title}: {context}",
            {"title": title, "type": "visual", "frame_index": i},
            "visual",
            i,
        )

    if documents:
        embeddings = embed_model.encode(documents).tolist()
        coll.add(
            documents=documents,
            embeddings=embeddings,
            metadatas=metadatas,
            ids=ids,
        )
    return len(documents)
def search_knowledge(query, n_results=5, session_state=None):
    """Return up to *n_results* matching chunks from the vector store."""
    embed_model = get_embedding_model()
    coll = session_state.collection if session_state else collection
    query_embedding = embed_model.encode([query]).tolist()
    results = coll.query(
        query_embeddings=query_embedding,
        n_results=n_results,
    )
    docs = results["documents"]
    if not docs or not docs[0]:
        return []
    return [
        {
            "content": doc,
            "title": meta.get("title", "Unknown"),
            "type": meta.get("type", "unknown"),
        }
        for doc, meta in zip(docs[0], results["metadatas"][0])
    ]
def is_valid_youtube_url(url: str) -> tuple[bool, str]:
    """Validate a YouTube URL.

    Returns (True, normalized_url) for recognized YouTube links, otherwise
    (False, user-facing error message). Normalization prepends https:// when
    the scheme is missing.
    """
    url = url.strip()
    if not url:
        return False, "Please enter a YouTube URL."
    lowered = url.lower()
    # Accepted YouTube link shapes (videos, playlists, shorts, embeds).
    recognized = (
        "youtube.com/watch",
        "youtube.com/playlist",
        "youtube.com/shorts",
        "youtu.be/",
        "youtube.com/embed",
        "youtube.com/v/",
    )
    if not any(marker in lowered for marker in recognized):
        if "youtube" in lowered or "youtu" in lowered:
            return False, "Invalid YouTube URL format. Please use a full video or playlist URL."
        return False, "Please enter a valid YouTube URL (e.g., https://youtube.com/watch?v=...)"
    if not url.startswith(("http://", "https://")):
        url = "https://" + url
    return True, url
def _process_youtube_impl(url, num_frames, session_state=None, progress=gr.Progress()):
    """Internal implementation of video processing.

    Downloads the video(s) at *url*, transcribes the audio, captions
    *num_frames* sampled frames, and indexes everything in the (per-session)
    vector store. Returns a markdown summary string; errors are also
    returned as user-facing strings rather than raised.
    """
    is_valid, result = is_valid_youtube_url(url)
    if not is_valid:
        return result
    url = result  # Use normalized URL
    try:
        progress(0, desc="Loading models...")
        whisper_model = get_whisper_model()
        vision_processor, vision_model = get_vision_model()
        with tempfile.TemporaryDirectory() as tmpdir:
            progress(0.1, desc="Downloading video...")
            downloaded = download_video(url.strip(), tmpdir)
            results = []
            total = len(downloaded)
            for i, item in enumerate(downloaded):
                # Progress advances from 0.1 to 0.9 across all entries.
                base_progress = 0.1 + 0.8 * (i / total)
                video_result = [f"## {item['title']}"]
                # NOTE(review): this glob always picks the first media file in
                # tmpdir rather than item["path"], so for playlists every
                # iteration may process the same file — confirm intended.
                video_files = list(Path(tmpdir).glob("*.mp4")) + \
                    list(Path(tmpdir).glob("*.webm")) + \
                    list(Path(tmpdir).glob("*.mkv"))
                if not video_files:
                    video_result.append("*No video file found*")
                    results.append("\n\n".join(video_result))
                    continue
                video_path = str(video_files[0])
                # Extract and transcribe audio
                progress(base_progress + 0.2 * (1/total), desc=f"Extracting audio: {item['title']}")
                audio_path = extract_audio(video_path, tmpdir)
                progress(base_progress + 0.4 * (1/total), desc=f"Transcribing: {item['title']}")
                transcript = transcribe_audio(audio_path, whisper_model)
                visual_contexts = []
                if transcript:
                    video_result.append("### Transcript")
                    video_result.append(transcript)
                # Analyze frames (always enabled for better context)
                progress(base_progress + 0.6 * (1/total), desc=f"Analyzing frames: {item['title']}")
                frames = extract_frames(video_path, num_frames)
                if frames:
                    video_result.append("\n### Visual Context")
                    for j, frame in enumerate(frames):
                        caption = describe_frame(frame, vision_processor, vision_model)
                        visual_contexts.append(caption)
                        video_result.append(f"**Frame {j+1}:** {caption}")
                # Store in vector DB (session-specific)
                progress(base_progress + 0.8 * (1/total), desc=f"Storing in knowledge base: {item['title']}")
                num_stored = add_to_vector_db(item["title"], transcript, visual_contexts, session_state)
                video_result.append(f"\n*Added {num_stored} chunks to knowledge base*")
                results.append("\n\n".join(video_result))
            progress(1.0, desc="Done!")
            if results:
                summary = "\n\n---\n\n".join(results)
                summary += "\n\n---\n\n**Analysis complete!** You can now ask me questions about this video."
                return summary
            return "No content found to analyze."
    except Exception as e:
        # Map common yt-dlp failure modes to friendlier messages.
        error_msg = str(e)
        if "unavailable" in error_msg.lower():
            return "Video unavailable. It may be private, age-restricted, or removed."
        if "copyright" in error_msg.lower():
            return "Video blocked due to copyright restrictions."
        return f"Error analyzing video: {error_msg}"
# Run on CPU to avoid ZeroGPU duration limits
def process_youtube(url, num_frames, session_state=None, progress=gr.Progress()):
    """Public entry point for video analysis; delegates to _process_youtube_impl.

    NOTE(review): presumably kept as a thin wrapper so a GPU decorator
    (e.g. @spaces.GPU) could be applied here without touching the
    implementation — confirm before inlining.
    """
    return _process_youtube_impl(url, num_frames, session_state, progress)
def chat_with_videos(message, history, session_state=None):
    """Answer *message* using retrieval over the session's analyzed videos.

    *history* is accepted to match the chat-handler signature but is not
    read; each answer is generated from retrieved context only. Returns a
    user-facing string in every case (errors included).
    """
    # Get inference token from environment
    token = get_inference_token()
    if token is None:
        return "No API token configured. Please ask the Space owner to set HF_TOKEN."
    if not message or not message.strip():
        return "Please enter a question."
    # Use session-specific collection
    coll = session_state.collection if session_state else collection
    # Check if we have any content in the knowledge base
    if coll.count() == 0:
        return "No videos have been analyzed yet. Please analyze some videos first to build the knowledge base."
    # Search for relevant context
    matches = search_knowledge(message.strip(), n_results=5, session_state=session_state)
    if not matches:
        return "I couldn't find any relevant information in the analyzed videos."
    # Build context from matches
    context_parts = []
    for match in matches:
        source = f"[{match['title']} - {match['type']}]"
        context_parts.append(f"{source}: {match['content']}")
    context = "\n\n".join(context_parts)
    # Generate response using HF Inference API with fallback models
    client = InferenceClient(token=token)
    system_prompt = """You are a helpful assistant that answers questions about video content.
You have access to transcripts and visual descriptions from analyzed videos.
Answer based only on the provided context. If the context doesn't contain enough information, say so.
Be concise but thorough."""
    user_prompt = f"""Based on the following video content, answer the question.
Video Content:
{context}
Question: {message}"""
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    last_error = None
    used_model = None
    # Try each configured model in order; the for/else runs only when no
    # model succeeded (i.e. the loop finished without break).
    for model in CHAT_MODELS:
        try:
            response = client.chat.completions.create(
                model=model,
                messages=messages,
                max_tokens=1024,
            )
            answer = response.choices[0].message.content
            used_model = model.split("/")[-1]  # Get model name without org
            break
        except Exception as e:
            last_error = e
            continue
    else:
        # All models failed
        error_msg = str(last_error) if last_error else "Unknown error"
        if "otp" in error_msg.lower():
            return "Authentication error. Please check HF_TOKEN configuration."
        if "401" in error_msg or "unauthorized" in error_msg.lower():
            return "Authentication error. Please try logging out and back in."
        if "429" in error_msg or "rate" in error_msg.lower():
            return "Rate limit exceeded. Please wait a moment and try again."
        if "503" in error_msg or "unavailable" in error_msg.lower():
            return "Model service temporarily unavailable. Please try again later."
        return f"Could not generate response. Error: {error_msg}"
    # Add sources and model info
    sources = list(set(m["title"] for m in matches))
    answer += f"\n\n*Sources: {', '.join(sources)}*"
    answer += f"\n*Model: {used_model}*"
    return answer
def get_knowledge_stats(session_state=None):
    """Summarize the knowledge base as a short markdown status string."""
    coll = session_state.collection if session_state else collection
    count = coll.count()
    if count == 0:
        return "**Knowledge base is empty.** Paste a YouTube URL to get started!"
    try:
        all_data = coll.get(include=["metadatas"])
        titles = {m["title"] for m in all_data["metadatas"] if m and "title" in m}
        video_list = ", ".join(sorted(titles)[:5])
        if len(titles) > 5:
            video_list += f", ... (+{len(titles) - 5} more)"
        return f"**{count}** chunks from **{len(titles)}** videos: {video_list}"
    except Exception:
        # Fall back to a bare count if the metadata fetch fails.
        return f"**{count}** chunks in knowledge base"
def get_analyzed_videos(session_state=None):
    """Return the sorted list of distinct analyzed video titles ([] on error)."""
    coll = session_state.collection if session_state else collection
    try:
        data = coll.get(include=["metadatas"])
        titles = {m["title"] for m in data["metadatas"] if m and "title" in m}
        return sorted(titles)
    except Exception:
        return []
def clear_session_knowledge(session_state=None):
    """Reset the session's knowledge base; return a status message string."""
    if session_state is None:
        return "No session found."
    try:
        session_state.clear()
    except Exception as e:
        return f"Error clearing: {e!s}"
    return "Knowledge base cleared!"
# Keep for backward compatibility with tests
def clear_knowledge_base() -> str:
    """Clear all data from the default knowledge base (for tests)."""
    global collection
    try:
        _default_client.delete_collection("video_knowledge")
        # Recreate an empty collection under the same name and rebind it.
        collection = _default_client.get_or_create_collection(
            name="video_knowledge",
            metadata={"hnsw:space": "cosine"}
        )
    except Exception as e:
        return f"Error clearing knowledge base: {e!s}"
    return "Knowledge base cleared successfully!"
def handle_chat(message, history, session_state, progress=gr.Progress()):
    """Unified chat handler that processes URLs or answers questions.

    Returns (updated history, cleared textbox value, session_state) to match
    the Gradio outputs [chatbot, msg_input, session_state].
    """
    history = history or []
    # Create session state if needed
    if session_state is None:
        session_state = create_session_state()
    if not message or not message.strip():
        return history, "", session_state
    # Add user message to history
    history.append({"role": "user", "content": message})
    # Check if we have HF_TOKEN configured
    token = get_inference_token()
    if token is None:
        history.append({
            "role": "assistant",
            "content": "No API token available. Please ask the Space owner to configure HF_TOKEN."
        })
        return history, "", session_state
    message = message.strip()
    # Check if it's a YouTube URL
    is_url, normalized = is_valid_youtube_url(message)
    if is_url:
        # Process the YouTube video
        history.append({
            "role": "assistant",
            "content": "I'll analyze that video for you. This may take a few minutes..."
        })
        try:
            result = process_youtube(normalized, 5, session_state, progress)
            # Summarize the result for chat
            # NOTE(review): substring heuristic — any output containing "Error"
            # or "Please" is treated as a failure message; confirm this cannot
            # misfire on a legitimate transcript.
            if "Error" in result or "Please" in result:
                history.append({"role": "assistant", "content": result})
            else:
                # Extract just the summary
                lines = result.split("\n")
                title = next((l.replace("## ", "") for l in lines if l.startswith("## ")), "the video")
                history.append({
                    "role": "assistant",
                    "content": (
                        f"Done! I've analyzed **{title}** and added it to my knowledge base.\n\n"
                        f"I extracted the transcript and analyzed key visual frames. "
                        f"You can now ask me questions about this video!\n\n"
                        f"Try asking:\n"
                        f"- What are the main topics discussed?\n"
                        f"- Summarize the key points\n"
                        f"- What was shown in the video?"
                    )
                })
        except Exception as e:
            history.append({
                "role": "assistant",
                "content": f"Sorry, I couldn't analyze that video: {e}"
            })
    else:
        # Check if we have any analyzed videos
        coll = session_state.collection
        if coll.count() == 0:
            history.append({
                "role": "assistant",
                "content": (
                    "I don't have any videos analyzed yet. "
                    "Please paste a YouTube URL and I'll analyze it for you!\n\n"
                    "Example: `https://youtube.com/watch?v=...`"
                )
            })
        else:
            # Answer question about videos
            response = chat_with_videos(message, history, session_state)
            history.append({"role": "assistant", "content": response})
    return history, "", session_state
def get_welcome_message():
    """Return the initial assistant greeting shown when the chat loads."""
    greeting = (
        "Welcome to **Video Analyzer**!\n\n"
        "**Here's how I work:**\n"
        "1. Paste a YouTube URL and I'll analyze it\n"
        "2. Ask me questions about the video content\n\n"
        "Let's get started - paste a YouTube video URL!"
    )
    return [{"role": "assistant", "content": greeting}]
def create_demo():
    """Create and configure the Gradio demo application.

    Builds the Blocks UI (chat + status row), wires send/submit/clear events,
    and initializes per-session state on page load.
    """
    with gr.Blocks(title="Video Analyzer") as demo:
        # Per-session state for ChromaDB collection
        session_state = gr.State(value=None)
        # Centered header with description
        gr.HTML(
            """
            <div style="text-align: center; padding: 20px 0; max-width: 600px; margin: 0 auto;">
                <h1 style="margin-bottom: 12px;">Video Analyzer</h1>
                <p style="color: #666; margin-bottom: 16px; line-height: 1.5;">
                    Paste a YouTube URL to analyze the video. I'll transcribe the audio,
                    analyze key frames, and let you ask questions about the content.
                </p>
            </div>
            """
        )
        # Chat interface
        chatbot = gr.Chatbot(
            label="Video Analyzer",
            height=400,
            type="messages",
        )
        # Text input row
        with gr.Row():
            msg_input = gr.Textbox(
                label="Message",
                placeholder="Paste a YouTube URL or ask a question...",
                scale=5,
                lines=1,
            )
            send_btn = gr.Button("Send", variant="primary", scale=1)
        with gr.Row():
            kb_status = gr.Markdown()
            clear_btn = gr.Button("Clear Chat", size="sm")

        # Initialize session on load
        def init_session(current_state):
            """Initialize session state and welcome message."""
            if current_state is None:
                current_state = create_session_state()
            welcome = get_welcome_message()
            stats = get_knowledge_stats(current_state)
            return welcome, stats, current_state

        # Wire up text chat; the .then() refreshes the knowledge-base status
        # line after each exchange.
        send_btn.click(
            fn=handle_chat,
            inputs=[msg_input, chatbot, session_state],
            outputs=[chatbot, msg_input, session_state],
        ).then(
            fn=lambda ss: get_knowledge_stats(ss),
            inputs=[session_state],
            outputs=[kb_status],
        )
        msg_input.submit(
            fn=handle_chat,
            inputs=[msg_input, chatbot, session_state],
            outputs=[chatbot, msg_input, session_state],
        ).then(
            fn=lambda ss: get_knowledge_stats(ss),
            inputs=[session_state],
            outputs=[kb_status],
        )
        # Clearing the chat resets the transcript only, not the knowledge base.
        clear_btn.click(
            fn=lambda: [],
            outputs=[chatbot],
        )
        # Initialize session on page load
        demo.load(
            fn=init_session,
            inputs=[session_state],
            outputs=[chatbot, kb_status, session_state],
        )
    return demo
# Create demo at module level for HuggingFace Spaces
demo = create_demo()
# Monkey-patch to avoid Gradio schema bug with complex types
# The bug occurs when get_api_info() tries to parse additionalProperties: True
_original_get_api_info = demo.get_api_info
def _safe_get_api_info():
    """Call Gradio's original get_api_info(), swallowing its known TypeError."""
    try:
        return _original_get_api_info()
    except TypeError:
        # Return minimal API info to avoid the schema parsing bug
        return {"named_endpoints": {}, "unnamed_endpoints": {}}
demo.get_api_info = _safe_get_api_info
if __name__ == "__main__":
    demo.launch()