wahab5763 commited on
Commit
1c34698
Β·
verified Β·
1 Parent(s): b3926a8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +447 -0
app.py ADDED
@@ -0,0 +1,447 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py - YouTube Video RAG Q&A for Hugging Face Spaces
2
+
3
+ import gradio as gr
4
+ from youtube_transcript_api import YouTubeTranscriptApi
5
+ from youtube_transcript_api._errors import TranscriptsDisabled, NoTranscriptFound
6
+ from sentence_transformers import SentenceTransformer
7
+ import faiss
8
+ import numpy as np
9
+ import pickle
10
+ import os
11
+ import re
12
+ import groq
13
+ from typing import List, Dict, Tuple
14
+ import tempfile
15
+
16
+ # ============================================
17
+ # Configuration - Optimized for Token Limits
18
+ # ============================================
19
+
20
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY") # Get from Hugging Face Secrets
21
+ EMBEDDING_MODEL = "all-MiniLM-L6-v2"
22
+ CHUNK_SIZE = 300
23
+ MAX_CONTEXT_TOKENS = 1500
24
+ MAX_RETRIEVAL_CHUNKS = 2
25
+
26
+ # ============================================
27
+ # YouTube Transcript Extraction
28
+ # ============================================
29
+
30
+ class YouTubeTranscriptProcessor:
31
+ """Handles YouTube transcript extraction and processing using new API"""
32
+
33
+ @staticmethod
34
+ def extract_transcript(youtube_url: str) -> Tuple[List[Dict], str]:
35
+ """Extract transcript from YouTube video"""
36
+ try:
37
+ video_id = YouTubeTranscriptProcessor.extract_video_id(youtube_url)
38
+ if not video_id:
39
+ return None, "Invalid YouTube URL"
40
+
41
+ print(f"Processing video ID: {video_id}")
42
+
43
+ # Create API instance and fetch transcript
44
+ ytt_api = YouTubeTranscriptApi()
45
+
46
+ try:
47
+ fetched_transcript = ytt_api.fetch(video_id, languages=['en'])
48
+ print("Found English transcript")
49
+ except:
50
+ print("English transcript not found, trying any available language...")
51
+ fetched_transcript = ytt_api.fetch(video_id)
52
+ print(f"Found transcript in language: {fetched_transcript.language}")
53
+
54
+ # Convert to formatted transcript
55
+ formatted_transcript = []
56
+ for snippet in fetched_transcript.snippets:
57
+ formatted_transcript.append({
58
+ 'text': snippet.text,
59
+ 'start': snippet.start,
60
+ 'duration': snippet.duration
61
+ })
62
+
63
+ print(f"Successfully extracted {len(formatted_transcript)} transcript entries")
64
+ return formatted_transcript, None
65
+
66
+ except Exception as e:
67
+ return None, f"Error extracting transcript: {str(e)}"
68
+
69
+ @staticmethod
70
+ def extract_video_id(url: str) -> str:
71
+ """Extract video ID from YouTube URL"""
72
+ patterns = [
73
+ r'(?:youtube\.com\/watch\?v=)([\w-]+)',
74
+ r'(?:youtu\.be\/)([\w-]+)',
75
+ r'(?:youtube\.com\/embed\/)([\w-]+)',
76
+ r'(?:youtube\.com\/v\/)([\w-]+)',
77
+ r'(?:youtube\.com\/shorts\/)([\w-]+)'
78
+ ]
79
+
80
+ for pattern in patterns:
81
+ match = re.search(pattern, url)
82
+ if match:
83
+ return match.group(1)
84
+ return None
85
+
86
+ @staticmethod
87
+ def get_full_transcript_text(transcript: List[Dict]) -> str:
88
+ """Convert transcript to readable full text without timestamps"""
89
+ # Just join all text entries with spaces
90
+ full_text = " ".join([entry['text'] for entry in transcript])
91
+
92
+ # Clean up extra spaces
93
+ full_text = re.sub(r'\s+', ' ', full_text).strip()
94
+
95
+ # Add line breaks every ~100 characters for better readability
96
+ lines = []
97
+ words = full_text.split()
98
+ current_line = []
99
+ current_length = 0
100
+
101
+ for word in words:
102
+ if current_length + len(word) + 1 <= 100:
103
+ current_line.append(word)
104
+ current_length += len(word) + 1
105
+ else:
106
+ lines.append(" ".join(current_line))
107
+ current_line = [word]
108
+ current_length = len(word)
109
+
110
+ if current_line:
111
+ lines.append(" ".join(current_line))
112
+
113
+ return "\n".join(lines)
114
+
115
+ @staticmethod
116
+ def chunk_transcript(transcript: List[Dict]) -> List[Dict]:
117
+ """Split transcript into smaller overlapping chunks"""
118
+ full_text = " ".join([entry['text'] for entry in transcript])
119
+ sentences = re.split(r'(?<=[.!?])\s+', full_text)
120
+
121
+ chunks = []
122
+ current_chunk = []
123
+ current_length = 0
124
+
125
+ for sentence in sentences:
126
+ sentence_length = len(sentence)
127
+
128
+ if current_length + sentence_length <= CHUNK_SIZE:
129
+ current_chunk.append(sentence)
130
+ current_length += sentence_length
131
+ else:
132
+ if current_chunk:
133
+ chunk_text = " ".join(current_chunk)
134
+ chunks.append({
135
+ 'text': chunk_text,
136
+ 'chunk_id': len(chunks)
137
+ })
138
+
139
+ overlap_text = " ".join(current_chunk[-2:]) if len(current_chunk) > 2 else " ".join(current_chunk)
140
+ current_chunk = [overlap_text, sentence] if overlap_text else [sentence]
141
+ current_length = len(overlap_text) + sentence_length if overlap_text else sentence_length
142
+
143
+ if current_chunk:
144
+ chunks.append({
145
+ 'text': " ".join(current_chunk),
146
+ 'chunk_id': len(chunks)
147
+ })
148
+
149
+ print(f"Created {len(chunks)} chunks from transcript")
150
+ return chunks
151
+
152
+ # ============================================
153
+ # Vector Database Management
154
+ # ============================================
155
+
156
+ class VectorDatabase:
157
+ """Manages FAISS vector database and embeddings"""
158
+
159
+ def __init__(self):
160
+ print("Loading embedding model...")
161
+ self.embedding_model = SentenceTransformer(EMBEDDING_MODEL)
162
+ self.index = None
163
+ self.chunks = []
164
+ # Use temporary files for Hugging Face Spaces
165
+ self.index_path = tempfile.NamedTemporaryFile(delete=False, suffix='.bin').name
166
+ self.chunks_path = tempfile.NamedTemporaryFile(delete=False, suffix='.pkl').name
167
+
168
+ def create_embeddings(self, texts: List[str]) -> np.ndarray:
169
+ """Create embeddings for texts"""
170
+ print(f"Creating embeddings for {len(texts)} chunks...")
171
+ batch_size = 32
172
+ all_embeddings = []
173
+ for i in range(0, len(texts), batch_size):
174
+ batch = texts[i:i+batch_size]
175
+ batch_embeddings = self.embedding_model.encode(batch, show_progress_bar=True)
176
+ all_embeddings.append(batch_embeddings)
177
+
178
+ return np.vstack(all_embeddings)
179
+
180
+ def build_index(self, chunks: List[Dict]):
181
+ """Build FAISS index from chunks"""
182
+ self.chunks = chunks
183
+ texts = [chunk['text'] for chunk in chunks]
184
+ embeddings = self.create_embeddings(texts)
185
+
186
+ dimension = embeddings.shape[1]
187
+ self.index = faiss.IndexFlatL2(dimension)
188
+ self.index.add(embeddings.astype('float32'))
189
+
190
+ self.save()
191
+ return True
192
+
193
+ def search(self, query: str, k: int = MAX_RETRIEVAL_CHUNKS) -> List[Tuple[str, float]]:
194
+ """Search for similar chunks"""
195
+ if self.index is None or not self.chunks:
196
+ return []
197
+
198
+ query_embedding = self.embedding_model.encode([query])
199
+ distances, indices = self.index.search(query_embedding.astype('float32'), k)
200
+
201
+ results = []
202
+ for i, idx in enumerate(indices[0]):
203
+ if idx != -1 and idx < len(self.chunks):
204
+ results.append((self.chunks[idx]['text'], float(distances[0][i])))
205
+
206
+ return results
207
+
208
+ def save(self):
209
+ if self.index:
210
+ faiss.write_index(self.index, self.index_path)
211
+ with open(self.chunks_path, 'wb') as f:
212
+ pickle.dump(self.chunks, f)
213
+ print("Database saved successfully")
214
+
215
+ def load(self):
216
+ if os.path.exists(self.index_path) and os.path.exists(self.chunks_path):
217
+ self.index = faiss.read_index(self.index_path)
218
+ with open(self.chunks_path, 'rb') as f:
219
+ self.chunks = pickle.load(f)
220
+ print(f"Loaded database with {len(self.chunks)} chunks")
221
+ return True
222
+ return False
223
+
224
+ def clear(self):
225
+ self.index = None
226
+ self.chunks = []
227
+ if os.path.exists(self.index_path):
228
+ os.remove(self.index_path)
229
+ if os.path.exists(self.chunks_path):
230
+ os.remove(self.chunks_path)
231
+ print("Database cleared")
232
+
233
+ # ============================================
234
+ # RAG Question Answering
235
+ # ============================================
236
+
237
+ class RAGQA:
238
+ """Handles RAG-based question answering using Groq directly"""
239
+
240
+ def __init__(self):
241
+ self.vector_db = VectorDatabase()
242
+ self.client = groq.Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None
243
+ self.current_transcript_text = ""
244
+ self.vector_db.load()
245
+
246
+ def truncate_context(self, context: str, max_tokens: int = MAX_CONTEXT_TOKENS) -> str:
247
+ max_chars = max_tokens * 4
248
+ if len(context) > max_chars:
249
+ return context[:max_chars] + "..."
250
+ return context
251
+
252
+ def process_video(self, youtube_url: str) -> Tuple[str, str, bool]:
253
+ """Process YouTube video and build vector database, return full transcript"""
254
+ # Extract transcript
255
+ transcript, error = YouTubeTranscriptProcessor.extract_transcript(youtube_url)
256
+ if error:
257
+ return error, "", False
258
+
259
+ if not transcript:
260
+ return "No transcript data found", "", False
261
+
262
+ # Get full transcript text without timestamps
263
+ self.current_transcript_text = YouTubeTranscriptProcessor.get_full_transcript_text(transcript)
264
+
265
+ # Chunk transcript for RAG
266
+ chunks = YouTubeTranscriptProcessor.chunk_transcript(transcript)
267
+
268
+ if not chunks:
269
+ return "No content to process", self.current_transcript_text, False
270
+
271
+ # Build vector database
272
+ self.vector_db.build_index(chunks)
273
+
274
+ return f"βœ… Successfully processed {len(chunks)} chunks from video!", self.current_transcript_text, True
275
+
276
+ def ask_question(self, question: str) -> str:
277
+ """Answer question using RAG with Groq"""
278
+ if not GROQ_API_KEY:
279
+ return "⚠️ Please set your Groq API key in Hugging Face Secrets."
280
+
281
+ if self.vector_db.index is None or not self.vector_db.chunks:
282
+ return "⚠️ Please load a video transcript first (click 'Get Transcript') before asking questions."
283
+
284
+ relevant_chunks = self.vector_db.search(question, k=MAX_RETRIEVAL_CHUNKS)
285
+
286
+ if not relevant_chunks:
287
+ return "❓ No relevant information found in the transcript. Please try a different question."
288
+
289
+ context = "\n\n---\n\n".join([chunk[0] for chunk in relevant_chunks])
290
+ context = self.truncate_context(context, MAX_CONTEXT_TOKENS)
291
+
292
+ system_prompt = """Answer questions based ONLY on the provided transcript context. Be brief (2-3 sentences max). If the answer isn't in the context, say so."""
293
+ user_prompt = f"""Context: {context}\n\nQuestion: {question}\n\nAnswer:"""
294
+
295
+ try:
296
+ chat_completion = self.client.chat.completions.create(
297
+ messages=[
298
+ {"role": "system", "content": system_prompt},
299
+ {"role": "user", "content": user_prompt}
300
+ ],
301
+ model="llama-3.1-8b-instant",
302
+ temperature=0.3,
303
+ max_tokens=150
304
+ )
305
+
306
+ return chat_completion.choices[0].message.content
307
+
308
+ except Exception as e:
309
+ if "rate_limit_exceeded" in str(e) or "too large" in str(e):
310
+ return "⚠️ Context too large. Please ask a more specific question."
311
+ return f"❌ Error: {str(e)}"
312
+
313
+ def clear_database(self) -> str:
314
+ self.vector_db.clear()
315
+ self.current_transcript_text = ""
316
+ return "πŸ—‘οΈ Database cleared successfully!"
317
+
318
+ # ============================================
319
+ # Gradio UI Application
320
+ # ============================================
321
+
322
+ # Initialize RAG system
323
+ rag_system = RAGQA()
324
+
325
+ def process_youtube_url(youtube_url):
326
+ if not youtube_url or youtube_url.strip() == "":
327
+ return "❌ Please enter a YouTube URL", "⚠️ Waiting for video...", ""
328
+
329
+ message, transcript_text, success = rag_system.process_video(youtube_url)
330
+ if success:
331
+ return message, "βœ… Ready for questions!", transcript_text
332
+ else:
333
+ return message, "❌ Failed to process video", ""
334
+
335
+ def answer_question(question, history):
336
+ if not question or question.strip() == "":
337
+ return history
338
+
339
+ answer = rag_system.ask_question(question)
340
+ history = history or []
341
+ history.append((question, answer))
342
+ return history
343
+
344
+ def clear_everything():
345
+ message = rag_system.clear_database()
346
+ return message, "⚠️ Waiting for video...", "", []
347
+
348
+ # Create Gradio interface
349
+ with gr.Blocks(title="πŸŽ₯ YouTube Video RAG Q&A", theme=gr.themes.Soft()) as demo:
350
+ gr.Markdown("""
351
+ # πŸ“š YouTube Video Q&A with RAG
352
+ ### Extract transcript and ask questions about any YouTube video!
353
+
354
+ **How it works:**
355
+ 1. Enter a YouTube URL
356
+ 2. Click "Get Transcript" to extract and process the video transcript
357
+ 3. Ask questions about the video content
358
+ 4. Get accurate answers based solely on the transcript
359
+
360
+ **Note:** Make sure the video has captions/transcripts enabled.
361
+ """)
362
+
363
+ with gr.Row():
364
+ with gr.Column(scale=3):
365
+ youtube_url = gr.Textbox(
366
+ label="πŸ”— YouTube URL",
367
+ placeholder="https://www.youtube.com/watch?v=...",
368
+ lines=1
369
+ )
370
+
371
+ with gr.Column(scale=1):
372
+ process_btn = gr.Button("🎬 Get Transcript", variant="primary", size="lg")
373
+
374
+ with gr.Row():
375
+ status_text = gr.Textbox(label="πŸ“Š Status", interactive=False, lines=2)
376
+ qa_status = gr.Textbox(label="QA Status", interactive=False, lines=1, value="⚠️ Waiting for video...")
377
+
378
+ gr.Markdown("---")
379
+
380
+ with gr.Row():
381
+ with gr.Column(scale=1):
382
+ gr.Markdown("### πŸ“ Complete Transcript")
383
+ transcript_display = gr.Textbox(
384
+ label="",
385
+ interactive=False,
386
+ lines=25,
387
+ max_lines=25,
388
+ placeholder="Transcript will appear here after processing..."
389
+ )
390
+
391
+ with gr.Column(scale=1):
392
+ gr.Markdown("### πŸ’¬ Ask Questions")
393
+ chatbot = gr.Chatbot(
394
+ label="Chat",
395
+ height=400,
396
+ bubble_full_width=False,
397
+ avatar_images=(None, "πŸ€–")
398
+ )
399
+
400
+ with gr.Row():
401
+ question = gr.Textbox(
402
+ label="Your Question",
403
+ placeholder="Ask about the video...",
404
+ lines=2,
405
+ scale=4
406
+ )
407
+ submit_btn = gr.Button("Ask", variant="primary", scale=1)
408
+
409
+ with gr.Row():
410
+ clear_chat_btn = gr.Button("πŸ—‘οΈ Clear Chat", variant="secondary", size="sm")
411
+ clear_all_btn = gr.Button("πŸ”„ Clear All", variant="stop", size="sm")
412
+
413
+ # Event handlers
414
+ process_btn.click(
415
+ process_youtube_url,
416
+ inputs=[youtube_url],
417
+ outputs=[status_text, qa_status, transcript_display]
418
+ )
419
+
420
+ submit_btn.click(
421
+ answer_question,
422
+ inputs=[question, chatbot],
423
+ outputs=[chatbot]
424
+ ).then(
425
+ lambda: "", None, [question]
426
+ )
427
+
428
+ clear_chat_btn.click(
429
+ lambda: [], None, [chatbot]
430
+ )
431
+
432
+ clear_all_btn.click(
433
+ clear_everything,
434
+ outputs=[status_text, qa_status, transcript_display, chatbot]
435
+ )
436
+
437
+ question.submit(
438
+ answer_question,
439
+ inputs=[question, chatbot],
440
+ outputs=[chatbot]
441
+ ).then(
442
+ lambda: "", None, [question]
443
+ )
444
+
445
+ # Launch the app
446
+ if __name__ == "__main__":
447
+ demo.launch()