# LVP/utils/gemini_utils.py — Gemini (Vertex AI) video-captioning utilities.
# Originally added in the "add demo" commit (142a1ac) by kiwhansong.
import base64
import gc
import json
import os
import queue
import threading
import time
import traceback
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

# Gemini / Vertex AI imports
import vertexai
from vertexai.generative_models import GenerativeModel, Part
@dataclass
class VideoEntry:
    """A single video to be captioned, identified by its local mp4 path.

    The metadata fields default to None; when present on an input record
    they are copied through to the corresponding CaptionResult.
    """

    # Path to the video file on disk (required).
    mp4_path: str
    # optional keys below (None when unknown):
    youtube_key_segment: Optional[str] = None
    duration: Optional[float] = None
    fps: Optional[float] = None
    height: Optional[int] = None
    width: Optional[int] = None
    n_frames: Optional[int] = None
    # Add other metadata fields as needed
@dataclass
class CaptionResult:
    """The caption produced for one video, plus any metadata carried over.

    Mirrors VideoEntry's optional metadata fields so rows written to the
    output JSONL can preserve whatever the input record supplied.
    """

    # Path to the captioned video file (required).
    mp4_path: str
    # Caption text returned by the model (required).
    caption: str
    # optional keys below (None when unknown):
    youtube_key_segment: Optional[str] = None
    duration: Optional[float] = None
    fps: Optional[float] = None
    height: Optional[int] = None
    width: Optional[int] = None
    n_frames: Optional[int] = None
class GeminiCaptionProcessor:
    """Multi-threaded pipeline that captions videos with Gemini on Vertex AI.

    A producer (``process_entries``) feeds VideoEntry items through a
    thread-safe queue to worker threads; each worker sends the video bytes
    plus a fixed prompt to the Gemini model and collects the caption.
    Successful results are appended as JSON lines to ``output_file``.
    The instance is single-use: call ``process_entries`` once per processor.
    """

    def __init__(self, output_file: str, num_workers: int = 12):
        """Set up queues, counters, and the Vertex AI Gemini model.

        Args:
            output_file: path of the JSONL file that results are appended to.
            num_workers: number of captioning worker threads.
        """
        self.output_file = output_file
        self.num_workers = num_workers
        # Work distribution / result collection between producer and workers.
        self.entry_queue = queue.Queue()
        self.results_queue = queue.Queue()
        self.workers = []
        self.success_count = 0
        self.fail_count = 0
        self.start_time = None
        self.end_time = None
        # Initialize Vertex AI
        # NOTE(review): project is hard-coded; model_index picks the paired
        # (location, model) combination from the two lists below.
        PROJECT_ID = "fas-dev-kempner-2b75"
        model_index = 0
        LOCATION = ["us-central1", "us-east5"][model_index]
        vertexai.init(project=PROJECT_ID, location=LOCATION)
        MODEL_NAME = ["gemini-2.0-flash-001", "gemini-1.5-flash-002"][
            model_index
        ]  # "gemini-2.0-flash-001" or "gemini-1.5-flash-002"
        self.model = GenerativeModel(model_name=MODEL_NAME)
        print(f"Using model: {MODEL_NAME}")
        # Prompt sent with every video. The trailing Yes/No question lets
        # downstream consumers flag videos with rich hand motion.
        self.prompt = (
            "Summarize this video directly, when summarizing please provide a detailed description of major subjects, actions, and interactions. "
            "Focus on key actions, interactions, and movements. Include camera movements. "
            "Keep the summary brief and clear. "
            "Only include information that is certain, and avoid speculation or assumptions."
            "In the last sentence, answer the question with just Yes or No, does the video contain rich human hand motions?"
        )
        # Lock for updating success and fail counts
        self.count_lock = threading.Lock()
        # Metadata keys copied from input records to output rows when present.
        self.optional_keys = [
            "duration",
            "fps",
            "height",
            "width",
            "n_frames",
            "youtube_key_segment",
        ]

    def process_entries(self, records: List[Dict[str, Any]]) -> None:
        """Caption every record and append successful results to the output file.

        Each record must contain a "video_path" key; any of the optional
        metadata keys found on a record are carried through to its output row.
        Prints summary statistics (counts, total time, throughput) at the end.
        """
        self.start_time = time.time()
        # Start worker threads
        for _ in range(self.num_workers):
            worker = threading.Thread(target=self._worker_process, daemon=True)
            worker.start()
            self.workers.append(worker)
        # Producer: read input lines and put them into the queue
        to_process_count = 0
        for data in records:
            entry = VideoEntry(
                mp4_path=data["video_path"],
            )
            # add optional keys to entry:
            for key in self.optional_keys:
                if key in data:
                    entry.__dict__[key] = data[key]
            self.entry_queue.put(entry)
            to_process_count += 1
        if to_process_count == 0:
            print("No new entries to process. All done!")
            # Even if none, still send sentinels to avoid blocking
            for _ in range(self.num_workers):
                self.entry_queue.put(None)
            return
        # Add sentinel values to signal workers to stop
        for _ in range(self.num_workers):
            self.entry_queue.put(None)
        # Wait for all workers to finish
        for worker in self.workers:
            worker.join()
        # Collect results (safe to drain without locking: all workers joined).
        results = []
        while not self.results_queue.empty():
            result = self.results_queue.get()
            # Only append results that aren't error messages
            if not result.caption.startswith("Error"):
                results.append(result)
        # Append results to output file, one JSON object per line (JSONL).
        with open(self.output_file, "a", encoding="utf-8") as f:
            for result in results:
                obj = {"video_path": result.mp4_path, "caption": result.caption}
                for key in self.optional_keys:
                    if key in result.__dict__ and result.__dict__[key] is not None:
                        obj[key] = result.__dict__[key]
                f.write(json.dumps(obj) + "\n")
        self.end_time = time.time()
        total_time = self.end_time - self.start_time
        print(f"Processed {len(results)} entries successfully.")
        print(f"Failed on {self.fail_count} entries.")
        print(f"Total time: {total_time:.2f} seconds.")
        if to_process_count > 0:
            print(f"Throughput: {to_process_count / total_time:.2f} videos/second.")
        print(f"Output file: {self.output_file}")

    def _read_video_file(self, file_path) -> str:
        """Read video file and convert it to base64.

        Raises:
            FileNotFoundError: if ``file_path`` does not exist.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Video file not found: {file_path}")
        with open(file_path, "rb") as video_file:
            return base64.b64encode(video_file.read()).decode("utf-8")

    def get_gemini_caption(self, video_path: str) -> str:
        """Generate a caption for a single video using Gemini Flash.

        Returns the model's text, or a string starting with "Error" when the
        API call fails — callers test for that prefix rather than catching.
        A missing file still raises FileNotFoundError (not caught here).
        """
        video_data = self._read_video_file(video_path)
        # NOTE(review): video_data is a base64 *string*; Part.from_data is
        # usually given raw bytes — confirm the SDK accepts this encoding.
        video_part = Part.from_data(data=video_data, mime_type="video/mp4")
        try:
            response = self.model.generate_content(
                [video_part, self.prompt],
                # generation_config={
                #     "max_output_tokens": 1024,
                #     "temperature": 0.4
                # },
                stream=False,
            )
            return response.text
        except Exception as e:
            print(f"Error from Gemini API: {e}")
            return f"Error from Gemini API: {e}"

    def _process_single_entry(self, entry: VideoEntry) -> CaptionResult:
        """Caption one entry and copy its optional metadata onto the result."""
        caption = self.get_gemini_caption(entry.mp4_path)
        ret_result = CaptionResult(mp4_path=entry.mp4_path, caption=caption)
        for key in self.optional_keys:
            if key in entry.__dict__ and entry.__dict__[key] is not None:
                ret_result.__dict__[key] = entry.__dict__[key]
        return ret_result

    def _worker_process(self) -> None:
        """Worker loop: pull entries from the queue until a None sentinel."""
        while True:
            entry = self.entry_queue.get()
            if entry is None:  # Check for sentinel value
                break
            # Roughly every 100 remaining items: log progress and force a
            # garbage-collection pass (video payloads are large).
            if self.entry_queue.qsize() % 100 == 0:
                print(
                    f"Processing {entry.mp4_path}. {self.entry_queue.qsize()} entries left in queue."
                )
                gc_s_time = time.time()
                num_gc = gc.collect()
                gc_e_time = time.time()
                print(
                    f"Garbage collection took {gc_e_time - gc_s_time} seconds, collected {num_gc} objects"
                )
            try:
                result = self._process_single_entry(entry)
                # Check if result is error. If not, add to results_queue.
                if not result.caption.startswith("Error"):
                    with self.count_lock:
                        self.success_count += 1
                    self.results_queue.put(result)
                else:
                    with self.count_lock:
                        self.fail_count += 1
                    print(f"Skipping {entry.mp4_path} due to error in captioning.")
            except Exception as e:
                with self.count_lock:
                    self.fail_count += 1
                print(f"Error processing {entry.mp4_path}: {str(e)}")
                traceback.print_exc()
            finally:
                self.entry_queue.task_done()