Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,931 Bytes
142a1ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 |
import base64
import gc
import json
import os
import queue
import threading
import time
import traceback
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

# Gemini / Vertex AI imports
import vertexai
from vertexai.generative_models import GenerativeModel, Part
@dataclass
class VideoEntry:
    """A single video to be captioned, plus optional metadata carried through
    the pipeline unchanged."""

    mp4_path: str  # path to the .mp4 file on disk
    # Optional keys below (copied verbatim into the output record when set).
    # Fix: these default to None, so they must be typed Optional[...] —
    # the original implicit-Optional annotations (`str = None`) are invalid
    # per PEP 484.
    youtube_key_segment: Optional[str] = None
    duration: Optional[float] = None
    fps: Optional[float] = None
    height: Optional[int] = None
    width: Optional[int] = None
    n_frames: Optional[int] = None
    # Add other metadata fields as needed
@dataclass
class CaptionResult:
    """The caption produced for one video, with the entry's metadata echoed back."""

    mp4_path: str  # path of the captioned video (matches VideoEntry.mp4_path)
    caption: str  # model-generated caption text
    # Optional keys below (mirrors VideoEntry's metadata fields).
    # Fix: these default to None, so they must be typed Optional[...] —
    # the original implicit-Optional annotations (`str = None`) are invalid
    # per PEP 484.
    youtube_key_segment: Optional[str] = None
    duration: Optional[float] = None
    fps: Optional[float] = None
    height: Optional[int] = None
    width: Optional[int] = None
    n_frames: Optional[int] = None
class GeminiCaptionProcessor:
    """Caption videos with Gemini on Vertex AI using a pool of worker threads.

    Entries are pushed onto an in-memory queue, ``num_workers`` daemon threads
    each call the model once per video, and successful captions are appended
    to ``output_file`` as JSON lines.
    """

    def __init__(self, output_file: str, num_workers: int = 12):
        """Set up queues, counters, the Vertex AI client, and the caption prompt.

        Args:
            output_file: Path of the JSONL file results are appended to.
            num_workers: Number of concurrent captioning threads.
        """
        self.output_file = output_file
        self.num_workers = num_workers
        # Work items (VideoEntry) flow in here; ``None`` is the stop sentinel.
        self.entry_queue = queue.Queue()
        # Successful results accumulate here; drained after all workers join.
        self.results_queue = queue.Queue()
        self.workers = []
        self.success_count = 0
        self.fail_count = 0
        self.start_time = None
        self.end_time = None
        # Initialize Vertex AI
        PROJECT_ID = "fas-dev-kempner-2b75"
        # model_index selects both the region and the model name below in lockstep.
        model_index = 0
        LOCATION = ["us-central1", "us-east5"][model_index]
        vertexai.init(project=PROJECT_ID, location=LOCATION)
        MODEL_NAME = ["gemini-2.0-flash-001", "gemini-1.5-flash-002"][
            model_index
        ]  # "gemini-2.0-flash-001" or "gemini-1.5-flash-002"
        self.model = GenerativeModel(model_name=MODEL_NAME)
        print(f"Using model: {MODEL_NAME}")
        # Prompt sent alongside every video; its final sentence asks a Yes/No
        # question so downstream filtering can detect hand-motion content.
        self.prompt = (
            "Summarize this video directly, when summarizing please provide a detailed description of major subjects, actions, and interactions. "
            "Focus on key actions, interactions, and movements. Include camera movements. "
            "Keep the summary brief and clear. "
            "Only include information that is certain, and avoid speculation or assumptions."
            "In the last sentence, answer the question with just Yes or No, does the video contain rich human hand motions?"
        )
        # Lock for updating success and fail counts
        self.count_lock = threading.Lock()
        # Metadata keys copied from input records onto results when present.
        self.optional_keys = [
            "duration",
            "fps",
            "height",
            "width",
            "n_frames",
            "youtube_key_segment",
        ]

    def process_entries(self, records: List[Dict[str, Any]]):
        """Caption every record and append successful results to the output file.

        Args:
            records: Dicts each containing at least ``"video_path"`` plus any
                subset of the keys listed in ``self.optional_keys``.
        """
        self.start_time = time.time()
        # Start worker threads
        for _ in range(self.num_workers):
            worker = threading.Thread(target=self._worker_process, daemon=True)
            worker.start()
            self.workers.append(worker)
        # Producer: read input lines and put them into the queue
        to_process_count = 0
        for data in records:
            entry = VideoEntry(
                mp4_path=data["video_path"],
            )
            # add optional keys to entry:
            for key in self.optional_keys:
                if key in data:
                    entry.__dict__[key] = data[key]
            self.entry_queue.put(entry)
            to_process_count += 1
        if to_process_count == 0:
            print("No new entries to process. All done!")
            # Even if none, still send sentinels to avoid blocking
            for _ in range(self.num_workers):
                self.entry_queue.put(None)
            return
        # Add sentinel values to signal workers to stop
        for _ in range(self.num_workers):
            self.entry_queue.put(None)
        # Wait for all workers to finish
        for worker in self.workers:
            worker.join()
        # Collect results — safe to drain without a lock here because every
        # worker has already joined, so no thread is still producing.
        results = []
        while not self.results_queue.empty():
            result = self.results_queue.get()
            # Only append results that aren't error messages
            if not result.caption.startswith("Error"):
                results.append(result)
        # Append results to output file
        with open(self.output_file, "a", encoding="utf-8") as f:
            for result in results:
                obj = {"video_path": result.mp4_path, "caption": result.caption}
                for key in self.optional_keys:
                    if key in result.__dict__ and result.__dict__[key] is not None:
                        obj[key] = result.__dict__[key]
                f.write(json.dumps(obj) + "\n")
        self.end_time = time.time()
        total_time = self.end_time - self.start_time
        print(f"Processed {len(results)} entries successfully.")
        print(f"Failed on {self.fail_count} entries.")
        print(f"Total time: {total_time:.2f} seconds.")
        if to_process_count > 0:
            print(f"Throughput: {to_process_count / total_time:.2f} videos/second.")
        print(f"Output file: {self.output_file}")

    def _read_video_file(self, file_path: str) -> str:
        """Read video file and convert it to base64."""
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Video file not found: {file_path}")
        with open(file_path, "rb") as video_file:
            return base64.b64encode(video_file.read()).decode("utf-8")

    def get_gemini_caption(self, video_path: str) -> str:
        """Generate a caption for a single video using Gemini Flash."""
        video_data = self._read_video_file(video_path)
        # NOTE(review): video_data is a base64-encoded *str*, while
        # Part.from_data's ``data`` parameter is documented as raw bytes —
        # confirm the SDK accepts this, or pass the raw file bytes instead.
        video_part = Part.from_data(data=video_data, mime_type="video/mp4")
        try:
            response = self.model.generate_content(
                [video_part, self.prompt],
                # generation_config={
                #     "max_output_tokens": 1024,
                #     "temperature": 0.4
                # },
                stream=False,
            )
            return response.text
        except Exception as e:
            print(f"Error from Gemini API: {e}")
            # Error sentinel: downstream filters captions via startswith("Error").
            return f"Error from Gemini API: {e}"

    def _process_single_entry(self, entry: VideoEntry) -> CaptionResult:
        """Caption one entry and copy its non-None optional metadata onto the result."""
        caption = self.get_gemini_caption(entry.mp4_path)
        ret_result = CaptionResult(mp4_path=entry.mp4_path, caption=caption)
        for key in self.optional_keys:
            if key in entry.__dict__ and entry.__dict__[key] is not None:
                ret_result.__dict__[key] = entry.__dict__[key]
        return ret_result

    def _worker_process(self):
        """Worker loop: pull entries until a ``None`` sentinel, captioning each.

        NOTE(review): the sentinel path breaks without calling ``task_done()``;
        harmless as long as nothing calls ``entry_queue.join()`` — confirm.
        """
        while True:
            entry = self.entry_queue.get()
            if entry is None:  # Check for sentinel value
                break
            # Periodic progress log; qsize() is approximate under concurrency,
            # so this fires roughly every 100 remaining entries.
            if self.entry_queue.qsize() % 100 == 0:
                print(
                    f"Processing {entry.mp4_path}. {self.entry_queue.qsize()} entries left in queue."
                )
                # Force a GC pass — presumably to release the large base64
                # video buffers from earlier iterations; TODO confirm needed.
                gc_s_time = time.time()
                num_gc = gc.collect()
                gc_e_time = time.time()
                print(
                    f"Garbage collection took {gc_e_time - gc_s_time} seconds, collected {num_gc} objects"
                )
            try:
                result = self._process_single_entry(entry)
                # Check if result is error. If not, add to results_queue.
                if not result.caption.startswith("Error"):
                    with self.count_lock:
                        self.success_count += 1
                    self.results_queue.put(result)
                else:
                    with self.count_lock:
                        self.fail_count += 1
                    print(f"Skipping {entry.mp4_path} due to error in captioning.")
            except Exception as e:
                with self.count_lock:
                    self.fail_count += 1
                print(f"Error processing {entry.mp4_path}: {str(e)}")
                traceback.print_exc()
            finally:
                self.entry_queue.task_done()
|