"""Multimodal RAG video chat: BridgeTower embeddings + LanceDB retrieval + Pixtral answers."""

import os
import socket
import ssl
import subprocess
import sys
from os import path as osp
from pathlib import Path

import gradio as gr
import lancedb
from dotenv import load_dotenv
from langchain_core.runnables import (
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)
from mistralai import Mistral
from PIL import Image

from crud.vector_store import MultimodalLanceDB
from preprocess.embedding import BridgeTowerEmbeddings
from preprocess.preprocessing import extract_and_save_frames_and_metadata
from utils import (
    download_video,
    get_transcript_vtt,
    download_youtube_subtitle,
    get_video_id_from_url,
    str2time,
    maintain_aspect_ratio_resize,
    getSubs,
    encode_image,
)

# -------------------------------
# 1. Setup - HuggingFace Spaces Configuration
# -------------------------------
load_dotenv()

# HuggingFace Spaces specific setup: SPACE_ID is only present when running on Spaces.
SPACE_ID = os.getenv("SPACE_ID")
IS_SPACES = SPACE_ID is not None

# NOTE: an earlier version switched storage locations by environment:
#   Spaces: LANCEDB_HOST_FILE = "/tmp/.lancedb",          VIDEO_DIR = "/tmp/videos/video1"
#   local:  LANCEDB_HOST_FILE = "./shared_data/.lancedb", VIDEO_DIR = "./shared_data/videos/video1"
# The demo now always uses the pre-populated repository paths below.
LANCEDB_HOST_FILE = "src/shared_data/.lancedb"
VIDEO_DIR = "src/shared_data/videos/video1"
TBL_NAME = "vectorstore"

# Initialize components
db = lancedb.connect(LANCEDB_HOST_FILE)
embedder = BridgeTowerEmbeddings()


def fix_dns_resolution():
    """Relax network settings for containerized environments (best effort).

    Despite the name, nothing here touches DNS. The function:
      * sets ``PYTHONHTTPSVERIFY=0`` in ``os.environ`` (an environment variable
        only; it does not reconfigure the ``ssl`` module already imported in
        this process);
      * replaces ``ssl._create_default_https_context`` with the unverified
        factory, disabling certificate checks for stdlib HTTPS clients that
        fall back to that default context;
      * sets a 30-second default timeout for newly created sockets.

    Any exception is printed as a warning and otherwise ignored.
    """
    # SECURITY NOTE(review): turning off certificate verification process-wide
    # exposes any HTTPS call relying on the default context to interception.
    # The only download path that may have needed it is currently disabled;
    # consider removing this.
    try:
        os.environ["PYTHONHTTPSVERIFY"] = "0"
        ssl._create_default_https_context = ssl._create_unverified_context
        socket.setdefaulttimeout(30)
    except Exception as e:
        print(f"DNS fix warning: {e}")


# Call the fix before any downloads
fix_dns_resolution()


def download_video_robust(video_url, output_dir):
    """Download ``video_url`` into ``output_dir`` using the ``yt-dlp`` CLI.

    Args:
        video_url: URL handed to yt-dlp unchanged.
        output_dir: Target directory; created if missing.

    Returns:
        Path of the first ``.mp4`` / ``.webm`` / ``.mkv`` file found in
        ``output_dir``. NOTE(review): if the directory already contained a
        video, that file may be returned instead of the fresh download
        (``os.listdir`` order is arbitrary).

    Raises:
        Exception: "Download timeout - ..." when yt-dlp runs longer than 120 s,
            or "Download failed: ..." for any other failure (original cause chained).
    """
    try:
        os.makedirs(output_dir, exist_ok=True)

        # Argument list, no shell: the URL is never interpreted by a shell.
        cmd = [
            "yt-dlp",
            "--no-check-certificates",
            "--geo-bypass",
            "--no-warnings",
            "--output", f"{output_dir}/%(title)s.%(ext)s",
            "--format", "best[height<=720]",
            video_url,
        ]
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=120)
        if result.returncode != 0:
            raise Exception(f"yt-dlp failed: {result.stderr}")

        for file in os.listdir(output_dir):
            if file.endswith((".mp4", ".webm", ".mkv")):
                return os.path.join(output_dir, file)
        raise Exception("No video file found after download")
    except subprocess.TimeoutExpired as e:
        raise Exception("Download timeout - video may be too long or network issues") from e
    except Exception as e:
        raise Exception(f"Download failed: {str(e)}") from e


# -------------------------------
# 2. Preprocessing + Storage
# -------------------------------
def _windowed_transcripts(transcripts, window=7):
    """Join each transcript with its neighbours in a window centred on it.

    The window spans ``window // 2`` segments on each side of segment ``i``
    (``window`` segments in total for odd ``window``), clipped at both ends
    of the list. Returns one string per input segment.
    """
    half = window // 2
    return [
        " ".join(transcripts[max(0, i - half): i + half + 1])
        for i in range(len(transcripts))
    ]


def preprocess_and_store(youtube_url: str):
    """Extract frames and transcripts from the demo video, embed them, and store them in LanceDB.

    Downloading is disabled for this demo. The video and caption file are read
    from ``VIDEO_DIR``, and ``youtube_url`` is only echoed back in the status
    message. The LanceDB table ``TBL_NAME`` is overwritten.

    Returns:
        A status string starting with "✅" on success or "❌" on failure.
        Exceptions are never propagated.
    """
    try:
        video_dir = VIDEO_DIR
        extracted_frames_path = os.path.join(video_dir, "extracted_frame")
        # Create output folders first (parents=True also creates video_dir), so
        # the transcript fallback below has somewhere to write.
        Path(extracted_frames_path).mkdir(parents=True, exist_ok=True)

        # Pre-downloaded video (direct YouTube downloads are blocked on free Spaces).
        # To re-enable: video_filepath = download_video_robust(youtube_url, video_dir)
        video_filepath = os.path.join(video_dir, "Welcome back to Planet Earth.mp4")

        # Subtitles are optional: fall back to a placeholder transcript on failure.
        try:
            # To re-enable: video_transcript_filepath = download_youtube_subtitle(youtube_url, video_dir)
            video_transcript_filepath = os.path.join(video_dir, "generated_captions.vtt")
        except Exception as e:
            print(f"Warning: Could not download subtitles: {e}")
            video_transcript_filepath = os.path.join(video_dir, "empty.vtt")
            with open(video_transcript_filepath, "w", encoding="utf-8") as f:
                f.write("WEBVTT\n\n00:00:00.000 --> 00:00:10.000\nNo transcript available\n")

        # Extract frames and metadata
        metadatas = extract_and_save_frames_and_metadata(
            video_filepath,
            video_transcript_filepath,
            extracted_frames_path,
            video_dir,
        )

        video_trans = [vid["transcript"] for vid in metadatas]
        video_img_path = [vid["extracted_frame_path"] for vid in metadatas]

        # Enrich each frame's transcript with its neighbours' (presumably so that
        # short caption fragments carry more context into the embedding).
        updated_video_trans = _windowed_transcripts(video_trans, window=7)
        for metadata, transcript in zip(metadatas, updated_video_trans):
            metadata["transcript"] = transcript

        # Store in vector database (replaces any existing table contents).
        MultimodalLanceDB.from_text_image_pairs(
            texts=updated_video_trans,
            image_paths=video_img_path,
            embedding=embedder,
            metadatas=metadatas,
            connection=db,
            table_name=TBL_NAME,
            mode="overwrite",
        )
        return f"✅ Video processed and stored: {youtube_url}"
    except Exception as e:
        return f"❌ Error processing video: {str(e)}"


# -------------------------------
# 3. Retrieval + Prompt Functions
# -------------------------------
vectorstore = MultimodalLanceDB(
    uri=LANCEDB_HOST_FILE,
    embedding=embedder,
    table_name=TBL_NAME,
)
retriever_module = vectorstore.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 3},
)


def prompt_processing(input):
    """Build the vision-language prompt from the first retrieved result.

    Args:
        input: dict with ``"retrieved_results"`` (retriever output, assumed
            best-first) and ``"user_query"`` (the user's question).

    Returns:
        dict with ``"prompt"`` (str) and ``"frame_path"`` (the first result's
        ``extracted_frame_path`` metadata, or None when nothing was retrieved).
    """
    retrieved_results = input["retrieved_results"]
    user_query = input["user_query"]

    if not retrieved_results:
        return {"prompt": "No relevant content found.", "frame_path": None}

    # Only the first of the k retrieved results is sent to the model.
    top_result = retrieved_results[0]
    prompt_template = (
        "The transcript associated with the image is '{transcript}'. "
        "{user_query}"
    )
    retrieved_metadata = top_result.metadata
    transcript = retrieved_metadata["transcript"]
    frame_path = retrieved_metadata["extracted_frame_path"]
    return {
        "prompt": prompt_template.format(transcript=transcript, user_query=user_query),
        "frame_path": frame_path,
    }


def lvlm_inference(input):
    """Ask Pixtral about the retrieved frame.

    Args:
        input: dict with ``"prompt"`` and ``"frame_path"``, as produced by
            ``prompt_processing``.

    Returns:
        ``(answer_text, frame_path)``. On any failure the answer is an error
        message starting with "❌" and ``frame_path`` is whatever was read so
        far (None if reading the input failed).
    """
    # Bound before the try so the except handler can always reference it.
    frame_path = None
    try:
        lvlm_prompt = input["prompt"]
        frame_path = input["frame_path"]
        if frame_path is None:
            return "No relevant frame found.", None

        api_key = os.getenv("MISTRAL_API_KEY")
        if not api_key:
            return "❌ MISTRAL_API_KEY not found. Please set it in the environment variables.", frame_path

        client = Mistral(api_key=api_key)
        base64_image = encode_image(frame_path)

        # One user turn containing the text prompt plus the frame as a data URL.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": lvlm_prompt},
                    {"type": "image_url", "image_url": f"data:image/jpeg;base64,{base64_image}"},
                ],
            }
        ]

        chat_response = client.chat.complete(
            model="pixtral-12b-2409",
            messages=messages,
        )
        return chat_response.choices[0].message.content, frame_path
    except Exception as e:
        return f"❌ Error in inference: {str(e)}", frame_path


# LangChain Runnable chain: query -> {retrieved docs, query} -> prompt -> (answer, frame_path)
prompt_processing_module = RunnableLambda(prompt_processing)
lvlm_inference_module = RunnableLambda(lvlm_inference)
mm_rag_chain = (
    RunnableParallel({"retrieved_results": retriever_module, "user_query": RunnablePassthrough()})
    | prompt_processing_module
    | lvlm_inference_module
)
# -------------------------------
# 4. Chat API for Gradio
# -------------------------------
# Set by load_video and read by chat_interface. As a module global it is shared
# by every session connected to this process.
video_loaded = False


def load_video(youtube_url):
    """Gradio handler for "Process Video": run preprocessing and record success.

    Returns:
        The status string to show in the Status box.
    """
    global video_loaded
    try:
        if not youtube_url or not youtube_url.strip():
            return "❌ Please enter a valid YouTube URL"

        status = preprocess_and_store(youtube_url)
        video_loaded = status.startswith("✅")
        return status
    except Exception as e:
        video_loaded = False
        return f"❌ Unexpected error: {str(e)}"


def _load_frame(frame_path):
    """Return the frame at ``frame_path`` as a fully loaded PIL image, or None on failure."""
    if not frame_path:
        return None
    try:
        # Image.open is lazy and keeps the file open; copying inside the `with`
        # loads the pixels and lets the file handle close immediately.
        with Image.open(frame_path) as img:
            return img.copy()
    except Exception as e:
        print(f"Error loading image: {e}")
        return None


def chat_interface(message, history):
    """Gradio handler for one chat turn.

    Args:
        message: The user's question (None or blank is ignored).
        history: Chatbot history as a list of (user, bot) pairs; None is
            treated as empty.

    Returns:
        ``("", updated_history, retrieved_image_or_None)``. The empty string
        clears the input textbox.
    """
    history = history or []
    if not video_loaded:
        return "", history + [(message, "❌ Please load a video first in the 'Load Video' tab.")], None

    if not message or not message.strip():
        return "", history, None

    try:
        final_text_response, frame_path = mm_rag_chain.invoke(message)
        history.append((message, final_text_response))
        return "", history, _load_frame(frame_path)
    except Exception as e:
        error_msg = f"❌ Error processing query: {str(e)}"
        history.append((message, error_msg))
        return "", history, None


def clear_chat():
    """Reset the chat history and the retrieved-frame image."""
    return [], None


# -------------------------------
# 5. Enhanced Gradio Interface
# -------------------------------
with gr.Blocks(
    title="Multimodal RAG Video Chat",
    theme=gr.themes.Default(),
) as demo:
    gr.Markdown("""
    # 🎬 Multimodal RAG Video Chat
    Chat with YouTube videos using BridgeTower embeddings + LanceDB + Pixtral Vision-Language Model!

    ⚠️ **Important**: You need to set your `MISTRAL_API_KEY` in the Space settings for this to work.
    """)

    with gr.Tab("1. Load Video"):
        with gr.Column():
            # The URL is fixed: this demo only processes the pre-downloaded video.
            youtube_url = gr.Textbox(
                label="YouTube URL",
                value="https://www.youtube.com/watch?v=7Hcg-rLYwdM",
                interactive=False,
                lines=1,
                scale=4,
            )
            with gr.Row():
                load_btn = gr.Button("🔄 Process Video", variant="primary", scale=1)
            status = gr.Textbox(
                label="Status",
                interactive=False,
                lines=2,
            )

        load_btn.click(
            fn=load_video,
            inputs=youtube_url,
            outputs=status,
            show_progress=True,
        )

    with gr.Tab("2. Chat with Video"):
        with gr.Row():
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    label="Chat about the video",
                    height=500,
                )
            with gr.Column(scale=1):
                retrieved_image = gr.Image(
                    label="Retrieved Frame",
                    height=400,
                    show_label=True,
                    interactive=False,
                )

        with gr.Row():
            with gr.Column(scale=4):
                msg = gr.Textbox(
                    label="Your question",
                    placeholder="Ask something about the video content...",
                    lines=2,
                    container=False,
                )
            with gr.Column(scale=1, min_width=100):
                send_btn = gr.Button("📤 Send", variant="primary")
                clear_btn = gr.Button("🗑️ Clear", variant="secondary")

        # Event handlers: Enter in the textbox and the Send button do the same thing.
        msg.submit(
            fn=chat_interface,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot, retrieved_image],
            show_progress=True,
        )
        send_btn.click(
            fn=chat_interface,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot, retrieved_image],
            show_progress=True,
        )
        clear_btn.click(
            fn=clear_chat,
            outputs=[chatbot, retrieved_image],
        )

    # NOTE(review): the Features/Limitations lists below still describe automatic
    # YouTube downloads, which this demo has disabled.
    with gr.Tab("📖 Instructions"):
        gr.Markdown("""
        ## How to use this Multimodal RAG system:

        ### 🔧 Setup:
        1. **Set API Key**: Make sure `MISTRAL_API_KEY` is set in your Space settings
        2. This app uses Pixtral-12B for vision-language understanding

        ### 🎥 Load Video:
        For this demo, we **use a fixed pre-downloaded YouTube video**:
        👉 [https://www.youtube.com/watch?v=7Hcg-rLYwdM](https://www.youtube.com/watch?v=7Hcg-rLYwdM)

        Due to Hugging Face free Space network restrictions, direct downloads from YouTube are disabled.
        All processing in this demo is based on this preloaded video.

        Click **🔄 Process Video** in the "1. Load Video" tab before chatting.

        ### 💬 Chat with Video:
        1. Go to the "Chat with Video" tab
        2. Ask questions about the video content
        3. The system will retrieve the most relevant frame and provide answers
        4. The retrieved frame will be displayed on the right side

        ## ✨ Features:
        - 🎥 **Automatic YouTube Processing**: Downloads and processes YouTube videos
        - 🧠 **Multimodal Embeddings**: Uses BridgeTower for combined text+image understanding
        - 💾 **Vector Storage**: Stores data in LanceDB for fast similarity search
        - 🤖 **Vision-Language AI**: Powered by Mistral's Pixtral model
        - 🖼️ **Visual Context**: Shows relevant video frames alongside responses
        - 🔄 **Real-time Processing**: Fast retrieval and inference

        ## ⚠️ Limitations:
        - Works with publicly accessible YouTube videos only
        - Processing time depends on video length
        - Requires stable internet connection for video download
        - API rate limits apply based on Mistral usage

        ## 🛠️ Technical Stack:
        - **Embeddings**: BridgeTower (multimodal)
        - **Vector DB**: LanceDB
        - **Vision-Language Model**: Pixtral-12B
        - **Framework**: LangChain + Gradio
        """)

    with gr.Tab("🔍 About"):
        gr.Markdown("""
        ## Multimodal RAG Video Chat System

        This application demonstrates a complete multimodal Retrieval-Augmented Generation (RAG) pipeline that can understand and answer questions about video content.

        ### Architecture:
        1. **Video Processing**: Downloads YouTube videos and extracts frames with timestamps
        2. **Multimodal Embedding**: Uses BridgeTower to create embeddings that understand both visual and textual content
        3. **Vector Storage**: Stores embeddings in LanceDB for efficient similarity search
        4. **Retrieval**: Finds the most relevant video segments based on user queries
        5. **Generation**: Uses Pixtral vision-language model to generate contextual responses

        ### Built with:
        - **Gradio**: For the web interface
        - **LangChain**: For orchestrating the RAG pipeline
        - **LanceDB**: For vector storage and retrieval
        - **BridgeTower**: For multimodal embeddings
        - **Mistral Pixtral**: For vision-language understanding

        ---

        💡 **Tip**: For best results, ask specific questions about visual content, actions, or scenes in the video.
        """)


# -------------------------------
# 6. Launch Configuration
# -------------------------------
if __name__ == "__main__":
    print('🚀 Starting Multimodal RAG Video Chat App...')

    # Check for required environment variables
    if not os.getenv("MISTRAL_API_KEY"):
        print("⚠️ WARNING: MISTRAL_API_KEY not found in environment variables")
        print(" Please set this in your HuggingFace Space settings")

    if IS_SPACES:
        # Spaces already serves the app publicly; Gradio does not support
        # share links there (it warns and ignores share=True).
        demo.launch(server_name="0.0.0.0", server_port=7860)
    else:
        demo.launch(share=True, server_name="0.0.0.0", server_port=7860)