Claude committed
Commit e95686f · unverified · 1 parent: f2820bb

feat: Add vector DB and RAG chatbot

- Add ChromaDB for vector storage of video content
- Add sentence-transformers for embeddings
- Add FLAN-T5 for chat responses
- Store transcripts and visual context in vector DB
- Add 'Chat with Videos' tab with RAG-based Q&A
- Add requirements.txt for HuggingFace Spaces compatibility
- Chunk text with overlap for better retrieval (see the sketch below)
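
The pieces above compose a standard retrieval-augmented generation loop. The sketch below is not part of the diff; it is a minimal, self-contained illustration of the chunk → embed → store → retrieve → generate flow, assuming only the models this commit names (all-MiniLM-L6-v2 for embeddings, google/flan-t5-base for answers) and a placeholder transcript string.

# Minimal RAG sketch (illustration only, not code from this commit).
import chromadb
from sentence_transformers import SentenceTransformer
from transformers import pipeline

embedder = SentenceTransformer("all-MiniLM-L6-v2")
llm = pipeline("text2text-generation", model="google/flan-t5-base")

client = chromadb.Client()  # in-memory; chromadb.PersistentClient(path=...) would persist to disk
collection = client.get_or_create_collection("demo", metadata={"hnsw:space": "cosine"})

def chunk(text: str, size: int = 500, overlap: int = 50) -> list[str]:
    # Overlapping word chunks: each chunk starts `size - overlap` words after the previous one.
    words = text.split()
    return [" ".join(words[i:i + size]) for i in range(0, len(words), size - overlap)]

transcript = "hypothetical transcript text about rockets and orbital mechanics"  # placeholder input
chunks = chunk(transcript)

# Store: embed each chunk and index it in the collection.
collection.add(
    documents=chunks,
    embeddings=embedder.encode(chunks).tolist(),
    ids=[f"chunk_{i}" for i in range(len(chunks))],
)

# Retrieve: embed the question and pull the closest chunks.
question = "What is the video about?"
hits = collection.query(
    query_embeddings=embedder.encode([question]).tolist(),
    n_results=min(3, collection.count()),
)
context = "\n\n".join(hits["documents"][0])

# Generate: let FLAN-T5 answer from the retrieved context.
prompt = f"Based on the following video content, answer the question.\n\n{context}\n\nQuestion: {question}\n\nAnswer:"
print(llm(prompt, max_length=256, do_sample=False)[0]["generated_text"])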

Files changed (4)
  1. app.py +237 -48
  2. pyproject.toml +2 -0
  3. requirements.txt +10 -0
  4. uv.lock +0 -0
app.py CHANGED
@@ -1,17 +1,38 @@
 from __future__ import annotations

 import os
+import subprocess
 import tempfile
+import uuid
 from pathlib import Path

+import chromadb
 import cv2
 import gradio as gr
 import torch
 import yt_dlp
 from huggingface_hub import whoami
 from PIL import Image
+from sentence_transformers import SentenceTransformer
 from transformers import BlipForConditionalGeneration, BlipProcessor, pipeline

+# Initialize ChromaDB client (persistent storage)
+chroma_client = chromadb.Client()
+collection = chroma_client.get_or_create_collection(
+    name="video_knowledge",
+    metadata={"hnsw:space": "cosine"}
+)
+
+# Global embedding model
+embedding_model = None
+
+
+def get_embedding_model():
+    global embedding_model
+    if embedding_model is None:
+        embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
+    return embedding_model
+

 def hello(profile: gr.OAuthProfile | None) -> str:
     if profile is None:
@@ -49,6 +70,14 @@ def get_vision_model():
     return processor, model


+def get_text_generation_model():
+    return pipeline(
+        "text2text-generation",
+        model="google/flan-t5-base",
+        device=get_device(),
+    )
+
+
 def download_video(url: str, output_dir: str) -> list[dict]:
     """Download video from YouTube URL (video or playlist)."""
     ydl_opts = {
@@ -83,18 +112,6 @@ def download_video(url: str, output_dir: str) -> list[dict]:
 def extract_audio(video_path: str, output_dir: str) -> str:
     """Extract audio from video file."""
     audio_path = os.path.join(output_dir, "audio.mp3")
-    ydl_opts = {
-        "format": "bestaudio/best",
-        "postprocessors": [{
-            "key": "FFmpegExtractAudio",
-            "preferredcodec": "mp3",
-            "preferredquality": "192",
-        }],
-        "outtmpl": os.path.join(output_dir, "audio"),
-        "quiet": True,
-    }
-    # Use ffmpeg directly via yt-dlp's post-processor on local file
-    import subprocess
     subprocess.run([
         "ffmpeg", "-i", video_path, "-vn", "-acodec", "libmp3lame",
         "-q:a", "2", audio_path, "-y"
@@ -112,14 +129,12 @@ def extract_frames(video_path: str, num_frames: int = 5) -> list[Image.Image]:
         cap.release()
         return frames

-    # Get evenly spaced frame indices
     indices = [int(i * total_frames / (num_frames + 1)) for i in range(1, num_frames + 1)]

     for idx in indices:
         cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
         ret, frame = cap.read()
         if ret:
-            # Convert BGR to RGB
             frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
             frames.append(Image.fromarray(frame_rgb))

@@ -143,6 +158,80 @@ def transcribe_audio(audio_path: str, whisper_model) -> str:
     return result["text"]


+def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str]:
+    """Split text into overlapping chunks."""
+    words = text.split()
+    chunks = []
+    for i in range(0, len(words), chunk_size - overlap):
+        chunk = " ".join(words[i:i + chunk_size])
+        if chunk:
+            chunks.append(chunk)
+    return chunks
+
+
+def add_to_vector_db(title: str, transcript: str, visual_contexts: list[str]):
+    """Add video content to vector database."""
+    embed_model = get_embedding_model()
+
+    documents = []
+    metadatas = []
+    ids = []
+
+    # Add transcript chunks
+    if transcript:
+        chunks = chunk_text(transcript)
+        for i, chunk in enumerate(chunks):
+            documents.append(chunk)
+            metadatas.append({
+                "title": title,
+                "type": "transcript",
+                "chunk_index": i,
+            })
+            ids.append(f"{title}_transcript_{i}_{uuid.uuid4().hex[:8]}")
+
+    # Add visual context
+    for i, context in enumerate(visual_contexts):
+        documents.append(f"Visual scene from {title}: {context}")
+        metadatas.append({
+            "title": title,
+            "type": "visual",
+            "frame_index": i,
+        })
+        ids.append(f"{title}_visual_{i}_{uuid.uuid4().hex[:8]}")
+
+    if documents:
+        embeddings = embed_model.encode(documents).tolist()
+        collection.add(
+            documents=documents,
+            embeddings=embeddings,
+            metadatas=metadatas,
+            ids=ids,
+        )
+
+    return len(documents)
+
+
+def search_knowledge(query: str, n_results: int = 5) -> list[dict]:
+    """Search the vector database for relevant content."""
+    embed_model = get_embedding_model()
+    query_embedding = embed_model.encode([query]).tolist()
+
+    results = collection.query(
+        query_embeddings=query_embedding,
+        n_results=n_results,
+    )
+
+    matches = []
+    if results["documents"] and results["documents"][0]:
+        for doc, metadata in zip(results["documents"][0], results["metadatas"][0]):
+            matches.append({
+                "content": doc,
+                "title": metadata.get("title", "Unknown"),
+                "type": metadata.get("type", "unknown"),
+            })
+    return matches
+
+
 def process_youtube(
     url: str,
     analyze_frames: bool,
@@ -171,10 +260,9 @@ def process_youtube(
             total = len(downloaded)

             for i, item in enumerate(downloaded):
-                base_progress = 0.1 + 0.9 * (i / total)
+                base_progress = 0.1 + 0.8 * (i / total)
                 video_result = [f"## {item['title']}"]

-                # Find the actual video file
                 video_files = list(Path(tmpdir).glob("*.mp4")) + \
                     list(Path(tmpdir).glob("*.webm")) + \
                     list(Path(tmpdir).glob("*.mkv"))
@@ -187,27 +275,35 @@ def process_youtube(
                 video_path = str(video_files[0])

                 # Extract and transcribe audio
-                progress(base_progress + 0.3 * (1/total), desc=f"Extracting audio: {item['title']}")
+                progress(base_progress + 0.2 * (1/total), desc=f"Extracting audio: {item['title']}")
                 audio_path = extract_audio(video_path, tmpdir)

-                progress(base_progress + 0.5 * (1/total), desc=f"Transcribing: {item['title']}")
+                progress(base_progress + 0.4 * (1/total), desc=f"Transcribing: {item['title']}")
                 transcript = transcribe_audio(audio_path, whisper_model)

+                visual_contexts = []
+
                 if transcript:
                     video_result.append("### Transcript")
                     video_result.append(transcript)

                 # Analyze frames if enabled
                 if analyze_frames:
-                    progress(base_progress + 0.7 * (1/total), desc=f"Analyzing frames: {item['title']}")
+                    progress(base_progress + 0.6 * (1/total), desc=f"Analyzing frames: {item['title']}")
                     frames = extract_frames(video_path, num_frames)

                     if frames:
                         video_result.append("\n### Visual Context")
                         for j, frame in enumerate(frames):
                             caption = describe_frame(frame, vision_processor, vision_model)
+                            visual_contexts.append(caption)
                             video_result.append(f"**Frame {j+1}:** {caption}")

+                # Store in vector DB
+                progress(base_progress + 0.8 * (1/total), desc=f"Storing in knowledge base: {item['title']}")
+                num_stored = add_to_vector_db(item["title"], transcript, visual_contexts)
+                video_result.append(f"\n*Added {num_stored} chunks to knowledge base*")
+
                 results.append("\n\n".join(video_result))

             progress(1.0, desc="Done!")
@@ -217,9 +313,70 @@ def process_youtube(
         return f"Error: {e!s}"


+def chat_with_videos(
+    message: str,
+    history: list[dict],
+    profile: gr.OAuthProfile | None,
+) -> str:
+    if profile is None:
+        return "Please log in to use the chat feature."
+
+    if not message or not message.strip():
+        return "Please enter a question."
+
+    # Check if we have any content in the knowledge base
+    if collection.count() == 0:
+        return "No videos have been analyzed yet. Please analyze some videos first to build the knowledge base."
+
+    # Search for relevant context
+    matches = search_knowledge(message.strip(), n_results=5)
+
+    if not matches:
+        return "I couldn't find any relevant information in the analyzed videos."
+
+    # Build context from matches
+    context_parts = []
+    for match in matches:
+        source = f"[{match['title']} - {match['type']}]"
+        context_parts.append(f"{source}: {match['content']}")
+
+    context = "\n\n".join(context_parts)
+
+    # Generate response using the LLM
+    try:
+        llm = get_text_generation_model()
+        prompt = f"""Based on the following video content, answer the question.
+
+Video Content:
+{context}
+
+Question: {message}
+
+Answer:"""
+
+        response = llm(prompt, max_length=512, do_sample=False)[0]["generated_text"]
+
+        # Add sources
+        sources = list(set(m["title"] for m in matches))
+        response += f"\n\n*Sources: {', '.join(sources)}*"
+
+        return response
+
+    except Exception as e:
+        return f"Error generating response: {e!s}"
+
+
+def get_knowledge_stats() -> str:
+    """Get statistics about the knowledge base."""
+    count = collection.count()
+    if count == 0:
+        return "Knowledge base is empty. Analyze some videos to get started!"
+    return f"Knowledge base contains **{count}** chunks from analyzed videos."
+
+
 with gr.Blocks() as demo:
     gr.Markdown("# Video Analyzer")
-    gr.Markdown("Download, transcribe, and analyze YouTube videos using AI")
+    gr.Markdown("Download, transcribe, analyze, and chat with YouTube videos using AI")

     gr.LoginButton()
     m1 = gr.Markdown()
@@ -227,34 +384,66 @@ with gr.Blocks() as demo:

     gr.Markdown("---")

-    with gr.Row():
-        url_input = gr.Textbox(
-            label="YouTube URL",
-            placeholder="Enter a YouTube video or playlist URL",
-            scale=4,
-        )
-
-    with gr.Row():
-        analyze_frames = gr.Checkbox(
-            label="Analyze video frames (visual context)",
-            value=True,
-        )
-        num_frames = gr.Slider(
-            label="Number of frames to analyze",
-            minimum=1,
-            maximum=10,
-            value=5,
-            step=1,
-        )
-
-    submit_btn = gr.Button("Analyze Video", variant="primary")
-    output = gr.Markdown(label="Analysis")
-
-    submit_btn.click(
-        fn=process_youtube,
-        inputs=[url_input, analyze_frames, num_frames],
-        outputs=[output],
-    )
+    with gr.Tabs():
+        with gr.TabItem("Analyze Videos"):
+            with gr.Row():
+                url_input = gr.Textbox(
+                    label="YouTube URL",
+                    placeholder="Enter a YouTube video or playlist URL",
+                    scale=4,
+                )
+
+            with gr.Row():
+                analyze_frames = gr.Checkbox(
+                    label="Analyze video frames (visual context)",
+                    value=True,
+                )
+                num_frames = gr.Slider(
+                    label="Number of frames to analyze",
+                    minimum=1,
+                    maximum=10,
+                    value=5,
+                    step=1,
+                )
+
+            submit_btn = gr.Button("Analyze Video", variant="primary")
+            output = gr.Markdown(label="Analysis")
+
+            submit_btn.click(
+                fn=process_youtube,
+                inputs=[url_input, analyze_frames, num_frames],
+                outputs=[output],
+            )
+
+        with gr.TabItem("Chat with Videos"):
+            kb_stats = gr.Markdown()
+            chatbot = gr.Chatbot(label="Video Chat", type="messages")
+            chat_input = gr.Textbox(
+                label="Ask a question about your videos",
+                placeholder="What did the video say about...?",
+            )
+            chat_btn = gr.Button("Ask", variant="primary")
+
+            def respond(message, history, profile):
+                response = chat_with_videos(message, history, profile)
+                history = history or []
+                history.append({"role": "user", "content": message})
+                history.append({"role": "assistant", "content": response})
+                return history, ""
+
+            chat_btn.click(
+                fn=respond,
+                inputs=[chat_input, chatbot],
+                outputs=[chatbot, chat_input],
+            )
+            chat_input.submit(
+                fn=respond,
+                inputs=[chat_input, chatbot],
+                outputs=[chatbot, chat_input],
+            )
+
+            # Update stats on tab load
+            demo.load(get_knowledge_stats, outputs=kb_stats)

     demo.load(hello, inputs=None, outputs=m1)
     demo.load(list_organizations, inputs=None, outputs=m2)
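
One detail of chunk_text above worth spelling out: the loop advances by chunk_size - overlap words, so with the defaults (500-word chunks, 50-word overlap) consecutive chunks share exactly 50 words at their boundary. A small standalone check, repeating the function as committed and feeding it hypothetical dummy words:

# Standalone check of the overlap behaviour of chunk_text (defaults: 500-word
# chunks advancing 450 words at a time, i.e. 50 shared words between neighbours).
def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str]:
    """Split text into overlapping chunks."""
    words = text.split()
    chunks = []
    for i in range(0, len(words), chunk_size - overlap):
        chunk = " ".join(words[i:i + chunk_size])
        if chunk:
            chunks.append(chunk)
    return chunks

words = [f"w{i}" for i in range(1200)]          # 1200 dummy words
chunks = chunk_text(" ".join(words))
print(len(chunks))                              # 3 chunks, starting at words 0, 450, 900
first, second = chunks[0].split(), chunks[1].split()
print(first[-50:] == second[:50])               # True: the 50-word overlap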
pyproject.toml CHANGED
@@ -13,4 +13,6 @@ dependencies = [
     "accelerate>=0.25.0",
     "opencv-python-headless>=4.8.0",
     "Pillow>=10.0.0",
+    "chromadb>=0.4.0",
+    "sentence-transformers>=2.2.0",
 ]
requirements.txt ADDED
@@ -0,0 +1,10 @@
+gradio>=6.0.0
+huggingface_hub>=0.20.0
+yt-dlp>=2024.1.0
+transformers>=4.36.0
+torch>=2.0.0
+accelerate>=0.25.0
+opencv-python-headless>=4.8.0
+Pillow>=10.0.0
+chromadb>=0.4.0
+sentence-transformers>=2.2.0
uv.lock CHANGED
The diff for this file is too large to render.