# Source: Hugging Face Space by wahab5763 — "Create app.py" (commit 1c34698, verified)
# app.py - YouTube Video RAG Q&A for Hugging Face Spaces
import os
import pickle
import re
import tempfile
from typing import Dict, List, Optional, Tuple

import faiss
import gradio as gr
import groq
import numpy as np
from sentence_transformers import SentenceTransformer
from youtube_transcript_api import YouTubeTranscriptApi
from youtube_transcript_api._errors import TranscriptsDisabled, NoTranscriptFound
# ============================================
# Configuration - Optimized for Token Limits
# ============================================
GROQ_API_KEY = os.getenv("GROQ_API_KEY") # Get from Hugging Face Secrets
EMBEDDING_MODEL = "all-MiniLM-L6-v2"  # sentence-transformers model used for chunk embeddings
CHUNK_SIZE = 300  # target chunk length in CHARACTERS (not tokens)
MAX_CONTEXT_TOKENS = 1500  # rough prompt-context budget; converted to chars at ~4 chars/token
MAX_RETRIEVAL_CHUNKS = 2  # top-k chunks retrieved per question
# ============================================
# YouTube Transcript Extraction
# ============================================
class YouTubeTranscriptProcessor:
    """Extracts, formats, and chunks YouTube video transcripts.

    Uses the instance-based youtube_transcript_api v1.x ``fetch`` API.
    All methods are stateless and exposed as staticmethods.
    """

    @staticmethod
    def extract_transcript(youtube_url: str) -> Tuple[Optional[List[Dict]], Optional[str]]:
        """Fetch the transcript for a YouTube video.

        Returns:
            (transcript, error) — exactly one side is None. Each transcript
            entry is a dict with 'text', 'start' and 'duration' keys.
        """
        try:
            video_id = YouTubeTranscriptProcessor.extract_video_id(youtube_url)
            if not video_id:
                return None, "Invalid YouTube URL"
            print(f"Processing video ID: {video_id}")
            ytt_api = YouTubeTranscriptApi()
            try:
                # Prefer an English transcript when one exists.
                fetched_transcript = ytt_api.fetch(video_id, languages=['en'])
                print("Found English transcript")
            except NoTranscriptFound:
                # Bug fix: was a bare `except:` that swallowed everything
                # (including KeyboardInterrupt/SystemExit). Only a missing
                # English transcript should trigger the any-language fallback.
                print("English transcript not found, trying any available language...")
                fetched_transcript = ytt_api.fetch(video_id)
                print(f"Found transcript in language: {fetched_transcript.language}")
            # Normalize API snippet objects into plain dicts for downstream code.
            formatted_transcript = [
                {'text': snippet.text, 'start': snippet.start, 'duration': snippet.duration}
                for snippet in fetched_transcript.snippets
            ]
            print(f"Successfully extracted {len(formatted_transcript)} transcript entries")
            return formatted_transcript, None
        except Exception as e:
            # Covers TranscriptsDisabled, network failures, bad video IDs, etc.
            return None, f"Error extracting transcript: {str(e)}"

    @staticmethod
    def extract_video_id(url: str) -> Optional[str]:
        """Extract the video ID from common YouTube URL shapes; None if no match."""
        patterns = [
            r'(?:youtube\.com\/watch\?v=)([\w-]+)',
            r'(?:youtu\.be\/)([\w-]+)',
            r'(?:youtube\.com\/embed\/)([\w-]+)',
            r'(?:youtube\.com\/v\/)([\w-]+)',
            r'(?:youtube\.com\/shorts\/)([\w-]+)'
        ]
        for pattern in patterns:
            match = re.search(pattern, url)
            if match:
                return match.group(1)
        return None

    @staticmethod
    def get_full_transcript_text(transcript: List[Dict]) -> str:
        """Join transcript entries into plain text wrapped at ~100 chars per line."""
        full_text = " ".join(entry['text'] for entry in transcript)
        # Collapse whitespace runs left over from joining caption fragments.
        full_text = re.sub(r'\s+', ' ', full_text).strip()
        # Greedy word-wrap; a word longer than the width gets its own line.
        lines = []
        current_line = []
        current_length = 0
        for word in full_text.split():
            if current_length + len(word) + 1 <= 100:
                current_line.append(word)
                current_length += len(word) + 1
            else:
                # Bug fix: guard against emitting a leading empty line when the
                # very first word already exceeds the wrap width.
                if current_line:
                    lines.append(" ".join(current_line))
                current_line = [word]
                current_length = len(word)
        if current_line:
            lines.append(" ".join(current_line))
        return "\n".join(lines)

    @staticmethod
    def chunk_transcript(transcript: List[Dict]) -> List[Dict]:
        """Split the transcript into ~CHUNK_SIZE-char chunks with sentence overlap.

        Consecutive chunks share the trailing (up to two) sentences of the
        previous chunk so retrieval keeps context across chunk boundaries.
        Each chunk is {'text': str, 'chunk_id': int}.
        """
        full_text = " ".join(entry['text'] for entry in transcript)
        # Split on sentence-ending punctuation followed by whitespace.
        sentences = re.split(r'(?<=[.!?])\s+', full_text)
        chunks = []
        current_chunk = []
        current_length = 0
        for sentence in sentences:
            sentence_length = len(sentence)
            if current_length + sentence_length <= CHUNK_SIZE:
                current_chunk.append(sentence)
                current_length += sentence_length
            else:
                if current_chunk:
                    chunks.append({
                        'text': " ".join(current_chunk),
                        'chunk_id': len(chunks)
                    })
                # Carry the tail of the finished chunk forward as overlap.
                overlap_text = " ".join(current_chunk[-2:]) if len(current_chunk) > 2 else " ".join(current_chunk)
                current_chunk = [overlap_text, sentence] if overlap_text else [sentence]
                current_length = len(overlap_text) + sentence_length if overlap_text else sentence_length
        if current_chunk:
            chunks.append({
                'text': " ".join(current_chunk),
                'chunk_id': len(chunks)
            })
        print(f"Created {len(chunks)} chunks from transcript")
        return chunks
# ============================================
# Vector Database Management
# ============================================
class VectorDatabase:
    """FAISS L2 index over sentence-transformer embeddings with temp-file persistence."""

    def __init__(self):
        print("Loading embedding model...")
        self.embedding_model = SentenceTransformer(EMBEDDING_MODEL)
        self.index = None  # faiss.IndexFlatL2 once built
        self.chunks: List[Dict] = []  # chunk dicts aligned with index row order
        # Use temp files for Hugging Face Spaces (only /tmp is reliably writable).
        # NOTE: NamedTemporaryFile(delete=False) creates the files immediately,
        # so they exist but are EMPTY until save() first runs — load() must not
        # rely on mere existence.
        self.index_path = tempfile.NamedTemporaryFile(delete=False, suffix='.bin').name
        self.chunks_path = tempfile.NamedTemporaryFile(delete=False, suffix='.pkl').name

    def create_embeddings(self, texts: List[str]) -> np.ndarray:
        """Embed `texts` in batches of 32; returns an (n, dim) float array."""
        print(f"Creating embeddings for {len(texts)} chunks...")
        batch_size = 32
        all_embeddings = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            all_embeddings.append(self.embedding_model.encode(batch, show_progress_bar=True))
        return np.vstack(all_embeddings)

    def build_index(self, chunks: List[Dict]) -> bool:
        """Build a fresh FAISS index over `chunks` and persist it."""
        self.chunks = chunks
        embeddings = self.create_embeddings([chunk['text'] for chunk in chunks])
        dimension = embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dimension)
        self.index.add(embeddings.astype('float32'))
        self.save()
        return True

    def search(self, query: str, k: int = MAX_RETRIEVAL_CHUNKS) -> List[Tuple[str, float]]:
        """Return up to k (chunk_text, l2_distance) pairs nearest to `query`."""
        if self.index is None or not self.chunks:
            return []
        query_embedding = self.embedding_model.encode([query])
        distances, indices = self.index.search(query_embedding.astype('float32'), k)
        results = []
        for i, idx in enumerate(indices[0]):
            # FAISS pads with -1 when fewer than k vectors exist.
            if idx != -1 and idx < len(self.chunks):
                results.append((self.chunks[idx]['text'], float(distances[0][i])))
        return results

    def save(self):
        """Persist the index and chunk list to the temp files."""
        # Bug fix: use an identity check instead of truthiness on the index object.
        if self.index is not None:
            faiss.write_index(self.index, self.index_path)
            with open(self.chunks_path, 'wb') as f:
                pickle.dump(self.chunks, f)
            print("Database saved successfully")

    def load(self) -> bool:
        """Load a previously saved index; False when nothing usable is on disk.

        Bug fix: __init__ pre-creates both temp files, so os.path.exists() was
        always True and faiss.read_index() raised on the empty .bin file at
        startup. Require non-empty files and tolerate unreadable/corrupt data.
        """
        try:
            if os.path.getsize(self.index_path) > 0 and os.path.getsize(self.chunks_path) > 0:
                self.index = faiss.read_index(self.index_path)
                with open(self.chunks_path, 'rb') as f:
                    self.chunks = pickle.load(f)
                print(f"Loaded database with {len(self.chunks)} chunks")
                return True
        except (OSError, RuntimeError, pickle.UnpicklingError, EOFError):
            # Corrupt or partially written files: start empty rather than crash.
            self.index = None
            self.chunks = []
        return False

    def clear(self):
        """Drop the in-memory index and delete the persisted files."""
        self.index = None
        self.chunks = []
        for path in (self.index_path, self.chunks_path):
            if os.path.exists(path):
                os.remove(path)
        print("Database cleared")
# ============================================
# RAG Question Answering
# ============================================
class RAGQA:
    """RAG pipeline: retrieve transcript chunks, then answer via the Groq API."""

    def __init__(self):
        # Retrieval backend plus an optional Groq client; the client stays
        # None whenever no API key is configured in the environment.
        self.vector_db = VectorDatabase()
        self.client = groq.Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None
        self.current_transcript_text = ""
        self.vector_db.load()

    def truncate_context(self, context: str, max_tokens: int = MAX_CONTEXT_TOKENS) -> str:
        """Clip `context` to ~max_tokens using the rough 4-chars-per-token rule."""
        limit = max_tokens * 4
        return context if len(context) <= limit else context[:limit] + "..."

    def process_video(self, youtube_url: str) -> Tuple[str, str, bool]:
        """Fetch, chunk, and index a video transcript.

        Returns (status_message, full_transcript_text, success_flag).
        """
        transcript, error = YouTubeTranscriptProcessor.extract_transcript(youtube_url)
        if error:
            return error, "", False
        if not transcript:
            return "No transcript data found", "", False
        # Human-readable transcript for the UI pane.
        full_text = YouTubeTranscriptProcessor.get_full_transcript_text(transcript)
        self.current_transcript_text = full_text
        # Overlapping chunks feed the vector index for retrieval.
        chunks = YouTubeTranscriptProcessor.chunk_transcript(transcript)
        if not chunks:
            return "No content to process", full_text, False
        self.vector_db.build_index(chunks)
        status = f"✅ Successfully processed {len(chunks)} chunks from video!"
        return status, full_text, True

    def ask_question(self, question: str) -> str:
        """Answer `question` from retrieved transcript context using Groq."""
        if not GROQ_API_KEY:
            return "⚠️ Please set your Groq API key in Hugging Face Secrets."
        if self.vector_db.index is None or not self.vector_db.chunks:
            return "⚠️ Please load a video transcript first (click 'Get Transcript') before asking questions."
        hits = self.vector_db.search(question, k=MAX_RETRIEVAL_CHUNKS)
        if not hits:
            return "❓ No relevant information found in the transcript. Please try a different question."
        # Concatenate retrieved chunks, then clip to the token budget.
        ctx = self.truncate_context(
            "\n\n---\n\n".join(text for text, _score in hits),
            MAX_CONTEXT_TOKENS,
        )
        messages = [
            {"role": "system",
             "content": "Answer questions based ONLY on the provided transcript context. Be brief (2-3 sentences max). If the answer isn't in the context, say so."},
            {"role": "user",
             "content": f"Context: {ctx}\n\nQuestion: {question}\n\nAnswer:"},
        ]
        try:
            response = self.client.chat.completions.create(
                messages=messages,
                model="llama-3.1-8b-instant",
                temperature=0.3,
                max_tokens=150
            )
            return response.choices[0].message.content
        except Exception as err:
            detail = str(err)
            if "rate_limit_exceeded" in detail or "too large" in detail:
                return "⚠️ Context too large. Please ask a more specific question."
            return f"❌ Error: {detail}"

    def clear_database(self) -> str:
        """Drop the index and cached transcript; report success to the UI."""
        self.vector_db.clear()
        self.current_transcript_text = ""
        return "🗑️ Database cleared successfully!"
# ============================================
# Gradio UI Application
# ============================================
# Initialize RAG system
# Module-level singleton shared by every UI callback; constructing it loads
# the embedding model once at import time (slow first start, fast thereafter).
rag_system = RAGQA()
def process_youtube_url(youtube_url):
    """Gradio callback: process a URL, returning (status, qa_status, transcript)."""
    # Reject an empty or whitespace-only URL box without touching the backend.
    if not youtube_url or not youtube_url.strip():
        return "❌ Please enter a YouTube URL", "⚠️ Waiting for video...", ""
    message, transcript_text, success = rag_system.process_video(youtube_url)
    if not success:
        return message, "❌ Failed to process video", ""
    return message, "✅ Ready for questions!", transcript_text
def answer_question(question, history):
    """Gradio callback: append a (question, answer) pair to the chat history."""
    # Blank submissions leave the history exactly as it was.
    if not question or not question.strip():
        return history
    reply = rag_system.ask_question(question)
    chat = history if history else []
    chat.append((question, reply))
    return chat
def clear_everything():
    """Gradio callback: reset database, both status boxes, transcript pane, and chat."""
    status = rag_system.clear_database()
    return status, "⚠️ Waiting for video...", "", []
# Create Gradio interface
with gr.Blocks(title="🎥 YouTube Video RAG Q&A", theme=gr.themes.Soft()) as demo:
    # Header / usage instructions.
    gr.Markdown("""
    # 📚 YouTube Video Q&A with RAG
    ### Extract transcript and ask questions about any YouTube video!
    **How it works:**
    1. Enter a YouTube URL
    2. Click "Get Transcript" to extract and process the video transcript
    3. Ask questions about the video content
    4. Get accurate answers based solely on the transcript
    **Note:** Make sure the video has captions/transcripts enabled.
    """)
    # URL entry plus the process button.
    with gr.Row():
        with gr.Column(scale=3):
            youtube_url = gr.Textbox(
                label="🔗 YouTube URL",
                placeholder="https://www.youtube.com/watch?v=...",
                lines=1
            )
        with gr.Column(scale=1):
            process_btn = gr.Button("🎬 Get Transcript", variant="primary", size="lg")
    # Status readouts: processing result and Q&A readiness.
    with gr.Row():
        status_text = gr.Textbox(label="📊 Status", interactive=False, lines=2)
        qa_status = gr.Textbox(label="QA Status", interactive=False, lines=1, value="⚠️ Waiting for video...")
    gr.Markdown("---")
    # Main layout: full transcript on the left, chat on the right.
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📝 Complete Transcript")
            transcript_display = gr.Textbox(
                label="",
                interactive=False,
                lines=25,
                max_lines=25,
                placeholder="Transcript will appear here after processing..."
            )
        with gr.Column(scale=1):
            gr.Markdown("### 💬 Ask Questions")
            # NOTE(review): tuple-style chat history and `bubble_full_width`
            # are deprecated in newer Gradio releases — confirm the pinned
            # gradio version still supports them.
            chatbot = gr.Chatbot(
                label="Chat",
                height=400,
                bubble_full_width=False,
                avatar_images=(None, "🤖")
            )
            with gr.Row():
                question = gr.Textbox(
                    label="Your Question",
                    placeholder="Ask about the video...",
                    lines=2,
                    scale=4
                )
                submit_btn = gr.Button("Ask", variant="primary", scale=1)
            with gr.Row():
                clear_chat_btn = gr.Button("🗑️ Clear Chat", variant="secondary", size="sm")
                clear_all_btn = gr.Button("🔄 Clear All", variant="stop", size="sm")
    # Event handlers
    # Process the URL and populate status boxes + transcript pane.
    process_btn.click(
        process_youtube_url,
        inputs=[youtube_url],
        outputs=[status_text, qa_status, transcript_display]
    )
    # Ask via the button; the .then() clears the question box afterwards.
    submit_btn.click(
        answer_question,
        inputs=[question, chatbot],
        outputs=[chatbot]
    ).then(
        lambda: "", None, [question]
    )
    # Clear only the chat pane.
    clear_chat_btn.click(
        lambda: [], None, [chatbot]
    )
    # Full reset: database, statuses, transcript, and chat.
    clear_all_btn.click(
        clear_everything,
        outputs=[status_text, qa_status, transcript_display, chatbot]
    )
    # Pressing Enter in the question box mirrors the Ask button.
    question.submit(
        answer_question,
        inputs=[question, chatbot],
        outputs=[chatbot]
    ).then(
        lambda: "", None, [question]
    )
# Launch the app
if __name__ == "__main__":
    demo.launch()