Upload folder using huggingface_hub
Browse files- .gitignore +8 -0
- README.md +178 -0
- gradio_utils.py +483 -0
- lrn_vector_embeddings.py +111 -0
- main_demo.py +142 -0
- requirements.txt +22 -0
- s2_download_data.py +49 -0
- s3_data_to_vector_embedding.py +61 -0
- s4_calculate_distance.py +83 -0
- s5-how-to-umap.py +137 -0
- s6_prepare_video_input.py +90 -0
- s7_store_in_rag.py +105 -0
- upload_huggingface.py +8 -0
- utility.py +693 -0
.gitignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
myenv
|
| 2 |
+
__pycache__
|
| 3 |
+
mm_rag/*
|
| 4 |
+
shared_data
|
| 5 |
+
.gradio
|
| 6 |
+
.env
|
| 7 |
+
.venv
|
| 8 |
+
.github
|
README.md
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: multimodel-rag-chat-with-videos
|
| 3 |
+
app_file: main_demo.py
|
| 4 |
+
sdk: gradio
|
| 5 |
+
sdk_version: 5.17.1
|
| 6 |
+
---
|
| 7 |
+
# ReArchitecture Multimodal RAG System Pipeline Journey
|
| 8 |
+
I ported it locally and isolated each concept into a step as Python runnable
|
| 9 |
+
It is simplified, refactored and bug-fixed now.
|
| 10 |
+
I migrated from Prediction Guard to HuggingFace.
|
| 11 |
+
|
| 12 |
+
[**Interactive Video Chat Demo and Multimodal RAG System Architecture**](https://learn.deeplearning.ai/courses/multimodal-rag-chat-with-videos/lesson/2/interactive-demo-and-multimodal-rag-system-architecture)
|
| 13 |
+
|
| 14 |
+
### A multimodal AI system should be able to understand both text and video content.
|
| 15 |
+
|
| 16 |
+
---
|
| 17 |
+
|
| 18 |
+
## Step 1 - Learn Gradio (UI) (30 mins)
|
| 19 |
+
|
| 20 |
+
Gradio is a powerful Python library for quickly building browser-based UIs. It supports hot reloading for fast development.
|
| 21 |
+
|
| 22 |
+
### Key Concepts:
|
| 23 |
+
- **fn**: The function wrapped by the UI.
|
| 24 |
+
- **inputs**: The Gradio components used for input (should match function arguments).
|
| 25 |
+
- **outputs**: The Gradio components used for output (should match return values).
|
| 26 |
+
|
| 27 |
+
📖 [**Gradio Documentation**](https://www.gradio.app/docs/gradio/introduction)
|
| 28 |
+
|
| 29 |
+
Gradio includes **30+ built-in components**.
|
| 30 |
+
|
| 31 |
+
💡 **Tip**: For `inputs` and `outputs`, you can pass either:
|
| 32 |
+
- The **component name** as a string (e.g., `"textbox"`)
|
| 33 |
+
- An **instance of the component class** (e.g., `gr.Textbox()`)
|
| 34 |
+
|
| 35 |
+
### Sharing Your Demo
|
| 36 |
+
```python
|
| 37 |
+
demo.launch(share=True) # Share your demo with just one extra parameter.
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
## Gradio Advanced Features
|
| 41 |
+
|
| 42 |
+
### **Gradio.Blocks**
|
| 43 |
+
Gradio provides `gr.Blocks`, a flexible way to design web apps with **custom layouts and complex interactions**:
|
| 44 |
+
- Arrange components freely on the page.
|
| 45 |
+
- Handle multiple data flows.
|
| 46 |
+
- Use outputs as inputs for other components.
|
| 47 |
+
- Dynamically update components based on user interaction.
|
| 48 |
+
|
| 49 |
+
### **Gradio.ChatInterface**
|
| 50 |
+
- Always set `type="messages"` in `gr.ChatInterface`.
|
| 51 |
+
- The default (`type="tuples"`) is **deprecated** and will be removed in future versions.
|
| 52 |
+
- For more UI flexibility, use `gr.ChatBot`.
|
| 53 |
+
- `gr.ChatInterface` supports **Markdown** (not tested yet).
|
| 54 |
+
|
| 55 |
+
---
|
| 56 |
+
|
| 57 |
+
## Step 2 - Learn Bridge Tower Embedding Model (Multimodal Learning) (15 mins)
|
| 58 |
+
|
| 59 |
+
Developed in collaboration with Intel, this model maps image-caption pairs into **512-dimensional vectors**.
|
| 60 |
+
|
| 61 |
+
### Measuring Similarity
|
| 62 |
+
- **Cosine Similarity** → Measures how close images are in vector space (**efficient & commonly used**).
|
| 63 |
+
- **Euclidean Distance** → Uses `cv2.NORM_L2` to compute similarity between two images.
|
| 64 |
+
|
| 65 |
+
### Converting to 2D for Visualization
|
| 66 |
+
- **UMAP** reduces 512D embeddings to **2D for display purposes**.
|
| 67 |
+
|
| 68 |
+
## Preprocessing Videos for Multimodal RAG
|
| 69 |
+
|
| 70 |
+
### **Case 1: WEBVTT → Extracting Text Segments from Video**
|
| 71 |
+
- Converts video + text into structured metadata.
|
| 72 |
+
- Splits content into multiple segments.
|
| 73 |
+
|
| 74 |
+
### **Case 2: Whisper (Small) → Video Only**
|
| 75 |
+
- Extracts **audio** → `model.transcribe()`.
|
| 76 |
+
- Applies `getSubs()` helper function to retrieve **WEBVTT** subtitles.
|
| 77 |
+
- Uses **Case 1** processing.
|
| 78 |
+
|
| 79 |
+
### **Case 3: LvLM → Video + Silent/Music Extraction**
|
| 80 |
+
- Uses **Llava (LvLM model)** for **frame-based captioning**.
|
| 81 |
+
- Encodes each frame as a **Base64 image**.
|
| 82 |
+
- Extracts context and captions from video frames.
|
| 83 |
+
- Uses **Case 1** processing.
|
| 84 |
+
|
| 85 |
+
# Step 4 - What is LLaVA?
|
| 86 |
+
LLaVA (Large Language-and-Vision Assistant) is a large multimodal model that connects a vision encoder to a language model, so it doesn't just see images — it understands them, reads the text embedded in them, and reasons about their context.
|
| 87 |
+
|
| 88 |
+
# Step 5 - What Is a Vector Store?
|
| 89 |
+
A vector store is a specialized database designed to:
|
| 90 |
+
|
| 91 |
+
- Store and manage high-dimensional vector data efficiently
|
| 92 |
+
- Perform similarity-based searches where K=1 returns the most similar result
|
| 93 |
+
|
| 94 |
+
- In LanceDB specifically, store multiple data types:
|
| 95 |
+
. Text content (captions)
|
| 96 |
+
. Image file paths
|
| 97 |
+
. Metadata
|
| 98 |
+
. Vector embeddings
|
| 99 |
+
|
| 100 |
+
```python
|
| 101 |
+
_ = MultimodalLanceDB.from_text_image_pairs(
|
| 102 |
+
texts=updated_vid1_trans+vid2_trans,
|
| 103 |
+
image_paths=vid1_img_path+vid2_img_path,
|
| 104 |
+
embedding=BridgeTowerEmbeddings(),
|
| 105 |
+
metadatas=vid1_metadata+vid2_metadata,
|
| 106 |
+
connection=db,
|
| 107 |
+
table_name=TBL_NAME,
|
| 108 |
+
mode="overwrite",
|
| 109 |
+
)
|
| 110 |
+
```
|
| 111 |
+
# Gotchas and Solutions
|
| 112 |
+
Image Processing: When working with base64 encoded images, convert them to PIL.Image format before processing with BridgeTower
|
| 113 |
+
Model Selection: Using BridgeTowerForContrastiveLearning instead of PredictionGuard due to API access limitations
|
| 114 |
+
Model Size: BridgeTower model requires ~3.5GB download
|
| 115 |
+
Image Downloads: Some Flickr images may be unavailable; implement robust error handling
|
| 116 |
+
Token Decoding: BridgeTower contrastive learning model works with embeddings, not token predictions
|
| 117 |
+
Install from git+https://github.com/openai/whisper.git
|
| 118 |
+
|
| 119 |
+
# Install ffmpeg using brew
|
| 120 |
+
```bash
|
| 121 |
+
brew install ffmpeg
|
| 122 |
+
brew link ffmpeg
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
# Learning and Skills
|
| 127 |
+
|
| 128 |
+
## Technical Skills:
|
| 129 |
+
|
| 130 |
+
Basic Machine learning and deep learning
|
| 131 |
+
Vector embeddings and similarity search
|
| 132 |
+
Multimodal data processing
|
| 133 |
+
|
| 134 |
+
## Framework & Library Expertise:
|
| 135 |
+
|
| 136 |
+
Hugging Face Transformers
|
| 137 |
+
Gradio UI development
|
| 138 |
+
LangChain integration (Basic)
|
| 139 |
+
PyTorch basics
|
| 140 |
+
LanceDB vector storage
|
| 141 |
+
|
| 142 |
+
## AI/ML Concepts:
|
| 143 |
+
|
| 144 |
+
Multimodal RAG system architecture
|
| 145 |
+
Vector embeddings and similarity search
|
| 146 |
+
Large Language Models (LLaVA)
|
| 147 |
+
Image-text pair processing
|
| 148 |
+
Dimensionality reduction techniques
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
## Multimedia Processing:
|
| 152 |
+
|
| 153 |
+
Video frame extraction
|
| 154 |
+
Audio transcription (Whisper)
|
| 155 |
+
Image processing (PIL)
|
| 156 |
+
Base64 encoding/decoding
|
| 157 |
+
WebVTT handling
|
| 158 |
+
|
| 159 |
+
## System Design:
|
| 160 |
+
|
| 161 |
+
Client-server architecture
|
| 162 |
+
API endpoint design
|
| 163 |
+
Data pipeline construction
|
| 164 |
+
Vector store implementation
|
| 165 |
+
Multimodal system integration
|
| 166 |
+
## Hugging Face
|
| 167 |
+
Remote: hf_origin
|
| 168 |
+
branch:hf_main
|
| 169 |
+
title: Hg Demo
|
| 170 |
+
emoji: 😻
|
| 171 |
+
colorFrom: gray
|
| 172 |
+
colorTo: red
|
| 173 |
+
sdk: gradio
|
| 174 |
+
sdk_version: 5.18.0
|
| 175 |
+
app_file: app.py
|
| 176 |
+
pinned: false
|
| 177 |
+
license: mit
|
| 178 |
+
short_description: 'A space to keep AI work for demo '
|
gradio_utils.py
ADDED
|
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import io
|
| 3 |
+
import sys
|
| 4 |
+
import time
|
| 5 |
+
import dataclasses
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import os
|
| 8 |
+
from enum import auto, Enum
|
| 9 |
+
from typing import List, Tuple, Any
|
| 10 |
+
from utility import prediction_guard_llava_conv
|
| 11 |
+
import lancedb
|
| 12 |
+
from utility import load_json_file
|
| 13 |
+
from mm_rag.embeddings.bridgetower_embeddings import BridgeTowerEmbeddings
|
| 14 |
+
from mm_rag.vectorstores.multimodal_lancedb import MultimodalLanceDB
|
| 15 |
+
from mm_rag.MLM.client import PredictionGuardClient
|
| 16 |
+
from mm_rag.MLM.lvlm import LVLM
|
| 17 |
+
from PIL import Image
|
| 18 |
+
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
|
| 19 |
+
from moviepy.video.io.VideoFileClip import VideoFileClip
|
| 20 |
+
from utility import prediction_guard_llava_conv, encode_image, Conversation, lvlm_inference_with_conversation
|
| 21 |
+
|
| 22 |
+
server_error_msg="**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
|
| 23 |
+
|
| 24 |
+
def split_video(video_path, timestamp_in_ms, output_video_path: str = "./shared_data/splitted_videos", output_video_name: str = "video_tmp.mp4", play_before_sec: int = 3, play_after_sec: int = 3):
    """Cut a short clip out of a video around a given timestamp.

    The clip spans from ``play_before_sec`` seconds before the timestamp to
    ``play_after_sec`` seconds after it (clamped to the video's bounds).

    Args:
        video_path: Path to the source video file.
        timestamp_in_ms: Center point of the clip, in milliseconds.
        output_video_path: Directory the clip is written into (created if missing).
        output_video_name: File name of the written clip.
        play_before_sec: Seconds of footage to keep before the timestamp.
        play_after_sec: Seconds of footage to keep after the timestamp.

    Returns:
        The path of the written clip file.
    """
    center_sec = int(timestamp_in_ms / 1000)
    # Make sure the destination directory exists before writing.
    Path(output_video_path).mkdir(parents=True, exist_ok=True)
    clip_target = os.path.join(output_video_path, output_video_name)
    with VideoFileClip(video_path) as source:
        clip_start = max(center_sec - play_before_sec, 0)
        clip_end = min(center_sec + play_after_sec, source.duration)
        # 'aac' keeps the audio track playable in browsers (Gradio video widget).
        source.subclip(clip_start, clip_end).write_videofile(clip_target, audio_codec='aac')
    return clip_target
|
| 38 |
+
|
| 39 |
+
# Template that fuses the retrieved frame's transcript with the user's query
# before the combined prompt is sent to the LVLM.
prompt_template = """The transcript associated with the image is '{transcript}'. {user_query}"""
|
| 40 |
+
|
| 41 |
+
# Builds the default multimodal RAG chain, used whenever a caller does not
# supply its own chain (see get_gradio_instance / get_demo).
def get_default_rag_chain():
    """Create the default retrieval-augmented generation chain.

    Pipeline: user query -> LanceDB similarity retrieval (k=1) ->
    prompt construction from the retrieved frame's transcript ->
    LVLM inference. The chain's output is a dict with keys
    'final_text_output' (the LVLM answer) and 'input_to_lvlm'
    (the prompt/image/metadata dict fed to the LVLM).
    """
    # declare host file
    LANCEDB_HOST_FILE = "./shared_data/.lancedb"
    # declare table name
    TBL_NAME = "demo_tbl"

    # initialize vectorstore
    # NOTE(review): `db` is created but never used below — MultimodalLanceDB
    # opens its own connection from `uri`; candidate for removal.
    db = lancedb.connect(LANCEDB_HOST_FILE)

    # initialize a BridgeTower embedder
    embedder = BridgeTowerEmbeddings()

    ## Creating a LanceDB vector store
    vectorstore = MultimodalLanceDB(uri=LANCEDB_HOST_FILE, embedding=embedder, table_name=TBL_NAME)
    ### creating a retriever for the vector store; k=1 keeps only the best match
    retriever_module = vectorstore.as_retriever(search_type='similarity', search_kwargs={"k": 1})

    # initialize a client as PredictionGuardClient
    client = PredictionGuardClient()
    # initialize LVLM with the given client
    lvlm_inference_module = LVLM(client=client)

    def prompt_processing(input):
        """Turn {retrieved_results, user_query} into the LVLM input dict."""
        # get the retrieved results and user's query
        retrieved_results, user_query = input['retrieved_results'], input['user_query']
        # get the first retrieved result by default (k=1 above guarantees one)
        retrieved_result = retrieved_results[0]
        # prompt_template = """The transcript associated with the image is '{transcript}'. {user_query}"""

        # get all metadata of the retrieved video segment
        metadata_retrieved_video_segment = retrieved_result.metadata['metadata']

        # transcript text and path to the extracted frame of the retrieved segment
        transcript = metadata_retrieved_video_segment['transcript']
        frame_path = metadata_retrieved_video_segment['extracted_frame_path']
        return {
            'prompt': prompt_template.format(transcript=transcript, user_query=user_query),
            'image' : frame_path,
            'metadata' : metadata_retrieved_video_segment,
        }
    # wrap prompt_processing as a LangChain RunnableLambda so it can be piped
    prompt_processing_module = RunnableLambda(prompt_processing)

    # the output of this chain is a dictionary: the LVLM answer plus the
    # exact input that was fed to the LVLM (so callers can recover metadata)
    mm_rag_chain_with_retrieved_image = (
        RunnableParallel({"retrieved_results": retriever_module ,
                          "user_query": RunnablePassthrough()})
        | prompt_processing_module
        | RunnableParallel({'final_text_output': lvlm_inference_module,
                            'input_to_lvlm' : RunnablePassthrough()})
    )
    return mm_rag_chain_with_retrieved_image
|
| 94 |
+
|
| 95 |
+
class SeparatorStyle(Enum):
    """Enumeration of message-separator styles used by GradioInstance."""

    SINGLE = auto()
|
| 98 |
+
|
| 99 |
+
@dataclasses.dataclass
class GradioInstance:
    """A class that keeps all conversation history for one chat session.

    ``messages`` is a list of [role, message] pairs; a ``None`` message marks
    a pending assistant reply. Paths to the retrieved frame/video and the
    retrieved transcript (``caption``) are cached after the first RAG round.
    """
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int                      # index where the visible conversation starts
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "\n"
    sep2: str = None
    version: str = "Unknown"
    path_to_img: str = None          # extracted frame of the retrieved segment
    video_title: str = None
    path_to_video: str = None        # clip produced by split_video()
    caption: str = None              # transcript of the retrieved segment
    mm_rag_chain: Any = None         # LangChain runnable used for the first query

    skip_next: bool = False          # set when the last user input was invalid

    def _template_caption(self):
        """Return the caption phrased as a prompt prefix, or '' if unset."""
        out = ""
        if self.caption is not None:
            out = f"The caption associated with the image is '{self.caption}'. "
        return out

    def get_prompt_for_rag(self):
        """Return the user's first query; only valid on a fresh 2-entry conversation."""
        messages = self.messages
        assert len(messages) == 2, "length of current conversation should be 2"
        assert messages[1][1] is None, "the first response message of current conversation should be None"
        ret = messages[0][1]
        return ret

    def get_conversation_for_lvlm(self):
        """Rebuild the history as a PredictionGuard-style conversation.

        The first user turn carries the base64 frame; the last complete user
        turn is re-wrapped with prompt_template so the transcript travels
        with the follow-up question.
        """
        pg_conv = prediction_guard_llava_conv.copy()
        image_path = self.path_to_img
        b64_img = encode_image(image_path)
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if msg is None:
                # pending assistant reply — stop copying here
                break
            if i == 0:
                pg_conv.append_message(prediction_guard_llava_conv.roles[0], [msg, b64_img])
            elif i == len(self.messages[self.offset:]) - 2:
                pg_conv.append_message(role, [prompt_template.format(transcript=self.caption, user_query=msg)])
            else:
                pg_conv.append_message(role, [msg])
        return pg_conv

    def append_message(self, role, message):
        """Append one [role, message] pair to the history."""
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        # NOTE(review): return_pil is currently ignored — the path string is
        # returned either way; confirm whether PIL loading was intended.
        images = []
        if self.path_to_img is not None:
            path_to_image = self.path_to_img
            images.append(path_to_image)
        return images

    def to_gradio_chatbot(self):
        """Convert the history into gr.Chatbot's [[user, bot], ...] format.

        Even indices are user turns; a tuple message means (text, PIL image,
        mode) and is rendered inline as a resized base64 <img> tag.
        """
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    msg, image, image_process_mode = msg
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    # cap display size at 800x400 while keeping aspect ratio
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace('<image>', '').strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                # odd index: assistant reply fills the pending slot
                ret[-1][-1] = msg
        return ret

    def copy(self):
        """Return a fresh instance sharing the chain but deep-copying messages."""
        return GradioInstance(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version,
            mm_rag_chain=self.mm_rag_chain,
        )

    def dict(self):
        """Serialize the user-visible state (omits chain and style fields)."""
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "path_to_img": self.path_to_img,
            "video_title" : self.video_title,
            "path_to_video": self.path_to_video,
            "caption" : self.caption,
        }

    def get_path_to_subvideos(self):
        """Resolve the video clip to show next to the chat, or None.

        NOTE(review): `video_helper_map` is not defined or imported in this
        module — the first branch would raise NameError if video_title is
        ever set; confirm where the map is supposed to come from.
        """
        if self.video_title is not None and self.path_to_img is not None:
            info = video_helper_map[self.video_title]
            path = info['path']
            prefix = info['prefix']
            # frame files are named ..._<index>.jpg; recover <index> for the clip
            vid_index = self.path_to_img.split('/')[-1]
            vid_index = vid_index.split('_')[-1]
            vid_index = vid_index.replace('.jpg', '')
            ret = f"{prefix}{vid_index}.mp4"
            ret = os.path.join(path, ret)
            return ret
        elif self.path_to_video is not None:
            return self.path_to_video
        return None
|
| 227 |
+
|
| 228 |
+
def get_gradio_instance(mm_rag_chain=None):
    """Create a fresh, empty conversation state.

    Args:
        mm_rag_chain: RAG chain to attach; when None, the default chain
            from get_default_rag_chain() is built and used.

    Returns:
        A new GradioInstance with no messages and no cached media.
    """
    chain = get_default_rag_chain() if mm_rag_chain is None else mm_rag_chain
    return GradioInstance(
        system="",
        roles=prediction_guard_llava_conv.roles,
        messages=[],
        offset=0,
        sep_style=SeparatorStyle.SINGLE,
        sep="\n",
        path_to_img=None,
        video_title=None,
        caption=None,
        mm_rag_chain=chain,
    )
|
| 245 |
+
|
| 246 |
+
# Allow Gradio to serve bundled assets (fonts, header image) directly.
gr.set_static_paths(paths=["./assets/"])

# Intel-branded dark theme: both hues share the same blue palette, with the
# page background matching primary c950.
theme = gr.themes.Base(
    primary_hue=gr.themes.Color(
        c100="#dbeafe", c200="#bfdbfe", c300="#93c5fd", c400="#60a5fa", c50="#eff6ff", c500="#0054ae", c600="#00377c", c700="#00377c", c800="#1e40af", c900="#1e3a8a", c950="#0a0c2b"),
    secondary_hue=gr.themes.Color(
        c100="#dbeafe", c200="#bfdbfe", c300="#93c5fd", c400="#60a5fa", c50="#eff6ff", c500="#0054ae", c600="#0054ae", c700="#0054ae", c800="#1e40af", c900="#1e3a8a", c950="#1d3660"),
).set(
    body_background_fill_dark='*primary_950',
    body_text_color_dark='*neutral_300',
    border_color_accent='*primary_700',
    border_color_accent_dark='*neutral_800',
    block_background_fill_dark='*primary_950',
    block_border_width='2px',
    block_border_width_dark='2px',
    button_primary_background_fill_dark='*primary_500',
    button_primary_border_color_dark='*primary_500'
)

# Custom CSS: loads the IntelOne font and flattens the chat table borders.
css='''
@font-face {
    font-family: IntelOne;
    src: url("/file=./assets/intelone-bodytext-font-family-regular.ttf");
}
.gradio-container {background-color: #0a0c2b}
table {
    border-collapse: collapse;
    border: none;
}
'''

## <td style="border-bottom:0"><img src="file/assets/DCAI_logo.png" height="300" width="300"></td>

# Earlier multi-logo header, kept for reference:
# html_title = '''
# <table style="bordercolor=#0a0c2b; border=0">
#   <tr style="height:150px; border:0">
#       <td style="border:0"><img src="/file=../assets/intel-labs.png" height="100" width="100"></td>
#       <td style="vertical-align:bottom; border:0">
#       <p style="font-size:xx-large;font-family:IntelOne, Georgia, sans-serif;color: white;">
#       Multimodal RAG:
#       <br>
#       Chat with Videos
#       </p>
#       </td>
#       <td style="border:0"><img src="/file=../assets/gaudi.png" width="100" height="100"></td>

#       <td style="border:0"><img src="/file=../assets/IDC7.png" width="300" height="350"></td>
#       <td style="border:0"><img src="/file=../assets/prediction_guard3.png" width="120" height="120"></td>
#   </tr>
# </table>

# '''

# Current header: a single pre-rendered banner image.
html_title = '''
<table style="bordercolor=#0a0c2b; border=0">
  <tr style="height:150px; border:0">
      <td style="border:0"><img src="/file=./assets/header.png"></td>
  </tr>
</table>

'''

#<td style="border:0"><img src="/file=../assets/xeon.png" width="100" height="100"></td>

# Sample queries offered in the query dropdown (custom input is also allowed).
dropdown_list = [
    "What is the name of one of the astronauts?",
    "An astronaut's spacewalk",
    "What does the astronaut say?",

]

# Sentinel button states returned by the event handlers to toggle the
# clear-history button: unchanged / enabled / disabled.
no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)
|
| 318 |
+
|
| 319 |
+
def clear_history(state, request: gr.Request):
    """Reset the chat: new empty state, cleared chatbot/query/video widgets.

    Keeps the existing RAG chain so retrieval does not get rebuilt, and
    disables the clear-history button until there is something to clear.
    """
    fresh = get_gradio_instance(state.mm_rag_chain)
    widgets = (fresh, fresh.to_gradio_chatbot(), "", None)
    return widgets + (disable_btn,)
|
| 322 |
+
|
| 323 |
+
def add_text(state, text, request: gr.Request):
    """Record the user's query in the conversation state.

    Appends the query plus a pending (None) assistant slot, truncating the
    query to 1536 characters. On empty/missing input, marks the state so
    http_bot skips generation.

    Returns (state, chatbot_messages, textbox_value, clear_btn_state) —
    matching the 4 outputs bound in get_demo's submit handler.
    """
    # BUG FIX: `not text` also covers None (gr.Dropdown can yield None);
    # the old `len(text) <= 0` raised TypeError on None.
    if not text:
        state.skip_next = True
        # BUG FIX: previously returned 5 values (an extra None) while the
        # submit handler binds only 4 outputs; the stray None is dropped.
        return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 1

    text = text[:1536]  # Hard cut-off

    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)  # pending assistant reply
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 1
|
| 334 |
+
|
| 335 |
+
def http_bot(
    state, request: gr.Request
):
    """Generate the assistant's reply; a generator yielding UI updates.

    Yields (state, chatbot_messages, video_path, clear_btn_state). The first
    query of a session runs the full RAG chain; follow-ups reuse the cached
    frame/transcript and call the LVLM directly with the conversation.
    """
    # NOTE(review): start/finish timestamps are computed but never used —
    # presumably leftover logging hooks.
    start_tstamp = time.time()

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        path_to_sub_videos = state.get_path_to_subvideos()
        yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (no_change_btn,) * 1
        return

    if len(state.messages) == state.offset + 2:
        # First round of conversation: rebuild a clean state carrying only
        # the user's query (drops any stale cached media).
        new_state = get_gradio_instance(state.mm_rag_chain)
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    all_images = state.get_images(return_pil=False)

    # Make requests: no cached image means this is the very first query and
    # retrieval must run; otherwise reuse the cached retrieval result.
    is_very_first_query = True
    if len(all_images) == 0:
        # first query need to do RAG
        # Construct prompt
        prompt_or_conversation = state.get_prompt_for_rag()
    else:
        # subsequent queries, no need to do Retrieval
        is_very_first_query = False
        prompt_or_conversation = state.get_conversation_for_lvlm()

    if is_very_first_query:
        executor = state.mm_rag_chain
    else:
        executor = lvlm_inference_with_conversation

    # Show a typing cursor while generation runs.
    state.messages[-1][-1] = "▌"
    path_to_sub_videos = state.get_path_to_subvideos()
    yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (disable_btn,) * 1

    try:
        if is_very_first_query:
            # get response by invoking the executor chain; the chain returns
            # {'final_text_output': ..., 'input_to_lvlm': {...}} (see
            # get_default_rag_chain)
            response = executor.invoke(prompt_or_conversation)
            message = response['final_text_output']
            if 'metadata' in response['input_to_lvlm']:
                metadata = response['input_to_lvlm']['metadata']
                # Cache the retrieved frame, clipped video and transcript so
                # follow-up queries can skip retrieval.
                if (state.path_to_img is None
                        and 'input_to_lvlm' in response
                        and 'image' in response['input_to_lvlm']
                    ):
                    state.path_to_img = response['input_to_lvlm']['image']

                if state.path_to_video is None and 'video_path' in metadata:
                    video_path = metadata['video_path']
                    mid_time_ms = metadata['mid_time_ms']
                    splited_video_path = split_video(video_path, mid_time_ms)
                    state.path_to_video = splited_video_path

                if state.caption is None and 'transcript' in metadata:
                    state.caption = metadata['transcript']
            else:
                raise ValueError("Response's format is changed")
        else:
            # get the response message by directly calling the LVLM with the
            # reconstructed conversation
            message = executor(prompt_or_conversation)

    except Exception as e:
        # Surface a generic error in the chat and re-enable the clear button.
        print(e)
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot(), None) + (
            enable_btn,
        )
        return

    # Replace the typing cursor with the final answer.
    state.messages[-1][-1] = message
    path_to_sub_videos = state.get_path_to_subvideos()
    # path_to_image = state.path_to_img
    # caption = state.caption
    yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (enable_btn,) * 1

    finish_tstamp = time.time()
    return
|
| 421 |
+
|
| 422 |
+
def get_demo(rag_chain=None):
    """Build the Gradio Blocks app: video pane + chatbot + query dropdown.

    Args:
        rag_chain: RAG chain shared by all sessions; defaults to
            get_default_rag_chain().

    Returns:
        The assembled gr.Blocks demo (caller launches it).
    """
    if rag_chain is None:
        rag_chain = get_default_rag_chain()

    with gr.Blocks(theme=theme, css=css) as demo:
        # gr.Markdown(description)
        # Per-session conversation state seeded with the shared chain.
        instance = get_gradio_instance(rag_chain)
        state = gr.State(instance)
        # Force the dark theme via a URL parameter on first load.
        demo.load(
            None,
            None,
            js="""
  () => {
  const params = new URLSearchParams(window.location.search);
  if (!params.has('__theme')) {
    params.set('__theme', 'dark');
    window.location.search = params.toString();
  }
  }""",
        )
        gr.HTML(value=html_title)
        with gr.Row():
            with gr.Column(scale=4):
                video = gr.Video(height=512, width=512, elem_id="video", interactive=False )
            with gr.Column(scale=7):
                chatbot = gr.Chatbot(
                    elem_id="chatbot", label="Multimodal RAG Chatbot", height=512,
                )
        with gr.Row():
            with gr.Column(scale=8):
                # textbox.render()
                # Dropdown doubles as a free-text query box (allow_custom_value).
                textbox = gr.Dropdown(
                    dropdown_list,
                    allow_custom_value=True,
                    # show_label=False,
                    # container=False,
                    label="Query",
                    info="Enter your query here or choose a sample from the dropdown list!"
                )
            with gr.Column(scale=1, min_width=50):
                submit_btn = gr.Button(
                    value="Send", variant="primary", interactive=True
                )
        with gr.Row(elem_id="buttons") as button_row:
            clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)

        btn_list = [clear_btn]

        # Wire events: clearing resets all widgets; submitting records the
        # query (add_text) then streams the answer (http_bot generator).
        clear_btn.click(
            clear_history, [state], [state, chatbot, textbox, video] + btn_list
        )
        submit_btn.click(
            add_text,
            [state, textbox],
            [state, chatbot, textbox,] + btn_list,
        ).then(
            http_bot,
            [state],
            [state, chatbot, video] + btn_list,
        )
    return demo
|
| 483 |
+
|
lrn_vector_embeddings.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import numpy as np
|
| 4 |
+
from numpy.linalg import norm
|
| 5 |
+
import cv2
|
| 6 |
+
from io import StringIO, BytesIO
|
| 7 |
+
from umap import UMAP
|
| 8 |
+
from sklearn.preprocessing import MinMaxScaler
|
| 9 |
+
import pandas as pd
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
import base64
|
| 12 |
+
from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning, BridgeTowerForImageAndTextRetrieval, BridgeTowerForMaskedLM
|
| 13 |
+
import requests
|
| 14 |
+
from PIL import Image
|
| 15 |
+
import torch
|
| 16 |
+
|
| 17 |
+
# Three sample (image, caption) pairs used throughout the embedding demos.
url1 = 'http://farm3.staticflickr.com/2519/4126738647_cc436c111b_z.jpg'
cap1 = 'A motorcycle sits parked across from a herd of livestock'

url2 = 'http://farm3.staticflickr.com/2046/2003879022_1b4b466d1d_z.jpg'
cap2 = 'Motorcycle on platform to be worked on in garage'

url3 = 'https://i.natgeofe.com/n/548467d8-c5f1-4551-9f58-6817a8d2c45e/NationalGeographic_2572187_3x2.jpg'
cap3 = 'a cat laying down stretched out near a laptop'

# Each record: source URL, human caption, and the local path it is saved to.
img1 = {'flickr_url': url1,
        'caption': cap1,
        'image_path': './shared_data/motorcycle_1.jpg'}

img2 = {'flickr_url': url2,
        'caption': cap2,
        'image_path': './shared_data/motorcycle_2.jpg'}

img3 = {'flickr_url': url3,
        'caption': cap3,
        'image_path': './shared_data/cat_1.jpg'}
|
| 43 |
+
|
| 44 |
+
# Contrastive (ITC) checkpoint; loaded once and memoized — the original code
# re-ran from_pretrained() (a multi-GB download/deserialize) on every call.
_BT_ITC_CKPT = "BridgeTower/bridgetower-large-itm-mlm-itc"
_bt_itc_cache = {}


def _get_bt_itc():
    """Return (model, processor) for the ITC checkpoint, loading them once."""
    if not _bt_itc_cache:
        _bt_itc_cache['model'] = BridgeTowerForContrastiveLearning.from_pretrained(_BT_ITC_CKPT)
        _bt_itc_cache['processor'] = BridgeTowerProcessor.from_pretrained(_BT_ITC_CKPT)
    return _bt_itc_cache['model'], _bt_itc_cache['processor']


def bt_embeddings_from_local(text, image):
    """Embed a (text, image) pair with BridgeTower's contrastive heads.

    Args:
        text: caption string.
        image: PIL image (anything BridgeTowerProcessor accepts).

    Returns:
        dict with 'cross_modal_embeddings', 'text_embeddings' and
        'image_embeddings', each a torch.Tensor with a leading batch dim.
    """
    model, processor = _get_bt_itc()
    processed_inputs = processor(image, text, padding=True, return_tensors="pt")
    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**processed_inputs)
    return {
        'cross_modal_embeddings': outputs.cross_embeds,
        'text_embeddings': outputs.text_embeds,
        'image_embeddings': outputs.image_embeds,
    }
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def bt_scores_with_image_and_text_retrieval():
    """Score two candidate captions against a fixed COCO image using the
    BridgeTower image-text matching head.

    Returns:
        dict mapping each caption to its "match" logit (higher = better fit).
    """
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    texts = ["An image of two cats chilling on a couch", "A football player scoring a goal"]

    processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-gaudi")
    model = BridgeTowerForImageAndTextRetrieval.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-gaudi")

    scores = {}
    for caption in texts:
        encoding = processor(image, caption, return_tensors="pt")
        outputs = model(**encoding)
        # logits[0, 1] is the positive ("match") logit for this pair.
        scores[caption] = outputs.logits[0, 1].item()
    return scores
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def bt_with_masked_input():
    """Fill the ``<mask>`` token in a caption with BridgeTower's masked-LM
    head, conditioned on a COCO image.

    Prints and returns the greedily decoded prediction string.
    """
    url = "http://images.cocodataset.org/val2017/000000360943.jpg"
    image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
    text = "a <mask> looking out of the window"

    processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-gaudi")
    model = BridgeTowerForMaskedLM.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-gaudi")

    encoding = processor(image, text, return_tensors="pt")
    outputs = model(**encoding)

    # Greedy decode: the most likely vocabulary id at every position.
    token_ids = outputs.logits.argmax(dim=-1).squeeze(0).tolist()
    if isinstance(token_ids, list):
        results = processor.tokenizer.decode(token_ids)
    else:
        # A single-token sequence squeezes down to a bare int.
        results = processor.tokenizer.decode([token_ids])

    print(results)
    return results
|
| 105 |
+
|
| 106 |
+
if __name__ == "__main__":
    #res = bt_embeddingsl()
    #print((res['text_embeddings']))
    # Smoke test: embed each locally saved sample image with its caption and
    # print the cross-modal embedding shape (assumes the images were already
    # downloaded to ./shared_data — see s2_download_data.py).
    for img in [img1, img2, img3]:
        embeddings = bt_embeddings_from_local(img['caption'], Image.open(img['image_path']))
        print(embeddings['cross_modal_embeddings'][0].shape)
|
main_demo.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import os
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import ollama
|
| 6 |
+
from utility import download_video, get_transcript_vtt, extract_meta_data
|
| 7 |
+
from mm_rag.embeddings.bridgetower_embeddings import (
|
| 8 |
+
BridgeTowerEmbeddings
|
| 9 |
+
)
|
| 10 |
+
from mm_rag.vectorstores.multimodal_lancedb import MultimodalLanceDB
|
| 11 |
+
import lancedb
|
| 12 |
+
import json
|
| 13 |
+
import os
|
| 14 |
+
from PIL import Image
|
| 15 |
+
from utility import load_json_file, display_retrieved_results
|
| 16 |
+
import pyarrow as pa
|
| 17 |
+
|
| 18 |
+
# declare host file
|
| 19 |
+
# Local LanceDB storage directory (created on first connect).
LANCEDB_HOST_FILE = "./shared_data/.lancedb"
# Name of the table that holds the (transcript, frame) embeddings.
TBL_NAME = "demo_tbl"
# Module-level connection shared by store_in_rag()/open_table().
# NOTE(review): runs at import time — importing this module touches disk.
db = lancedb.connect(LANCEDB_HOST_FILE)
# Shared BridgeTower embedder used for both indexing and querying.
embedder = BridgeTowerEmbeddings()

# Where downloaded YouTube videos and extracted frames live.
vid_dir = "./shared_data/videos/yt_video"
Path(vid_dir).mkdir(parents=True, exist_ok=True)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def open_table():
    """Open table TBL_NAME and print a small summary of its contents.

    Side effects only (prints the row count); returns None.
    """
    tbl = db.open_table(TBL_NAME)
    # Materialize the table once — the original called to_pandas() twice,
    # pulling the full table into memory two times.
    df = tbl.to_pandas()
    print(f"There are {df.shape[0]} rows in the table")
    # Peek at the first 3 rows (result is discarded; kept for parity with
    # the original notebook-style code).
    df[['text', 'image_path']].head(3)
|
| 38 |
+
|
| 39 |
+
def store_in_rag():
    """Embed every (transcript, frame) pair from the video metadata and
    (re)build LanceDB table TBL_NAME.

    Each frame's transcript is widened to a sliding window of ``n``
    neighbouring subtitle segments so the indexed text carries enough
    context for retrieval.  Uses mode="overwrite"; pass mode="append"
    instead to extend an existing table.
    """
    # Load metadata produced by extract_meta_data().
    vid_metadata_path = './shared_data/videos/yt_video/metadatas.json'
    vid_metadata = load_json_file(vid_metadata_path)

    vid_subs = [vid['transcript'] for vid in vid_metadata]
    vid_img_path = [vid['extracted_frame_path'] for vid in vid_metadata]

    # For video1 we use a context window of n = 7 subtitle segments.
    n = 7
    half = n // 2
    # Window is [i - half, i + half] inclusive, clamped at index 0.  The
    # original slice stopped at i + half (exclusive), which yields only
    # n - 1 = 6 segments instead of the intended 7.
    updated_vid_subs = [
        ' '.join(vid_subs[max(i - half, 0): i + half + 1])
        for i in range(len(vid_subs))
    ]

    # Keep the widened transcripts in the metadata as well.
    for i, sub in enumerate(updated_vid_subs):
        vid_metadata[i]['transcript'] = sub

    # mode="overwrite" starts a fresh vector store; use mode="append" to
    # add more entries to an existing one.
    _ = MultimodalLanceDB.from_text_image_pairs(
        texts=updated_vid_subs,
        image_paths=vid_img_path,
        embedding=embedder,
        metadatas=vid_metadata,
        connection=db,
        table_name=TBL_NAME,
        mode="overwrite",
    )
|
| 75 |
+
|
| 76 |
+
def get_metadata_of_yt_video_with_captions(vid_url):
    """Download a YouTube video plus its caption track, index the extracted
    frames/transcript in the vector store, and return the local video path.
    """
    local_video = download_video(vid_url, vid_dir)
    transcript_file = get_transcript_vtt(vid_url, vid_dir)
    # should return lowercase file name without spaces
    extract_meta_data(vid_dir, local_video, transcript_file)
    store_in_rag()
    open_table()
    return local_video
|
| 83 |
+
|
| 84 |
+
"""
|
| 85 |
+
def chat_response_llvm(instruction):
|
| 86 |
+
#file_path = the_metadatas[0]
|
| 87 |
+
file_path = 'shared_data/videos/yt_video/extracted_frame/'
|
| 88 |
+
result = ollama.generate(
|
| 89 |
+
model='llava',
|
| 90 |
+
prompt=instruction,
|
| 91 |
+
images=[file_path],
|
| 92 |
+
stream=True
|
| 93 |
+
)['response']
|
| 94 |
+
return result
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
def return_top_k_most_similar_docs(query="show me a group of astronauts", max_docs=1):
    """Retrieve the most similar indexed frame for *query*.

    Returns:
        (transcript_text, PIL.Image) for the single best match.
        Note: only the top hit is returned even when max_docs > 1.
    """
    # Open the existing LanceDB-backed vector store.
    vectorstore = MultimodalLanceDB(
        uri=LANCEDB_HOST_FILE,
        embedding=embedder,
        table_name=TBL_NAME)

    # Similarity search returning the top-`max_docs` documents.
    retriever = vectorstore.as_retriever(
        search_type='similarity',
        search_kwargs={"k": max_docs})

    docs = retriever.invoke(query)
    best = docs[0]
    return best.page_content, Image.open(best.metadata['extracted_frame_path'])
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def process_url_and_init(youtube_url):
    """Gradio callback: ingest *youtube_url* and return the local video path."""
    return get_metadata_of_yt_video_with_captions(youtube_url)
|
| 123 |
+
|
| 124 |
+
def init_ui():
    """Assemble the demo UI: video ingestion on top, Q&A below."""
    with gr.Blocks() as demo:
        url_input = gr.Textbox(
            label="Enter YouTube URL",
            value="https://www.youtube.com/watch?v=7Hcg-rLYwdM",
            interactive=False,
        )
        submit_btn = gr.Button("Process Video")
        #vid_filepath = 'shared_data/videos/yt_video/Welcome_back_to_Planet_Earth.mp4'
        chatbox = gr.Textbox(
            label="What question do you want to ask?",
            value="show me a group of astronauts",
        )
        response = gr.Textbox(label="Response", interactive=False)
        video = gr.Video()
        frame = gr.Image()
        submit_btn2 = gr.Button("ASK")

        # Ingest the video, then answer questions against the index.
        submit_btn.click(fn=process_url_and_init, inputs=url_input, outputs=[video])
        submit_btn2.click(fn=return_top_k_most_similar_docs, inputs=[chatbox], outputs=[response, frame])
    return demo
|
| 138 |
+
|
| 139 |
+
if __name__ == '__main__':
    demo = init_ui()
    # NOTE(review): True is passed positionally; gradio's launch() takes
    # several leading boolean-ish parameters, so this may not be share=True —
    # confirm the intent and make it a keyword argument.
    demo.launch(True)
|
| 142 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
langchain-predictionguard
|
| 3 |
+
IPython
|
| 4 |
+
umap-learn
|
| 5 |
+
pytubefix
|
| 6 |
+
youtube_transcript_api
|
| 7 |
+
torch
|
| 8 |
+
transformers
|
| 9 |
+
matplotlib
|
| 10 |
+
seaborn
|
| 11 |
+
datasets
|
| 12 |
+
moviepy
|
| 13 |
+
whisper  # NOTE: PyPI 'whisper' is an unrelated package; 'openai-whisper' below provides the speech model — consider dropping this line
|
| 14 |
+
webvtt-py
|
| 15 |
+
tqdm
|
| 16 |
+
lancedb
|
| 17 |
+
langchain-core
|
| 18 |
+
langchain-community
|
| 19 |
+
ollama
|
| 20 |
+
opencv-python
|
| 21 |
+
openai-whisper
|
| 22 |
+
huggingface_hub[cli]
|
s2_download_data.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests
|
| 2 |
+
from PIL import Image
|
| 3 |
+
from IPython.display import display
|
| 4 |
+
|
| 5 |
+
# You can use your own uploaded images and captions.
|
| 6 |
+
# You will be responsible for the legal use of images that
|
| 7 |
+
# you are going to use.
|
| 8 |
+
|
| 9 |
+
# Sample (image, caption) pairs; each record carries the source URL, the
# human caption, and the local path the image is saved under.
url1 = 'http://farm3.staticflickr.com/2519/4126738647_cc436c111b_z.jpg'
cap1 = 'A motorcycle sits parked across from a herd of livestock'

url2 = 'http://farm3.staticflickr.com/2046/2003879022_1b4b466d1d_z.jpg'
cap2 = 'Motorcycle on platform to be worked on in garage'

url3 = 'https://i.natgeofe.com/n/548467d8-c5f1-4551-9f58-6817a8d2c45e/NationalGeographic_2572187_3x2.jpg'
cap3 = 'a cat laying down stretched out near a laptop'

img1 = {'flickr_url': url1,
        'caption': cap1,
        'image_path': './shared_data/motorcycle_1.jpg'}

img2 = {'flickr_url': url2,
        'caption': cap2,
        'image_path': './shared_data/motorcycle_2.jpg'}

img3 = {'flickr_url': url3,
        'caption': cap3,
        'image_path': './shared_data/cat_1.jpg'}
|
| 35 |
+
|
| 36 |
+
def download_images():
    """Download the three sample images and display each with its caption.

    Raises:
        requests.HTTPError: if any download returns a non-2xx status.
    """
    imgs = [img1, img2, img3]
    for img in imgs:
        # Timeout + raise_for_status: the original had neither, so a hung
        # server would block forever and an error page would be silently
        # written to disk as if it were the image.
        resp = requests.get(img['flickr_url'], timeout=30)
        resp.raise_for_status()
        with open(img['image_path'], 'wb') as f:
            f.write(resp.content)

    for img in imgs:
        display(Image.open(img['image_path']))
        print(img['caption'])
|
| 49 |
+
|
s3_data_to_vector_embedding.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from numpy.linalg import norm
|
| 2 |
+
from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning
|
| 3 |
+
import torch
|
| 4 |
+
from PIL import Image
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# Sample (image, caption) pairs.  'tensor_path' is the stem (no extension)
# under which the computed embedding tensor is saved as <stem>.pt.
url1 = 'http://farm3.staticflickr.com/2519/4126738647_cc436c111b_z.jpg'
cap1 = 'A motorcycle sits parked across from a herd of livestock'

url2 = 'http://farm3.staticflickr.com/2046/2003879022_1b4b466d1d_z.jpg'
cap2 = 'Motorcycle on platform to be worked on in garage'

url3 = 'https://i.natgeofe.com/n/548467d8-c5f1-4551-9f58-6817a8d2c45e/NationalGeographic_2572187_3x2.jpg'
cap3 = 'a cat laying down stretched out near a laptop'

img1 = {'flickr_url': url1,
        'caption': cap1,
        'image_path': './shared_data/motorcycle_1.jpg',
        'tensor_path': './shared_data/motorcycle_1'}

img2 = {'flickr_url': url2,
        'caption': cap2,
        'image_path': './shared_data/motorcycle_2.jpg',
        'tensor_path': './shared_data/motorcycle_2'}

img3 = {'flickr_url': url3,
        'caption': cap3,
        'image_path': './shared_data/cat_1.jpg',
        'tensor_path': './shared_data/cat_1'}
|
| 36 |
+
|
| 37 |
+
# Contrastive (ITC) checkpoint; loaded once and memoized — the original code
# re-ran from_pretrained() (a multi-GB deserialize) on every call.
_BT_ITC_CKPT = "BridgeTower/bridgetower-large-itm-mlm-itc"
_bt_itc_cache = {}


def _get_bt_itc():
    """Return (model, processor) for the ITC checkpoint, loading them once."""
    if not _bt_itc_cache:
        _bt_itc_cache['model'] = BridgeTowerForContrastiveLearning.from_pretrained(_BT_ITC_CKPT)
        _bt_itc_cache['processor'] = BridgeTowerProcessor.from_pretrained(_BT_ITC_CKPT)
    return _bt_itc_cache['model'], _bt_itc_cache['processor']


def bt_embeddings_from_local(text, image):
    """Embed a (text, image) pair with BridgeTower's contrastive heads.

    Args:
        text: caption string.
        image: PIL image (anything BridgeTowerProcessor accepts).

    Returns:
        dict with 'cross_modal_embeddings', 'text_embeddings' and
        'image_embeddings', each a torch.Tensor with a leading batch dim.
    """
    model, processor = _get_bt_itc()
    processed_inputs = processor(image, text, padding=True, return_tensors="pt")
    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**processed_inputs)
    return {
        'cross_modal_embeddings': outputs.cross_embeds,
        'text_embeddings': outputs.text_embeds,
        'image_embeddings': outputs.image_embeds,
    }
|
| 54 |
+
|
| 55 |
+
def save_embeddings():
    """Compute and persist the cross-modal embedding for each sample image
    under <tensor_path>.pt."""
    for img in (img1, img2, img3):
        embedding = bt_embeddings_from_local(img['caption'], Image.open(img['image_path']))
        cross = embedding['cross_modal_embeddings'][0]  # torch.Tensor
        print(cross.shape)
        torch.save(cross, img['tensor_path'] + '.pt')
|
| 60 |
+
|
| 61 |
+
|
s4_calculate_distance.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from numpy.linalg import norm
|
| 3 |
+
import torch
|
| 4 |
+
from IPython.display import display
|
| 5 |
+
import cv2
|
| 6 |
+
|
| 7 |
+
# Sample (image, caption) pairs; 'tensor_path' is the stem under which the
# precomputed embedding was saved (see s3_data_to_vector_embedding.py).
url1 = 'http://farm3.staticflickr.com/2519/4126738647_cc436c111b_z.jpg'
cap1 = 'A motorcycle sits parked across from a herd of livestock'

url2 = 'http://farm3.staticflickr.com/2046/2003879022_1b4b466d1d_z.jpg'
cap2 = 'Motorcycle on platform to be worked on in garage'

url3 = 'https://i.natgeofe.com/n/548467d8-c5f1-4551-9f58-6817a8d2c45e/NationalGeographic_2572187_3x2.jpg'
cap3 = 'a cat laying down stretched out near a laptop'

img1 = {'flickr_url': url1,
        'caption': cap1,
        'image_path': './shared_data/motorcycle_1.jpg',
        'tensor_path': './shared_data/motorcycle_1'}

img2 = {'flickr_url': url2,
        'caption': cap2,
        'image_path': './shared_data/motorcycle_2.jpg',
        'tensor_path': './shared_data/motorcycle_2'}

img3 = {'flickr_url': url3,
        'caption': cap3,
        'image_path': './shared_data/cat_1.jpg',
        'tensor_path': './shared_data/cat_1'}
|
| 36 |
+
|
| 37 |
+
def load_tensor(path):
    """Load a serialized tensor from *path*.

    map_location='cpu' makes the load work even if the tensor was saved
    from a GPU that is not present on this machine.
    """
    return torch.load(path, map_location='cpu')
|
| 39 |
+
|
| 40 |
+
def load_embeddings():
    """Load the three saved cross-modal embeddings as numpy arrays,
    in (img1, img2, img3) order."""
    tensors = [load_tensor(img['tensor_path'] + '.pt') for img in (img1, img2, img3)]
    arr1, arr2, arr3 = (t.data.numpy() for t in tensors)
    return arr1, arr2, arr3
|
| 45 |
+
|
| 46 |
+
def cosine_similarity(vec1, vec2):
    """Cosine of the angle between *vec1* and *vec2* (1.0 = parallel,
    0.0 = orthogonal, -1.0 = opposite)."""
    denom = norm(vec1) * norm(vec2)
    return np.dot(vec1, vec2) / denom
|
| 49 |
+
|
| 50 |
+
def calculate_cosine_distance():
    """Pairwise cosine similarities of the three saved embeddings.

    Returns:
        [sim(1,2), sim(1,3), sim(2,3)].
    """
    emb1, emb2, emb3 = load_embeddings()
    return [
        cosine_similarity(emb1, emb2),
        cosine_similarity(emb1, emb3),
        cosine_similarity(emb2, emb3),
    ]
|
| 56 |
+
|
| 57 |
+
def calcuate_euclidean_distance():
    """Pairwise Euclidean (L2) distances of the three saved embeddings.

    (Function name kept, including its typo, for backward compatibility.)

    Returns:
        [d(1,2), d(1,3), d(2,3)].
    """
    ex1_embed, ex2_embed, ex3_embed = load_embeddings()
    # norm(a - b) == cv2.norm(a, b, cv2.NORM_L2); numpy's norm is already
    # imported here, so the heavyweight cv2 dependency is unnecessary.
    distance1 = float(norm(ex1_embed - ex2_embed))
    distance2 = float(norm(ex1_embed - ex3_embed))
    distance3 = float(norm(ex2_embed - ex3_embed))
    return [distance1, distance2, distance3]
|
| 63 |
+
|
| 64 |
+
def show_cosine_distance():
    """Print the three pairwise cosine similarities with labels."""
    distances = calculate_cosine_distance()
    print("Cosine similarity between ex1_embeded and ex2_embeded is:")
    display(distances[0])
    print("Cosine similarity between ex1_embeded and ex3_embeded is:")
    display(distances[1])
    # Fixed label: the third value is the (ex2, ex3) pair, not (ex2, ex2).
    print("Cosine similarity between ex2_embeded and ex3_embeded is:")
    display(distances[2])
|
| 72 |
+
|
| 73 |
+
def show_euclidean_distance():
    """Print the three pairwise Euclidean distances with labels."""
    distances = calcuate_euclidean_distance()
    print("Euclidean distance between ex1_embeded and ex2_embeded is:")
    display(distances[0])
    print("Euclidean distance between ex1_embeded and ex3_embeded is:")
    display(distances[1])
    # Fixed label: the third value is the (ex2, ex3) pair, not (ex2, ex2).
    print("Euclidean distance between ex2_embeded and ex3_embeded is:")
    display(distances[2])
|
| 81 |
+
|
| 82 |
+
if __name__ == '__main__':
    # Guarded so importing this module (e.g. for cosine_similarity) no longer
    # triggers tensor loads and printing as an import side effect.
    show_cosine_distance()
    show_euclidean_distance()
|
s5-how-to-umap.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from os import path
|
| 2 |
+
from IPython.display import display
|
| 3 |
+
from umap import UMAP
|
| 4 |
+
from sklearn.preprocessing import MinMaxScaler
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import seaborn as sns
|
| 9 |
+
from s3_data_to_vector_embedding import bt_embeddings_from_local
|
| 10 |
+
import random
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
from sklearn.model_selection import train_test_split
|
| 14 |
+
from datasets import load_dataset
|
| 15 |
+
|
| 16 |
+
# Caption prompt templates.  NOTE(review): the '{}' placeholder is never
# filled by data_prep() below — confirm whether a class name was meant to be
# substituted via .format().
templates = [
    'a picture of {}',
    'an image of {}',
    'a nice {}',
    'a beautiful {}',
]
|
| 23 |
+
# function helps to prepare list image-text pairs from the first [test_size] data
def data_prep(hf_dataset_name, templates=templates, test_size=1000):
    """Build a list of {'caption', 'pil_img'} dicts from a HF image dataset.

    Captions are drawn uniformly at random from *templates* (non-deterministic),
    NOT derived from the image's label.
    NOTE(review): the chosen template still contains the literal '{}'
    placeholder — confirm whether it should be .format()-ed with a label.
    """
    # load Huggingface dataset by streaming the dataset which doesn’t download anything, and lets you use it instantly
    #dataset = load_dataset(hf_dataset_name, trust_remote_code=True, split='train', streaming=True)

    dataset = load_dataset(hf_dataset_name)
    # split dataset with specific test_size
    train_test_dataset = dataset['train'].train_test_split(test_size=test_size)
    test_dataset = train_test_dataset['test']
    print(test_dataset)
    # get the test dataset
    img_txt_pairs = []
    for i in range(len(test_dataset)):
        img_txt_pairs.append({
            'caption' : templates[random.randint(0, len(templates)-1)],
            'pil_img' : test_dataset[i]['image']
        })
    return img_txt_pairs
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def load_all_dataset():
    """Fetch 50 car and 50 cat image-text pairs.

    Returns:
        (cat_pairs, car_pairs) — note the order is cats first.
    """
    car_pairs = data_prep("tanganke/stanford_cars", test_size=50)
    cat_pairs = data_prep("yashikota/cat-image-dataset", test_size=50)
    return cat_pairs, car_pairs
|
| 50 |
+
# compute BridgeTower embeddings for cat image-text pairs
def load_cat_and_car_embeddings():
    """Compute BridgeTower cross-modal embeddings for the cat and car pairs.

    NOTE(review): the inner helper builds the `embeddings` list but returns
    `cross_modal_embeddings` — i.e. only the LAST pair's (512,)-vector — and
    the save_embeddings path is commented out, so nothing is persisted.
    Confirm whether `return embeddings` (or np.stack(embeddings)) was the
    intent; the downstream UMAP plot currently runs on per-scalar "samples"
    of the last embedding only.
    """
    # prepare image_text pairs
    cat_img_txt_pairs, car_img_txt_pairs = load_all_dataset()
    def save_embeddings(embedding, path):
        # currently unused (call site below is commented out)
        torch.save(embedding, path)

    def load_embeddings(img_txt_pair):
        # embed one (caption, PIL image) pair
        pil_img = img_txt_pair['pil_img']
        caption = img_txt_pair['caption']
        return bt_embeddings_from_local(caption, pil_img)

    def load_all_embeddings_from_image_text_pairs(img_txt_pairs, file_name):
        embeddings = []
        for img_txt_pair in tqdm(
            img_txt_pairs,
            total=len(img_txt_pairs)
        ):

            embedding = load_embeddings(img_txt_pair)
            print(embedding)
            cross_modal_embeddings = embedding['cross_modal_embeddings'][0].detach().numpy() #this is not the right way to convert tensor to numpy
            #print(cross_modal_embeddings.shape) #<class 'torch.Tensor'>
            #save_embeddings(cross_modal_embeddings, file_name)
            embeddings.append(cross_modal_embeddings)
        # NOTE(review): returns only the final loop iteration's vector;
        # the accumulated `embeddings` list is discarded.
        return cross_modal_embeddings


    cat_embeddings = load_all_embeddings_from_image_text_pairs(cat_img_txt_pairs, './shared_data/cat_embeddings.pt')
    car_embeddings = load_all_embeddings_from_image_text_pairs(car_img_txt_pairs, './shared_data/car_embeddings.pt')

    return cat_embeddings, car_embeddings
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# function transforms high-dimension vectors to 2D vectors using UMAP
def dimensionality_reduction(embeddings, labels):
    """Project *embeddings* to 2-D with UMAP.

    Args:
        embeddings: array-like with one labelled sample per entry of *labels*
            (1-D input is treated as one scalar feature per sample, matching
            the previous reshape(-1, 1) behavior).
        labels: sequence of per-sample labels, same length as the sample axis.

    Returns:
        DataFrame with columns X, Y, label.
    """
    emb = np.asarray(embeddings)
    # One row per labelled sample.  The original blind reshape(-1, 1) scaled
    # every scalar independently, destroying per-sample feature structure for
    # genuine (N, D) input; anchoring on len(labels) preserves it while
    # remaining identical for the current 1-D caller.
    feature_matrix = emb.reshape(len(labels), -1)
    X_scaled = MinMaxScaler().fit_transform(feature_matrix)
    mapper = UMAP(n_components=2, metric="cosine").fit(X_scaled)
    df_emb = pd.DataFrame(mapper.embedding_, columns=["X", "Y"])
    df_emb["label"] = labels
    return df_emb
|
| 95 |
+
|
| 96 |
+
def show_umap_visualization():
    """Embed the cat/car pairs, reduce to 2-D with UMAP and show a
    seaborn scatter plot colored by label."""
    def reduce_dimensions():
        # compute embeddings for both datasets
        cat_embeddings, car_embeddings = load_cat_and_car_embeddings()
        # stacking embeddings of cat and car examples into one numpy array
        all_embeddings = np.concatenate([cat_embeddings, car_embeddings]) # This is not the right way to scale the data

        # prepare labels for the 3 examples
        labels = ['cat'] * len(cat_embeddings) + ['car'] * len(car_embeddings)

        # compute dimensionality reduction for the 3 examples
        reduced_dim_emb = dimensionality_reduction(all_embeddings, labels)
        return reduced_dim_emb

    reduced_dim_emb = reduce_dimensions()
    # Plot the centroids against the cluster
    fig, ax = plt.subplots(figsize=(8,6)) # Set figsize

    sns.set_style("whitegrid", {'axes.grid' : False})
    sns.scatterplot(data=reduced_dim_emb,
                    x=reduced_dim_emb['X'],
                    y=reduced_dim_emb['Y'],
                    hue='label',
                    palette='bright')
    # move the legend outside the axes so it doesn't cover points
    sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
    plt.title('Scatter plot of images of cats and cars using UMAP')
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.show()
|
| 124 |
+
|
| 125 |
+
def an_example_of_cat_and_car_pair_data():
    """Show the first cat pair and the first car pair (caption + image)."""
    cat_img_txt_pairs, car_img_txt_pairs = load_all_dataset()
    for sample in (cat_img_txt_pairs[0], car_img_txt_pairs[0]):
        display(sample['caption'])
        display(sample['pil_img'])
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
if __name__ == '__main__':
    # Entry point: embeds the sample datasets and renders the UMAP plot.
    show_umap_visualization()
|
s6_prepare_video_input.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import os
|
| 3 |
+
from os import path as osp
|
| 4 |
+
import whisper
|
| 5 |
+
from moviepy import VideoFileClip
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from utility import download_video, extract_meta_data, get_transcript_vtt, getSubs
|
| 8 |
+
from urllib.request import urlretrieve
|
| 9 |
+
from IPython.display import display
|
| 10 |
+
import ollama
|
| 11 |
+
|
| 12 |
+
def demp_video_input_that_has_transcript():
    """Download a YouTube video plus its existing subtitles and return extracted metadata."""
    # First demo video; it has an author-provided transcript on YouTube.
    vid_url = "https://www.youtube.com/watch?v=7Hcg-rLYwdM"
    vid_dir = "./shared_data/videos/video1"

    # Fetch the video file and its .vtt subtitle track into the same folder.
    vid_filepath = download_video(vid_url, vid_dir)
    vid_transcript_filepath = get_transcript_vtt(vid_url, vid_dir)

    return extract_meta_data(vid_dir, vid_filepath, vid_transcript_filepath)
|
| 25 |
+
def demp_video_input_that_has_no_transcript():
    """Download a raw mp4, generate a transcript with Whisper, and return extracted metadata."""
    # Second demo video: plain mp4 with no subtitle track available.
    vid_url = (
        "https://multimedia-commons.s3-us-west-2.amazonaws.com/"
        "data/videos/mp4/010/a07/010a074acb1975c4d6d6e43c1faeb8.mp4"
    )
    vid_dir = "./shared_data/videos/video2"
    vid_name = "toddler_in_playground.mp4"

    # Make sure the destination folder exists, then download the file.
    Path(vid_dir).mkdir(parents=True, exist_ok=True)
    vid_filepath, _ = urlretrieve(vid_url, osp.join(vid_dir, vid_name))

    # Extract the audio track to mp3 so Whisper can transcribe it.
    path_to_extracted_audio_file = os.path.join(vid_dir, 'audio.mp3')
    clip = VideoFileClip(vid_filepath)
    clip.audio.write_audiofile(path_to_extracted_audio_file)

    # Transcribe (translating to English) with the small Whisper model.
    model = whisper.load_model("small")
    options = dict(task="translate", best_of=1, language='en')
    results = model.transcribe(path_to_extracted_audio_file, **options)

    # Convert Whisper segments to WebVTT and persist next to the video.
    vtt = getSubs(results["segments"], "vtt")
    path_to_generated_trans = osp.join(vid_dir, 'generated_video1.vtt')
    with open(path_to_generated_trans, 'w') as f:
        f.write(vtt)

    return extract_meta_data(vid_dir, vid_filepath, path_to_generated_trans)
|
| 66 |
+
def ask_llvm(instruction, file_path):
    """Ask the local LLaVA model (via Ollama) a question about an image.

    Args:
        instruction: Natural-language prompt about the image.
        file_path: Path to the image file attached to the prompt.
    """
    result = ollama.generate(
        model='llava',
        prompt=instruction,
        images=[file_path],
        stream=False
    )['response']

    # Show a slightly shrunken preview of the image being discussed.
    img = Image.open(file_path, mode='r')
    img = img.resize([int(i / 1.2) for i in img.size])
    display(img)

    # Bug fix: the old loop split on '.' and printed with end='', which
    # silently dropped every period from the model's answer.
    print(result, flush=True)
+
# Demo driver: run both pipelines — one video with an existing transcript,
# one whose transcript is generated locally with Whisper — then ask LLaVA
# to describe one extracted frame of the second video.
if __name__ == "__main__":
    meta_data = demp_video_input_that_has_transcript()

    meta_data1 = demp_video_input_that_has_no_transcript()
    # Inspect the second extracted frame/segment of the Whisper-transcribed video.
    data = meta_data1[1]
    caption = data['transcript']
    print(f'Generated caption is: "{caption}"')
    frame = Image.open(data['extracted_frame_path'])
    display(frame)
    instruction = "Can you describe the image?"
    ask_llvm(instruction, data['extracted_frame_path'])
    #print(meta_data)
s7_store_in_rag.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from mm_rag.embeddings.bridgetower_embeddings import (
|
| 2 |
+
BridgeTowerEmbeddings
|
| 3 |
+
)
|
| 4 |
+
from mm_rag.vectorstores.multimodal_lancedb import MultimodalLanceDB
|
| 5 |
+
import lancedb
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from utility import load_json_file, display_retrieved_results
|
| 10 |
+
import pyarrow as pa
|
| 11 |
+
|
| 12 |
+
# Path of the on-disk LanceDB database shared by all helpers in this module.
LANCEDB_HOST_FILE = "./shared_data/.lancedb"
# Name of the table holding the video frame/transcript embeddings.
TBL_NAME = "test_tbl"
# Module-level connection and BridgeTower embedder, reused by every function below.
db = lancedb.connect(LANCEDB_HOST_FILE)
embedder = BridgeTowerEmbeddings()
| 22 |
+
def return_top_k_most_similar_docs(max_docs=3):
    """Run two example similarity queries and display the top `max_docs` hits for each."""
    # Wrap the existing LanceDB table in a LangChain-compatible vector store.
    vectorstore = MultimodalLanceDB(
        uri=LANCEDB_HOST_FILE,
        embedding=embedder,
        table_name=TBL_NAME,
    )

    # A plain similarity-search retriever; search_kwargs={"k": max_docs}
    # means each query returns the top `max_docs` most similar documents.
    retriever = vectorstore.as_retriever(
        search_type='similarity',
        search_kwargs={"k": max_docs},
    )

    query2 = (
        "an astronaut's spacewalk "
        "with an amazing view of the earth from space behind"
    )
    results2 = retriever.invoke(query2)
    display_retrieved_results(results2)

    query3 = "a group of astronauts"
    results3 = retriever.invoke(query3)
    display_retrieved_results(results3)
def open_table(TBL_NAME):
    """Open LanceDB table `TBL_NAME` and print a short summary of its contents.

    Returns:
        The opened LanceDB table handle.
    """
    # Bug fix: the table name was previously not passed to open_table(),
    # which raised a TypeError before anything could be printed.
    tbl = db.open_table(TBL_NAME)

    df = tbl.to_pandas()
    print(f"There are {df.shape[0]} rows in the table")
    # Bug fix: the head(3) expression previously had no effect outside a
    # notebook; print it so the preview is actually shown.
    print(df[['text', 'image_path']].head(3))
    return tbl
| 59 |
+
def store_in_rag():
    """Embed transcripts + extracted frames of both demo videos and write them to LanceDB."""
    # Load the per-frame metadata produced by the preprocessing step.
    vid1_metadata_path = './shared_data/videos/video1/metadatas.json'
    vid2_metadata_path = './shared_data/videos/video2/metadatas.json'
    vid1_metadata = load_json_file(vid1_metadata_path)
    vid2_metadata = load_json_file(vid2_metadata_path)

    # Collect transcripts and extracted-frame paths for each video.
    vid1_trans = [vid['transcript'] for vid in vid1_metadata]
    vid1_img_path = [vid['extracted_frame_path'] for vid in vid1_metadata]
    vid2_trans = [vid['transcript'] for vid in vid2_metadata]
    vid2_img_path = [vid['extracted_frame_path'] for vid in vid2_metadata]

    # For video1, augment each transcript with its neighbours
    # (window n = 7, i.e. half-window of 3, clipped at the start of the list).
    n = 7
    half = int(n / 2)
    updated_vid1_trans = []
    for i in range(len(vid1_trans)):
        start = i - half if i - half >= 0 else 0
        updated_vid1_trans.append(' '.join(vid1_trans[start : i + half]))

    # Keep the metadata in sync with the augmented transcripts.
    for i, trans in enumerate(updated_vid1_trans):
        vid1_metadata[i]['transcript'] = trans

    # mode="overwrite" rebuilds the table from scratch; pass mode="append"
    # instead to add more entries to an existing vector store.
    _ = MultimodalLanceDB.from_text_image_pairs(
        texts=updated_vid1_trans + vid2_trans,
        image_paths=vid1_img_path + vid2_img_path,
        embedding=embedder,
        metadatas=vid1_metadata + vid2_metadata,
        connection=db,
        table_name=TBL_NAME,
        mode="overwrite",
    )
|
| 101 |
+
# Demo driver: report the current row count, then run the example queries.
if __name__ == "__main__":
    tbl = db.open_table(TBL_NAME)
    print(f"There are {tbl.to_pandas().shape[0]} rows in the table")
    #display the first 3 rows of the table
    return_top_k_most_similar_docs()
upload_huggingface.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# One-off helper script: upload the whole project folder to the
# 88hours/hg_demo Hugging Face Space.
from huggingface_hub import HfApi
api = HfApi()
api.upload_large_folder(
    repo_id="88hours/hg_demo",
    repo_type="space",
    folder_path="./",

)
utility.py
ADDED
|
@@ -0,0 +1,693 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Add your utilities or helper functions to this file.
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from dotenv import load_dotenv, find_dotenv
|
| 6 |
+
from io import StringIO, BytesIO
|
| 7 |
+
import textwrap
|
| 8 |
+
from typing import Iterator, TextIO, List, Dict, Any, Optional, Sequence, Union
|
| 9 |
+
from enum import auto, Enum
|
| 10 |
+
import base64
|
| 11 |
+
import glob
|
| 12 |
+
import requests
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
from pytubefix import YouTube, Stream
|
| 15 |
+
import webvtt
|
| 16 |
+
from youtube_transcript_api import YouTubeTranscriptApi
|
| 17 |
+
from youtube_transcript_api.formatters import WebVTTFormatter
|
| 18 |
+
from predictionguard import PredictionGuard
|
| 19 |
+
import cv2
|
| 20 |
+
import json
|
| 21 |
+
import PIL
|
| 22 |
+
from ollama import chat
|
| 23 |
+
from ollama import ChatResponse
|
| 24 |
+
from PIL import Image
|
| 25 |
+
import dataclasses
|
| 26 |
+
import random
|
| 27 |
+
from datasets import load_dataset
|
| 28 |
+
from os import path as osp
|
| 29 |
+
from IPython.display import display
|
| 30 |
+
from langchain_core.prompt_values import PromptValue
|
| 31 |
+
from langchain_core.messages import (
|
| 32 |
+
MessageLikeRepresentation,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
MultimodalModelInput = Union[PromptValue, str, Sequence[MessageLikeRepresentation], Dict[str, Any]]
|
| 36 |
+
|
| 37 |
+
def get_from_dict_or_env(
    data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None
) -> str:
    """Return data[key] when present and truthy, else fall back to the environment."""
    value = data.get(key)
    if value:
        return value
    # Not in the dict (or falsy there): consult the environment / default.
    return get_from_env(key, env_key, default=default)
|
| 46 |
+
def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str:
    """Return os.environ[env_key] when set and non-empty, otherwise `default`.

    `key` is accepted for signature compatibility with get_from_dict_or_env
    but is not consulted here.
    """
    value = os.environ.get(env_key)
    return value if value else default
|
| 53 |
+
def load_env():
    # Load variables from the nearest .env file into os.environ
    # (a no-op when no .env file is found).
    _ = load_dotenv(find_dotenv())
|
| 56 |
+
def get_openai_api_key():
    """Load .env (if any) and return the OPENAI_API_KEY value, or None when unset."""
    load_env()
    return os.getenv("OPENAI_API_KEY")
|
| 61 |
+
def get_prediction_guard_api_key():
    """Return the Prediction Guard API key from the environment, prompting when unset."""
    load_env()
    api_key = os.getenv("PREDICTION_GUARD_API_KEY", None)
    if api_key is None:
        # Fall back to an interactive prompt when the variable is not set.
        api_key = input("Please enter your Prediction Guard API Key: ")
    return api_key
|
| 68 |
+
# Prediction Guard endpoint; overridable via DLAI_PREDICTION_GUARD_URL_ENDPOINT.
PREDICTION_GUARD_URL_ENDPOINT = os.getenv("DLAI_PREDICTION_GUARD_URL_ENDPOINT", "https://dl-itdc.predictionguard.com") ###"https://proxy-dl-itdc.predictionguard.com"

# prompt templates: each pairs a class name with a short natural-language caption
templates = [
    'a picture of {}',
    'an image of {}',
    'a nice {}',
    'a beautiful {}',
]
|
| 78 |
+
# function helps to prepare list image-text pairs from the first [test_size] data of a Huggingface dataset
|
| 79 |
+
def prepare_dataset_for_umap_visualization(hf_dataset, class_name, templates=templates, test_size=1000):
    """Build `test_size` caption/image pairs from a Hugging Face image dataset.

    Args:
        hf_dataset: Dataset identifier passed to `datasets.load_dataset`.
        class_name: Class label substituted into the caption templates.
        templates: Caption templates containing one `{}` placeholder.
        test_size: Number of examples sampled into the test split.

    Returns:
        List of dicts with keys 'caption' (str) and 'pil_img' (PIL image).
    """
    # Download (if needed) and split off a random test subset of the requested size.
    dataset = load_dataset(hf_dataset, trust_remote_code=True)
    test_dataset = dataset['train'].train_test_split(test_size=test_size)['test']

    # Pair each image with a randomly chosen caption template
    # (random.choice replaces the manual randint indexing).
    return [
        {
            'caption': random.choice(templates).format(class_name),
            'pil_img': example['image'],
        }
        for example in test_dataset
    ]
| 94 |
+
|
| 95 |
+
def download_video(video_url, path='/tmp/'):
    """Download a YouTube video as mp4 into `path` and return the local file path.

    A non-http(s) `video_url` is treated as a local filename inside `path`.
    Any .mp4 already present in `path` is reused instead of re-downloading.

    Raises:
        RuntimeError: If no progressive mp4 stream is available for the URL.
    """
    print(f'Getting video information for {video_url}')
    if not video_url.startswith('http'):
        return os.path.join(path, video_url)

    # Reuse any previously downloaded mp4 in the target folder.
    existing = glob.glob(os.path.join(path, '*.mp4'))
    if existing:
        print('Video already downloaded')
        return existing[0]

    def progress_callback(stream: Stream, data_chunk: bytes, bytes_remaining: int) -> None:
        # `pbar` is the tqdm bar created just before stream.download() runs.
        pbar.update(len(data_chunk))

    yt = YouTube(video_url, on_progress_callback=progress_callback)
    # Prefer a progressive 480p mp4; otherwise fall back to the best progressive mp4.
    stream = yt.streams.filter(progressive=True, file_extension='mp4', res='480p').desc().first()
    if stream is None:
        stream = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first()
    if stream is None:
        # Bug fix: previously this fell through and crashed with an
        # AttributeError on stream.default_filename.
        raise RuntimeError(f'No progressive mp4 stream available for {video_url}')

    os.makedirs(path, exist_ok=True)
    filename = stream.default_filename.replace(' ', '_')
    filepath = os.path.join(path, filename)

    if not os.path.exists(filepath):
        print('Downloading video from YouTube...')
        pbar = tqdm(desc='Downloading video from YouTube', total=stream.filesize, unit="bytes")
        stream.download(path, filename=filename)
        pbar.close()
    return filepath
|
| 124 |
+
def get_video_id_from_url(video_url):
    """Extract the video id from common YouTube URL shapes.

    Examples:
    - http://youtu.be/SA2iWivDJiE
    - http://www.youtube.com/watch?v=_oPAwA_Udwc&feature=feedu
    - http://www.youtube.com/embed/SA2iWivDJiE
    - http://www.youtube.com/v/SA2iWivDJiE?version=3&hl=en_US

    Unrecognised URLs are returned unchanged.
    """
    import urllib.parse
    parsed = urllib.parse.urlparse(video_url)

    if parsed.hostname == 'youtu.be':
        # Short links carry the id directly in the path.
        return parsed.path[1:]

    if parsed.hostname in ('www.youtube.com', 'youtube.com'):
        if parsed.path == '/watch':
            # Standard watch URL: the id lives in the 'v' query parameter.
            return urllib.parse.parse_qs(parsed.query)['v'][0]
        if parsed.path.startswith('/embed/') or parsed.path.startswith('/v/'):
            # Embed/player URLs: the id is the second path segment.
            return parsed.path.split('/')[2]

    return video_url
|
| 147 |
+
# if this has transcript then download
|
| 148 |
+
def get_transcript_vtt(video_url, path='/tmp'):
    """Fetch the English transcript of a YouTube video and save it as WebVTT.

    Returns the path of the .vtt file; a previously saved file is reused.
    """
    video_id = get_video_id_from_url(video_url)
    filepath = os.path.join(path,'captions.vtt')
    if os.path.exists(filepath):
        return filepath

    transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=['en-GB', 'en'])
    formatter = WebVTTFormatter()
    webvtt_formatted = formatter.format_transcript(transcript)

    # The `with` block closes the file; the old explicit close() inside it
    # was redundant and has been removed.
    with open(filepath, 'w', encoding='utf-8') as webvtt_file:
        webvtt_file.write(webvtt_formatted)

    return filepath
|
| 164 |
+
|
| 165 |
+
# helper function for convert time in second to time format for .vtt or .srt file
|
| 166 |
+
def format_timestamp(seconds: float, always_include_hours: bool = False, fractionalSeperator: str = '.'):
    """Format a duration in seconds as [HH:]MM:SS<sep>mmm for .vtt/.srt files."""
    assert seconds >= 0, "non-negative timestamp expected"
    total_ms = round(seconds * 1000.0)

    # Peel off hours, minutes and seconds; the remainder is milliseconds.
    hours, rem = divmod(total_ms, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    secs, ms = divmod(rem, 1_000)

    # Hours are omitted unless requested or non-zero.
    prefix = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return f"{prefix}{minutes:02d}:{secs:02d}{fractionalSeperator}{ms:03d}"
|
| 182 |
+
# a help function that helps to convert a specific time written as a string in format `webvtt` into a time in miliseconds
|
| 183 |
+
def str2time(strtime):
    """Convert an 'HH:MM:SS.mmm' webvtt timestamp string into milliseconds."""
    # Tolerate surrounding double quotes around the timestamp.
    hrs, mins, secs = (float(part) for part in strtime.strip('"').split(':'))
    # Total seconds, then scale to milliseconds.
    return (hrs * 60**2 + mins * 60 + secs) * 1000
|
| 193 |
+
def _processText(text: str, maxLineWidth=None):
|
| 194 |
+
if (maxLineWidth is None or maxLineWidth < 0):
|
| 195 |
+
return text
|
| 196 |
+
|
| 197 |
+
lines = textwrap.wrap(text, width=maxLineWidth, tabsize=4)
|
| 198 |
+
return '\n'.join(lines)
|
| 199 |
+
|
| 200 |
+
# Resizes a image and maintains aspect ratio
|
| 201 |
+
def maintain_aspect_ratio_resize(image, width=None, height=None, inter=cv2.INTER_AREA):
    """Resize `image` to the given width OR height, preserving the aspect ratio."""
    (h, w) = image.shape[:2]

    # No target dimension requested: hand the image back untouched.
    if width is None and height is None:
        return image

    if width is None:
        # Scale by the requested height.
        ratio = height / float(h)
        dim = (int(w * ratio), height)
    else:
        # Scale by the requested width (takes precedence when both are given).
        ratio = width / float(w)
        dim = (width, int(h * ratio))

    return cv2.resize(image, dim, interpolation=inter)
|
| 224 |
+
# helper function to convert transcripts generated by whisper to .vtt file
|
| 225 |
+
def write_vtt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
    """Write Whisper segments to `file` in WebVTT format."""
    print("WEBVTT\n", file=file)
    for segment in transcript:
        # '-->' inside cue text would be parsed as a timing separator; neutralise it.
        text = _processText(segment['text'], maxLineWidth).replace('-->', '->')
        cue = (
            f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
            f"{text}\n"
        )
        print(cue, file=file, flush=True)
|
| 237 |
+
# helper function to convert transcripts generated by whisper to .srt file
|
| 238 |
+
def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
    """Write Whisper segments to `file` in SRT format (1-based numbered cues).

    Example usage:
        from pathlib import Path
        from whisper.utils import write_srt
        result = transcribe(model, audio_path, temperature=temperature, **args)
        audio_basename = Path(audio_path).stem
        with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
            write_srt(result["segments"], file=srt)
    """
    for index, segment in enumerate(transcript, start=1):
        # SRT uses ',' before milliseconds and always includes hours.
        text = _processText(segment['text'].strip(), maxLineWidth).replace('-->', '->')
        start = format_timestamp(segment['start'], always_include_hours=True, fractionalSeperator=',')
        end = format_timestamp(segment['end'], always_include_hours=True, fractionalSeperator=',')
        print(f"{index}\n{start} --> {end}\n{text}\n", file=file, flush=True)
|
| 264 |
+
def getSubs(segments: Iterator[dict], format: str, maxLineWidth: int = -1) -> str:
    """Render Whisper segments as a subtitle document.

    Args:
        segments: Iterable of Whisper segment dicts ('start', 'end', 'text').
        format: Either 'vtt' or 'srt'.
        maxLineWidth: Wrap width for cue text; negative disables wrapping.

    Returns:
        The complete subtitle file contents as a string.

    Raises:
        ValueError: If `format` is not 'vtt' or 'srt'.
    """
    segmentStream = StringIO()

    if format == 'vtt':
        write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
    elif format == 'srt':
        write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
    else:
        # ValueError is more precise than a bare Exception and is still
        # caught by any existing caller handling Exception.
        raise ValueError("Unknown format " + format)

    return segmentStream.getvalue()
|
| 277 |
+
# encoding image at given path or PIL Image using base64
|
| 278 |
+
def encode_image(image_path_or_PIL_img):
    """Return the base64 (utf-8) encoding of a PIL image or of an image file path."""
    if isinstance(image_path_or_PIL_img, PIL.Image.Image):
        # In-memory PIL image: serialise it to JPEG first.
        buffered = BytesIO()
        image_path_or_PIL_img.save(buffered, format="JPEG")
        raw = buffered.getvalue()
    else:
        # Otherwise treat the argument as a filesystem path.
        with open(image_path_or_PIL_img, "rb") as image_file:
            raw = image_file.read()
    return base64.b64encode(raw).decode('utf-8')
|
| 289 |
+
# checking whether the given string is base64 or not
|
| 290 |
+
def isBase64(sb):
    """Return True when `sb` (str or bytes) is valid, canonical base64."""
    try:
        if isinstance(sb, str):
            # Non-ASCII text cannot be base64; encode() raises and we return False.
            sb_bytes = sb.encode('ascii')
        elif isinstance(sb, bytes):
            sb_bytes = sb
        else:
            raise ValueError("Argument must be string or bytes")
        # Canonical base64 survives a decode/encode round trip unchanged.
        return base64.b64encode(base64.b64decode(sb_bytes)) == sb_bytes
    except Exception:
        return False
|
| 303 |
+
def encode_image_from_path_or_url(image_path_or_url):
    """Return the base64 (utf-8) encoding of an image given a URL or a local path.

    URLs are fetched over HTTP(S); anything else is read from disk.
    """
    # Bug fix: `urlopen` was referenced but never imported, so the URL branch
    # always raised NameError into the bare except and URLs were mis-handled.
    from urllib.request import urlopen

    try:
        # If this opens, the argument is a fetchable URL.
        with urlopen(image_path_or_url) as response:
            return base64.b64encode(response.read()).decode('utf-8')
    except (ValueError, OSError):
        # Not a URL (or unreachable): treat it as a local file path.
        with open(image_path_or_url, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')
|
| 314 |
+
# helper function to compute the joint embedding of a prompt and a base64-encoded image through PredictionGuard
|
| 315 |
+
def bt_embedding_from_prediction_guard(prompt, base64_image):
    """Embed a (prompt, optional base64 image) pair with BridgeTower via Prediction Guard."""
    client = _getPredictionGuardClient()

    message = {"text": prompt}
    if base64_image is not None and base64_image != "":
        # The endpoint only accepts base64-encoded images.
        if not isBase64(base64_image):
            raise TypeError("image input must be in base64 encoding!")
        message['image'] = base64_image

    response = client.embeddings.create(
        model="bridgetower-large-itm-mlm-itc",
        input=[message],
    )
    return response['data'][0]['embedding']
+
|
| 330 |
+
def load_json_file(file_path):
    """Parse the JSON document at `file_path` and return the resulting object."""
    with open(file_path, 'r') as file:
        return json.load(file)
+
|
| 336 |
+
def display_retrieved_results(results):
    """Pretty-print each retrieved document (caption + full record) to stdout."""
    print(f'There is/are {len(results)} retrieved result(s)')
    print()
    for i, res in enumerate(results):
        print(f'The caption of the {str(i+1)}-th retrieved result is:\n"{res.page_content}"')
        print()
        print(res)
        #display(Image.open(res.metadata['metadata']['extracted_frame_path']))
        print("------------------------------------------------------------")
|
| 346 |
+
class SeparatorStyle(Enum):
    """Different separator style."""
    # Only one style is supported today: a single separator string between turns.
    SINGLE = auto()
|
| 350 |
+
@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history.

    The first message of a conversation must be a [prompt, base64_image]
    pair; every follow-up message is a single-element [text] list. The
    class can render the history either as API-shaped messages
    (get_message) or as one flat single-turn prompt (serialize_messages).
    """
    system: str                     # optional system prompt prepended on serialization
    roles: List[str]                # allowed speaker roles, first entry opens the chat
    messages: List[List[str]]       # list of [role, message_content] pairs
    map_roles: Dict[str, str]       # display-name mapping used when serializing
    version: str = "Unknown"
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "\n"

    def _get_prompt_role(self, role):
        """Map a role to its display name; fall back to the role itself."""
        if self.map_roles is not None and role in self.map_roles.keys():
            return self.map_roles[role]
        else:
            return role

    def _build_content_for_first_message_in_conversation(self, first_message: List[str]):
        """Build API content parts (text + image_url) for the opening message."""
        content = []
        if len(first_message) != 2:
            raise TypeError("First message in Conversation needs to include a prompt and a base64-enconded image!")

        prompt, b64_image = first_message[0], first_message[1]

        # handling prompt
        if prompt is None:
            raise TypeError("API does not support None prompt yet")
        content.append({
            "type": "text",
            "text": prompt
        })
        if b64_image is None:
            raise TypeError("API does not support text only conversation yet")

        # handling image
        if not isBase64(b64_image):
            raise TypeError("Image in Conversation's first message must be stored under base64 encoding!")

        content.append({
            "type": "image_url",
            "image_url": {
                "url": b64_image,
            }
        })
        return content

    def _build_content_for_follow_up_messages_in_conversation(self, follow_up_message: List[str]):
        """Extract the single text element of a follow-up message (no image allowed)."""

        if follow_up_message is not None and len(follow_up_message) > 1:
            raise TypeError("Follow-up message in Conversation must not include an image!")

        # handling text prompt
        if follow_up_message is None or follow_up_message[0] is None:
            raise TypeError("Follow-up message in Conversation must include exactly one text message")

        text = follow_up_message[0]
        return text

    def get_message(self):
        """Return the history as a list of {'role', 'content'} API messages."""
        messages = self.messages
        api_messages = []
        for i, msg in enumerate(messages):
            role, message_content = msg
            if i == 0:
                # get content for very first message in conversation
                content = self._build_content_for_first_message_in_conversation(message_content)
            else:
                # get content for follow-up message in conversation
                content = self._build_content_for_follow_up_messages_in_conversation(message_content)

            api_messages.append({
                "role": role,
                "content": content,
            })
        return api_messages

    # this method helps represent a multi-turn chat into as a single turn chat format
    def serialize_messages(self):
        """Flatten the whole conversation into one prompt string."""
        messages = self.messages
        ret = ""
        if self.sep_style == SeparatorStyle.SINGLE:
            if self.system is not None and self.system != "":
                ret = self.system + self.sep
            for i, (role, message) in enumerate(messages):
                role = self._get_prompt_role(role)
                if message:
                    if isinstance(message, List):
                        # get prompt only
                        message = message[0]
                    if i == 0:
                        # do not include role at the beginning
                        ret += message
                    else:
                        ret += role + ": " + message
                    if i < len(messages) - 1:
                        # avoid including sep at the end of serialized message
                        ret += self.sep
                else:
                    # empty message: emit the role marker as a generation cue
                    ret += role + ":"
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        """Append [role, message] after validating first vs. follow-up message rules."""
        if len(self.messages) == 0:
            # data verification for the very first message
            assert role == self.roles[0], f"the very first message in conversation must be from role {self.roles[0]}"
            assert len(message) == 2, f"the very first message in conversation must include both prompt and an image"
            prompt, image = message[0], message[1]
            assert prompt is not None, f"prompt must be not None"
            assert isBase64(image), f"image must be under base64 encoding"
        else:
            # data verification for follow-up message
            assert role in self.roles, f"the follow-up message must be from one of the roles {self.roles}"
            assert len(message) == 1, f"the follow-up message must consist of one text message only, no image"

        self.messages.append([role, message])

    def copy(self):
        """Return a Conversation with a shallow copy of the message list."""
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x,y] for x, y in self.messages],
            version=self.version,
            map_roles=self.map_roles,
        )

    def dict(self):
        """Return a plain-dict snapshot; single-element messages are unwrapped."""
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": [[x, y[0] if len(y) == 1 else y] for x, y in self.messages],
            "version": self.version,
        }
| 486 |
+
# Template conversation for the Prediction Guard LLaVA endpoint.
# Used as a prototype: callers take a .copy() and append their own messages
# (see lvlm_inference below). map_roles translates the internal role names
# into the upper-case tags the prompt serialization expects.
prediction_guard_llava_conv = Conversation(
    system="",
    roles=("user", "assistant"),
    messages=[],
    version="Prediction Guard LLaVA enpoint Conversation v0",
    sep_style=SeparatorStyle.SINGLE,
    map_roles={
        "user": "USER",
        "assistant": "ASSISTANT"
    }
)
# build a PredictionGuard client from the configured key and endpoint
def _getPredictionGuardClient():
    """Return a PredictionGuard API client ready for chat-completion calls."""
    api_key = get_prediction_guard_api_key()
    return PredictionGuard(
        api_key=api_key,
        url=PREDICTION_GUARD_URL_ENDPOINT,
    )
# helper function to call chat completion endpoint of PredictionGuard given a prompt and an image
def lvlm_inference(prompt, image, max_tokens: int = 200, temperature: float = 0.95, top_p: float = 0.1, top_k: int = 10):
    """Run a single-turn LLaVA completion for `prompt` grounded on `image`.

    `image` must be a base64-encoded string (validated by append_message).
    Sampling parameters are forwarded unchanged to the chat endpoint.
    """
    # build a one-turn conversation from the prototype template
    conv = prediction_guard_llava_conv.copy()
    conv.append_message(conv.roles[0], [prompt, image])
    return lvlm_inference_with_conversation(
        conv,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
    )
def lvlm_inference_with_conversation(conversation, max_tokens: int = 200, temperature: float = 0.95, top_p: float = 0.1, top_k: int = 10):
    """Send a full Conversation to the Prediction Guard chat endpoint.

    Converts the conversation into API-format messages, calls the
    llava-1.5-7b-hf model, and returns the assistant text of the last choice.
    """
    client = _getPredictionGuardClient()
    # call the chat completion endpoint at Prediction Guard
    response = client.chat.completions.create(
        model="llava-1.5-7b-hf",
        messages=conversation.get_message(),
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
    )
    return response['choices'][-1]['message']['content']
def lvlm_inference_with_ollama(conversation, max_tokens: int = 200, temperature: float = 0.95, top_p: float = 0.1, top_k: int = 10):
    """Stream a LLaVA completion from a local Ollama server.

    `conversation` is passed straight through as the messages payload —
    NOTE(review): unlike lvlm_inference_with_conversation this does NOT call
    .get_message(), so callers appear to pass raw message dicts; confirm
    against the call sites.

    Bug fix: ollama's ``chat()`` does not accept ``temperature``,
    ``max_tokens``, ``top_p`` or ``top_k`` as top-level keyword arguments —
    sampling parameters must be supplied via the ``options`` dict (with
    ``num_predict`` playing the role of max_tokens); passing them directly
    raised ``TypeError``.
    """
    stream = chat(
        model="llava-1.5-7b-hf",
        messages=conversation,
        stream=True,
        options={
            "temperature": temperature,
            "num_predict": max_tokens,
            "top_p": top_p,
            "top_k": top_k,
        },
    )

    # accumulate the streamed chunks into the full response text
    chunks = []
    for chunk in stream:
        chunks.append(chunk['message']['content'])
    return ''.join(chunks)
# function `extract_and_save_frames_and_metadata`:
#   receives as input a video and its transcript
#   extracts and saves one frame (plus metadata) per transcript segment
#   returns the extracted metadatas
def extract_and_save_frames_and_metadata(
        path_to_video,
        path_to_transcript,
        path_to_save_extracted_frames,
        path_to_save_metadatas):
    """Extract one frame per WebVTT caption segment and save frames + metadata.

    For each caption, grabs the frame at the segment's temporal midpoint,
    resizes it to height 350, writes it as ``frame_<idx>.jpg`` under
    `path_to_save_extracted_frames`, and records
    {frame path, transcript text, segment id, video path, midpoint ms}.
    All metadata is also dumped to ``metadatas.json`` under
    `path_to_save_metadatas`.

    Bug fix: the ``cv2.VideoCapture`` handle was never released; it is now
    released in a ``finally`` block so the OS handle is not leaked even if
    frame extraction raises.
    """
    # metadatas will store the metadata of all extracted frames
    metadatas = []

    # load video using cv2
    video = cv2.VideoCapture(path_to_video)
    # load transcript using webvtt
    trans = webvtt.read(path_to_transcript)

    try:
        # iterate the transcript file: one frame per caption segment
        for idx, transcript in enumerate(trans):
            # get the start and end time in milliseconds
            start_time_ms = str2time(transcript.start)
            end_time_ms = str2time(transcript.end)
            # midpoint of the segment, in ms
            mid_time_ms = (end_time_ms + start_time_ms) / 2
            # get the transcript text, remove the next-line symbol
            text = transcript.text.replace("\n", ' ')
            # seek to the midpoint and grab the frame there
            video.set(cv2.CAP_PROP_POS_MSEC, mid_time_ms)
            success, frame = video.read()
            if not success:
                print(f"ERROR! Cannot extract frame: idx = {idx}")
                continue

            # frame extracted successfully: resize it, keeping aspect ratio
            image = maintain_aspect_ratio_resize(frame, height=350)
            # save frame as JPEG file
            img_fname = f'frame_{idx}.jpg'
            img_fpath = osp.join(
                path_to_save_extracted_frames, img_fname
            )
            cv2.imwrite(img_fpath, image)

            # prepare the metadata
            metadata = {
                'extracted_frame_path': img_fpath,
                'transcript': text,
                'video_segment_id': idx,
                'video_path': path_to_video,
                'mid_time_ms': mid_time_ms,
            }
            metadatas.append(metadata)
    finally:
        # release the capture handle (previously leaked)
        video.release()

    # save metadata of all extracted frames
    fn = osp.join(path_to_save_metadatas, 'metadatas.json')
    with open(fn, 'w') as outfile:
        json.dump(metadatas, outfile)
    return metadatas
def extract_meta_data(vid_dir, vid_filepath, vid_transcript_filepath):
    """Extract frames and metadata for one video into `vid_dir`.

    Creates the output folders if needed, then delegates to
    extract_and_save_frames_and_metadata and returns its metadata list.
    """
    # frames go into a subfolder; metadata lives directly under vid_dir
    frames_dir = osp.join(vid_dir, 'extracted_frame')
    for folder in (frames_dir, vid_dir):
        Path(folder).mkdir(parents=True, exist_ok=True)

    return extract_and_save_frames_and_metadata(
        vid_filepath,
        vid_transcript_filepath,
        frames_dir,
        vid_dir,
    )
# function extract_and_save_frames_and_metadata_with_fps
#   receives as input a video
#   extracts frames at a fixed rate, captions each with the LVLM,
#   and saves frames and their metadatas
#   returns the extracted metadatas
def extract_and_save_frames_and_metadata_with_fps(
        lvlm_prompt,
        path_to_video,
        path_to_save_extracted_frames,
        path_to_save_metadatas,
        num_of_extracted_frames_per_second=1):
    """Extract frames at a fixed rate, caption each with the LVLM, save all.

    Every ``hop``-th frame is resized to height 350, written as
    ``frame_<idx>.jpg``, captioned via ``lvlm_inference(lvlm_prompt, ...)``,
    and recorded in the returned metadata list (also dumped to
    ``metadatas.json`` under `path_to_save_metadatas`).

    Bug fixes:
      - ``hop`` is clamped to >= 1: when the requested extraction rate
        exceeded the video fps, ``round(...)`` yielded 0 and
        ``curr_frame % hop`` raised ZeroDivisionError.
      - the ``cv2.VideoCapture`` handle is now released in a ``finally``
        block instead of being leaked.
    """
    # metadatas will store the metadata of all extracted frames
    metadatas = []

    # load video using cv2
    video = cv2.VideoCapture(path_to_video)

    # Get the frames per second
    fps = video.get(cv2.CAP_PROP_FPS)
    # hop = the number of frames to pass before a frame is extracted
    hop = max(1, round(fps / num_of_extracted_frames_per_second))
    curr_frame = 0
    idx = -1
    try:
        while True:
            # iterate all frames
            ret, frame = video.read()
            if not ret:
                break
            if curr_frame % hop == 0:
                idx = idx + 1

                # resize the extracted frame, keeping aspect ratio
                image = maintain_aspect_ratio_resize(frame, height=350)
                # save frame as JPEG file
                img_fname = f'frame_{idx}.jpg'
                img_fpath = osp.join(
                    path_to_save_extracted_frames,
                    img_fname
                )
                cv2.imwrite(img_fpath, image)

                # generate caption using lvlm_inference
                b64_image = encode_image(img_fpath)
                caption = lvlm_inference(lvlm_prompt, b64_image)

                # prepare the metadata
                metadata = {
                    'extracted_frame_path': img_fpath,
                    'transcript': caption,
                    'video_segment_id': idx,
                    'video_path': path_to_video,
                }
                metadatas.append(metadata)
            curr_frame += 1
    finally:
        # release the capture handle (previously leaked)
        video.release()

    # save metadata of all extracted frames
    metadatas_path = osp.join(path_to_save_metadatas, 'metadatas.json')
    with open(metadatas_path, 'w') as outfile:
        json.dump(metadatas, outfile)
    return metadatas