|
|
import base64
import gc
import json
import os
import queue
import threading
import time
import traceback
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import vertexai
from vertexai.generative_models import GenerativeModel, Part
|
|
|
|
|
|
|
|
@dataclass
class VideoEntry:
    """One video queued for captioning.

    Only ``mp4_path`` is required. The remaining fields are optional metadata
    copied from the input record (when present) and carried through to the
    output line unchanged.
    """

    # Path to the local MP4 file that is read and sent to the model.
    mp4_path: str

    # Optional pass-through metadata; None means "not provided".
    youtube_key_segment: Optional[str] = None
    duration: Optional[float] = None
    fps: Optional[float] = None
    height: Optional[int] = None
    width: Optional[int] = None
    n_frames: Optional[int] = None
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class CaptionResult:
    """The caption produced for one video, plus its pass-through metadata.

    ``caption`` holds either the model's response text or, when the API call
    failed, an error string starting with ``"Error"``.
    """

    # Path of the video the caption belongs to.
    mp4_path: str
    # Model response text (or an "Error ..." marker string on failure).
    caption: str

    # Optional metadata copied from the originating VideoEntry; None = absent.
    youtube_key_segment: Optional[str] = None
    duration: Optional[float] = None
    fps: Optional[float] = None
    height: Optional[int] = None
    width: Optional[int] = None
    n_frames: Optional[int] = None
|
|
|
|
|
|
|
|
class GeminiCaptionProcessor:
    """Caption local MP4 videos with a Gemini model on Vertex AI, using worker threads.

    ``process_entries`` converts input records into ``VideoEntry`` objects,
    feeds them to ``num_workers`` daemon threads through ``entry_queue``, waits
    for all threads to finish, then appends every successful caption to
    ``output_file`` as one JSON object per line.
    """

    def __init__(
        self,
        output_file: str,
        num_workers: int = 12,
        project_id: str = "fas-dev-kempner-2b75",
        location: str = "us-central1",
        model_name: str = "gemini-2.0-flash-001",
    ):
        """Initialise Vertex AI and the worker-pool state.

        Args:
            output_file: JSON Lines file that results are appended to.
            num_workers: Number of worker threads started per ``process_entries`` call.
            project_id: Google Cloud project passed to ``vertexai.init``.
            location: Vertex AI region passed to ``vertexai.init``. The previously
                hard-coded alternative pairing was ``"us-east5"`` with
                ``"gemini-1.5-flash-002"``.
            model_name: Gemini model used for captioning.
        """
        self.output_file = output_file
        self.num_workers = num_workers
        self.entry_queue = queue.Queue()
        self.results_queue = queue.Queue()
        self.workers = []
        self.success_count = 0
        self.fail_count = 0
        self.start_time = None
        self.end_time = None

        # vertexai.init sets the SDK's default project/location, which the
        # GenerativeModel below then uses.
        vertexai.init(project=project_id, location=location)
        self.model = GenerativeModel(model_name=model_name)
        print(f"Using model: {model_name}")

        # Each fragment ends with a space so the concatenated sentences stay separated.
        self.prompt = (
            "Summarize this video directly, when summarizing please provide a detailed description of major subjects, actions, and interactions. "
            "Focus on key actions, interactions, and movements. Include camera movements. "
            "Keep the summary brief and clear. "
            "Only include information that is certain, and avoid speculation or assumptions. "
            "In the last sentence, answer the question with just Yes or No, does the video contain rich human hand motions?"
        )

        # Guards success_count / fail_count, which all worker threads update.
        self.count_lock = threading.Lock()

        # Record keys copied record -> VideoEntry -> CaptionResult -> output line.
        self.optional_keys = [
            "duration",
            "fps",
            "height",
            "width",
            "n_frames",
            "youtube_key_segment",
        ]

    @staticmethod
    def _is_error_caption(caption: str) -> bool:
        """Return True if *caption* should be treated as a failed captioning attempt.

        ``get_gemini_caption`` signals API failures with a string starting with
        ``"Error"``. NOTE(review): a genuine model caption beginning with
        "Error" would also be classified as a failure.
        """
        return caption.startswith("Error")

    def process_entries(self, records: List[Dict[str, Any]]):
        """Caption every record and append the successful results to ``output_file``.

        Args:
            records: Dicts with a required ``"video_path"`` key. Any key listed
                in ``self.optional_keys`` is carried through to the output line.

        Raises:
            KeyError: if a record has no ``"video_path"``. Raised before any
                worker thread is started, so no API calls are spent on a bad batch.
        """
        self.start_time = time.time()
        # Reset per-run state so repeated calls report only their own run.
        self.success_count = 0
        self.fail_count = 0
        self.workers = []

        entries = [self._record_to_entry(data) for data in records]
        to_process_count = len(entries)
        if to_process_count == 0:
            print("No new entries to process. All done!")
            return

        for _ in range(self.num_workers):
            worker = threading.Thread(target=self._worker_process, daemon=True)
            worker.start()
            self.workers.append(worker)

        for entry in entries:
            self.entry_queue.put(entry)
        # One None sentinel per worker; the queue is FIFO, so each thread sees
        # its sentinel only after all real entries have been taken.
        for _ in range(self.num_workers):
            self.entry_queue.put(None)

        for worker in self.workers:
            worker.join()

        results = self._drain_results()
        self._write_results(results)
        self._print_summary(len(results), to_process_count)

    def _record_to_entry(self, data: Dict[str, Any]) -> VideoEntry:
        """Build a ``VideoEntry`` from one input record, copying any optional keys present."""
        entry = VideoEntry(mp4_path=data["video_path"])
        for key in self.optional_keys:
            if key in data:
                setattr(entry, key, data[key])
        return entry

    def _drain_results(self) -> List[CaptionResult]:
        """Collect the results the (already joined) workers left on ``results_queue``."""
        results = []
        while not self.results_queue.empty():
            result = self.results_queue.get()
            # Workers only enqueue successes; this re-check is purely defensive.
            if not self._is_error_caption(result.caption):
                results.append(result)
        return results

    def _write_results(self, results: List[CaptionResult]) -> None:
        """Append one JSON line per result, including only non-None optional fields."""
        with open(self.output_file, "a", encoding="utf-8") as f:
            for result in results:
                obj = {"video_path": result.mp4_path, "caption": result.caption}
                for key in self.optional_keys:
                    value = getattr(result, key, None)
                    if value is not None:
                        obj[key] = value
                f.write(json.dumps(obj) + "\n")

    def _print_summary(self, n_succeeded: int, n_submitted: int) -> None:
        """Record ``end_time`` and print success/failure counts and timing."""
        self.end_time = time.time()
        total_time = self.end_time - self.start_time
        print(f"Processed {n_succeeded} entries successfully.")
        print(f"Failed on {self.fail_count} entries.")
        print(f"Total time: {total_time:.2f} seconds.")
        if n_submitted > 0 and total_time > 0:
            print(f"Throughput: {n_submitted / total_time:.2f} videos/second.")
        print(f"Output file: {self.output_file}")

    def _read_video_bytes(self, file_path: str) -> bytes:
        """Return the raw contents of the video at *file_path*.

        Raises:
            FileNotFoundError: if *file_path* does not exist.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Video file not found: {file_path}")
        with open(file_path, "rb") as video_file:
            return video_file.read()

    def _read_video_file(self, file_path):
        """Read a video file and return its contents base64-encoded as a str.

        Kept for backward compatibility; captioning itself now sends raw bytes
        via ``_read_video_bytes``.

        Raises:
            FileNotFoundError: if *file_path* does not exist.
        """
        return base64.b64encode(self._read_video_bytes(file_path)).decode("utf-8")

    def get_gemini_caption(self, video_path: str) -> str:
        """Generate a caption for a single video using Gemini.

        Args:
            video_path: Path to a local MP4 file.

        Returns:
            The model's response text, or ``"Error from Gemini API: <exc>"`` if
            the API call or reading ``response.text`` raised.

        Raises:
            FileNotFoundError: if *video_path* does not exist. File reading is
                deliberately outside the API error handling.
        """
        # Part.from_data takes bytes; passing them directly avoids building an
        # extra base64 copy (~4/3 of the file size) per in-flight video.
        video_part = Part.from_data(
            data=self._read_video_bytes(video_path), mime_type="video/mp4"
        )
        try:
            response = self.model.generate_content(
                [video_part, self.prompt],
                stream=False,
            )
            return response.text
        except Exception as e:
            print(f"Error from Gemini API: {e}")
            return f"Error from Gemini API: {e}"

    def _process_single_entry(self, entry: VideoEntry) -> CaptionResult:
        """Caption one entry and copy its non-None optional metadata onto the result."""
        caption = self.get_gemini_caption(entry.mp4_path)
        result = CaptionResult(mp4_path=entry.mp4_path, caption=caption)
        for key in self.optional_keys:
            value = getattr(entry, key, None)
            if value is not None:
                setattr(result, key, value)
        return result

    def _worker_process(self):
        """Worker loop: caption queued entries until a ``None`` sentinel arrives.

        Successful results go to ``results_queue`` and bump ``success_count``.
        Error captions and exceptions bump ``fail_count`` and are logged.
        Every ``get()`` is matched by ``task_done()``, so ``entry_queue.join()``
        would not block.
        """
        while True:
            entry = self.entry_queue.get()
            if entry is None:
                self.entry_queue.task_done()
                break

            # Read qsize once so the check and the message agree (it may
            # change concurrently). The count includes pending sentinels.
            remaining = self.entry_queue.qsize()
            if remaining % 100 == 0:
                print(
                    f"Processing {entry.mp4_path}. {remaining} entries left in queue."
                )
                # Periodic full collection — presumably to bound memory held by
                # large video payloads. NOTE(review): confirm it is still needed.
                gc_s_time = time.time()
                num_gc = gc.collect()
                gc_e_time = time.time()
                print(
                    f"Garbage collection took {gc_e_time - gc_s_time} seconds, collected {num_gc} objects"
                )

            try:
                result = self._process_single_entry(entry)
                if self._is_error_caption(result.caption):
                    with self.count_lock:
                        self.fail_count += 1
                    print(f"Skipping {entry.mp4_path} due to error in captioning.")
                else:
                    with self.count_lock:
                        self.success_count += 1
                    self.results_queue.put(result)
            except Exception as e:
                # Best-effort per entry: log and keep the worker alive.
                with self.count_lock:
                    self.fail_count += 1
                print(f"Error processing {entry.mp4_path}: {str(e)}")
                traceback.print_exc()
            finally:
                self.entry_queue.task_done()
|
|
|