agnixcode commited on
Commit
23e75f0
·
verified ·
1 Parent(s): e34d257

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +335 -177
app.py CHANGED
@@ -1,192 +1,350 @@
1
- # ================================
2
- # INSTALL DEPENDENCIES
3
- # ================================
4
- # pip install sentence-transformers faiss-cpu gradio groq requests
5
-
6
- # ================================
7
- # IMPORTS
8
- # ================================
9
- import requests
10
- from sentence_transformers import SentenceTransformer
11
- import faiss
12
- import numpy as np
13
- import gradio as gr
14
- from groq import Groq
15
- import re
16
- import os
17
 
18
- # ================================
19
- # CONFIG
20
- # ================================
21
- GROQ_API_KEY = os.getenv("GROQ_API_KEY")
22
- SUPADATA_API_KEY = os.getenv("SUPADATA_API_KEY")
23
-
24
- client = Groq(api_key=GROQ_API_KEY)
25
- embed_model = SentenceTransformer("all-MiniLM-L6-v2")
26
-
27
- # Global store
28
- vector_store = None
29
- stored_chunks = []
30
-
31
- # ================================
32
- # UTIL: EXTRACT VIDEO ID
33
- # ================================
34
- def extract_video_id(url):
35
- match = re.search(r"(?:v=|\/)([0-9A-Za-z_-]{11})", url)
36
- return match.group(1) if match else None
37
-
38
- # ================================
39
- # STEP 1: GET TRANSCRIPT
40
- # Using Supadata API — works from any cloud server (no IP blocks)
41
- # ================================
42
- def get_transcript(url):
43
- video_id = extract_video_id(url)
44
- if not video_id:
45
- return "❌ Invalid YouTube URL"
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  try:
48
- response = requests.get(
49
- "https://api.supadata.ai/v1/youtube/transcript",
50
- params={"videoId": video_id, "text": "true"},
51
- headers={"x-api-key": SUPADATA_API_KEY},
52
- timeout=30
53
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- if response.status_code == 401:
56
- return "❌ Invalid Supadata API key. Check your HF secret: SUPADATA_API_KEY"
57
- if response.status_code == 404:
58
- return "❌ No transcript found for this video (it may have captions disabled)"
59
- if response.status_code != 200:
60
- return f"❌ Supadata API error {response.status_code}: {response.text}"
61
-
62
- data = response.json()
63
-
64
- # text=true returns content as a plain string
65
- content = data.get("content", "")
66
- if not content:
67
- return "❌ Transcript is empty"
68
-
69
- return content
70
-
71
- except Exception as e:
72
- return f"❌ Transcript Error: {str(e)}"
73
-
74
- # ================================
75
- # STEP 2: CHUNKING
76
- # ================================
77
- def chunk_text(text, chunk_size=300):
78
- words = text.split()
79
- chunks = []
80
- for i in range(0, len(words), chunk_size):
81
- chunk = " ".join(words[i:i + chunk_size])
82
- chunks.append(chunk)
83
- return chunks
84
-
85
- # ================================
86
- # STEP 3: VECTOR STORE
87
- # ================================
88
- def create_vector_store(chunks):
89
- global vector_store, stored_chunks
90
- embeddings = embed_model.encode(chunks)
91
  dim = embeddings.shape[1]
92
  index = faiss.IndexFlatL2(dim)
93
- index.add(np.array(embeddings))
94
- vector_store = index
95
- stored_chunks = chunks
96
-
97
- # ================================
98
- # STEP 4: RETRIEVAL
99
- # ================================
100
- def retrieve(query, top_k=3):
101
- query_embedding = embed_model.encode([query])
102
- distances, indices = vector_store.search(np.array(query_embedding), top_k)
103
- results = [stored_chunks[i] for i in indices[0]]
104
- return "\n".join(results)
105
-
106
- # ================================
107
- # STEP 5: LLM (GROQ)
108
- # ================================
109
- def generate_answer(query, context):
110
- prompt = f"""You are a helpful assistant.
111
-
112
- Use ONLY the context below to answer the question.
113
-
114
- Context:
115
- {context}
116
-
117
- Question:
118
- {query}
119
-
120
- Answer:"""
121
-
122
- response = client.chat.completions.create(
123
- model="llama-3.3-70b-versatile",
124
- messages=[{"role": "user", "content": prompt}],
125
- temperature=0.3
126
  )
127
- return response.choices[0].message.content
128
-
129
- # ================================
130
- # HANDLERS
131
- # ================================
132
- def handle_process(url):
133
- transcript = get_transcript(url)
134
- if transcript.startswith(""):
135
- return transcript, "", []
136
- chunks = chunk_text(transcript)
137
- create_vector_store(chunks)
138
- preview = transcript[:500]
139
- return "✅ Video processed successfully!", preview, []
140
-
141
- def handle_chat(query, chat_history):
142
- if vector_store is None:
143
- return "", chat_history + [(query, "❌ Process a video first")]
144
- context = retrieve(query)
145
- answer = generate_answer(query, context)
146
- chat_history.append((query, answer))
147
- return "", chat_history
148
-
149
- # ================================
150
- # UI
151
- # ================================
152
- with gr.Blocks(theme=gr.themes.Soft()) as app:
153
-
154
- gr.Markdown("# 🎥 YouTube Video Assistant")
155
- gr.Markdown("Paste a YouTube link → process → chat with the video")
156
-
157
- with gr.Row():
158
- url_input = gr.Textbox(label="🔗 YouTube URL", scale=4)
159
- process_btn = gr.Button("🚀 Process", scale=1)
160
-
161
- status_output = gr.Markdown("")
162
-
163
- transcript_preview = gr.Textbox(
164
- label="📄 Transcript Preview",
165
- lines=5,
166
- interactive=False
 
 
 
 
 
 
 
 
 
 
 
 
167
  )
168
 
169
- gr.Markdown("---")
 
 
 
 
 
170
 
171
- chatbot = gr.Chatbot(label="💬 Chat with Video")
 
 
 
172
 
173
- with gr.Row():
174
- query_input = gr.Textbox(
175
- placeholder="Ask something about the video...",
176
- scale=4
 
 
 
 
 
 
 
 
177
  )
178
- send_btn = gr.Button("Send", scale=1)
179
-
180
- process_btn.click(
181
- handle_process,
182
- inputs=url_input,
183
- outputs=[status_output, transcript_preview, chatbot]
184
- )
185
 
186
- send_btn.click(
187
- handle_chat,
188
- inputs=[query_input, chatbot],
189
- outputs=[query_input, chatbot]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  )
191
 
192
- app.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
+ import os
4
+ import re
5
+ import gradio as gr
6
+ import numpy as np
7
+ import faiss
8
+ import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ from youtube_transcript_api import YouTubeTranscriptApi, TranscriptsDisabled, NoTranscriptFound
11
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
12
+ from sentence_transformers import SentenceTransformer
13
+ from huggingface_hub import InferenceClient
14
+
15
+ # ---------------------------------------------------------------------------
16
+ # Global state
17
+ # ---------------------------------------------------------------------------
18
+ embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
19
+
20
+ faiss_index: faiss.IndexFlatL2 | None = None
21
+ chunk_store: list[str] = [] # parallel list of text chunks
22
+ full_transcript: str = "" # raw transcript for display
23
+
24
+ # HF Inference API – set HF_TOKEN as a Space secret or environment variable.
25
+ HF_TOKEN = os.environ.get("HF_TOKEN", "")
26
+ LLM_MODEL = "mistralai/Mistral-7B-Instruct-v0.3" # swap freely
27
+
28
+ inference_client = InferenceClient(model=LLM_MODEL, token=HF_TOKEN or None)
29
+
30
+ # ---------------------------------------------------------------------------
31
+ # Helper – extract video id from various YouTube URL formats
32
+ # ---------------------------------------------------------------------------
33
+ def _extract_video_id(url: str) -> str:
34
+ """Return the 11-char YouTube video ID from any common URL format."""
35
+ patterns = [
36
+ r"(?:v=)([A-Za-z0-9_-]{11})", # ?v=xxxx
37
+ r"(?:youtu\.be/)([A-Za-z0-9_-]{11})", # short link
38
+ r"(?:embed/)([A-Za-z0-9_-]{11})", # embed link
39
+ r"(?:shorts/)([A-Za-z0-9_-]{11})", # shorts
40
+ ]
41
+ for pattern in patterns:
42
+ match = re.search(pattern, url)
43
+ if match:
44
+ return match.group(1)
45
+ raise ValueError(f"Could not extract a valid video ID from URL: {url}")
46
+
47
+
48
+ # ---------------------------------------------------------------------------
49
+ # 1. Fetch transcript
50
+ # ---------------------------------------------------------------------------
51
+ def get_transcript(url: str) -> str:
52
+ """
53
+ Fetch the transcript for a YouTube video.
54
+
55
+ Returns the full transcript as a single string.
56
+ Raises ValueError with a human-readable message on failure.
57
+ """
58
+ video_id = _extract_video_id(url)
59
  try:
60
+ transcript_list = YouTubeTranscriptApi.get_transcript(video_id)
61
+ except TranscriptsDisabled:
62
+ raise ValueError("Transcripts are disabled for this video.")
63
+ except NoTranscriptFound:
64
+ # Try fetching any available language and translating to English
65
+ try:
66
+ transcript_list = (
67
+ YouTubeTranscriptApi.list_transcripts(video_id)
68
+ .find_generated_transcript(["en", "en-US", "en-GB"])
69
+ .fetch()
70
+ )
71
+ except Exception as inner_exc:
72
+ raise ValueError(
73
+ f"No transcript found for this video. Details: {inner_exc}"
74
+ )
75
+ except Exception as exc:
76
+ raise ValueError(f"Failed to retrieve transcript: {exc}")
77
+
78
+ # Concatenate all segments into a single string
79
+ return " ".join(seg["text"] for seg in transcript_list)
80
+
81
+
82
+ # ---------------------------------------------------------------------------
83
+ # 2. Process video – build FAISS index
84
+ # ---------------------------------------------------------------------------
85
+ def process_video(url: str):
86
+ """
87
+ Full pipeline:
88
+ 1. Fetch transcript
89
+ 2. Split into chunks
90
+ 3. Compute embeddings
91
+ 4. Build FAISS index
92
+
93
+ Returns (status_message, transcript_text) for the Gradio UI.
94
+ """
95
+ global faiss_index, chunk_store, full_transcript
96
+
97
+ # Reset state
98
+ faiss_index = None
99
+ chunk_store = []
100
+ full_transcript = ""
101
+
102
+ # -- Step 1: transcript --------------------------------------------------
103
+ try:
104
+ transcript = get_transcript(url)
105
+ except ValueError as exc:
106
+ return str(exc), ""
107
+
108
+ full_transcript = transcript
109
+
110
+ # -- Step 2: chunking ----------------------------------------------------
111
+ splitter = RecursiveCharacterTextSplitter(
112
+ chunk_size=500,
113
+ chunk_overlap=50,
114
+ length_function=len,
115
+ )
116
+ chunks = splitter.split_text(transcript)
117
+ if not chunks:
118
+ return "Transcript was fetched but produced no text chunks.", transcript
119
 
120
+ chunk_store = chunks
121
+
122
+ # -- Step 3: embeddings --------------------------------------------------
123
+ embeddings = embedding_model.encode(chunks, show_progress_bar=False)
124
+ embeddings = np.array(embeddings, dtype="float32")
125
+
126
+ # -- Step 4: FAISS index -------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  dim = embeddings.shape[1]
128
  index = faiss.IndexFlatL2(dim)
129
+ index.add(embeddings)
130
+ faiss_index = index
131
+
132
+ status = (
133
+ f"✅ Video processed successfully!\n"
134
+ f" • Chunks created : {len(chunks)}\n"
135
+ f" • Embedding dim : {dim}\n"
136
+ f" • FAISS vectors : {index.ntotal}\n\n"
137
+ f"Switch to the **Chat with Video** tab to start asking questions."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  )
139
+ return status, transcript
140
+
141
+
142
+ # ---------------------------------------------------------------------------
143
+ # 3. Retrieve top-k chunks
144
+ # ---------------------------------------------------------------------------
145
+ def retrieve_context(query: str, top_k: int = 3) -> str:
146
+ """
147
+ Encode the query and retrieve the top-k most relevant transcript chunks
148
+ from the FAISS index.
149
+
150
+ Returns a single string with the chunks separated by newlines.
151
+ """
152
+ if faiss_index is None or not chunk_store:
153
+ return ""
154
+
155
+ query_vec = embedding_model.encode([query], show_progress_bar=False)
156
+ query_vec = np.array(query_vec, dtype="float32")
157
+
158
+ k = min(top_k, len(chunk_store))
159
+ distances, indices = faiss_index.search(query_vec, k)
160
+
161
+ retrieved = [chunk_store[i] for i in indices[0] if i < len(chunk_store)]
162
+ return "\n\n".join(retrieved)
163
+
164
+
165
+ # ---------------------------------------------------------------------------
166
+ # 4. Generate answer via HF Inference API (RAG prompt)
167
+ # ---------------------------------------------------------------------------
168
+ def generate_answer(query: str) -> str:
169
+ """
170
+ Retrieve context chunks and call the LLM to produce a grounded answer.
171
+ The prompt explicitly instructs the model to rely only on the provided
172
+ context and not hallucinate.
173
+ """
174
+ if faiss_index is None:
175
+ return (
176
+ "⚠️ No video has been processed yet. "
177
+ "Please go to the **Process Video** tab and load a YouTube URL first."
178
+ )
179
+
180
+ context = retrieve_context(query, top_k=3)
181
+ if not context:
182
+ return "⚠️ Could not retrieve any relevant context for your question."
183
+
184
+ # RAG prompt – works well with instruction-tuned models
185
+ system_prompt = (
186
+ "You are a helpful assistant that answers questions strictly based on "
187
+ "the provided transcript context. "
188
+ "If the answer is not contained in the context, say: "
189
+ "'I could not find this information in the video transcript.' "
190
+ "Do NOT make up information."
191
  )
192
 
193
+ user_prompt = (
194
+ f"Context from the video transcript:\n"
195
+ f"---\n{context}\n---\n\n"
196
+ f"Question: {query}\n\n"
197
+ f"Answer:"
198
+ )
199
 
200
+ messages = [
201
+ {"role": "system", "content": system_prompt},
202
+ {"role": "user", "content": user_prompt},
203
+ ]
204
 
205
+ try:
206
+ response = inference_client.chat_completion(
207
+ messages=messages,
208
+ max_tokens=512,
209
+ temperature=0.2, # low temperature → more faithful to context
210
+ top_p=0.9,
211
+ )
212
+ answer = response.choices[0].message.content.strip()
213
+ except Exception as exc:
214
+ answer = (
215
+ f"❌ Model inference failed: {exc}\n\n"
216
+ "Make sure HF_TOKEN is set and the model endpoint is available."
217
  )
 
 
 
 
 
 
 
218
 
219
+ return answer
220
+
221
+
222
+ # ---------------------------------------------------------------------------
223
+ # 5. Gradio chat helper (maintains history list)
224
+ # ---------------------------------------------------------------------------
225
+ def chat(user_message: str, history: list[list[str]]):
226
+ """
227
+ Called by the Gradio ChatInterface-style callback.
228
+ Appends the new Q-A pair to history and returns updated history.
229
+ """
230
+ if not user_message.strip():
231
+ history.append([user_message, "Please enter a question."])
232
+ return history, ""
233
+
234
+ answer = generate_answer(user_message)
235
+ history.append([user_message, answer])
236
+ return history, ""
237
+
238
+
239
+ # ---------------------------------------------------------------------------
240
+ # 6. Gradio UI
241
+ # ---------------------------------------------------------------------------
242
+ with gr.Blocks(
243
+ title="YouTube RAG Chatbot",
244
+ theme=gr.themes.Soft(),
245
+ ) as app:
246
+
247
+ gr.Markdown(
248
+ """
249
+ # 🎬 YouTube RAG Chatbot
250
+ **Process any YouTube video and chat with its transcript using Retrieval-Augmented Generation.**
251
+
252
+ > **Note:** Set your `HF_TOKEN` environment variable (Space secret) so the LLM inference works.
253
+ """
254
  )
255
 
256
+ with gr.Tabs():
257
+
258
+ # ------------------------------------------------------------------ #
259
+ # Tab 1 – Process Video
260
+ # ------------------------------------------------------------------ #
261
+ with gr.TabItem("📥 Process Video"):
262
+ gr.Markdown(
263
+ "Paste a YouTube URL below and click **Process**. "
264
+ "The transcript will be fetched, chunked, embedded, and indexed."
265
+ )
266
+ with gr.Row():
267
+ url_input = gr.Textbox(
268
+ label="YouTube URL",
269
+ placeholder="https://www.youtube.com/watch?v=...",
270
+ scale=5,
271
+ )
272
+ process_btn = gr.Button("⚙️ Process", variant="primary", scale=1)
273
+
274
+ status_output = gr.Textbox(
275
+ label="Status",
276
+ lines=6,
277
+ interactive=False,
278
+ )
279
+ transcript_output = gr.Textbox(
280
+ label="Transcript (read-only)",
281
+ lines=15,
282
+ interactive=False,
283
+ show_copy_button=True,
284
+ )
285
+
286
+ process_btn.click(
287
+ fn=process_video,
288
+ inputs=[url_input],
289
+ outputs=[status_output, transcript_output],
290
+ )
291
+
292
+ # ------------------------------------------------------------------ #
293
+ # Tab 2 – Chat with Video
294
+ # ------------------------------------------------------------------ #
295
+ with gr.TabItem("💬 Chat with Video"):
296
+ gr.Markdown(
297
+ "Ask any question about the processed video. "
298
+ "The bot retrieves the most relevant transcript segments "
299
+ "and generates a grounded answer."
300
+ )
301
+
302
+ chatbot = gr.Chatbot(
303
+ label="Conversation",
304
+ height=450,
305
+ bubble_full_width=False,
306
+ )
307
+
308
+ with gr.Row():
309
+ query_input = gr.Textbox(
310
+ label="Your question",
311
+ placeholder="What is the main topic discussed in this video?",
312
+ scale=5,
313
+ )
314
+ send_btn = gr.Button("Send 🚀", variant="primary", scale=1)
315
+
316
+ clear_btn = gr.Button("🗑️ Clear conversation", variant="secondary")
317
+
318
+ # Shared state for conversation history
319
+ chat_history = gr.State([])
320
+
321
+ send_btn.click(
322
+ fn=chat,
323
+ inputs=[query_input, chat_history],
324
+ outputs=[chatbot, query_input],
325
+ ).then(
326
+ fn=lambda h: h,
327
+ inputs=[chatbot],
328
+ outputs=[chat_history],
329
+ )
330
+
331
+ query_input.submit(
332
+ fn=chat,
333
+ inputs=[query_input, chat_history],
334
+ outputs=[chatbot, query_input],
335
+ ).then(
336
+ fn=lambda h: h,
337
+ inputs=[chatbot],
338
+ outputs=[chat_history],
339
+ )
340
+
341
+ clear_btn.click(
342
+ fn=lambda: ([], []),
343
+ outputs=[chatbot, chat_history],
344
+ )
345
+
346
+ # ---------------------------------------------------------------------------
347
+ # Entry point
348
+ # ---------------------------------------------------------------------------
349
+ if __name__ == "__main__":
350
+ app.launch()