# NOTE: this file was captured from a Hugging Face Spaces page whose status
# banner read "Runtime error".
# app.py - YouTube Video RAG Q&A for Hugging Face Spaces
import os
import pickle
import re
import tempfile
from typing import Dict, List, Optional, Tuple

import faiss
import gradio as gr
import groq
import numpy as np
from sentence_transformers import SentenceTransformer
from youtube_transcript_api import YouTubeTranscriptApi
from youtube_transcript_api._errors import TranscriptsDisabled, NoTranscriptFound
# ============================================
# Configuration - Optimized for Token Limits
# ============================================
# API key for the Groq client; read from the environment (Hugging Face Secrets).
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
# Model name handed to SentenceTransformer when the vector database is created.
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
# Character budget used when grouping transcript sentences into chunks.
CHUNK_SIZE = 300
# Token budget for the retrieved context; converted to characters at 4 per token.
MAX_CONTEXT_TOKENS = 1500
# Number of transcript chunks retrieved for each question.
MAX_RETRIEVAL_CHUNKS = 2
# ============================================
# YouTube Transcript Extraction
# ============================================
class YouTubeTranscriptProcessor:
    """Handles YouTube transcript extraction and processing using new API.

    All methods are static helpers; the class is only used as a namespace and
    is called as ``YouTubeTranscriptProcessor.method(...)``.
    """

    # URL shapes a video ID can be pulled from. The first pattern also accepts
    # ``v=`` when it is not the first query parameter.
    _VIDEO_ID_PATTERNS = (
        r'youtube\.com/watch\?(?:[^#\s]*?&)?v=([\w-]+)',
        r'youtu\.be/([\w-]+)',
        r'youtube(?:-nocookie)?\.com/embed/([\w-]+)',
        r'youtube\.com/v/([\w-]+)',
        r'youtube\.com/shorts/([\w-]+)',
        r'youtube\.com/live/([\w-]+)',
    )

    @staticmethod
    def extract_transcript(youtube_url: str) -> Tuple[Optional[List[Dict]], Optional[str]]:
        """Extract transcript from YouTube video.

        Args:
            youtube_url: URL of the video.

        Returns:
            ``(entries, None)`` on success, where each entry is a dict with
            ``'text'``, ``'start'`` and ``'duration'`` keys, or
            ``(None, error_message)`` on failure. This method never raises.
        """
        try:
            video_id = YouTubeTranscriptProcessor.extract_video_id(youtube_url)
            if not video_id:
                return None, "Invalid YouTube URL"
            print(f"Processing video ID: {video_id}")

            # Create API instance and fetch transcript
            ytt_api = YouTubeTranscriptApi()
            try:
                fetched_transcript = ytt_api.fetch(video_id, languages=['en'])
                print("Found English transcript")
            except NoTranscriptFound:
                print("English transcript not found, trying any available language...")
                # fetch() without ``languages`` still asks for English, so take
                # the first transcript the video actually offers instead.
                fallback = next(iter(ytt_api.list(video_id)), None)
                if fallback is None:
                    return None, "No transcript available for this video"
                fetched_transcript = fallback.fetch()
                print(f"Found transcript in language: {fetched_transcript.language}")

            # Convert to formatted transcript
            formatted_transcript = [
                {
                    'text': snippet.text,
                    'start': snippet.start,
                    'duration': snippet.duration,
                }
                for snippet in fetched_transcript.snippets
            ]
            print(f"Successfully extracted {len(formatted_transcript)} transcript entries")
            return formatted_transcript, None
        except Exception as e:
            # UI boundary: report every failure as a message instead of raising.
            return None, f"Error extracting transcript: {str(e)}"

    @staticmethod
    def extract_video_id(url: str) -> Optional[str]:
        """Extract video ID from YouTube URL.

        Returns:
            The ID string, or ``None`` when ``url`` is empty or matches none of
            the known URL shapes.
        """
        if not url:
            return None
        for pattern in YouTubeTranscriptProcessor._VIDEO_ID_PATTERNS:
            match = re.search(pattern, url)
            if match:
                return match.group(1)
        return None

    @staticmethod
    def get_full_transcript_text(transcript: List[Dict], line_width: int = 100) -> str:
        """Convert transcript to readable full text without timestamps.

        Args:
            transcript: Entries with a ``'text'`` key.
            line_width: Maximum characters per line; a single word longer than
                this is kept whole on its own line.

        Returns:
            The transcript text with whitespace collapsed and a line break
            inserted roughly every ``line_width`` characters.
        """
        # Just join all text entries with spaces, then clean up extra spaces
        full_text = " ".join(entry['text'] for entry in transcript)
        full_text = re.sub(r'\s+', ' ', full_text).strip()

        lines = []
        current_line = []
        current_length = 0  # length of " ".join(current_line)
        for word in full_text.split():
            projected = current_length + len(word) + (1 if current_line else 0)
            # Only wrap when the line already has content, so an over-long
            # first word does not produce a leading empty line.
            if current_line and projected > line_width:
                lines.append(" ".join(current_line))
                current_line = [word]
                current_length = len(word)
            else:
                current_line.append(word)
                current_length = projected
        if current_line:
            lines.append(" ".join(current_line))
        return "\n".join(lines)

    @staticmethod
    def _split_long_sentence(sentence: str, limit: int) -> List[str]:
        """Break a sentence longer than ``limit`` into word groups that fit."""
        if len(sentence) <= limit:
            return [sentence]
        parts = []
        current = ""
        for word in sentence.split():
            # A single word longer than the limit is hard-sliced.
            while len(word) > limit:
                if current:
                    parts.append(current)
                    current = ""
                parts.append(word[:limit])
                word = word[limit:]
            if not word:
                continue
            if not current:
                current = word
            elif len(current) + 1 + len(word) <= limit:
                current = f"{current} {word}"
            else:
                parts.append(current)
                current = word
        if current:
            parts.append(current)
        return parts

    @staticmethod
    def chunk_transcript(transcript: List[Dict], chunk_size: Optional[int] = None) -> List[Dict]:
        """Split transcript into smaller overlapping chunks.

        Sentences are grouped until the next one would push the chunk past
        ``chunk_size`` characters. The last sentence of a finished chunk is
        repeated at the start of the next one when it fits, giving a bounded
        overlap. Sentences longer than ``chunk_size`` (common in unpunctuated
        captions) are split on word boundaries first.

        Args:
            transcript: Entries with a ``'text'`` key.
            chunk_size: Maximum characters per chunk; defaults to the
                module-level ``CHUNK_SIZE``.

        Returns:
            A list of ``{'text': str, 'chunk_id': int}`` dicts; empty when the
            transcript contains no text.

        Raises:
            ValueError: If the effective chunk size is not positive.
        """
        limit = CHUNK_SIZE if chunk_size is None else chunk_size
        if limit <= 0:
            raise ValueError("chunk_size must be positive")

        full_text = " ".join(entry['text'] for entry in transcript)
        full_text = re.sub(r'\s+', ' ', full_text).strip()

        pieces = []
        for sentence in re.split(r'(?<=[.!?])\s+', full_text):
            if sentence:
                pieces.extend(
                    YouTubeTranscriptProcessor._split_long_sentence(sentence, limit)
                )

        chunks = []
        current_chunk = []
        current_length = 0  # length of " ".join(current_chunk)
        for piece in pieces:
            added = len(piece) + (1 if current_chunk else 0)
            if current_chunk and current_length + added > limit:
                chunks.append({
                    'text': " ".join(current_chunk),
                    'chunk_id': len(chunks),
                })
                # Carry over only the final sentence, and only when it fits
                # beside the new piece, so chunks cannot grow without bound.
                tail = current_chunk[-1]
                if len(tail) + 1 + len(piece) <= limit:
                    current_chunk = [tail]
                    current_length = len(tail)
                else:
                    current_chunk = []
                    current_length = 0
                added = len(piece) + (1 if current_chunk else 0)
            current_chunk.append(piece)
            current_length += added

        if current_chunk:
            chunks.append({
                'text': " ".join(current_chunk),
                'chunk_id': len(chunks),
            })
        print(f"Created {len(chunks)} chunks from transcript")
        return chunks
# ============================================
# Vector Database Management
# ============================================
class VectorDatabase:
    """Manages FAISS vector database and embeddings.

    Attributes set in ``__init__``:
        embedding_model: the ``SentenceTransformer`` used for all encoding.
        index: the FAISS index, or ``None`` when nothing is loaded.
        chunks: list of chunk dicts whose positions match the index rows.
        index_path / chunks_path: files used by ``save`` and ``load``.
    """

    def __init__(self):
        print("Loading embedding model...")
        self.embedding_model = SentenceTransformer(EMBEDDING_MODEL)
        self.index = None
        self.chunks = []
        # Use temporary files for Hugging Face Spaces. Only the directory is
        # created here: load() treats existing files as a saved database, so
        # pre-creating empty placeholder files would make it try to read them.
        storage_dir = tempfile.mkdtemp(prefix="yt_rag_")
        self.index_path = os.path.join(storage_dir, "index.bin")
        self.chunks_path = os.path.join(storage_dir, "chunks.pkl")

    def create_embeddings(self, texts: List[str]) -> np.ndarray:
        """Create embeddings for texts.

        Raises:
            ValueError: If ``texts`` is empty.
        """
        if not texts:
            raise ValueError("need at least one text to embed")
        print(f"Creating embeddings for {len(texts)} chunks...")
        # Let the model do the batching instead of stacking batches by hand.
        embeddings = self.embedding_model.encode(
            texts, batch_size=32, show_progress_bar=True
        )
        return np.asarray(embeddings)

    def build_index(self, chunks: List[Dict]):
        """Build FAISS index from chunks and persist it.

        Args:
            chunks: Dicts with a ``'text'`` key.

        Returns:
            ``True`` when an index was built, ``False`` when ``chunks`` is
            empty (existing state is left untouched in that case).
        """
        if not chunks:
            return False
        texts = [chunk['text'] for chunk in chunks]
        embeddings = self.create_embeddings(texts)
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings.astype('float32'))
        # Swap in the new state only once the index is fully built.
        self.index = index
        self.chunks = chunks
        self.save()
        return True

    def search(self, query: str, k: Optional[int] = None) -> List[Tuple[str, float]]:
        """Search for similar chunks.

        Args:
            query: Text to embed and look up.
            k: Number of neighbours wanted; defaults to the module-level
                ``MAX_RETRIEVAL_CHUNKS`` and is capped at the chunk count.

        Returns:
            ``(chunk_text, distance)`` pairs in the order FAISS returns them;
            empty when no index is loaded.
        """
        if self.index is None or not self.chunks:
            return []
        if k is None:
            k = MAX_RETRIEVAL_CHUNKS
        k = min(k, len(self.chunks))
        if k <= 0:
            return []
        query_embedding = self.embedding_model.encode([query])
        distances, indices = self.index.search(
            np.asarray(query_embedding).astype('float32'), k
        )
        # FAISS pads missing neighbours with -1; skip those slots.
        return [
            (self.chunks[idx]['text'], float(distance))
            for distance, idx in zip(distances[0], indices[0])
            if 0 <= idx < len(self.chunks)
        ]

    def save(self):
        """Write the index and chunk list to disk; no-op without an index."""
        if self.index is None:
            return
        faiss.write_index(self.index, self.index_path)
        with open(self.chunks_path, 'wb') as f:
            pickle.dump(self.chunks, f)
        print("Database saved successfully")

    def load(self):
        """Load a previously saved database.

        Returns:
            ``True`` when both files were read, ``False`` when either file is
            missing, empty, or unreadable.
        """
        paths = (self.index_path, self.chunks_path)
        if not all(os.path.exists(p) and os.path.getsize(p) > 0 for p in paths):
            return False
        try:
            index = faiss.read_index(self.index_path)
            # NOTE: pickle must only be used on files this class wrote itself;
            # never point chunks_path at untrusted data.
            with open(self.chunks_path, 'rb') as f:
                chunks = pickle.load(f)
        except (RuntimeError, OSError, EOFError, pickle.UnpicklingError) as exc:
            print(f"Could not load saved database: {exc}")
            return False
        self.index = index
        self.chunks = chunks
        print(f"Loaded database with {len(self.chunks)} chunks")
        return True

    def clear(self):
        """Drop the in-memory state and delete the saved files."""
        self.index = None
        self.chunks = []
        for path in (self.index_path, self.chunks_path):
            if os.path.exists(path):
                os.remove(path)
        print("Database cleared")
# ============================================
# RAG Question Answering
# ============================================
class RAGQA:
    """Handles RAG-based question answering using Groq directly."""

    def __init__(self):
        self.vector_db = VectorDatabase()
        # No client without a key; ask_question reports that to the user.
        self.client = groq.Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None
        self.current_transcript_text = ""
        try:
            self.vector_db.load()
        except Exception as exc:
            # Restoring a saved index is best-effort; a failure here must not
            # stop the app from starting with an empty database.
            print(f"Could not restore a saved database, starting empty: {exc}")

    def truncate_context(self, context: str, max_tokens: Optional[int] = None) -> str:
        """Cut ``context`` to a rough token budget.

        Args:
            context: Text to shorten.
            max_tokens: Budget in tokens; defaults to the module-level
                ``MAX_CONTEXT_TOKENS``. Converted to characters at 4 per token.

        Returns:
            ``context`` unchanged when it fits, otherwise its first
            ``max_tokens * 4`` characters followed by ``"..."``.
        """
        if max_tokens is None:
            max_tokens = MAX_CONTEXT_TOKENS
        max_chars = max_tokens * 4
        if len(context) > max_chars:
            return context[:max_chars] + "..."
        return context

    def process_video(self, youtube_url: str) -> Tuple[str, str, bool]:
        """Process YouTube video and build vector database, return full transcript.

        Returns:
            ``(status_message, transcript_text, success)``.
        """
        # Extract transcript
        transcript, error = YouTubeTranscriptProcessor.extract_transcript(youtube_url)
        if error:
            return error, "", False
        if not transcript:
            return "No transcript data found", "", False

        # Get full transcript text without timestamps
        self.current_transcript_text = YouTubeTranscriptProcessor.get_full_transcript_text(transcript)

        # Chunk transcript for RAG
        chunks = YouTubeTranscriptProcessor.chunk_transcript(transcript)
        if not chunks:
            return "No content to process", self.current_transcript_text, False

        # Build vector database
        self.vector_db.build_index(chunks)
        return (
            f"✅ Successfully processed {len(chunks)} chunks from video!",
            self.current_transcript_text,
            True,
        )

    def ask_question(self, question: str) -> str:
        """Answer question using RAG with Groq.

        Returns:
            The model's answer, or a user-facing message when the API key is
            missing, no transcript is loaded, nothing relevant is retrieved,
            or the API call fails. This method does not raise on API errors.
        """
        if self.client is None:
            return "⚠️ Please set your Groq API key in Hugging Face Secrets."
        if self.vector_db.index is None or not self.vector_db.chunks:
            return "⚠️ Please load a video transcript first (click 'Get Transcript') before asking questions."

        relevant_chunks = self.vector_db.search(question, k=MAX_RETRIEVAL_CHUNKS)
        if not relevant_chunks:
            return "❌ No relevant information found in the transcript. Please try a different question."

        context = "\n\n---\n\n".join(text for text, _distance in relevant_chunks)
        context = self.truncate_context(context, MAX_CONTEXT_TOKENS)

        system_prompt = """Answer questions based ONLY on the provided transcript context. Be brief (2-3 sentences max). If the answer isn't in the context, say so."""
        user_prompt = f"""Context: {context}\n\nQuestion: {question}\n\nAnswer:"""

        try:
            chat_completion = self.client.chat.completions.create(
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                model="llama-3.1-8b-instant",
                temperature=0.3,
                max_tokens=150,
            )
            return chat_completion.choices[0].message.content
        except Exception as e:
            details = str(e)
            # Check "too large" first: an oversized request and plain rate
            # limiting need different advice for the user.
            if "too large" in details:
                return "⚠️ Context too large. Please ask a more specific question."
            if "rate_limit_exceeded" in details:
                return "⚠️ Rate limit reached. Please wait a moment and try again."
            return f"❌ Error: {details}"

    def clear_database(self) -> str:
        """Clear the vector database and the cached transcript text."""
        self.vector_db.clear()
        self.current_transcript_text = ""
        return "🗑️ Database cleared successfully!"
# ============================================
# Gradio UI Application
# ============================================
# Initialize RAG system (module-level instance shared by the UI handlers below)
rag_system = RAGQA()
def process_youtube_url(youtube_url):
    """Handle the "Get Transcript" button.

    Args:
        youtube_url: Text from the URL box; may be ``None`` or blank.

    Returns:
        ``(status_message, qa_status, transcript_text)`` for the three output
        components. ``transcript_text`` is empty unless processing succeeded.
    """
    if not youtube_url or not youtube_url.strip():
        return "❌ Please enter a YouTube URL", "⚠️ Waiting for video...", ""

    # Strip surrounding whitespace left over from copy/paste before processing.
    message, transcript_text, success = rag_system.process_video(youtube_url.strip())
    if success:
        return message, "✅ Ready for questions!", transcript_text
    return message, "❌ Failed to process video", ""
def answer_question(question, history):
    """Answer ``question`` and append the exchange to the chat history.

    Args:
        question: Text from the question box; may be ``None`` or blank.
        history: Current chat history as a list of ``(question, answer)``
            pairs, or ``None``.

    Returns:
        ``history`` unchanged when the question is blank, otherwise a new list
        with the ``(question, answer)`` pair appended. The caller's list is
        not modified.
    """
    if not question or not question.strip():
        return history
    answer = rag_system.ask_question(question)
    return [*(history or []), (question, answer)]
def clear_everything():
    """Handle the "Clear All" button.

    Returns:
        ``(status_message, qa_status, transcript_text, chat_history)`` that
        resets the four output components after the database is cleared.
    """
    message = rag_system.clear_database()
    return message, "⚠️ Waiting for video...", "", []
# Create Gradio interface
with gr.Blocks(title="🎥 YouTube Video RAG Q&A", theme=gr.themes.Soft()) as demo:
    # The Markdown text is kept flush-left so its rendering does not depend on
    # how the surrounding code is indented.
    gr.Markdown("""
# YouTube Video Q&A with RAG

### Extract transcript and ask questions about any YouTube video!

**How it works:**
1. Enter a YouTube URL
2. Click "Get Transcript" to extract and process the video transcript
3. Ask questions about the video content
4. Get accurate answers based solely on the transcript

**Note:** Make sure the video has captions/transcripts enabled.
""")

    # URL input and the button that starts transcript processing.
    with gr.Row():
        with gr.Column(scale=3):
            youtube_url = gr.Textbox(
                label="YouTube URL",
                placeholder="https://www.youtube.com/watch?v=...",
                lines=1,
            )
        with gr.Column(scale=1):
            process_btn = gr.Button("🎬 Get Transcript", variant="primary", size="lg")

    # Read-only status fields filled in by the handlers.
    with gr.Row():
        status_text = gr.Textbox(label="Status", interactive=False, lines=2)
        qa_status = gr.Textbox(
            label="QA Status",
            interactive=False,
            lines=1,
            value="⚠️ Waiting for video...",
        )

    gr.Markdown("---")

    with gr.Row():
        # Left column: the full transcript text.
        with gr.Column(scale=1):
            gr.Markdown("### Complete Transcript")
            transcript_display = gr.Textbox(
                label="",
                interactive=False,
                lines=25,
                max_lines=25,
                placeholder="Transcript will appear here after processing...",
            )
        # Right column: chat history, question box and clear buttons.
        with gr.Column(scale=1):
            gr.Markdown("### 💬 Ask Questions")
            # NOTE(review): avatar_images is documented as image paths/URLs;
            # confirm an emoji string is accepted by the installed Gradio.
            # NOTE(review): bubble_full_width may be deprecated or removed in
            # newer Gradio releases; verify against the pinned version.
            chatbot = gr.Chatbot(
                label="Chat",
                height=400,
                bubble_full_width=False,
                avatar_images=(None, "🤖"),
            )
            with gr.Row():
                question = gr.Textbox(
                    label="Your Question",
                    placeholder="Ask about the video...",
                    lines=2,
                    scale=4,
                )
                submit_btn = gr.Button("Ask", variant="primary", scale=1)
            with gr.Row():
                clear_chat_btn = gr.Button("🗑️ Clear Chat", variant="secondary", size="sm")
                clear_all_btn = gr.Button("Clear All", variant="stop", size="sm")

    # Event handlers
    process_btn.click(
        process_youtube_url,
        inputs=[youtube_url],
        outputs=[status_text, qa_status, transcript_display],
    )

    # The "Ask" button and pressing Enter in the question box do the same
    # thing: answer the question, then empty the question box.
    for ask_trigger in (submit_btn.click, question.submit):
        ask_trigger(
            answer_question,
            inputs=[question, chatbot],
            outputs=[chatbot],
        ).then(
            lambda: "", None, [question]
        )

    clear_chat_btn.click(
        lambda: [], None, [chatbot]
    )
    clear_all_btn.click(
        clear_everything,
        outputs=[status_text, qa_status, transcript_display, chatbot],
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()