from pathlib import Path
import os
from os import path as osp
import gradio as gr
from dotenv import load_dotenv
from crud.vector_store import MultimodalLanceDB
from preprocess.embedding import BridgeTowerEmbeddings
from preprocess.preprocessing import extract_and_save_frames_and_metadata
from utils import (
    download_video,
    get_transcript_vtt,
    download_youtube_subtitle,
    get_video_id_from_url,
    str2time,
    maintain_aspect_ratio_resize,
    getSubs,
    encode_image,
)
from mistralai import Mistral
from langchain_core.runnables import (
    RunnableParallel,
    RunnablePassthrough,
    RunnableLambda,
)
from PIL import Image
import lancedb
import ssl
import socket
import subprocess
import sys
# -------------------------------
# 1. Setup - HuggingFace Spaces Configuration
# -------------------------------
load_dotenv()

# HuggingFace Spaces specific setup
SPACE_ID = os.getenv("SPACE_ID")
IS_SPACES = SPACE_ID is not None

'''
if IS_SPACES:
    LANCEDB_HOST_FILE = "/tmp/.lancedb"
    VIDEO_DIR = "/tmp/videos/video1"
    os.makedirs("/tmp", exist_ok=True)
else:
    LANCEDB_HOST_FILE = "./shared_data/.lancedb"
    VIDEO_DIR = "./shared_data/videos/video1"
'''
LANCEDB_HOST_FILE = "src/shared_data/.lancedb"
VIDEO_DIR = "src/shared_data/videos/video1"
TBL_NAME = "vectorstore"

# Initialize components
db = lancedb.connect(LANCEDB_HOST_FILE)
embedder = BridgeTowerEmbeddings()
# Work around SSL/network issues in containerized environments
def fix_dns_resolution():
    """Relax SSL verification and set a socket timeout for containerized environments."""
    try:
        # Disable HTTPS certificate verification (workaround for restricted
        # containers; note this weakens transport security)
        os.environ['PYTHONHTTPSVERIFY'] = '0'
        ssl._create_default_https_context = ssl._create_unverified_context
        # Set a default socket timeout so downloads fail fast instead of hanging
        socket.setdefaulttimeout(30)
    except Exception as e:
        print(f"DNS fix warning: {e}")


# Call the fix before any downloads
fix_dns_resolution()
# Alternative download function using subprocess
def download_video_robust(video_url, output_dir):
    """More robust video download using yt-dlp via subprocess."""
    try:
        # Ensure the output directory exists
        os.makedirs(output_dir, exist_ok=True)
        # Use yt-dlp with options suited to containerized environments
        cmd = [
            'yt-dlp',
            '--no-check-certificates',
            '--geo-bypass',
            '--no-warnings',
            '--output', f'{output_dir}/%(title)s.%(ext)s',
            '--format', 'best[height<=720]',
            video_url,
        ]
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=120)
        if result.returncode != 0:
            raise Exception(f"yt-dlp failed: {result.stderr}")
        # Find the downloaded file
        for file in os.listdir(output_dir):
            if file.endswith(('.mp4', '.webm', '.mkv')):
                return os.path.join(output_dir, file)
        raise Exception("No video file found after download")
    except subprocess.TimeoutExpired:
        raise Exception("Download timeout - the video may be too long or there are network issues")
    except Exception as e:
        raise Exception(f"Download failed: {str(e)}")
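
# Illustrative usage (not executed in this demo, since the Space cannot reach YouTube):
#   video_filepath = download_video_robust("https://www.youtube.com/watch?v=7Hcg-rLYwdM", VIDEO_DIR)
# would return the path of the <=720p file that yt-dlp saved under VIDEO_DIR.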
# -------------------------------
# 2. Preprocessing + Storage
# -------------------------------
def preprocess_and_store(youtube_url: str):
    """Download video, extract frames+metadata, embed & store in LanceDB."""
    try:
        video_url = youtube_url
        '''
        if os.getenv("SPACE_ID"):
            video_dir = "/tmp/videos/video1"
        else:
            video_dir = "./shared_data/videos/video1"
        '''
        video_dir = "src/shared_data/videos/video1"

        # Use the robust download function
        # video_filepath = download_video_robust(video_url, video_dir)
        video_filepath = "src/shared_data/videos/video1/Welcome back to Planet Earth.mp4"

        # Try to download the subtitles, but don't fail if it doesn't work
        try:
            # video_transcript_filepath = download_youtube_subtitle(video_url, video_dir)
            video_transcript_filepath = "src/shared_data/videos/video1/generated_captions.vtt"
        except Exception as e:
            print(f"Warning: Could not download subtitles: {e}")
            # Create an empty transcript file
            video_transcript_filepath = os.path.join(video_dir, "empty.vtt")
            with open(video_transcript_filepath, 'w') as f:
                f.write("WEBVTT\n\n00:00:00.000 --> 00:00:10.000\nNo transcript available\n")

        extracted_frames_path = os.path.join(video_dir, 'extracted_frame')

        # Create output folders
        Path(extracted_frames_path).mkdir(parents=True, exist_ok=True)
        Path(video_dir).mkdir(parents=True, exist_ok=True)

        # Extract frames and metadata
        metadatas = extract_and_save_frames_and_metadata(
            video_filepath,
            video_transcript_filepath,
            extracted_frames_path,
            video_dir,
        )
        # Process transcripts and images
        video_trans = [vid['transcript'] for vid in metadatas]
        video_img_path = [vid['extracted_frame_path'] for vid in metadatas]

        n = 7
        updated_video_trans = [
            ' '.join(video_trans[i - int(n / 2): i + int(n / 2)]) if i - int(n / 2) >= 0
            else ' '.join(video_trans[0: i + int(n / 2)])
            for i in range(len(video_trans))
        ]
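        # Each frame's caption is replaced by a sliding window of roughly n neighbouring
        # captions, so a retrieved frame carries enough surrounding speech context for
        # the vision-language model to answer from.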
        # Update metadata with the windowed transcripts
        for i in range(len(updated_video_trans)):
            metadatas[i]['transcript'] = updated_video_trans[i]

        # Store in the vector database
        _ = MultimodalLanceDB.from_text_image_pairs(
            texts=updated_video_trans,
            image_paths=video_img_path,
            embedding=embedder,
            metadatas=metadatas,
            connection=db,
            table_name=TBL_NAME,
            mode="overwrite",
        )

        return f"✅ Video processed and stored: {youtube_url}"
    except Exception as e:
        return f"❌ Error processing video: {str(e)}"
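
# Illustrative local check (assumes the preloaded video and captions exist on disk):
#   print(preprocess_and_store("https://www.youtube.com/watch?v=7Hcg-rLYwdM"))
# should report "✅ Video processed and stored: ..." and (re)build the "vectorstore" table.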
# -------------------------------
# 3. Retrieval + Prompt Functions
# -------------------------------
vectorstore = MultimodalLanceDB(
    uri=LANCEDB_HOST_FILE,
    embedding=embedder,
    table_name=TBL_NAME,
)
retriever_module = vectorstore.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 3}
)
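# The retriever returns the k=3 most similar text+frame pairs for a query;
# prompt_processing below only keeps the top-ranked hit.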
def prompt_processing(input):
    retrieved_results = input["retrieved_results"]
    user_query = input["user_query"]

    if not retrieved_results:
        return {"prompt": "No relevant content found.", "frame_path": None}

    retrieved_results = retrieved_results[0]

    prompt_template = (
        "The transcript associated with the image is '{transcript}'. "
        "{user_query}"
    )
    retrieved_metadata = retrieved_results.metadata
    transcript = retrieved_metadata["transcript"]
    frame_path = retrieved_metadata["extracted_frame_path"]

    return {
        "prompt": prompt_template.format(transcript=transcript, user_query=user_query),
        "frame_path": frame_path,
    }
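
# Example of the dict handed to lvlm_inference (values are illustrative, not real paths):
#   {"prompt": "The transcript associated with the image is '...'. What is happening here?",
#    "frame_path": "src/shared_data/videos/video1/extracted_frame/frame_0042.jpg"}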
def lvlm_inference(input):
    frame_path = None  # initialized so the except clause below never sees an unbound name
    try:
        # Get the prompt and the retrieved frame path
        lvlm_prompt = input['prompt']
        frame_path = input['frame_path']

        if frame_path is None:
            return "No relevant frame found.", None

        # Retrieve the API key from environment variables
        api_key = os.getenv("MISTRAL_API_KEY")
        if not api_key:
            return "❌ MISTRAL_API_KEY not found. Please set it in the environment variables.", frame_path

        # Initialize the Mistral client
        client = Mistral(api_key=api_key)

        # Pass the retrieved frame as a base64-encoded data URL alongside the text prompt
        base64_image = encode_image(frame_path)
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": lvlm_prompt
                    },
                    {
                        "type": "image_url",
                        "image_url": f"data:image/jpeg;base64,{base64_image}"
                    }
                ]
            }
        ]

        # Get the chat response
        chat_response = client.chat.complete(
            model="pixtral-12b-2409",
            messages=messages
        )
        return chat_response.choices[0].message.content, frame_path
    except Exception as e:
        return f"❌ Error in inference: {str(e)}", frame_path
# LangChain Runnable chain
prompt_processing_module = RunnableLambda(prompt_processing)
lvlm_inference_module = RunnableLambda(lvlm_inference)

mm_rag_chain = (
    RunnableParallel({"retrieved_results": retriever_module, "user_query": RunnablePassthrough()})
    | prompt_processing_module
    | lvlm_inference_module
)
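
# Data flow for one user query:
#   query -> {"retrieved_results": top-k docs from LanceDB, "user_query": query}
#         -> prompt_processing (build the Pixtral prompt + pick the frame path)
#         -> lvlm_inference (returns the answer text and the frame path)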
# -------------------------------
# 4. Chat API for Gradio
# -------------------------------
video_loaded = False


def load_video(youtube_url):
    global video_loaded
    try:
        if not youtube_url or not youtube_url.strip():
            return "❌ Please enter a valid YouTube URL"

        status = preprocess_and_store(youtube_url)
        if status.startswith("✅"):
            video_loaded = True
        else:
            video_loaded = False
        return status
    except Exception as e:
        video_loaded = False
        return f"❌ Unexpected error: {str(e)}"


def chat_interface(message, history):
    if not video_loaded:
        return "", history + [(message, "❌ Please load a video first in the 'Load Video' tab.")], None

    if not message.strip():
        return "", history, None

    try:
        final_text_response, frame_path = mm_rag_chain.invoke(message)
        history.append((message, final_text_response))

        # Load and return the retrieved frame as an image
        retrieved_image = None
        if frame_path:
            try:
                retrieved_image = Image.open(frame_path)
            except Exception as e:
                print(f"Error loading image: {e}")

        return "", history, retrieved_image
    except Exception as e:
        error_msg = f"❌ Error processing query: {str(e)}"
        history.append((message, error_msg))
        return "", history, None


def clear_chat():
    return [], None
# -------------------------------
# 5. Enhanced Gradio Interface
# -------------------------------
with gr.Blocks(
    title="Multimodal RAG Video Chat",
    theme=gr.themes.Default()
) as demo:
    gr.Markdown("""
# 🎬 Multimodal RAG Video Chat

Chat with YouTube videos using BridgeTower embeddings + LanceDB + the Pixtral vision-language model!

⚠️ **Important**: You need to set your `MISTRAL_API_KEY` in the Space settings for this to work.
""")

    with gr.Tab("1. Load Video"):
        with gr.Column():
            youtube_url = gr.Textbox(
                label="YouTube URL",
                value="https://www.youtube.com/watch?v=7Hcg-rLYwdM",
                interactive=False,
                lines=1,
                scale=4
            )
            with gr.Row():
                load_btn = gr.Button("🚀 Process Video", variant="primary", scale=1)
            status = gr.Textbox(
                label="Status",
                interactive=False,
                lines=2
            )
            load_btn.click(
                fn=load_video,
                inputs=youtube_url,
                outputs=status,
                show_progress=True
            )
    with gr.Tab("2. Chat with Video"):
        with gr.Row():
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    label="Chat about the video",
                    height=500
                )
            with gr.Column(scale=1):
                retrieved_image = gr.Image(
                    label="Retrieved Frame",
                    height=400,
                    show_label=True,
                    interactive=False
                )
        with gr.Row():
            with gr.Column(scale=4):
                msg = gr.Textbox(
                    label="Your question",
                    placeholder="Ask something about the video content...",
                    lines=2,
                    container=False
                )
            with gr.Column(scale=1, min_width=100):
                send_btn = gr.Button("📤 Send", variant="primary")
                clear_btn = gr.Button("🗑️ Clear", variant="secondary")

        # Event handlers
        msg.submit(
            fn=chat_interface,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot, retrieved_image],
            show_progress=True
        )
        send_btn.click(
            fn=chat_interface,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot, retrieved_image],
            show_progress=True
        )
        clear_btn.click(
            fn=clear_chat,
            outputs=[chatbot, retrieved_image]
        )
    with gr.Tab("📖 Instructions"):
        gr.Markdown("""
## How to use this Multimodal RAG system:

### 🔧 Setup:
1. **Set API Key**: Make sure `MISTRAL_API_KEY` is set in your Space settings
2. This app uses Pixtral-12B for vision-language understanding

### 📥 Load Video:
For this demo, we **use a fixed pre-downloaded YouTube video**:
🔗 [https://www.youtube.com/watch?v=7Hcg-rLYwdM](https://www.youtube.com/watch?v=7Hcg-rLYwdM)

Due to Hugging Face free Space network restrictions, direct downloads from YouTube are disabled.
All processing in this demo is based on this preloaded video.

### 💬 Chat with Video:
1. Go to the "Chat with Video" tab
2. Ask questions about the video content
3. The system will retrieve the most relevant frame and provide an answer
4. The retrieved frame will be displayed on the right side

## ✨ Features:
- 📥 **YouTube Processing**: Downloads and processes YouTube videos (disabled in this demo; a preloaded video is used)
- 🧠 **Multimodal Embeddings**: Uses BridgeTower for combined text+image understanding
- 💾 **Vector Storage**: Stores data in LanceDB for fast similarity search
- 🤖 **Vision-Language AI**: Powered by Mistral's Pixtral model
- 🖼️ **Visual Context**: Shows relevant video frames alongside responses
- ⚡ **Real-time Processing**: Fast retrieval and inference

## ⚠️ Limitations:
- Works with publicly accessible YouTube videos only
- Processing time depends on video length
- Video download requires a stable internet connection (not available on free Spaces)
- API rate limits apply based on Mistral usage

## 🛠️ Technical Stack:
- **Embeddings**: BridgeTower (multimodal)
- **Vector DB**: LanceDB
- **Vision-Language Model**: Pixtral-12B
- **Framework**: LangChain + Gradio
""")
    with gr.Tab("ℹ️ About"):
        gr.Markdown("""
## Multimodal RAG Video Chat System

This application demonstrates a complete multimodal Retrieval-Augmented Generation (RAG) pipeline that can understand and answer questions about video content.

### Architecture:
1. **Video Processing**: Downloads YouTube videos and extracts frames with timestamps
2. **Multimodal Embedding**: Uses BridgeTower to create embeddings that understand both visual and textual content
3. **Vector Storage**: Stores embeddings in LanceDB for efficient similarity search
4. **Retrieval**: Finds the most relevant video segments based on user queries
5. **Generation**: Uses the Pixtral vision-language model to generate contextual responses

### Built with:
- **Gradio**: For the web interface
- **LangChain**: For orchestrating the RAG pipeline
- **LanceDB**: For vector storage and retrieval
- **BridgeTower**: For multimodal embeddings
- **Mistral Pixtral**: For vision-language understanding

---
💡 **Tip**: For best results, ask specific questions about visual content, actions, or scenes in the video.
""")
# -------------------------------
# 6. Launch Configuration
# -------------------------------
if __name__ == "__main__":
    print('🚀 Starting Multimodal RAG Video Chat App...')

    # Check for required environment variables
    if not os.getenv("MISTRAL_API_KEY"):
        print("⚠️ WARNING: MISTRAL_API_KEY not found in environment variables")
        print("   Please set this in your HuggingFace Space settings")

    # The same launch settings work locally and on HF Spaces
    demo.launch(share=True, server_name="0.0.0.0", server_port=7860)