Spaces:
Running
on
Zero
Running
on
Zero
import base64
import gc
import json
import os
import queue
import threading
import time
import traceback
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

# Gemini / Vertex AI imports
import vertexai
from vertexai.generative_models import GenerativeModel, Part
# Deliberately not ``slots=True``: other code in this file reads and writes
# the optional fields through ``__dict__``.
@dataclass
class VideoEntry:
    """One video queued for captioning.

    ``mp4_path`` is required; every other field is optional metadata that
    stays ``None`` unless the input record supplies it.
    """

    mp4_path: str
    # optional keys below:
    youtube_key_segment: Optional[str] = None
    duration: Optional[float] = None
    fps: Optional[float] = None
    height: Optional[int] = None
    width: Optional[int] = None
    n_frames: Optional[int] = None
    # Add other metadata fields as needed
# Deliberately not ``slots=True``: other code in this file reads and writes
# the optional fields through ``__dict__``.
@dataclass
class CaptionResult:
    """Caption produced for one video, plus any metadata carried along.

    ``mp4_path`` and ``caption`` are required; the remaining fields are
    optional metadata and default to ``None``.
    """

    mp4_path: str
    caption: str
    # optional keys below:
    youtube_key_segment: Optional[str] = None
    duration: Optional[float] = None
    fps: Optional[float] = None
    height: Optional[int] = None
    width: Optional[int] = None
    n_frames: Optional[int] = None
| class GeminiCaptionProcessor: | |
| def __init__(self, output_file: str, num_workers: int = 12): | |
| self.output_file = output_file | |
| self.num_workers = num_workers | |
| self.entry_queue = queue.Queue() | |
| self.results_queue = queue.Queue() | |
| self.workers = [] | |
| self.success_count = 0 | |
| self.fail_count = 0 | |
| self.start_time = None | |
| self.end_time = None | |
| # Initialize Vertex AI | |
| PROJECT_ID = "fas-dev-kempner-2b75" | |
| model_index = 0 | |
| LOCATION = ["us-central1", "us-east5"][model_index] | |
| vertexai.init(project=PROJECT_ID, location=LOCATION) | |
| MODEL_NAME = ["gemini-2.0-flash-001", "gemini-1.5-flash-002"][ | |
| model_index | |
| ] # "gemini-2.0-flash-001" or "gemini-1.5-flash-002" | |
| self.model = GenerativeModel(model_name=MODEL_NAME) | |
| print(f"Using model: {MODEL_NAME}") | |
| self.prompt = ( | |
| "Summarize this video directly, when summarizing please provide a detailed description of major subjects, actions, and interactions. " | |
| "Focus on key actions, interactions, and movements. Include camera movements. " | |
| "Keep the summary brief and clear. " | |
| "Only include information that is certain, and avoid speculation or assumptions." | |
| "In the last sentence, answer the question with just Yes or No, does the video contain rich human hand motions?" | |
| ) | |
| # Lock for updating success and fail counts | |
| self.count_lock = threading.Lock() | |
| self.optional_keys = [ | |
| "duration", | |
| "fps", | |
| "height", | |
| "width", | |
| "n_frames", | |
| "youtube_key_segment", | |
| ] | |
| def process_entries(self, records: List[Dict[str, Any]]): | |
| self.start_time = time.time() | |
| # Start worker threads | |
| for _ in range(self.num_workers): | |
| worker = threading.Thread(target=self._worker_process, daemon=True) | |
| worker.start() | |
| self.workers.append(worker) | |
| # Producer: read input lines and put them into the queue | |
| to_process_count = 0 | |
| for data in records: | |
| entry = VideoEntry( | |
| mp4_path=data["video_path"], | |
| ) | |
| # add optional keys to entry: | |
| for key in self.optional_keys: | |
| if key in data: | |
| entry.__dict__[key] = data[key] | |
| self.entry_queue.put(entry) | |
| to_process_count += 1 | |
| if to_process_count == 0: | |
| print("No new entries to process. All done!") | |
| # Even if none, still send sentinels to avoid blocking | |
| for _ in range(self.num_workers): | |
| self.entry_queue.put(None) | |
| return | |
| # Add sentinel values to signal workers to stop | |
| for _ in range(self.num_workers): | |
| self.entry_queue.put(None) | |
| # Wait for all workers to finish | |
| for worker in self.workers: | |
| worker.join() | |
| # Collect results | |
| results = [] | |
| while not self.results_queue.empty(): | |
| result = self.results_queue.get() | |
| # Only append results that aren't error messages | |
| if not result.caption.startswith("Error"): | |
| results.append(result) | |
| # Append results to output file | |
| with open(self.output_file, "a", encoding="utf-8") as f: | |
| for result in results: | |
| obj = {"video_path": result.mp4_path, "caption": result.caption} | |
| for key in self.optional_keys: | |
| if key in result.__dict__ and result.__dict__[key] is not None: | |
| obj[key] = result.__dict__[key] | |
| f.write(json.dumps(obj) + "\n") | |
| self.end_time = time.time() | |
| total_time = self.end_time - self.start_time | |
| print(f"Processed {len(results)} entries successfully.") | |
| print(f"Failed on {self.fail_count} entries.") | |
| print(f"Total time: {total_time:.2f} seconds.") | |
| if to_process_count > 0: | |
| print(f"Throughput: {to_process_count / total_time:.2f} videos/second.") | |
| print(f"Output file: {self.output_file}") | |
| def _read_video_file(self, file_path): | |
| """Read video file and convert it to base64.""" | |
| if not os.path.exists(file_path): | |
| raise FileNotFoundError(f"Video file not found: {file_path}") | |
| with open(file_path, "rb") as video_file: | |
| return base64.b64encode(video_file.read()).decode("utf-8") | |
| def get_gemini_caption(self, video_path: str) -> str: | |
| """Generate a caption for a single video using Gemini Flash.""" | |
| video_data = self._read_video_file(video_path) | |
| video_part = Part.from_data(data=video_data, mime_type="video/mp4") | |
| try: | |
| response = self.model.generate_content( | |
| [video_part, self.prompt], | |
| # generation_config={ | |
| # "max_output_tokens": 1024, | |
| # "temperature": 0.4 | |
| # }, | |
| stream=False, | |
| ) | |
| return response.text | |
| except Exception as e: | |
| print(f"Error from Gemini API: {e}") | |
| return f"Error from Gemini API: {e}" | |
| def _process_single_entry(self, entry: VideoEntry) -> CaptionResult: | |
| caption = self.get_gemini_caption(entry.mp4_path) | |
| ret_result = CaptionResult(mp4_path=entry.mp4_path, caption=caption) | |
| for key in self.optional_keys: | |
| if key in entry.__dict__ and entry.__dict__[key] is not None: | |
| ret_result.__dict__[key] = entry.__dict__[key] | |
| return ret_result | |
| def _worker_process(self): | |
| while True: | |
| entry = self.entry_queue.get() | |
| if entry is None: # Check for sentinel value | |
| break | |
| if self.entry_queue.qsize() % 100 == 0: | |
| print( | |
| f"Processing {entry.mp4_path}. {self.entry_queue.qsize()} entries left in queue." | |
| ) | |
| gc_s_time = time.time() | |
| num_gc = gc.collect() | |
| gc_e_time = time.time() | |
| print( | |
| f"Garbage collection took {gc_e_time - gc_s_time} seconds, collected {num_gc} objects" | |
| ) | |
| try: | |
| result = self._process_single_entry(entry) | |
| # Check if result is error. If not, add to results_queue. | |
| if not result.caption.startswith("Error"): | |
| with self.count_lock: | |
| self.success_count += 1 | |
| self.results_queue.put(result) | |
| else: | |
| with self.count_lock: | |
| self.fail_count += 1 | |
| print(f"Skipping {entry.mp4_path} due to error in captioning.") | |
| except Exception as e: | |
| with self.count_lock: | |
| self.fail_count += 1 | |
| print(f"Error processing {entry.mp4_path}: {str(e)}") | |
| traceback.print_exc() | |
| finally: | |
| self.entry_queue.task_done() | |