File size: 16,343 Bytes
1c34698
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
# app.py - YouTube Video RAG Q&A for Hugging Face Spaces

import os
import pickle
import re
import tempfile
import textwrap
from typing import Dict, List, Optional, Tuple

import gradio as gr
from youtube_transcript_api import YouTubeTranscriptApi
from youtube_transcript_api._errors import TranscriptsDisabled, NoTranscriptFound
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import groq

# ============================================
# Configuration - Optimized for Token Limits
# ============================================

# Groq API key, supplied as a Hugging Face Spaces secret; None when it is not set.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
# sentence-transformers model used to embed transcript chunks and questions.
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
# Target chunk length, in characters.
CHUNK_SIZE = 300
# Context budget handed to the LLM, in tokens (estimated at ~4 characters each).
MAX_CONTEXT_TOKENS = 1500
# Number of nearest chunks retrieved for each question.
MAX_RETRIEVAL_CHUNKS = 2

# ============================================
# YouTube Transcript Extraction
# ============================================

class YouTubeTranscriptProcessor:
    """Fetch YouTube transcripts and prepare them for display and retrieval.

    All methods are static; the class only groups related helpers.
    """

    # URL shapes understood by ``extract_video_id``; group 1 captures the video ID.
    _VIDEO_ID_PATTERNS = (
        r'youtube\.com/watch\?(?:[^#\s]*&)?v=([\w-]+)',  # ``v`` need not be the first query parameter
        r'youtu\.be/([\w-]+)',
        r'youtube(?:-nocookie)?\.com/embed/([\w-]+)',
        r'youtube\.com/v/([\w-]+)',
        r'youtube\.com/shorts/([\w-]+)',
        r'youtube\.com/live/([\w-]+)',
    )

    @staticmethod
    def extract_transcript(youtube_url: str) -> Tuple[Optional[List[Dict]], Optional[str]]:
        """Download the transcript of a YouTube video.

        An English transcript is preferred; if the video has none, the first
        transcript listed for the video is used instead.

        Args:
            youtube_url: Any URL form accepted by ``extract_video_id``.

        Returns:
            ``(entries, None)`` on success, where each entry is a dict with
            ``'text'``, ``'start'`` and ``'duration'`` keys, or
            ``(None, message)`` on failure. Errors are reported through the
            message and never raised.
        """
        try:
            video_id = YouTubeTranscriptProcessor.extract_video_id(youtube_url)
            if not video_id:
                return None, "Invalid YouTube URL"

            print(f"Processing video ID: {video_id}")

            ytt_api = YouTubeTranscriptApi()
            try:
                fetched_transcript = ytt_api.fetch(video_id, languages=['en'])
                print("Found English transcript")
            except NoTranscriptFound:
                # Calling fetch() again without ``languages`` would only retry the
                # library's default language list (English), so enumerate the
                # transcripts that actually exist and take the first one.
                print("English transcript not found, trying any available language...")
                available = next(iter(ytt_api.list(video_id)), None)
                if available is None:
                    return None, "No transcript available for this video"
                fetched_transcript = available.fetch()
                print(f"Found transcript in language: {fetched_transcript.language}")

            formatted_transcript = [
                {
                    'text': snippet.text,
                    'start': snippet.start,
                    'duration': snippet.duration,
                }
                for snippet in fetched_transcript.snippets
            ]

            print(f"Successfully extracted {len(formatted_transcript)} transcript entries")
            return formatted_transcript, None

        except Exception as e:
            # UI boundary: report any failure as a message instead of raising.
            return None, f"Error extracting transcript: {str(e)}"

    @staticmethod
    def extract_video_id(url: str) -> Optional[str]:
        """Return the video ID embedded in *url*, or ``None`` if no known URL form matches.

        Recognises ``watch?...v=``, ``youtu.be/``, ``embed/``, ``v/``,
        ``shorts/`` and ``live/`` links. Empty or ``None`` input yields ``None``.
        """
        if not url:
            return None
        for pattern in YouTubeTranscriptProcessor._VIDEO_ID_PATTERNS:
            match = re.search(pattern, url)
            if match:
                return match.group(1)
        return None

    @staticmethod
    def get_full_transcript_text(transcript: List[Dict], width: int = 100) -> str:
        """Join transcript entries into plain text wrapped at *width* characters.

        Timestamps are dropped and runs of whitespace collapsed to single spaces.
        Lines break only between words, so a word longer than *width* gets a
        line of its own rather than being split.

        Args:
            transcript: Entries with a ``'text'`` key.
            width: Maximum line length (default 100).

        Returns:
            The wrapped text, or ``""`` for an empty transcript.
        """
        full_text = re.sub(r'\s+', ' ', " ".join(entry['text'] for entry in transcript)).strip()
        return textwrap.fill(full_text, width=width, break_long_words=False, break_on_hyphens=False)

    @staticmethod
    def chunk_transcript(
        transcript: List[Dict],
        chunk_size: Optional[int] = None,
        overlap_sentences: int = 2,
    ) -> List[Dict]:
        """Split a transcript into overlapping chunks for embedding.

        The text is split at sentence boundaries (``.``, ``!``, ``?``); any
        sentence longer than *chunk_size* (e.g. caption text without
        punctuation) is further word-wrapped. Pieces are then packed greedily
        into chunks of at most *chunk_size* characters; only a single word
        longer than that can exceed the limit. Each new chunk starts with up to
        *overlap_sentences* trailing pieces of the previous chunk, dropping the
        oldest ones when the overlap would leave no room for the next piece.

        Args:
            transcript: Entries with a ``'text'`` key.
            chunk_size: Maximum chunk length in characters; defaults to ``CHUNK_SIZE``.
            overlap_sentences: Trailing pieces carried into the next chunk (default 2).

        Returns:
            Dicts with ``'text'`` and a sequential ``'chunk_id'`` (0, 1, 2, ...).
        """
        size = CHUNK_SIZE if chunk_size is None else chunk_size
        full_text = re.sub(r'\s+', ' ', " ".join(entry['text'] for entry in transcript)).strip()

        # Sentence pieces, each at most ``size`` characters (barring single huge words).
        pieces = []
        for sentence in re.split(r'(?<=[.!?])\s+', full_text):
            pieces.extend(
                textwrap.wrap(sentence, width=size, break_long_words=False, break_on_hyphens=False)
            )

        chunks = []
        current = []
        current_length = 0  # always equals len(" ".join(current))

        for piece in pieces:
            if current and current_length + 1 + len(piece) > size:
                chunks.append({'text': " ".join(current), 'chunk_id': len(chunks)})
                # Carry over individual trailing pieces (never a previously joined
                # chunk) so the overlap cannot snowball into ever-growing chunks.
                current = current[-overlap_sentences:] if overlap_sentences > 0 else []
                while current and len(" ".join(current)) + 1 + len(piece) > size:
                    current = current[1:]
                current_length = len(" ".join(current))
            current_length += len(piece) + (1 if current else 0)
            current.append(piece)

        if current:
            chunks.append({'text': " ".join(current), 'chunk_id': len(chunks)})

        print(f"Created {len(chunks)} chunks from transcript")
        return chunks

# ============================================
# Vector Database Management
# ============================================

class VectorDatabase:
    """FAISS index over transcript-chunk embeddings, plus the chunk dicts it indexes.

    Row ``i`` of the index corresponds to ``self.chunks[i]``. Both are also
    written to temporary files so that ``load`` can restore them.
    """

    def __init__(self):
        print("Loading embedding model...")
        self.embedding_model = SentenceTransformer(EMBEDDING_MODEL)
        self.index = None
        self.chunks = []
        # Use temporary files for Hugging Face Spaces. They are created empty
        # here; ``load`` treats an empty file as "nothing saved yet".
        self.index_path = self._make_temp_path('.bin')
        self.chunks_path = self._make_temp_path('.pkl')

    @staticmethod
    def _make_temp_path(suffix: str) -> str:
        """Create an empty temporary file that is not auto-deleted and return its path."""
        fd, path = tempfile.mkstemp(suffix=suffix)
        os.close(fd)  # only the path is needed; don't keep the descriptor open
        return path

    def create_embeddings(self, texts: List[str]) -> np.ndarray:
        """Embed *texts* in batches of 32 and stack the results row-wise.

        Raises:
            ValueError: from ``np.vstack`` when *texts* is empty.
        """
        print(f"Creating embeddings for {len(texts)} chunks...")
        batch_size = 32
        all_embeddings = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            all_embeddings.append(self.embedding_model.encode(batch, show_progress_bar=True))
        return np.vstack(all_embeddings)

    def build_index(self, chunks: List[Dict]) -> bool:
        """Replace the current index with one built from *chunks*, then save it.

        Returns:
            Always ``True``.
        """
        self.chunks = chunks
        embeddings = self.create_embeddings([chunk['text'] for chunk in chunks])

        self.index = faiss.IndexFlatL2(embeddings.shape[1])
        self.index.add(embeddings.astype('float32'))

        self.save()
        return True

    def search(self, query: str, k: Optional[int] = None) -> List[Tuple[str, float]]:
        """Return up to *k* ``(chunk_text, distance)`` pairs nearest to *query*.

        Distances are as reported by the L2 FAISS index (smaller is closer).
        *k* defaults to ``MAX_RETRIEVAL_CHUNKS``. An empty list is returned when
        no index has been built or *k* is not positive.
        """
        if self.index is None or not self.chunks:
            return []
        if k is None:
            k = MAX_RETRIEVAL_CHUNKS
        if k <= 0:
            return []

        query_embedding = self.embedding_model.encode([query])
        distances, indices = self.index.search(query_embedding.astype('float32'), k)

        # FAISS pads with -1 when fewer than k vectors are indexed.
        return [
            (self.chunks[idx]['text'], float(distance))
            for idx, distance in zip(indices[0], distances[0])
            if idx != -1 and idx < len(self.chunks)
        ]

    def save(self):
        """Write the index (if one exists) and the chunk list to their temporary files."""
        if self.index is not None:
            faiss.write_index(self.index, self.index_path)
        with open(self.chunks_path, 'wb') as f:
            pickle.dump(self.chunks, f)
        print("Database saved successfully")

    def load(self) -> bool:
        """Restore a previously saved index and chunk list.

        Returns:
            ``True`` when both were loaded. ``False`` (current state untouched)
            when either file is missing or empty -- ``__init__`` creates both
            files empty -- or cannot be read.
        """
        paths = (self.index_path, self.chunks_path)
        if not all(os.path.exists(p) and os.path.getsize(p) > 0 for p in paths):
            return False
        try:
            index = faiss.read_index(self.index_path)
            # This pickle is one written by save(), not external input.
            with open(self.chunks_path, 'rb') as f:
                chunks = pickle.load(f)
        except (RuntimeError, OSError, EOFError, pickle.UnpicklingError) as e:
            print(f"Could not load saved database: {e}")
            return False
        self.index, self.chunks = index, chunks
        print(f"Loaded database with {len(self.chunks)} chunks")
        return True

    def clear(self):
        """Drop the in-memory index and chunks and delete both temporary files."""
        self.index = None
        self.chunks = []
        for path in (self.index_path, self.chunks_path):
            if os.path.exists(path):
                os.remove(path)
        print("Database cleared")

# ============================================
# RAG Question Answering
# ============================================

class RAGQA:
    """Retrieval-augmented question answering over one YouTube transcript.

    Combines a ``VectorDatabase`` of transcript chunks with a Groq chat model:
    the chunks nearest to a question are sent to the model as its context.
    """

    def __init__(self):
        self.vector_db = VectorDatabase()
        # None when GROQ_API_KEY is unset; ask_question reports that to the user.
        self.client = groq.Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None
        # Wrapped transcript text of the most recently processed video (shown in the UI).
        self.current_transcript_text = ""
        # Restore a previously saved index, if any.
        self.vector_db.load()

    def truncate_context(self, context: str, max_tokens: int = MAX_CONTEXT_TOKENS) -> str:
        """Cut *context* to about *max_tokens* tokens, appending "..." when cut.

        Uses a rough 4-characters-per-token estimate rather than a tokenizer.
        """
        max_chars = max_tokens * 4
        if len(context) > max_chars:
            return context[:max_chars] + "..."
        return context

    def process_video(self, youtube_url: str) -> Tuple[str, str, bool]:
        """Fetch, chunk and index a video's transcript.

        Returns:
            ``(status_message, transcript_text, success)``. ``transcript_text``
            is the wrapped full transcript, or ``""`` if it could not be fetched.
        """
        transcript, error = YouTubeTranscriptProcessor.extract_transcript(youtube_url)
        if error:
            return error, "", False
        if not transcript:
            return "No transcript data found", "", False

        self.current_transcript_text = YouTubeTranscriptProcessor.get_full_transcript_text(transcript)

        chunks = YouTubeTranscriptProcessor.chunk_transcript(transcript)
        if not chunks:
            return "No content to process", self.current_transcript_text, False

        self.vector_db.build_index(chunks)
        return f"✅ Successfully processed {len(chunks)} chunks from video!", self.current_transcript_text, True

    def ask_question(self, question: str) -> str:
        """Answer *question* from the indexed transcript using the Groq model.

        Returns a user-facing message instead of raising: a warning when no API
        key is configured or no video is indexed, and an error text when the
        API call fails.
        """
        if self.client is None:
            return "⚠️ Please set your Groq API key in Hugging Face Secrets."

        if self.vector_db.index is None or not self.vector_db.chunks:
            return "⚠️ Please load a video transcript first (click 'Get Transcript') before asking questions."

        relevant_chunks = self.vector_db.search(question, k=MAX_RETRIEVAL_CHUNKS)
        if not relevant_chunks:
            return "❓ No relevant information found in the transcript. Please try a different question."

        context = "\n\n---\n\n".join(text for text, _distance in relevant_chunks)
        context = self.truncate_context(context, MAX_CONTEXT_TOKENS)

        system_prompt = "Answer questions based ONLY on the provided transcript context. Be brief (2-3 sentences max). If the answer isn't in the context, say so."
        user_prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:"

        try:
            chat_completion = self.client.chat.completions.create(
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                model="llama-3.1-8b-instant",
                temperature=0.3,
                max_tokens=150,
            )
            # content may be None; always hand the UI a string.
            return chat_completion.choices[0].message.content or ""

        except Exception as e:
            details = str(e)
            # Oversized requests are reported by the API with these markers.
            if "rate_limit_exceeded" in details or "too large" in details:
                return "⚠️ Context too large. Please ask a more specific question."
            return f"❌ Error: {details}"

    def clear_database(self) -> str:
        """Forget the indexed transcript and return a status message for the UI."""
        self.vector_db.clear()
        self.current_transcript_text = ""
        return "🗑️ Database cleared successfully!"

# ============================================
# Gradio UI Application
# ============================================

# One shared RAG engine for the whole app; the callback functions below use it.
rag_system: RAGQA = RAGQA()

def process_youtube_url(youtube_url):
    """Gradio handler for the "Get Transcript" button.

    Returns:
        ``(status, qa_status, transcript_text)`` for the status, QA-status and
        transcript textboxes. The transcript is ``""`` whenever processing fails.
    """
    if not youtube_url or not youtube_url.strip():
        return "❌ Please enter a YouTube URL", "⚠️ Waiting for video...", ""

    message, transcript_text, success = rag_system.process_video(youtube_url)
    if not success:
        return message, "❌ Failed to process video", ""
    return message, "✅ Ready for questions!", transcript_text

def answer_question(question, history):
    """Gradio handler for the Ask button and the question box's Enter key.

    A blank question returns *history* untouched. Otherwise the answer is
    appended as a ``(question, answer)`` tuple; an empty or missing history is
    replaced by a fresh list before appending.
    """
    if not (question and question.strip()):
        return history

    reply = rag_system.ask_question(question)
    if not history:
        history = []
    history.append((question, reply))
    return history

def clear_everything():
    """Gradio handler for "Clear All": reset the RAG state and every output widget.

    Returns values for (status_text, qa_status, transcript_display, chatbot).
    """
    return rag_system.clear_database(), "⚠️ Waiting for video...", "", []

# Create Gradio interface.
# Components render in the order they are created inside each layout context,
# so the statement order below defines the page layout.
with gr.Blocks(title="🎥 YouTube Video RAG Q&A", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 📚 YouTube Video Q&A with RAG
    ### Extract transcript and ask questions about any YouTube video!
    
    **How it works:**
    1. Enter a YouTube URL
    2. Click "Get Transcript" to extract and process the video transcript
    3. Ask questions about the video content
    4. Get accurate answers based solely on the transcript
    
    **Note:** Make sure the video has captions/transcripts enabled.
    """)

    # Row 1: URL input and the button that fetches and indexes the transcript.
    with gr.Row():
        with gr.Column(scale=3):
            youtube_url = gr.Textbox(
                label="🔗 YouTube URL",
                placeholder="https://www.youtube.com/watch?v=...",
                lines=1
            )

        with gr.Column(scale=1):
            process_btn = gr.Button("🎬 Get Transcript", variant="primary", size="lg")

    # Row 2: processing status and Q&A readiness.
    with gr.Row():
        status_text = gr.Textbox(label="📊 Status", interactive=False, lines=2)
        qa_status = gr.Textbox(label="QA Status", interactive=False, lines=1, value="⚠️ Waiting for video...")

    gr.Markdown("---")

    # Row 3: full transcript (left) beside the chat panel (right).
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📝 Complete Transcript")
            transcript_display = gr.Textbox(
                label="",
                interactive=False,
                lines=25,
                max_lines=25,
                placeholder="Transcript will appear here after processing..."
            )

        with gr.Column(scale=1):
            gr.Markdown("### 💬 Ask Questions")
            # avatar_images expects image file paths/URLs; the previous value was
            # an emoji string rather than an image, so no custom avatars are set.
            chatbot = gr.Chatbot(
                label="Chat",
                height=400,
                bubble_full_width=False,
                avatar_images=None
            )

            with gr.Row():
                question = gr.Textbox(
                    label="Your Question",
                    placeholder="Ask about the video...",
                    lines=2,
                    scale=4
                )
                submit_btn = gr.Button("Ask", variant="primary", scale=1)

            with gr.Row():
                clear_chat_btn = gr.Button("🗑️ Clear Chat", variant="secondary", size="sm")
                clear_all_btn = gr.Button("🔄 Clear All", variant="stop", size="sm")

    # Event handlers
    process_btn.click(
        process_youtube_url,
        inputs=[youtube_url],
        outputs=[status_text, qa_status, transcript_display]
    )

    # The Ask button and pressing Enter in the question box behave identically:
    # answer the question, then empty the input box.
    for ask_event in (submit_btn.click, question.submit):
        ask_event(
            answer_question,
            inputs=[question, chatbot],
            outputs=[chatbot]
        ).then(
            lambda: "", None, [question]
        )

    # Clears only the chat history; the indexed transcript is kept.
    clear_chat_btn.click(
        lambda: [], None, [chatbot]
    )

    # Clears the database, both status boxes, the transcript and the chat.
    clear_all_btn.click(
        clear_everything,
        outputs=[status_text, qa_status, transcript_display, chatbot]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()