88hours committed
Commit 17934c8 · verified · 1 Parent(s): 2047272

Upload folder using huggingface_hub
.gitignore ADDED
@@ -0,0 +1,8 @@
myenv
__pycache__
mm_rag/*
shared_data
.gradio
.env
.venv
.github
README.md ADDED
@@ -0,0 +1,178 @@
---
title: multimodel-rag-chat-with-videos
app_file: main_demo.py
sdk: gradio
sdk_version: 5.17.1
---
# Re-Architecting a Multimodal RAG System Pipeline: A Journey
I ported the course project locally and isolated each concept into its own runnable Python step.
The code is now simplified, refactored, and bug-fixed.
I migrated from Prediction Guard to Hugging Face.

[**Interactive Video Chat Demo and Multimodal RAG System Architecture**](https://learn.deeplearning.ai/courses/multimodal-rag-chat-with-videos/lesson/2/interactive-demo-and-multimodal-rag-system-architecture)

### A multimodal AI system should be able to understand both text and video content.

---

## Step 1 - Learn Gradio (UI) (30 mins)

Gradio is a powerful Python library for quickly building browser-based UIs. It supports hot reloading for fast development.

### Key Concepts:
- **fn**: The function wrapped by the UI.
- **inputs**: The Gradio components used for input (should match the function's arguments).
- **outputs**: The Gradio components used for output (should match the return values).

📖 [**Gradio Documentation**](https://www.gradio.app/docs/gradio/introduction)

Gradio includes **30+ built-in components**.

💡 **Tip**: For `inputs` and `outputs`, you can pass either:
- The **component name** as a string (e.g., `"textbox"`)
- An **instance of the component class** (e.g., `gr.Textbox()`)

### Sharing Your Demo
```python
demo.launch(share=True)  # Share your demo with just one extra parameter.
```

## Gradio Advanced Features

### **Gradio.Blocks**
Gradio provides `gr.Blocks`, a flexible way to design web apps with **custom layouts and complex interactions**:
- Arrange components freely on the page.
- Handle multiple data flows.
- Use outputs as inputs for other components.
- Dynamically update components based on user interaction.

### **Gradio.ChatInterface**
- Always set `type="messages"` in `gr.ChatInterface`.
- The default (`type="tuples"`) is **deprecated** and will be removed in future versions.
- For more UI flexibility, use `gr.Chatbot`.
- `gr.ChatInterface` supports **Markdown** (not tested yet).

---

## Step 2 - Learn Bridge Tower Embedding Model (Multimodal Learning) (15 mins)

Developed in collaboration with Intel, this model maps image-caption pairs into **512-dimensional vectors**.

### Measuring Similarity
- **Cosine Similarity** → Measures the angle between embedding vectors, so two images are similar when their embeddings point in the same direction (**efficient & commonly used**).
- **Euclidean Distance** → Measures the straight-line (L2) distance between embeddings; `cv2.NORM_L2` selects this norm when comparing two images.

### Converting to 2D for Visualization
- **UMAP** reduces 512D embeddings to **2D for display purposes**.

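The cosine measure above can be sketched directly with NumPy; the short 3-d vectors stand in for real 512-d BridgeTower embeddings.

```python
import numpy as np

# Cosine similarity: dot product of the vectors divided by the
# product of their lengths, i.e. the cosine of the angle between them.
def cosine_similarity(a, b):
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

v1 = np.array([1.0, 0.0, 1.0])
v2 = np.array([1.0, 0.0, 1.0])   # same direction -> similarity 1.0
v3 = np.array([0.0, 1.0, 0.0])   # orthogonal -> similarity 0.0

print(cosine_similarity(v1, v2))  # 1.0
print(cosine_similarity(v1, v3))  # 0.0
```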
## Preprocessing Videos for Multimodal RAG

### **Case 1: WEBVTT → Extracting Text Segments from Video**
- Converts video + text into structured metadata.
- Splits content into multiple segments.

### **Case 2: Whisper (Small) → Video Only**
- Extracts **audio** → `model.transcribe()`.
- Applies the `getSubs()` helper function to retrieve **WEBVTT** subtitles.
- Uses **Case 1** processing.

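As a rough pure-Python sketch of what a helper like `getSubs()` does, the function below turns Whisper-style segments (dicts with `start`, `end`, `text`) into a WEBVTT string. The name `segments_to_webvtt` is hypothetical; the project's actual helper may differ.

```python
# Hypothetical sketch: format Whisper-style segments as WEBVTT.

def _ts(seconds):
    # format seconds as HH:MM:SS.mmm
    h, rem = divmod(int(seconds), 3600)
    m, s = divmod(rem, 60)
    ms = int((seconds - int(seconds)) * 1000)
    return f"{h:02d}:{m:02d}:{s:02d}.{ms:03d}"

def segments_to_webvtt(segments):
    lines = ["WEBVTT", ""]
    for seg in segments:  # each seg: {"start": float, "end": float, "text": str}
        lines.append(f"{_ts(seg['start'])} --> {_ts(seg['end'])}")
        lines.append(seg["text"].strip())
        lines.append("")
    return "\n".join(lines)

vtt = segments_to_webvtt([{"start": 0.0, "end": 2.5, "text": " Hello"}])
print(vtt)
```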
### **Case 3: LVLM → Video + Silent/Music Extraction**
- Uses **LLaVA (an LVLM model)** for **frame-based captioning**.
- Encodes each frame as a **Base64 image**.
- Extracts context and captions from video frames.
- Uses **Case 1** processing.

# Step 4 - What is LLaVA?
LLaVA (Large Language-and-Vision Assistant) is a large multimodal model that connects a vision encoder to a language model. It doesn't just see images: it understands them, reads the text embedded in them, and reasons about their context.

# Step 5 - What is a Vector Store?
A vector store is a specialized database designed to:

- Store and manage high-dimensional vector data efficiently
- Perform similarity-based searches, where K=1 returns the single most similar result
- In LanceDB specifically, store multiple data types:
  - Text content (captions)
  - Image file paths
  - Metadata
  - Vector embeddings

```python
_ = MultimodalLanceDB.from_text_image_pairs(
    texts=updated_vid1_trans + vid2_trans,
    image_paths=vid1_img_path + vid2_img_path,
    embedding=BridgeTowerEmbeddings(),
    metadatas=vid1_metadata + vid2_metadata,
    connection=db,
    table_name=TBL_NAME,
    mode="overwrite",
)
```
# Gotchas and Solutions
- **Image Processing**: When working with base64-encoded images, convert them to `PIL.Image` format before processing with BridgeTower.
- **Model Selection**: Using `BridgeTowerForContrastiveLearning` instead of PredictionGuard due to API access limitations.
- **Model Size**: The BridgeTower model requires a ~3.5GB download.
- **Image Downloads**: Some Flickr images may be unavailable; implement robust error handling.
- **Token Decoding**: The BridgeTower contrastive learning model works with embeddings, not token predictions.
- **Whisper**: Install from `git+https://github.com/openai/whisper.git`.

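The first gotcha (base64 → `PIL.Image` before BridgeTower) can be sketched like this; the tiny generated image is only for the round-trip demo.

```python
import base64
import io
from PIL import Image

# Decode a base64-encoded image into a PIL.Image before handing it to BridgeTower.
def b64_to_pil(b64_str):
    return Image.open(io.BytesIO(base64.b64decode(b64_str))).convert("RGB")

# round-trip demo with a tiny in-memory image
buf = io.BytesIO()
Image.new("RGB", (2, 2), color=(255, 0, 0)).save(buf, format="PNG")
b64_str = base64.b64encode(buf.getvalue()).decode()

img = b64_to_pil(b64_str)
print(img.size)  # (2, 2)
```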
# Install ffmpeg using brew
```bash
brew install ffmpeg
brew link ffmpeg
```

# Learning and Skills

## Technical Skills:
- Basic machine learning and deep learning
- Vector embeddings and similarity search
- Multimodal data processing

## Framework & Library Expertise:
- Hugging Face Transformers
- Gradio UI development
- LangChain integration (basic)
- PyTorch basics
- LanceDB vector storage

## AI/ML Concepts:
- Multimodal RAG system architecture
- Vector embeddings and similarity search
- Large Language Models (LLaVA)
- Image-text pair processing
- Dimensionality reduction techniques

## Multimedia Processing:
- Video frame extraction
- Audio transcription (Whisper)
- Image processing (PIL)
- Base64 encoding/decoding
- WebVTT handling

## System Design:
- Client-server architecture
- API endpoint design
- Data pipeline construction
- Vector store implementation
- Multimodal system integration

## Hugging Face
Remote: hf_origin
Branch: hf_main

title: Hg Demo
emoji: 😻
colorFrom: gray
colorTo: red
sdk: gradio
sdk_version: 5.18.0
app_file: app.py
pinned: false
license: mit
short_description: 'A space to keep AI work for demo '
gradio_utils.py ADDED
@@ -0,0 +1,483 @@
import gradio as gr
import io
import sys
import time
import dataclasses
from pathlib import Path
import os
from enum import auto, Enum
from typing import List, Tuple, Any
import lancedb
from utility import load_json_file
from mm_rag.embeddings.bridgetower_embeddings import BridgeTowerEmbeddings
from mm_rag.vectorstores.multimodal_lancedb import MultimodalLanceDB
from mm_rag.MLM.client import PredictionGuardClient
from mm_rag.MLM.lvlm import LVLM
from PIL import Image
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
from moviepy.video.io.VideoFileClip import VideoFileClip
from utility import prediction_guard_llava_conv, encode_image, Conversation, lvlm_inference_with_conversation

server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"

# function to split video at a timestamp
def split_video(video_path, timestamp_in_ms, output_video_path: str = "./shared_data/splitted_videos", output_video_name: str = "video_tmp.mp4", play_before_sec: int = 3, play_after_sec: int = 3):
    timestamp_in_sec = int(timestamp_in_ms / 1000)
    # create the output_video_path folder if it does not exist
    Path(output_video_path).mkdir(parents=True, exist_ok=True)
    output_video = os.path.join(output_video_path, output_video_name)
    with VideoFileClip(video_path) as video:
        duration = video.duration
        start_time = max(timestamp_in_sec - play_before_sec, 0)
        end_time = min(timestamp_in_sec + play_after_sec, duration)
        new = video.subclip(start_time, end_time)
        new.write_videofile(output_video, audio_codec='aac')
    return output_video


prompt_template = """The transcript associated with the image is '{transcript}'. {user_query}"""

# define default rag_chain
def get_default_rag_chain():
    # declare host file
    LANCEDB_HOST_FILE = "./shared_data/.lancedb"
    # declare table name
    TBL_NAME = "demo_tbl"

    # initialize vectorstore
    db = lancedb.connect(LANCEDB_HOST_FILE)

    # initialize a BridgeTower embedder
    embedder = BridgeTowerEmbeddings()

    ## Creating a LanceDB vector store
    vectorstore = MultimodalLanceDB(uri=LANCEDB_HOST_FILE, embedding=embedder, table_name=TBL_NAME)
    ### creating a retriever for the vector store
    retriever_module = vectorstore.as_retriever(search_type='similarity', search_kwargs={"k": 1})

    # initialize a PredictionGuardClient
    client = PredictionGuardClient()
    # initialize LVLM with the given client
    lvlm_inference_module = LVLM(client=client)

    def prompt_processing(input):
        # get the retrieved results and user's query
        retrieved_results, user_query = input['retrieved_results'], input['user_query']
        # get the first retrieved result by default
        retrieved_result = retrieved_results[0]

        # get all metadata of the retrieved video segment
        metadata_retrieved_video_segment = retrieved_result.metadata['metadata']

        # get the transcript and the path to the extracted frame of the retrieved video segment
        transcript = metadata_retrieved_video_segment['transcript']
        frame_path = metadata_retrieved_video_segment['extracted_frame_path']
        return {
            'prompt': prompt_template.format(transcript=transcript, user_query=user_query),
            'image': frame_path,
            'metadata': metadata_retrieved_video_segment,
        }

    # initialize prompt processing module as a LangChain RunnableLambda of function prompt_processing
    prompt_processing_module = RunnableLambda(prompt_processing)

    # the output of this new chain will be a dictionary
    mm_rag_chain_with_retrieved_image = (
        RunnableParallel({"retrieved_results": retriever_module,
                          "user_query": RunnablePassthrough()})
        | prompt_processing_module
        | RunnableParallel({'final_text_output': lvlm_inference_module,
                            'input_to_lvlm': RunnablePassthrough()})
    )
    return mm_rag_chain_with_retrieved_image

class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()

@dataclasses.dataclass
class GradioInstance:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "\n"
    sep2: str = None
    version: str = "Unknown"
    path_to_img: str = None
    video_title: str = None
    path_to_video: str = None
    caption: str = None
    mm_rag_chain: Any = None

    skip_next: bool = False

    def _template_caption(self):
        out = ""
        if self.caption is not None:
            out = f"The caption associated with the image is '{self.caption}'. "
        return out

    def get_prompt_for_rag(self):
        messages = self.messages
        assert len(messages) == 2, "length of current conversation should be 2"
        assert messages[1][1] is None, "the first response message of current conversation should be None"
        ret = messages[0][1]
        return ret

    def get_conversation_for_lvlm(self):
        pg_conv = prediction_guard_llava_conv.copy()
        image_path = self.path_to_img
        b64_img = encode_image(image_path)
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if msg is None:
                break
            if i == 0:
                pg_conv.append_message(prediction_guard_llava_conv.roles[0], [msg, b64_img])
            elif i == len(self.messages[self.offset:]) - 2:
                pg_conv.append_message(role, [prompt_template.format(transcript=self.caption, user_query=msg)])
            else:
                pg_conv.append_message(role, [msg])
        return pg_conv

    def append_message(self, role, message):
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        images = []
        if self.path_to_img is not None:
            path_to_image = self.path_to_img
            images.append(path_to_image)
        return images

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    msg, image, image_process_mode = msg
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace('<image>', '').strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return GradioInstance(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version,
            mm_rag_chain=self.mm_rag_chain,
        )

    def dict(self):
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "path_to_img": self.path_to_img,
            "video_title": self.video_title,
            "path_to_video": self.path_to_video,
            "caption": self.caption,
        }

    def get_path_to_subvideos(self):
        if self.video_title is not None and self.path_to_img is not None:
            info = video_helper_map[self.video_title]
            path = info['path']
            prefix = info['prefix']
            vid_index = self.path_to_img.split('/')[-1]
            vid_index = vid_index.split('_')[-1]
            vid_index = vid_index.replace('.jpg', '')
            ret = f"{prefix}{vid_index}.mp4"
            ret = os.path.join(path, ret)
            return ret
        elif self.path_to_video is not None:
            return self.path_to_video
        return None

def get_gradio_instance(mm_rag_chain=None):
    if mm_rag_chain is None:
        mm_rag_chain = get_default_rag_chain()

    instance = GradioInstance(
        system="",
        roles=prediction_guard_llava_conv.roles,
        messages=[],
        offset=0,
        sep_style=SeparatorStyle.SINGLE,
        sep="\n",
        path_to_img=None,
        video_title=None,
        caption=None,
        mm_rag_chain=mm_rag_chain,
    )
    return instance

gr.set_static_paths(paths=["./assets/"])
theme = gr.themes.Base(
    primary_hue=gr.themes.Color(
        c100="#dbeafe", c200="#bfdbfe", c300="#93c5fd", c400="#60a5fa", c50="#eff6ff", c500="#0054ae", c600="#00377c", c700="#00377c", c800="#1e40af", c900="#1e3a8a", c950="#0a0c2b"),
    secondary_hue=gr.themes.Color(
        c100="#dbeafe", c200="#bfdbfe", c300="#93c5fd", c400="#60a5fa", c50="#eff6ff", c500="#0054ae", c600="#0054ae", c700="#0054ae", c800="#1e40af", c900="#1e3a8a", c950="#1d3660"),
).set(
    body_background_fill_dark='*primary_950',
    body_text_color_dark='*neutral_300',
    border_color_accent='*primary_700',
    border_color_accent_dark='*neutral_800',
    block_background_fill_dark='*primary_950',
    block_border_width='2px',
    block_border_width_dark='2px',
    button_primary_background_fill_dark='*primary_500',
    button_primary_border_color_dark='*primary_500'
)

css = '''
@font-face {
  font-family: IntelOne;
  src: url("/file=./assets/intelone-bodytext-font-family-regular.ttf");
}
.gradio-container {background-color: #0a0c2b}
table {
  border-collapse: collapse;
  border: none;
}
'''

## <td style="border-bottom:0"><img src="file/assets/DCAI_logo.png" height="300" width="300"></td>

# html_title = '''
# <table style="bordercolor=#0a0c2b; border=0">
# <tr style="height:150px; border:0">
# <td style="border:0"><img src="/file=../assets/intel-labs.png" height="100" width="100"></td>
# <td style="vertical-align:bottom; border:0">
# <p style="font-size:xx-large;font-family:IntelOne, Georgia, sans-serif;color: white;">
# Multimodal RAG:
# <br>
# Chat with Videos
# </p>
# </td>
# <td style="border:0"><img src="/file=../assets/gaudi.png" width="100" height="100"></td>
# <td style="border:0"><img src="/file=../assets/IDC7.png" width="300" height="350"></td>
# <td style="border:0"><img src="/file=../assets/prediction_guard3.png" width="120" height="120"></td>
# </tr>
# </table>
# '''

html_title = '''
<table style="bordercolor=#0a0c2b; border=0">
<tr style="height:150px; border:0">
<td style="border:0"><img src="/file=./assets/header.png"></td>
</tr>
</table>
'''

# <td style="border:0"><img src="/file=../assets/xeon.png" width="100" height="100"></td>
dropdown_list = [
    "What is the name of one of the astronauts?",
    "An astronaut's spacewalk",
    "What does the astronaut say?",
]

no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)

def clear_history(state, request: gr.Request):
    state = get_gradio_instance(state.mm_rag_chain)
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 1

def add_text(state, text, request: gr.Request):
    if len(text) <= 0:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 1

    text = text[:1536]  # Hard cut-off

    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 1

def http_bot(
    state, request: gr.Request
):
    start_tstamp = time.time()

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        path_to_sub_videos = state.get_path_to_subvideos()
        yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (no_change_btn,) * 1
        return

    if len(state.messages) == state.offset + 2:
        # First round of conversation
        new_state = get_gradio_instance(state.mm_rag_chain)
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    all_images = state.get_images(return_pil=False)

    # Make requests
    is_very_first_query = True
    if len(all_images) == 0:
        # the first query needs to do RAG; construct the prompt
        prompt_or_conversation = state.get_prompt_for_rag()
    else:
        # subsequent queries, no need to do retrieval
        is_very_first_query = False
        prompt_or_conversation = state.get_conversation_for_lvlm()

    if is_very_first_query:
        executor = state.mm_rag_chain
    else:
        executor = lvlm_inference_with_conversation

    state.messages[-1][-1] = "▌"
    path_to_sub_videos = state.get_path_to_subvideos()
    yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (disable_btn,) * 1

    try:
        if is_very_first_query:
            # get response by invoking the executor chain
            response = executor.invoke(prompt_or_conversation)
            message = response['final_text_output']
            if 'metadata' in response['input_to_lvlm']:
                metadata = response['input_to_lvlm']['metadata']
                if (state.path_to_img is None
                        and 'input_to_lvlm' in response
                        and 'image' in response['input_to_lvlm']
                        ):
                    state.path_to_img = response['input_to_lvlm']['image']

                if state.path_to_video is None and 'video_path' in metadata:
                    video_path = metadata['video_path']
                    mid_time_ms = metadata['mid_time_ms']
                    splited_video_path = split_video(video_path, mid_time_ms)
                    state.path_to_video = splited_video_path

                if state.caption is None and 'transcript' in metadata:
                    state.caption = metadata['transcript']
            else:
                raise ValueError("Response's format is changed")
        else:
            # get the response message by directly calling the PredictionGuard API
            message = executor(prompt_or_conversation)

    except Exception as e:
        print(e)
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot(), None) + (
            enable_btn,
        )
        return

    state.messages[-1][-1] = message
    path_to_sub_videos = state.get_path_to_subvideos()
    # path_to_image = state.path_to_img
    # caption = state.caption
    # print(path_to_sub_videos)
    # print(path_to_image)
    # print('caption: ', caption)
    yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (enable_btn,) * 1

    finish_tstamp = time.time()
    return

def get_demo(rag_chain=None):
    if rag_chain is None:
        rag_chain = get_default_rag_chain()

    with gr.Blocks(theme=theme, css=css) as demo:
        # gr.Markdown(description)
        instance = get_gradio_instance(rag_chain)
        state = gr.State(instance)
        demo.load(
            None,
            None,
            js="""
            () => {
                const params = new URLSearchParams(window.location.search);
                if (!params.has('__theme')) {
                    params.set('__theme', 'dark');
                    window.location.search = params.toString();
                }
            }""",
        )
        gr.HTML(value=html_title)
        with gr.Row():
            with gr.Column(scale=4):
                video = gr.Video(height=512, width=512, elem_id="video", interactive=False)
            with gr.Column(scale=7):
                chatbot = gr.Chatbot(
                    elem_id="chatbot", label="Multimodal RAG Chatbot", height=512,
                )
        with gr.Row():
            with gr.Column(scale=8):
                # textbox.render()
                textbox = gr.Dropdown(
                    dropdown_list,
                    allow_custom_value=True,
                    # show_label=False,
                    # container=False,
                    label="Query",
                    info="Enter your query here or choose a sample from the dropdown list!"
                )
            with gr.Column(scale=1, min_width=50):
                submit_btn = gr.Button(
                    value="Send", variant="primary", interactive=True
                )
        with gr.Row(elem_id="buttons") as button_row:
            clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)

        btn_list = [clear_btn]

        clear_btn.click(
            clear_history, [state], [state, chatbot, textbox, video] + btn_list
        )
        submit_btn.click(
            add_text,
            [state, textbox],
            [state, chatbot, textbox] + btn_list,
        ).then(
            http_bot,
            [state],
            [state, chatbot, video] + btn_list,
        )
    return demo
lrn_vector_embeddings.py ADDED
@@ -0,0 +1,111 @@
import json
import os
import numpy as np
from numpy.linalg import norm
import cv2
from io import StringIO, BytesIO
from umap import UMAP
from sklearn.preprocessing import MinMaxScaler
import pandas as pd
from tqdm import tqdm
import base64
from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning, BridgeTowerForImageAndTextRetrieval, BridgeTowerForMaskedLM
import requests
from PIL import Image
import torch

url1 = 'http://farm3.staticflickr.com/2519/4126738647_cc436c111b_z.jpg'
cap1 = 'A motorcycle sits parked across from a herd of livestock'

url2 = 'http://farm3.staticflickr.com/2046/2003879022_1b4b466d1d_z.jpg'
cap2 = 'Motorcycle on platform to be worked on in garage'

url3 = 'https://i.natgeofe.com/n/548467d8-c5f1-4551-9f58-6817a8d2c45e/NationalGeographic_2572187_3x2.jpg'
cap3 = 'a cat laying down stretched out near a laptop'

img1 = {
    'flickr_url': url1,
    'caption': cap1,
    'image_path': './shared_data/motorcycle_1.jpg'
}

img2 = {
    'flickr_url': url2,
    'caption': cap2,
    'image_path': './shared_data/motorcycle_2.jpg'
}

img3 = {
    'flickr_url': url3,
    'caption': cap3,
    'image_path': './shared_data/cat_1.jpg'
}

def bt_embeddings_from_local(text, image):
    model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
    processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")

    processed_inputs = processor(image, text, padding=True, return_tensors="pt")
    outputs = model(**processed_inputs)

    cross_modal_embeddings = outputs.cross_embeds
    text_embeddings = outputs.text_embeds
    image_embeddings = outputs.image_embeds
    return {
        'cross_modal_embeddings': cross_modal_embeddings,
        'text_embeddings': text_embeddings,
        'image_embeddings': image_embeddings
    }


def bt_scores_with_image_and_text_retrieval():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    texts = ["An image of two cats chilling on a couch", "A football player scoring a goal"]

    processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-gaudi")
    model = BridgeTowerForImageAndTextRetrieval.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-gaudi")

    # forward pass for each candidate caption
    scores = dict()
    for text in texts:
        # prepare inputs
        encoding = processor(image, text, return_tensors="pt")
        outputs = model(**encoding)
        scores[text] = outputs.logits[0, 1].item()
    return scores


def bt_with_masked_input():
    url = "http://images.cocodataset.org/val2017/000000360943.jpg"
    image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
    text = "a <mask> looking out of the window"

    processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-gaudi")
    model = BridgeTowerForMaskedLM.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-gaudi")

    # prepare inputs
    encoding = processor(image, text, return_tensors="pt")

    # forward pass
    outputs = model(**encoding)

    token_ids = outputs.logits.argmax(dim=-1).squeeze(0).tolist()
    if isinstance(token_ids, list):
        results = processor.tokenizer.decode(token_ids)
    else:
        results = processor.tokenizer.decode([token_ids])

    print(results)
    return results

if __name__ == "__main__":
    for img in [img1, img2, img3]:
        embeddings = bt_embeddings_from_local(img['caption'], Image.open(img['image_path']))
        print(embeddings['cross_modal_embeddings'][0].shape)
main_demo.py ADDED
@@ -0,0 +1,142 @@
+ from pathlib import Path
+ import gradio as gr
+ from PIL import Image
+ import ollama
+ import lancedb
+ from utility import download_video, get_transcript_vtt, extract_meta_data
+ from utility import load_json_file, display_retrieved_results
+ from mm_rag.embeddings.bridgetower_embeddings import (
+     BridgeTowerEmbeddings
+ )
+ from mm_rag.vectorstores.multimodal_lancedb import MultimodalLanceDB
+ 
+ # declare host file
+ LANCEDB_HOST_FILE = "./shared_data/.lancedb"
+ # declare table name
+ TBL_NAME = "demo_tbl"
+ # initialize the vector store connection
+ db = lancedb.connect(LANCEDB_HOST_FILE)
+ # initialize a BridgeTower embedder
+ embedder = BridgeTowerEmbeddings()
+ 
+ vid_dir = "./shared_data/videos/yt_video"
+ Path(vid_dir).mkdir(parents=True, exist_ok=True)
+ 
+ 
+ def open_table():
+     # open a connection to table TBL_NAME
+     tbl = db.open_table(TBL_NAME)
+     print(f"There are {tbl.to_pandas().shape[0]} rows in the table")
+     # display the first 3 rows of the table
+     print(tbl.to_pandas()[['text', 'image_path']].head(3))
+ 
+ 
+ def store_in_rag():
+     # load the metadata file
+     vid_metadata_path = './shared_data/videos/yt_video/metadatas.json'
+     vid_metadata = load_json_file(vid_metadata_path)
+ 
+     vid_subs = [vid['transcript'] for vid in vid_metadata]
+     vid_img_path = [vid['extracted_frame_path'] for vid in vid_metadata]
+ 
+     # join each transcript segment with its neighbours (window of n = 7)
+     # so that every entry carries enough context to embed well
+     n = 7
+     updated_vid_subs = [
+         ' '.join(vid_subs[i - int(n / 2): i + int(n / 2)]) if i - int(n / 2) >= 0 else
+         ' '.join(vid_subs[0: i + int(n / 2)]) for i in range(len(vid_subs))
+     ]
+ 
+     # also write the updated transcripts back into the metadata
+     for i in range(len(updated_vid_subs)):
+         vid_metadata[i]['transcript'] = updated_vid_subs[i]
+ 
+     # pass mode="append" to add more entries to the vector store;
+     # pass mode="overwrite" to start with a fresh vector store
+     _ = MultimodalLanceDB.from_text_image_pairs(
+         texts=updated_vid_subs,
+         image_paths=vid_img_path,
+         embedding=embedder,
+         metadatas=vid_metadata,
+         connection=db,
+         table_name=TBL_NAME,
+         mode="overwrite",
+     )
+ 
+ 
+ def get_metadata_of_yt_video_with_captions(vid_url):
+     vid_filepath = download_video(vid_url, vid_dir)
+     vid_transcript_filepath = get_transcript_vtt(vid_url, vid_dir)
+     # should produce a lowercase file name without spaces
+     extract_meta_data(vid_dir, vid_filepath, vid_transcript_filepath)
+     store_in_rag()
+     open_table()
+     return vid_filepath
+ 
+ 
+ """
+ def chat_response_llvm(instruction):
+     # file_path = the_metadatas[0]
+     file_path = 'shared_data/videos/yt_video/extracted_frame/'
+     result = ollama.generate(
+         model='llava',
+         prompt=instruction,
+         images=[file_path],
+         stream=True
+     )['response']
+     return result
+ """
+ 
+ 
+ def return_top_k_most_similar_docs(query="show me a group of astronauts", max_docs=1):
+     # create a LanceDB vector store over the existing table
+     vectorstore = MultimodalLanceDB(
+         uri=LANCEDB_HOST_FILE,
+         embedding=embedder,
+         table_name=TBL_NAME)
+ 
+     # create a retriever for the vector store;
+     # search_type="similarity" makes the retriever perform similarity search,
+     # and search_kwargs={"k": max_docs} returns the top max_docs documents
+     retriever = vectorstore.as_retriever(
+         search_type='similarity',
+         search_kwargs={"k": max_docs})
+ 
+     results = retriever.invoke(query)
+     return results[0].page_content, Image.open(results[0].metadata['extracted_frame_path'])
+ 
+ 
+ def process_url_and_init(youtube_url):
+     vid_filepath = get_metadata_of_yt_video_with_captions(youtube_url)
+     return vid_filepath
+ 
+ 
+ def init_ui():
+     with gr.Blocks() as demo:
+         url_input = gr.Textbox(label="Enter YouTube URL", value="https://www.youtube.com/watch?v=7Hcg-rLYwdM")
+         submit_btn = gr.Button("Process Video")
+         chatbox = gr.Textbox(label="What question do you want to ask?", value="show me a group of astronauts")
+         response = gr.Textbox(label="Response", interactive=False)
+         video = gr.Video()
+         frame = gr.Image()
+         submit_btn2 = gr.Button("ASK")
+ 
+         submit_btn.click(fn=process_url_and_init, inputs=url_input, outputs=[video])
+         submit_btn2.click(fn=return_top_k_most_similar_docs, inputs=[chatbox], outputs=[response, frame])
+     return demo
+ 
+ 
+ if __name__ == '__main__':
+     demo = init_ui()
+     demo.launch(share=True)
+ 
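The transcript windowing used by `store_in_rag` is worth seeing in isolation: each short VTT caption is joined with its neighbours so the stored text carries enough context to embed well. A minimal pure-Python sketch (`window_context` is a hypothetical helper name; the arithmetic mirrors the list comprehension above):

```python
def window_context(subs, n=7):
    # join each caption with its n//2 neighbours on either side;
    # near the start of the list, clamp the window at index 0
    half = n // 2
    return [
        ' '.join(subs[i - half: i + half]) if i - half >= 0
        else ' '.join(subs[0: i + half])
        for i in range(len(subs))
    ]

print(window_context(['a', 'b', 'c', 'd', 'e'], n=3))
# -> ['a', 'a b', 'b c', 'c d', 'd e']
```

Note that the window is asymmetric (it extends `half` items back but only `half` items forward, exclusive), which is inherited from the original comprehension.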
requirements.txt ADDED
@@ -0,0 +1,22 @@
+ gradio
+ langchain-predictionguard
+ IPython
+ umap-learn
+ pytubefix
+ youtube_transcript_api
+ torch
+ transformers
+ matplotlib
+ seaborn
+ datasets
+ moviepy
+ webvtt-py
+ tqdm
+ lancedb
+ langchain-core
+ langchain-community
+ ollama
+ opencv-python
+ openai-whisper
+ huggingface_hub[cli]
s2_download_data.py ADDED
@@ -0,0 +1,49 @@
+ import requests
+ from PIL import Image
+ from IPython.display import display
+ 
+ # You can use your own uploaded images and captions.
+ # You are responsible for the legal use of any images you use.
+ 
+ url1 = 'http://farm3.staticflickr.com/2519/4126738647_cc436c111b_z.jpg'
+ cap1 = 'A motorcycle sits parked across from a herd of livestock'
+ 
+ url2 = 'http://farm3.staticflickr.com/2046/2003879022_1b4b466d1d_z.jpg'
+ cap2 = 'Motorcycle on platform to be worked on in garage'
+ 
+ url3 = 'https://i.natgeofe.com/n/548467d8-c5f1-4551-9f58-6817a8d2c45e/NationalGeographic_2572187_3x2.jpg'
+ cap3 = 'a cat laying down stretched out near a laptop'
+ 
+ img1 = {
+     'flickr_url': url1,
+     'caption': cap1,
+     'image_path': './shared_data/motorcycle_1.jpg'
+ }
+ 
+ img2 = {
+     'flickr_url': url2,
+     'caption': cap2,
+     'image_path': './shared_data/motorcycle_2.jpg'
+ }
+ 
+ img3 = {
+     'flickr_url': url3,
+     'caption': cap3,
+     'image_path': './shared_data/cat_1.jpg'
+ }
+ 
+ def download_images():
+     # download each image to its local path
+     imgs = [img1, img2, img3]
+     for img in imgs:
+         data = requests.get(img['flickr_url']).content
+         with open(img['image_path'], 'wb') as f:
+             f.write(data)
+ 
+     # display each downloaded image with its caption
+     for img in imgs:
+         image = Image.open(img['image_path'])
+         caption = img['caption']
+         display(image)
+         print(caption)
+ 
s3_data_to_vector_embedding.py ADDED
@@ -0,0 +1,61 @@
+ from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning
+ import torch
+ from PIL import Image
+ 
+ url1 = 'http://farm3.staticflickr.com/2519/4126738647_cc436c111b_z.jpg'
+ cap1 = 'A motorcycle sits parked across from a herd of livestock'
+ 
+ url2 = 'http://farm3.staticflickr.com/2046/2003879022_1b4b466d1d_z.jpg'
+ cap2 = 'Motorcycle on platform to be worked on in garage'
+ 
+ url3 = 'https://i.natgeofe.com/n/548467d8-c5f1-4551-9f58-6817a8d2c45e/NationalGeographic_2572187_3x2.jpg'
+ cap3 = 'a cat laying down stretched out near a laptop'
+ 
+ img1 = {
+     'flickr_url': url1,
+     'caption': cap1,
+     'image_path': './shared_data/motorcycle_1.jpg',
+     'tensor_path': './shared_data/motorcycle_1'
+ }
+ 
+ img2 = {
+     'flickr_url': url2,
+     'caption': cap2,
+     'image_path': './shared_data/motorcycle_2.jpg',
+     'tensor_path': './shared_data/motorcycle_2'
+ }
+ 
+ img3 = {
+     'flickr_url': url3,
+     'caption': cap3,
+     'image_path': './shared_data/cat_1.jpg',
+     'tensor_path': './shared_data/cat_1'
+ }
+ 
+ def bt_embeddings_from_local(text, image):
+     # load the BridgeTower contrastive model and its processor
+     model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
+     processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
+ 
+     processed_inputs = processor(image, text, padding=True, return_tensors="pt")
+     outputs = model(**processed_inputs)
+ 
+     cross_modal_embeddings = outputs.cross_embeds
+     text_embeddings = outputs.text_embeds
+     image_embeddings = outputs.image_embeds
+     return {
+         'cross_modal_embeddings': cross_modal_embeddings,
+         'text_embeddings': text_embeddings,
+         'image_embeddings': image_embeddings
+     }
+ 
+ def save_embeddings():
+     for img in [img1, img2, img3]:
+         embedding = bt_embeddings_from_local(img['caption'], Image.open(img['image_path']))
+         print(embedding['cross_modal_embeddings'][0].shape)  # a torch.Tensor
+         torch.save(embedding['cross_modal_embeddings'][0], img['tensor_path'] + '.pt')
+ 
s4_calculate_distance.py ADDED
@@ -0,0 +1,83 @@
+ import numpy as np
+ from numpy.linalg import norm
+ import torch
+ from IPython.display import display
+ import cv2
+ 
+ url1 = 'http://farm3.staticflickr.com/2519/4126738647_cc436c111b_z.jpg'
+ cap1 = 'A motorcycle sits parked across from a herd of livestock'
+ 
+ url2 = 'http://farm3.staticflickr.com/2046/2003879022_1b4b466d1d_z.jpg'
+ cap2 = 'Motorcycle on platform to be worked on in garage'
+ 
+ url3 = 'https://i.natgeofe.com/n/548467d8-c5f1-4551-9f58-6817a8d2c45e/NationalGeographic_2572187_3x2.jpg'
+ cap3 = 'a cat laying down stretched out near a laptop'
+ 
+ img1 = {
+     'flickr_url': url1,
+     'caption': cap1,
+     'image_path': './shared_data/motorcycle_1.jpg',
+     'tensor_path': './shared_data/motorcycle_1'
+ }
+ 
+ img2 = {
+     'flickr_url': url2,
+     'caption': cap2,
+     'image_path': './shared_data/motorcycle_2.jpg',
+     'tensor_path': './shared_data/motorcycle_2'
+ }
+ 
+ img3 = {
+     'flickr_url': url3,
+     'caption': cap3,
+     'image_path': './shared_data/cat_1.jpg',
+     'tensor_path': './shared_data/cat_1'
+ }
+ 
+ def load_tensor(path):
+     return torch.load(path)
+ 
+ def load_embeddings():
+     ex1_embed = load_tensor(img1['tensor_path'] + '.pt')
+     ex2_embed = load_tensor(img2['tensor_path'] + '.pt')
+     ex3_embed = load_tensor(img3['tensor_path'] + '.pt')
+     return ex1_embed.data.numpy(), ex2_embed.data.numpy(), ex3_embed.data.numpy()
+ 
+ def cosine_similarity(vec1, vec2):
+     similarity = np.dot(vec1, vec2) / (norm(vec1) * norm(vec2))
+     return similarity
+ 
+ def calculate_cosine_similarity():
+     ex1_embed, ex2_embed, ex3_embed = load_embeddings()
+     similarity1 = cosine_similarity(ex1_embed, ex2_embed)
+     similarity2 = cosine_similarity(ex1_embed, ex3_embed)
+     similarity3 = cosine_similarity(ex2_embed, ex3_embed)
+     return [similarity1, similarity2, similarity3]
+ 
+ def calculate_euclidean_distance():
+     ex1_embed, ex2_embed, ex3_embed = load_embeddings()
+     distance1 = cv2.norm(ex1_embed, ex2_embed, cv2.NORM_L2)
+     distance2 = cv2.norm(ex1_embed, ex3_embed, cv2.NORM_L2)
+     distance3 = cv2.norm(ex2_embed, ex3_embed, cv2.NORM_L2)
+     return [distance1, distance2, distance3]
+ 
+ def show_cosine_similarity():
+     similarities = calculate_cosine_similarity()
+     print("Cosine similarity between ex1_embed and ex2_embed is:")
+     display(similarities[0])
+     print("Cosine similarity between ex1_embed and ex3_embed is:")
+     display(similarities[1])
+     print("Cosine similarity between ex2_embed and ex3_embed is:")
+     display(similarities[2])
+ 
+ def show_euclidean_distance():
+     distances = calculate_euclidean_distance()
+     print("Euclidean distance between ex1_embed and ex2_embed is:")
+     display(distances[0])
+     print("Euclidean distance between ex1_embed and ex3_embed is:")
+     display(distances[1])
+     print("Euclidean distance between ex2_embed and ex3_embed is:")
+     display(distances[2])
+ 
+ show_cosine_similarity()
+ show_euclidean_distance()
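The two metrics computed above are simple enough to verify by hand. A dependency-free sketch of both (`cosine_similarity` and `euclidean_distance` here are pure-Python stand-ins for the numpy/cv2 versions in the file):

```python
import math

def cosine_similarity(v1, v2):
    # dot product divided by the product of the vector norms
    dot = sum(a * b for a, b in zip(v1, v2))
    norm1 = math.sqrt(sum(a * a for a in v1))
    norm2 = math.sqrt(sum(b * b for b in v2))
    return dot / (norm1 * norm2)

def euclidean_distance(v1, v2):
    # straight-line (L2) distance between the two vectors
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(v1, v2)))

a, b = [1.0, 0.0], [0.0, 1.0]
print(cosine_similarity(a, b))   # orthogonal vectors -> 0.0
print(euclidean_distance(a, b))  # -> sqrt(2)
```

Cosine similarity ignores vector magnitude (it only measures angle), which is why it is the usual choice for comparing embeddings; Euclidean distance is magnitude-sensitive.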
s5-how-to-umap.py ADDED
@@ -0,0 +1,137 @@
+ from IPython.display import display
+ from umap import UMAP
+ from sklearn.preprocessing import MinMaxScaler
+ import pandas as pd
+ from tqdm import tqdm
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ from s3_data_to_vector_embedding import bt_embeddings_from_local
+ import random
+ import numpy as np
+ from datasets import load_dataset
+ 
+ # prompt templates
+ templates = [
+     'a picture of {}',
+     'an image of {}',
+     'a nice {}',
+     'a beautiful {}',
+ ]
+ 
+ # prepare a list of image-text pairs from the first [test_size] examples
+ def data_prep(hf_dataset_name, class_name, templates=templates, test_size=1000):
+     # streaming the dataset would avoid downloading it up front:
+     # dataset = load_dataset(hf_dataset_name, trust_remote_code=True, split='train', streaming=True)
+     dataset = load_dataset(hf_dataset_name)
+     # split the dataset with the given test_size
+     train_test_dataset = dataset['train'].train_test_split(test_size=test_size)
+     test_dataset = train_test_dataset['test']
+     print(test_dataset)
+     # build image-text pairs from the test split,
+     # filling a randomly chosen template with the class name
+     img_txt_pairs = []
+     for i in range(len(test_dataset)):
+         img_txt_pairs.append({
+             'caption': templates[random.randint(0, len(templates) - 1)].format(class_name),
+             'pil_img': test_dataset[i]['image']
+         })
+     return img_txt_pairs
+ 
+ 
+ def load_all_dataset():
+     car_img_txt_pairs = data_prep("tanganke/stanford_cars", 'car', test_size=50)
+     cat_img_txt_pairs = data_prep("yashikota/cat-image-dataset", 'cat', test_size=50)
+     return cat_img_txt_pairs, car_img_txt_pairs
+ 
+ 
+ # compute BridgeTower embeddings for the cat and car image-text pairs
+ def load_cat_and_car_embeddings():
+     # prepare image-text pairs
+     cat_img_txt_pairs, car_img_txt_pairs = load_all_dataset()
+ 
+     def load_embedding(img_txt_pair):
+         pil_img = img_txt_pair['pil_img']
+         caption = img_txt_pair['caption']
+         return bt_embeddings_from_local(caption, pil_img)
+ 
+     def load_all_embeddings_from_image_text_pairs(img_txt_pairs):
+         embeddings = []
+         for img_txt_pair in tqdm(img_txt_pairs, total=len(img_txt_pairs)):
+             embedding = load_embedding(img_txt_pair)
+             # detach the cross-modal embedding and convert it to numpy
+             cross_modal_embedding = embedding['cross_modal_embeddings'][0].detach().numpy()
+             embeddings.append(cross_modal_embedding)
+         # stack into a (num_pairs, embedding_dim) array
+         return np.stack(embeddings)
+ 
+     cat_embeddings = load_all_embeddings_from_image_text_pairs(cat_img_txt_pairs)
+     car_embeddings = load_all_embeddings_from_image_text_pairs(car_img_txt_pairs)
+     return cat_embeddings, car_embeddings
+ 
+ 
+ # transform high-dimensional vectors into 2D vectors using UMAP
+ def dimensionality_reduction(embeddings, labels):
+     # scale each embedding dimension to [0, 1] feature-wise
+     # (reshape(-1, 1) would wrongly flatten and scale all dimensions jointly)
+     X_scaled = MinMaxScaler().fit_transform(embeddings)
+     mapper = UMAP(n_components=2, metric="cosine").fit(X_scaled)
+     df_emb = pd.DataFrame(mapper.embedding_, columns=["X", "Y"])
+     df_emb["label"] = labels
+     print(df_emb)
+     return df_emb
+ 
+ 
+ def show_umap_visualization():
+     def reduce_dimensions():
+         cat_embeddings, car_embeddings = load_cat_and_car_embeddings()
+         # stack the cat and car embeddings into one numpy array
+         all_embeddings = np.concatenate([cat_embeddings, car_embeddings])
+ 
+         # prepare labels for the two classes
+         labels = ['cat'] * len(cat_embeddings) + ['car'] * len(car_embeddings)
+ 
+         # compute the dimensionality reduction
+         return dimensionality_reduction(all_embeddings, labels)
+ 
+     reduced_dim_emb = reduce_dimensions()
+     # plot the 2D embeddings coloured by class
+     fig, ax = plt.subplots(figsize=(8, 6))
+     sns.set_style("whitegrid", {'axes.grid': False})
+     sns.scatterplot(data=reduced_dim_emb,
+                     x="X",
+                     y="Y",
+                     hue='label',
+                     palette='bright',
+                     ax=ax)
+     sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
+     plt.title('Scatter plot of images of cats and cars using UMAP')
+     plt.xlabel('X')
+     plt.ylabel('Y')
+     plt.show()
+ 
+ 
+ def an_example_of_cat_and_car_pair_data():
+     cat_img_txt_pairs, car_img_txt_pairs = load_all_dataset()
+     # display an example of a cat image-text pair
+     display(cat_img_txt_pairs[0]['caption'])
+     display(cat_img_txt_pairs[0]['pil_img'])
+ 
+     # display an example of a car image-text pair
+     display(car_img_txt_pairs[0]['caption'])
+     display(car_img_txt_pairs[0]['pil_img'])
+ 
+ 
+ if __name__ == '__main__':
+     show_umap_visualization()
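One subtle bug in the original UMAP step was scaling with `reshape(-1, 1)`, which flattens every embedding into a single column and scales all dimensions jointly. Min-max scaling should run per feature. A pure-Python sketch of the correct behaviour (`min_max_scale` is a hypothetical stand-in for `MinMaxScaler` on a 2-D samples-by-dimensions array):

```python
def min_max_scale(rows):
    # scale each column (feature) of `rows` to [0, 1] independently,
    # matching what MinMaxScaler does with a 2-D input;
    # a column with no spread maps to 0.0
    cols = list(zip(*rows))
    lo = [min(c) for c in cols]
    hi = [max(c) for c in cols]
    return [
        [(v - l) / (h - l) if h > l else 0.0
         for v, l, h in zip(row, lo, hi)]
        for row in rows
    ]

print(min_max_scale([[0.0, 10.0], [5.0, 20.0], [10.0, 30.0]]))
# -> [[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]]
```

Feature-wise scaling keeps each embedding dimension comparable without letting large-magnitude dimensions dominate the UMAP distance computation.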
s6_prepare_video_input.py ADDED
@@ -0,0 +1,90 @@
+ from pathlib import Path
+ import os
+ from os import path as osp
+ import whisper
+ from moviepy import VideoFileClip
+ from PIL import Image
+ from utility import download_video, extract_meta_data, get_transcript_vtt, getSubs
+ from urllib.request import urlretrieve
+ from IPython.display import display
+ import ollama
+ 
+ def demo_video_input_that_has_transcript():
+     # first video's url
+     vid_url = "https://www.youtube.com/watch?v=7Hcg-rLYwdM"
+ 
+     # download the YouTube video to ./shared_data/videos/video1
+     vid_dir = "./shared_data/videos/video1"
+     vid_filepath = download_video(vid_url, vid_dir)
+ 
+     # download the YouTube video's subtitles to ./shared_data/videos/video1
+     vid_transcript_filepath = get_transcript_vtt(vid_url, vid_dir)
+ 
+     return extract_meta_data(vid_dir, vid_filepath, vid_transcript_filepath)
+ 
+ def demo_video_input_that_has_no_transcript():
+     # second video's url
+     vid_url = (
+         "https://multimedia-commons.s3-us-west-2.amazonaws.com/"
+         "data/videos/mp4/010/a07/010a074acb1975c4d6d6e43c1faeb8.mp4"
+     )
+     vid_dir = "./shared_data/videos/video2"
+     vid_name = "toddler_in_playground.mp4"
+ 
+     # create the folder to which video2 will be downloaded
+     Path(vid_dir).mkdir(parents=True, exist_ok=True)
+     vid_filepath = urlretrieve(
+         vid_url,
+         osp.join(vid_dir, vid_name)
+     )[0]
+ 
+     path_to_video_no_transcript = vid_filepath
+ 
+     # declare where to save the .mp3 audio
+     path_to_extracted_audio_file = os.path.join(vid_dir, 'audio.mp3')
+ 
+     # extract the mp3 audio track from the mp4 video file
+     clip = VideoFileClip(path_to_video_no_transcript)
+     clip.audio.write_audiofile(path_to_extracted_audio_file)
+ 
+     # transcribe (and translate to English) the audio with Whisper
+     model = whisper.load_model("small")
+     options = dict(task="translate", best_of=1, language='en')
+     results = model.transcribe(path_to_extracted_audio_file, **options)
+ 
+     vtt = getSubs(results["segments"], "vtt")
+ 
+     # path to save the generated transcript of video2
+     path_to_generated_trans = osp.join(vid_dir, 'generated_video2.vtt')
+     # write the transcription to file
+     with open(path_to_generated_trans, 'w') as f:
+         f.write(vtt)
+ 
+     return extract_meta_data(vid_dir, vid_filepath, path_to_generated_trans)
+ 
+ 
+ def ask_llvm(instruction, file_path):
+     result = ollama.generate(
+         model='llava',
+         prompt=instruction,
+         images=[file_path],
+         stream=False
+     )['response']
+     img = Image.open(file_path, mode='r')
+     img = img.resize([int(i / 1.2) for i in img.size])
+     display(img)
+     print(result)
+ 
+ 
+ if __name__ == "__main__":
+     meta_data = demo_video_input_that_has_transcript()
+ 
+     meta_data1 = demo_video_input_that_has_no_transcript()
+     data = meta_data1[1]
+     caption = data['transcript']
+     print(f'Generated caption is: "{caption}"')
+     frame = Image.open(data['extracted_frame_path'])
+     display(frame)
+     instruction = "Can you describe the image?"
+     ask_llvm(instruction, data['extracted_frame_path'])
+     # print(meta_data)
+ 
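The `.vtt` transcripts produced here are keyed by timestamps like `00:01:05.500`, and downstream code needs them as milliseconds to align captions with extracted frames. A minimal sketch of the conversion, mirroring the arithmetic of `str2time` in `utility.py` (`vtt_time_to_ms` is a hypothetical name):

```python
def vtt_time_to_ms(strtime):
    # parse 'HH:MM:SS.mmm' (optionally quoted) into hours, minutes,
    # fractional seconds, then convert the total to milliseconds
    hrs, mins, secs = [float(c) for c in strtime.strip('"').split(':')]
    return (hrs * 3600 + mins * 60 + secs) * 1000

print(vtt_time_to_ms("00:01:05.500"))  # -> 65500.0
```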
s7_store_in_rag.py ADDED
@@ -0,0 +1,105 @@
+ from mm_rag.embeddings.bridgetower_embeddings import (
+     BridgeTowerEmbeddings
+ )
+ from mm_rag.vectorstores.multimodal_lancedb import MultimodalLanceDB
+ import lancedb
+ from utility import load_json_file, display_retrieved_results
+ 
+ # declare host file
+ LANCEDB_HOST_FILE = "./shared_data/.lancedb"
+ # declare table name
+ TBL_NAME = "test_tbl"
+ # initialize the vector store connection
+ db = lancedb.connect(LANCEDB_HOST_FILE)
+ # initialize a BridgeTower embedder
+ embedder = BridgeTowerEmbeddings()
+ 
+ 
+ def return_top_k_most_similar_docs(max_docs=3):
+     # return the top 3 most similar documents by default
+     # create a LanceDB vector store over the existing table
+     vectorstore = MultimodalLanceDB(
+         uri=LANCEDB_HOST_FILE,
+         embedding=embedder,
+         table_name=TBL_NAME)
+ 
+     # create a retriever for the vector store;
+     # search_type="similarity" makes the retriever perform similarity search,
+     # and search_kwargs={"k": max_docs} returns the top max_docs documents
+     retriever = vectorstore.as_retriever(
+         search_type='similarity',
+         search_kwargs={"k": max_docs})
+ 
+     query2 = (
+         "an astronaut's spacewalk "
+         "with an amazing view of the earth from space behind"
+     )
+     results2 = retriever.invoke(query2)
+     display_retrieved_results(results2)
+ 
+     query3 = "a group of astronauts"
+     results3 = retriever.invoke(query3)
+     display_retrieved_results(results3)
+ 
+ 
+ def open_table(tbl_name):
+     # open a connection to the table
+     tbl = db.open_table(tbl_name)
+     print(f"There are {tbl.to_pandas().shape[0]} rows in the table")
+     # display the first 3 rows of the table
+     print(tbl.to_pandas()[['text', 'image_path']].head(3))
+ 
+ 
+ def store_in_rag():
+     # load the metadata files
+     vid1_metadata_path = './shared_data/videos/video1/metadatas.json'
+     vid2_metadata_path = './shared_data/videos/video2/metadatas.json'
+     vid1_metadata = load_json_file(vid1_metadata_path)
+     vid2_metadata = load_json_file(vid2_metadata_path)
+ 
+     # collect transcripts and image paths
+     vid1_trans = [vid['transcript'] for vid in vid1_metadata]
+     vid1_img_path = [vid['extracted_frame_path'] for vid in vid1_metadata]
+ 
+     vid2_trans = [vid['transcript'] for vid in vid2_metadata]
+     vid2_img_path = [vid['extracted_frame_path'] for vid in vid2_metadata]
+ 
+     # for video1, join each transcript segment with its neighbours (n = 7)
+     n = 7
+     updated_vid1_trans = [
+         ' '.join(vid1_trans[i - int(n / 2): i + int(n / 2)]) if i - int(n / 2) >= 0 else
+         ' '.join(vid1_trans[0: i + int(n / 2)]) for i in range(len(vid1_trans))
+     ]
+ 
+     # also write the updated transcripts back into the metadata
+     for i in range(len(updated_vid1_trans)):
+         vid1_metadata[i]['transcript'] = updated_vid1_trans[i]
+ 
+     # pass mode="append" to add more entries to the vector store;
+     # pass mode="overwrite" to start with a fresh vector store
+     _ = MultimodalLanceDB.from_text_image_pairs(
+         texts=updated_vid1_trans + vid2_trans,
+         image_paths=vid1_img_path + vid2_img_path,
+         embedding=embedder,
+         metadatas=vid1_metadata + vid2_metadata,
+         connection=db,
+         table_name=TBL_NAME,
+         mode="overwrite",
+     )
+ 
+ 
+ if __name__ == "__main__":
+     # build the table first so the queries below have something to search
+     store_in_rag()
+     open_table(TBL_NAME)
+     return_top_k_most_similar_docs()
upload_huggingface.py ADDED
@@ -0,0 +1,8 @@
+ from huggingface_hub import HfApi
+ 
+ api = HfApi()
+ api.upload_large_folder(
+     repo_id="88hours/hg_demo",
+     repo_type="space",
+     folder_path="./",
+ )
utility.py ADDED
@@ -0,0 +1,693 @@
1
+ # Add your utilities or helper functions to this file.
2
+
3
+ import os
4
+ from pathlib import Path
5
+ from dotenv import load_dotenv, find_dotenv
6
+ from io import StringIO, BytesIO
7
+ import textwrap
8
+ from typing import Iterator, TextIO, List, Dict, Any, Optional, Sequence, Union
9
+ from enum import auto, Enum
10
+ import base64
11
+ import glob
12
+ import requests
13
+ from tqdm import tqdm
14
+ from pytubefix import YouTube, Stream
15
+ import webvtt
16
+ from youtube_transcript_api import YouTubeTranscriptApi
17
+ from youtube_transcript_api.formatters import WebVTTFormatter
18
+ from predictionguard import PredictionGuard
19
+ import cv2
20
+ import json
21
+ import PIL
22
+ from ollama import chat
23
+ from ollama import ChatResponse
24
+ from PIL import Image
25
+ import dataclasses
26
+ import random
27
+ from datasets import load_dataset
28
+ from os import path as osp
29
+ from IPython.display import display
30
+ from langchain_core.prompt_values import PromptValue
31
+ from langchain_core.messages import (
32
+ MessageLikeRepresentation,
33
+ )
34
+
35
+ MultimodalModelInput = Union[PromptValue, str, Sequence[MessageLikeRepresentation], Dict[str, Any]]
36
+
37
+ def get_from_dict_or_env(
38
+ data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None
39
+ ) -> str:
40
+ """Get a value from a dictionary or an environment variable."""
41
+ if key in data and data[key]:
42
+ return data[key]
43
+ else:
44
+ return get_from_env(key, env_key, default=default)
45
+
46
+ def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str:
47
+ """Get a value from a dictionary or an environment variable."""
48
+ if env_key in os.environ and os.environ[env_key]:
49
+ return os.environ[env_key]
50
+ else:
51
+ return default
52
+
53
+ def load_env():
54
+ _ = load_dotenv(find_dotenv())
55
+
56
+ def get_openai_api_key():
57
+ load_env()
58
+ openai_api_key = os.getenv("OPENAI_API_KEY")
59
+ return openai_api_key
60
+
61
+ def get_prediction_guard_api_key():
62
+ load_env()
63
+ PREDICTION_GUARD_API_KEY = os.getenv("PREDICTION_GUARD_API_KEY", None)
64
+ if PREDICTION_GUARD_API_KEY is None:
65
+ PREDICTION_GUARD_API_KEY = input("Please enter your Prediction Guard API Key: ")
66
+ return PREDICTION_GUARD_API_KEY
67
+
68
+ PREDICTION_GUARD_URL_ENDPOINT = os.getenv("DLAI_PREDICTION_GUARD_URL_ENDPOINT", "https://dl-itdc.predictionguard.com") ###"https://proxy-dl-itdc.predictionguard.com"
69
+
70
+ # prompt templates
71
+ templates = [
72
+ 'a picture of {}',
73
+ 'an image of {}',
74
+ 'a nice {}',
75
+ 'a beautiful {}',
76
+ ]
77
+
78
+ # function helps to prepare list image-text pairs from the first [test_size] data of a Huggingface dataset
79
+ def prepare_dataset_for_umap_visualization(hf_dataset, class_name, templates=templates, test_size=1000):
80
+ # load Huggingface dataset (download if needed)
81
+ dataset = load_dataset(hf_dataset, trust_remote_code=True)
82
+ # split dataset with specific test_size
83
+ train_test_dataset = dataset['train'].train_test_split(test_size=test_size)
84
+ # get the test dataset
85
+ test_dataset = train_test_dataset['test']
86
+ img_txt_pairs = []
87
+ for i in range(len(test_dataset)):
88
+ img_txt_pairs.append({
89
+ 'caption' : templates[random.randint(0, len(templates)-1)].format(class_name),
90
+ 'pil_img' : test_dataset[i]['image']
91
+ })
92
+ return img_txt_pairs
93
+
94
+
95
+ def download_video(video_url, path='/tmp/'):
96
+ print(f'Getting video information for {video_url}')
97
+ if not video_url.startswith('http'):
98
+ return os.path.join(path, video_url)
99
+
100
+ filepath = glob.glob(os.path.join(path, '*.mp4'))
101
+ if len(filepath) > 0:
102
+ print('Video already downloaded')
103
+ return filepath[0]
104
+
105
+ def progress_callback(stream: Stream, data_chunk: bytes, bytes_remaining: int) -> None:
106
+ pbar.update(len(data_chunk))
107
+
108
+ yt = YouTube(video_url, on_progress_callback=progress_callback)
109
+ stream = yt.streams.filter(progressive=True, file_extension='mp4', res='480p').desc().first()
110
+ if stream is None:
111
+ stream = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first()
112
+ if not os.path.exists(path):
113
+ os.makedirs(path)
114
+ filename = stream.default_filename.replace(' ', '_')
115
+ filepath = os.path.join(path, filename)
116
+
117
+ if not os.path.exists(filepath):
118
+ print('Downloading video from YouTube...')
119
+ pbar = tqdm(desc='Downloading video from YouTube', total=stream.filesize, unit="bytes")
120
+ stream.download(path, filename=filename)
121
+ pbar.close()
122
+ return filepath
123
+
124
+ def get_video_id_from_url(video_url):
125
+ """
126
+ Examples:
127
+ - http://youtu.be/SA2iWivDJiE
128
+ - http://www.youtube.com/watch?v=_oPAwA_Udwc&feature=feedu
129
+ - http://www.youtube.com/embed/SA2iWivDJiE
130
+ - http://www.youtube.com/v/SA2iWivDJiE?version=3&amp;hl=en_US
131
+ """
132
+ import urllib.parse
133
+ url = urllib.parse.urlparse(video_url)
134
+ if url.hostname == 'youtu.be':
135
+ return url.path[1:]
136
+ if url.hostname in ('www.youtube.com', 'youtube.com'):
137
+ if url.path == '/watch':
138
+ p = urllib.parse.parse_qs(url.query)
139
+ return p['v'][0]
140
+ if url.path[:7] == '/embed/':
141
+ return url.path.split('/')[2]
142
+ if url.path[:3] == '/v/':
143
+ return url.path.split('/')[2]
144
+
145
+ return video_url
146
+
147
+ # if this has transcript then download
148
+ def get_transcript_vtt(video_url, path='/tmp'):
149
+ video_id = get_video_id_from_url(video_url)
150
+ filepath = os.path.join(path,'captions.vtt')
151
+ if os.path.exists(filepath):
152
+ return filepath
153
+
154
+ transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=['en-GB', 'en'])
155
+ formatter = WebVTTFormatter()
156
+ webvtt_formatted = formatter.format_transcript(transcript)
157
+
158
+ with open(filepath, 'w', encoding='utf-8') as webvtt_file:
159
+ webvtt_file.write(webvtt_formatted)
160
+ webvtt_file.close()
161
+
162
+ return filepath
163
+
164
+
165
+ # helper to convert a time in seconds into the timestamp format used by .vtt and .srt files
+ def format_timestamp(seconds: float, always_include_hours: bool = False, fractionalSeperator: str = '.'):
+     assert seconds >= 0, "non-negative timestamp expected"
+     milliseconds = round(seconds * 1000.0)
+
+     hours = milliseconds // 3_600_000
+     milliseconds -= hours * 3_600_000
+
+     minutes = milliseconds // 60_000
+     milliseconds -= minutes * 60_000
+
+     seconds = milliseconds // 1_000
+     milliseconds -= seconds * 1_000
+
+     hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
+     return f"{hours_marker}{minutes:02d}:{seconds:02d}{fractionalSeperator}{milliseconds:03d}"
+
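For example, `format_timestamp` renders 3661.5 seconds as an hour-qualified VTT timestamp, and switches to the comma separator for SRT (standalone copy of the helper, for illustration):

```python
# Standalone copy of format_timestamp, for illustration only.
def format_timestamp(seconds, always_include_hours=False, fractionalSeperator='.'):
    assert seconds >= 0, "non-negative timestamp expected"
    milliseconds = round(seconds * 1000.0)
    hours = milliseconds // 3_600_000
    milliseconds -= hours * 3_600_000
    minutes = milliseconds // 60_000
    milliseconds -= minutes * 60_000
    seconds = milliseconds // 1_000
    milliseconds -= seconds * 1_000
    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return f"{hours_marker}{minutes:02d}:{seconds:02d}{fractionalSeperator}{milliseconds:03d}"

print(format_timestamp(3661.5))   # 01:01:01.500
print(format_timestamp(5.0))      # 00:05.000  (hours omitted when zero)
print(format_timestamp(5.0, always_include_hours=True,
                       fractionalSeperator=','))  # 00:00:05,000  (SRT style)
```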
+ # helper that converts a `webvtt` timestamp string into a time in milliseconds
+ def str2time(strtime):
+     # strip surrounding double quotes, if any
+     strtime = strtime.strip('"')
+     # split the timestamp into hours, minutes, and seconds
+     hrs, mins, seconds = [float(c) for c in strtime.split(':')]
+     # convert to total milliseconds
+     total_seconds = hrs * 60**2 + mins * 60 + seconds
+     total_milliseconds = total_seconds * 1000
+     return total_milliseconds
+
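`str2time` is the inverse direction: it turns an `HH:MM:SS.mmm` cue timestamp back into milliseconds, which is how the frame extractor further down picks the midpoint of each transcript segment. A standalone sketch:

```python
# Standalone copy of str2time, for illustration only.
def str2time(strtime):
    hrs, mins, seconds = [float(c) for c in strtime.strip('"').split(':')]
    return (hrs * 3600 + mins * 60 + seconds) * 1000

start_ms = str2time('00:00:05.000')
end_ms = str2time('00:01:30.500')
print(start_ms, end_ms)         # 5000.0 90500.0
print((start_ms + end_ms) / 2)  # segment midpoint: 47750.0
```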
+ # wrap a caption to maxLineWidth characters per line (no-op when the width is None or negative)
+ def _processText(text: str, maxLineWidth=None):
+     if maxLineWidth is None or maxLineWidth < 0:
+         return text
+
+     lines = textwrap.wrap(text, width=maxLineWidth, tabsize=4)
+     return '\n'.join(lines)
+
+ # resize an image while maintaining its aspect ratio
+ def maintain_aspect_ratio_resize(image, width=None, height=None, inter=cv2.INTER_AREA):
+     # grab the current image size
+     (h, w) = image.shape[:2]
+
+     # return the original image if no resizing is requested
+     if width is None and height is None:
+         return image
+
+     if width is None:
+         # calculate the ratio from the target height and construct the dimensions
+         r = height / float(h)
+         dim = (int(w * r), height)
+     else:
+         # calculate the ratio from the target width and construct the dimensions
+         r = width / float(w)
+         dim = (width, int(h * r))
+
+     # return the resized image
+     return cv2.resize(image, dim, interpolation=inter)
+
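The dimension arithmetic can be checked without OpenCV; this sketch mirrors only the ratio logic (`target_dims` is a hypothetical helper name, not part of the module):

```python
def target_dims(h, w, width=None, height=None):
    # mirrors the ratio logic of maintain_aspect_ratio_resize
    if width is None and height is None:
        return (w, h)
    if width is None:
        r = height / float(h)
        return (int(w * r), height)
    r = width / float(w)
    return (width, int(h * r))

# a 1280x720 frame resized to the height=350 used by the frame extractors below
print(target_dims(720, 1280, height=350))  # (622, 350)
# the same frame resized to width=640 instead
print(target_dims(720, 1280, width=640))   # (640, 360)
```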
+ # helper to convert transcripts generated by Whisper into a .vtt file
+ def write_vtt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
+     print("WEBVTT\n", file=file)
+     for segment in transcript:
+         text = _processText(segment['text'], maxLineWidth).replace('-->', '->')
+
+         print(
+             f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
+             f"{text}\n",
+             file=file,
+             flush=True,
+         )
+
+ # helper to convert transcripts generated by Whisper into a .srt file
+ def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
+     """
+     Write a transcript to a file in SRT format.
+     Example usage:
+         from pathlib import Path
+         from whisper.utils import write_srt
+
+         result = transcribe(model, audio_path, temperature=temperature, **args)
+
+         # save SRT
+         audio_basename = Path(audio_path).stem
+         with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
+             write_srt(result["segments"], file=srt)
+     """
+     for i, segment in enumerate(transcript, start=1):
+         text = _processText(segment['text'].strip(), maxLineWidth).replace('-->', '->')
+
+         # write the SRT cue: index, time range, then the text
+         print(
+             f"{i}\n"
+             f"{format_timestamp(segment['start'], always_include_hours=True, fractionalSeperator=',')} --> "
+             f"{format_timestamp(segment['end'], always_include_hours=True, fractionalSeperator=',')}\n"
+             f"{text}\n",
+             file=file,
+             flush=True,
+         )
+
+ # render whisper segments as a subtitle document in the requested format
+ def getSubs(segments: Iterator[dict], format: str, maxLineWidth: int = -1) -> str:
+     segmentStream = StringIO()
+
+     if format == 'vtt':
+         write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
+     elif format == 'srt':
+         write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
+     else:
+         raise ValueError("Unknown format " + format)
+
+     segmentStream.seek(0)
+     return segmentStream.read()
+
+ # base64-encode an image given either a file path or a PIL Image
+ def encode_image(image_path_or_PIL_img):
+     if isinstance(image_path_or_PIL_img, PIL.Image.Image):
+         # this is a PIL image
+         buffered = BytesIO()
+         image_path_or_PIL_img.save(buffered, format="JPEG")
+         return base64.b64encode(buffered.getvalue()).decode('utf-8')
+     else:
+         # this is an image path
+         with open(image_path_or_PIL_img, "rb") as image_file:
+             return base64.b64encode(image_file.read()).decode('utf-8')
+
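The file-path branch boils down to reading raw bytes and base64-encoding them; a tiny self-contained sketch (the bytes here are a made-up stand-in for real JPEG data):

```python
import base64, os, tempfile

# write a few stand-in bytes to a temporary "image" file
with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as f:
    f.write(b'\xff\xd8\xff')  # a real JPEG starts with these marker bytes
    path = f.name

# the path branch of encode_image: read bytes, base64-encode, decode to str
with open(path, 'rb') as image_file:
    b64 = base64.b64encode(image_file.read()).decode('utf-8')
os.remove(path)

print(b64)  # /9j/  (the familiar prefix of any base64-encoded JPEG)
```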
+ # check whether the given string (or bytes) is valid base64
+ def isBase64(sb):
+     try:
+         if isinstance(sb, str):
+             # if there is any non-ASCII here, an exception is thrown and the function returns False
+             sb_bytes = bytes(sb, 'ascii')
+         elif isinstance(sb, bytes):
+             sb_bytes = sb
+         else:
+             raise ValueError("Argument must be string or bytes")
+         return base64.b64encode(base64.b64decode(sb_bytes)) == sb_bytes
+     except Exception:
+         return False
+
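`isBase64` relies on a decode/re-encode round trip: only canonical base64 survives it unchanged. A standalone sketch:

```python
import base64

# Standalone copy of isBase64, for illustration only.
def isBase64(sb):
    try:
        sb_bytes = bytes(sb, 'ascii') if isinstance(sb, str) else sb
        return base64.b64encode(base64.b64decode(sb_bytes)) == sb_bytes
    except Exception:
        return False

encoded = base64.b64encode(b'frame bytes').decode('utf-8')
print(isBase64(encoded))        # True
print(isBase64('not base64!'))  # False
```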
+ # base64-encode an image given either a local file path or a URL
+ def encode_image_from_path_or_url(image_path_or_url):
+     try:
+         # try to open the argument as a URL to check whether it is valid
+         urlopen(image_path_or_url)
+         # it is a URL: download the content and encode it
+         return base64.b64encode(requests.get(image_path_or_url).content).decode('utf-8')
+     except Exception:
+         # otherwise treat it as a path to a local image
+         with open(image_path_or_url, "rb") as image_file:
+             return base64.b64encode(image_file.read()).decode('utf-8')
+
+ # helper to compute the joint embedding of a prompt and a base64-encoded image through PredictionGuard
+ def bt_embedding_from_prediction_guard(prompt, base64_image):
+     # get PredictionGuard client
+     client = _getPredictionGuardClient()
+     message = {"text": prompt}
+     if base64_image is not None and base64_image != "":
+         if not isBase64(base64_image):
+             raise TypeError("image input must be in base64 encoding!")
+         message['image'] = base64_image
+     response = client.embeddings.create(
+         model="bridgetower-large-itm-mlm-itc",
+         input=[message]
+     )
+     return response['data'][0]['embedding']
+
+
+ def load_json_file(file_path):
+     # open the JSON file in read mode
+     with open(file_path, 'r') as file:
+         data = json.load(file)
+     return data
+
+ def display_retrieved_results(results):
+     print(f'There is/are {len(results)} retrieved result(s)')
+     print()
+     for i, res in enumerate(results):
+         print(f'The caption of the {i + 1}-th retrieved result is:\n"{res.page_content}"')
+         print()
+         print(res)
+         # display(Image.open(res.metadata['metadata']['extracted_frame_path']))
+         print("------------------------------------------------------------")
+
+ class SeparatorStyle(Enum):
+     """Different separator styles."""
+     SINGLE = auto()
+
+ @dataclasses.dataclass
+ class Conversation:
+     """A class that keeps all conversation history."""
+     system: str
+     roles: List[str]
+     messages: List[List[str]]
+     map_roles: Dict[str, str]
+     version: str = "Unknown"
+     sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+     sep: str = "\n"
+
+     def _get_prompt_role(self, role):
+         if self.map_roles is not None and role in self.map_roles.keys():
+             return self.map_roles[role]
+         else:
+             return role
+
+     def _build_content_for_first_message_in_conversation(self, first_message: List[str]):
+         content = []
+         if len(first_message) != 2:
+             raise TypeError("First message in Conversation needs to include a prompt and a base64-encoded image!")
+
+         prompt, b64_image = first_message[0], first_message[1]
+
+         # handling prompt
+         if prompt is None:
+             raise TypeError("API does not support None prompt yet")
+         content.append({
+             "type": "text",
+             "text": prompt
+         })
+         if b64_image is None:
+             raise TypeError("API does not support text-only conversation yet")
+
+         # handling image
+         if not isBase64(b64_image):
+             raise TypeError("Image in Conversation's first message must be stored under base64 encoding!")
+
+         content.append({
+             "type": "image_url",
+             "image_url": {
+                 "url": b64_image,
+             }
+         })
+         return content
+
+     def _build_content_for_follow_up_messages_in_conversation(self, follow_up_message: List[str]):
+         if follow_up_message is not None and len(follow_up_message) > 1:
+             raise TypeError("Follow-up message in Conversation must not include an image!")
+
+         # handling text prompt
+         if follow_up_message is None or follow_up_message[0] is None:
+             raise TypeError("Follow-up message in Conversation must include exactly one text message")
+
+         text = follow_up_message[0]
+         return text
+
+     def get_message(self):
+         messages = self.messages
+         api_messages = []
+         for i, msg in enumerate(messages):
+             role, message_content = msg
+             if i == 0:
+                 # get content for the very first message in the conversation
+                 content = self._build_content_for_first_message_in_conversation(message_content)
+             else:
+                 # get content for a follow-up message in the conversation
+                 content = self._build_content_for_follow_up_messages_in_conversation(message_content)
+
+             api_messages.append({
+                 "role": role,
+                 "content": content,
+             })
+         return api_messages
+
+     # this method represents a multi-turn chat as a single-turn prompt string
+     def serialize_messages(self):
+         messages = self.messages
+         ret = ""
+         if self.sep_style == SeparatorStyle.SINGLE:
+             if self.system is not None and self.system != "":
+                 ret = self.system + self.sep
+             for i, (role, message) in enumerate(messages):
+                 role = self._get_prompt_role(role)
+                 if message:
+                     if isinstance(message, List):
+                         # keep the text prompt only
+                         message = message[0]
+                     if i == 0:
+                         # do not include the role at the beginning
+                         ret += message
+                     else:
+                         ret += role + ": " + message
+                     if i < len(messages) - 1:
+                         # avoid including a separator at the end of the serialized message
+                         ret += self.sep
+                 else:
+                     ret += role + ":"
+         else:
+             raise ValueError(f"Invalid style: {self.sep_style}")
+
+         return ret
+
+     def append_message(self, role, message):
+         if len(self.messages) == 0:
+             # data verification for the very first message
+             assert role == self.roles[0], f"the very first message in conversation must be from role {self.roles[0]}"
+             assert len(message) == 2, "the very first message in conversation must include both a prompt and an image"
+             prompt, image = message[0], message[1]
+             assert prompt is not None, "prompt must not be None"
+             assert isBase64(image), "image must be under base64 encoding"
+         else:
+             # data verification for a follow-up message
+             assert role in self.roles, f"the follow-up message must be from one of the roles {self.roles}"
+             assert len(message) == 1, "the follow-up message must consist of one text message only, no image"
+
+         self.messages.append([role, message])
+
+     def copy(self):
+         return Conversation(
+             system=self.system,
+             roles=self.roles,
+             messages=[[x, y] for x, y in self.messages],
+             version=self.version,
+             map_roles=self.map_roles,
+         )
+
+     def dict(self):
+         return {
+             "system": self.system,
+             "roles": self.roles,
+             "messages": [[x, y[0] if len(y) == 1 else y] for x, y in self.messages],
+             "version": self.version,
+         }
+
+ prediction_guard_llava_conv = Conversation(
+     system="",
+     roles=("user", "assistant"),
+     messages=[],
+     version="Prediction Guard LLaVA endpoint Conversation v0",
+     sep_style=SeparatorStyle.SINGLE,
+     map_roles={
+         "user": "USER",
+         "assistant": "ASSISTANT"
+     }
+ )
+
+ # get a PredictionGuard client
+ def _getPredictionGuardClient():
+     PREDICTION_GUARD_API_KEY = get_prediction_guard_api_key()
+     client = PredictionGuard(
+         api_key=PREDICTION_GUARD_API_KEY,
+         url=PREDICTION_GUARD_URL_ENDPOINT,
+     )
+     return client
+
+ # helper to call the chat completion endpoint of PredictionGuard given a prompt and an image
+ def lvlm_inference(prompt, image, max_tokens: int = 200, temperature: float = 0.95, top_p: float = 0.1, top_k: int = 10):
+     # prepare conversation
+     conversation = prediction_guard_llava_conv.copy()
+     conversation.append_message(conversation.roles[0], [prompt, image])
+     return lvlm_inference_with_conversation(conversation, max_tokens=max_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
+
+ def lvlm_inference_with_conversation(conversation, max_tokens: int = 200, temperature: float = 0.95, top_p: float = 0.1, top_k: int = 10):
+     # get PredictionGuard client
+     client = _getPredictionGuardClient()
+     # get messages from the conversation
+     messages = conversation.get_message()
+     # call the chat completion endpoint at Prediction Guard
+     response = client.chat.completions.create(
+         model="llava-1.5-7b-hf",
+         messages=messages,
+         max_tokens=max_tokens,
+         temperature=temperature,
+         top_p=top_p,
+         top_k=top_k,
+     )
+     return response['choices'][-1]['message']['content']
+
+ def lvlm_inference_with_ollama(conversation, max_tokens: int = 200, temperature: float = 0.95, top_p: float = 0.1, top_k: int = 10):
+     # stream a chat completion from the local Ollama server;
+     # sampling parameters are passed through Ollama's `options` dict
+     stream = chat(
+         model="llava-1.5-7b-hf",
+         messages=conversation,
+         stream=True,
+         options={
+             "temperature": temperature,
+             "num_predict": max_tokens,
+             "top_p": top_p,
+             "top_k": top_k,
+         },
+     )
+
+     # concatenate the streamed chunks into a single response string
+     response_data = ''
+     for chunk in stream:
+         response_data += chunk['message']['content']
+
+     return response_data
+
+ # function `extract_and_save_frames_and_metadata`:
+ #   receives as input a video and its transcript
+ #   extracts and saves frames and their metadata
+ #   returns the extracted metadata
+ def extract_and_save_frames_and_metadata(
+         path_to_video,
+         path_to_transcript,
+         path_to_save_extracted_frames,
+         path_to_save_metadatas):
+
+     # metadatas will store the metadata of all extracted frames
+     metadatas = []
+
+     # load video using cv2
+     video = cv2.VideoCapture(path_to_video)
+     # load transcript using webvtt
+     trans = webvtt.read(path_to_transcript)
+
+     # iterate over each video segment specified in the transcript file
+     for idx, transcript in enumerate(trans):
+         # get the start and end times in milliseconds
+         start_time_ms = str2time(transcript.start)
+         end_time_ms = str2time(transcript.end)
+         # get the time in ms exactly in the middle of start time and end time
+         mid_time_ms = (end_time_ms + start_time_ms) / 2
+         # get the transcript text, replacing newline symbols with spaces
+         text = transcript.text.replace("\n", ' ')
+         # grab the frame at the middle time
+         video.set(cv2.CAP_PROP_POS_MSEC, mid_time_ms)
+         success, frame = video.read()
+         if success:
+             # if the frame is extracted successfully, resize it
+             image = maintain_aspect_ratio_resize(frame, height=350)
+             # save the frame as a JPEG file
+             img_fname = f'frame_{idx}.jpg'
+             img_fpath = osp.join(
+                 path_to_save_extracted_frames, img_fname
+             )
+             cv2.imwrite(img_fpath, image)
+
+             # prepare the metadata
+             metadata = {
+                 'extracted_frame_path': img_fpath,
+                 'transcript': text,
+                 'video_segment_id': idx,
+                 'video_path': path_to_video,
+                 'mid_time_ms': mid_time_ms,
+             }
+             metadatas.append(metadata)
+
+         else:
+             print(f"ERROR! Cannot extract frame: idx = {idx}")
+
+     # save metadata of all extracted frames
+     fn = osp.join(path_to_save_metadatas, 'metadatas.json')
+     with open(fn, 'w') as outfile:
+         json.dump(metadatas, outfile)
+     return metadatas
+
+ def extract_meta_data(vid_dir, vid_filepath, vid_transcript_filepath):
+     # output paths to save extracted frames and their metadata
+     extracted_frames_path = osp.join(vid_dir, 'extracted_frame')
+     metadatas_path = vid_dir
+
+     # create these output folders if they do not exist
+     Path(extracted_frames_path).mkdir(parents=True, exist_ok=True)
+     Path(metadatas_path).mkdir(parents=True, exist_ok=True)
+
+     # call the function to extract frames and metadata
+     metadatas = extract_and_save_frames_and_metadata(
+         vid_filepath,
+         vid_transcript_filepath,
+         extracted_frames_path,
+         metadatas_path,
+     )
+     return metadatas
+
+ # function `extract_and_save_frames_and_metadata_with_fps`:
+ #   receives as input a video
+ #   extracts frames at a fixed rate, captions them, and saves frames and metadata
+ #   returns the extracted metadata
+ def extract_and_save_frames_and_metadata_with_fps(
+         lvlm_prompt,
+         path_to_video,
+         path_to_save_extracted_frames,
+         path_to_save_metadatas,
+         num_of_extracted_frames_per_second=1):
+
+     # metadatas will store the metadata of all extracted frames
+     metadatas = []
+
+     # load video using cv2
+     video = cv2.VideoCapture(path_to_video)
+
+     # get the frames per second
+     fps = video.get(cv2.CAP_PROP_FPS)
+     # hop = the number of frames to pass before a frame is extracted
+     hop = round(fps / num_of_extracted_frames_per_second)
+     curr_frame = 0
+     idx = -1
+     while True:
+         # iterate over all frames
+         ret, frame = video.read()
+         if not ret:
+             break
+         if curr_frame % hop == 0:
+             idx = idx + 1
+
+             # resize the extracted frame
+             image = maintain_aspect_ratio_resize(frame, height=350)
+             # save the frame as a JPEG file
+             img_fname = f'frame_{idx}.jpg'
+             img_fpath = osp.join(
+                 path_to_save_extracted_frames,
+                 img_fname
+             )
+             cv2.imwrite(img_fpath, image)
+
+             # generate a caption using lvlm_inference
+             b64_image = encode_image(img_fpath)
+             caption = lvlm_inference(lvlm_prompt, b64_image)
+
+             # prepare the metadata
+             metadata = {
+                 'extracted_frame_path': img_fpath,
+                 'transcript': caption,
+                 'video_segment_id': idx,
+                 'video_path': path_to_video,
+             }
+             metadatas.append(metadata)
+         curr_frame += 1
+
+     # save metadata of all extracted frames
+     metadatas_path = osp.join(path_to_save_metadatas, 'metadatas.json')
+     with open(metadatas_path, 'w') as outfile:
+         json.dump(metadatas, outfile)
+     return metadatas
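The frame-sampling arithmetic above can be checked in isolation: with `hop = round(fps / n)`, every `hop`-th frame index is kept. A minimal sketch with a hypothetical 30 fps video:

```python
# hypothetical 30 fps video, 1 extracted frame per second
fps = 30.0
num_of_extracted_frames_per_second = 1
hop = round(fps / num_of_extracted_frames_per_second)

# indices of the frames that would be extracted from the first ~3 seconds
kept = [i for i in range(91) if i % hop == 0]
print(hop)   # 30
print(kept)  # [0, 30, 60, 90]
```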