File size: 7,931 Bytes
142a1ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import base64
import gc
import json
import os
import queue
import threading
import time
import traceback
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

# Gemini / Vertex AI imports
import vertexai
from vertexai.generative_models import GenerativeModel, Part


@dataclass
class VideoEntry:
    """One unit of captioning work: a video path plus optional metadata.

    Only ``mp4_path`` is required; the remaining fields mirror
    ``GeminiCaptionProcessor.optional_keys`` and are copied from the input
    record when present. Fields defaulting to ``None`` are annotated
    ``Optional[...]`` (the bare ``str = None`` form is implicit-Optional,
    which PEP 484 disallows).
    """

    mp4_path: str  # filesystem path to the .mp4 to caption
    # optional keys below:
    youtube_key_segment: Optional[str] = None
    duration: Optional[float] = None  # seconds
    fps: Optional[float] = None
    height: Optional[int] = None
    width: Optional[int] = None
    n_frames: Optional[int] = None
    # Add other metadata fields as needed


@dataclass
class CaptionResult:
    """Output of captioning one video: the path, the caption, and any
    metadata carried over from the corresponding ``VideoEntry``.

    ``mp4_path`` and ``caption`` are required; optional fields defaulting to
    ``None`` are annotated ``Optional[...]`` (the bare ``str = None`` form is
    implicit-Optional, which PEP 484 disallows).
    """

    mp4_path: str  # filesystem path of the captioned .mp4
    caption: str  # model output; "Error..."-prefixed on API failure
    # optional keys below:
    youtube_key_segment: Optional[str] = None
    duration: Optional[float] = None  # seconds
    fps: Optional[float] = None
    height: Optional[int] = None
    width: Optional[int] = None
    n_frames: Optional[int] = None


class GeminiCaptionProcessor:
    def __init__(self, output_file: str, num_workers: int = 12):
        self.output_file = output_file
        self.num_workers = num_workers
        self.entry_queue = queue.Queue()
        self.results_queue = queue.Queue()
        self.workers = []
        self.success_count = 0
        self.fail_count = 0
        self.start_time = None
        self.end_time = None

        # Initialize Vertex AI
        PROJECT_ID = "fas-dev-kempner-2b75"
        model_index = 0
        LOCATION = ["us-central1", "us-east5"][model_index]
        vertexai.init(project=PROJECT_ID, location=LOCATION)
        MODEL_NAME = ["gemini-2.0-flash-001", "gemini-1.5-flash-002"][
            model_index
        ]  # "gemini-2.0-flash-001" or "gemini-1.5-flash-002"
        self.model = GenerativeModel(model_name=MODEL_NAME)
        print(f"Using model: {MODEL_NAME}")

        self.prompt = (
            "Summarize this video directly, when summarizing please provide a detailed description of major subjects, actions, and interactions. "
            "Focus on key actions, interactions, and movements. Include camera movements. "
            "Keep the summary brief and clear. "
            "Only include information that is certain, and avoid speculation or assumptions."
            "In the last sentence, answer the question with just Yes or No, does the video contain rich human hand motions?"
        )
        # Lock for updating success and fail counts
        self.count_lock = threading.Lock()

        self.optional_keys = [
            "duration",
            "fps",
            "height",
            "width",
            "n_frames",
            "youtube_key_segment",
        ]

    def process_entries(self, records: List[Dict[str, Any]]):
        self.start_time = time.time()
        # Start worker threads
        for _ in range(self.num_workers):
            worker = threading.Thread(target=self._worker_process, daemon=True)
            worker.start()
            self.workers.append(worker)

        # Producer: read input lines and put them into the queue
        to_process_count = 0
        for data in records:
            entry = VideoEntry(
                mp4_path=data["video_path"],
            )
            # add optional keys to entry:
            for key in self.optional_keys:
                if key in data:
                    entry.__dict__[key] = data[key]
            self.entry_queue.put(entry)
            to_process_count += 1

        if to_process_count == 0:
            print("No new entries to process. All done!")
            # Even if none, still send sentinels to avoid blocking
            for _ in range(self.num_workers):
                self.entry_queue.put(None)
            return

        # Add sentinel values to signal workers to stop
        for _ in range(self.num_workers):
            self.entry_queue.put(None)

        # Wait for all workers to finish
        for worker in self.workers:
            worker.join()

        # Collect results
        results = []
        while not self.results_queue.empty():
            result = self.results_queue.get()
            # Only append results that aren't error messages
            if not result.caption.startswith("Error"):
                results.append(result)

        # Append results to output file
        with open(self.output_file, "a", encoding="utf-8") as f:
            for result in results:
                obj = {"video_path": result.mp4_path, "caption": result.caption}
                for key in self.optional_keys:
                    if key in result.__dict__ and result.__dict__[key] is not None:
                        obj[key] = result.__dict__[key]
                f.write(json.dumps(obj) + "\n")

        self.end_time = time.time()
        total_time = self.end_time - self.start_time
        print(f"Processed {len(results)} entries successfully.")
        print(f"Failed on {self.fail_count} entries.")
        print(f"Total time: {total_time:.2f} seconds.")
        if to_process_count > 0:
            print(f"Throughput: {to_process_count / total_time:.2f} videos/second.")
        print(f"Output file: {self.output_file}")

    def _read_video_file(self, file_path):
        """Read video file and convert it to base64."""
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Video file not found: {file_path}")
        with open(file_path, "rb") as video_file:
            return base64.b64encode(video_file.read()).decode("utf-8")

    def get_gemini_caption(self, video_path: str) -> str:
        """Generate a caption for a single video using Gemini Flash."""
        video_data = self._read_video_file(video_path)
        video_part = Part.from_data(data=video_data, mime_type="video/mp4")
        try:
            response = self.model.generate_content(
                [video_part, self.prompt],
                # generation_config={
                #     "max_output_tokens": 1024,
                #     "temperature": 0.4
                # },
                stream=False,
            )
            return response.text
        except Exception as e:
            print(f"Error from Gemini API: {e}")
            return f"Error from Gemini API: {e}"

    def _process_single_entry(self, entry: VideoEntry) -> CaptionResult:
        caption = self.get_gemini_caption(entry.mp4_path)

        ret_result = CaptionResult(mp4_path=entry.mp4_path, caption=caption)
        for key in self.optional_keys:
            if key in entry.__dict__ and entry.__dict__[key] is not None:
                ret_result.__dict__[key] = entry.__dict__[key]
        return ret_result

    def _worker_process(self):
        while True:
            entry = self.entry_queue.get()
            if entry is None:  # Check for sentinel value
                break
            if self.entry_queue.qsize() % 100 == 0:
                print(
                    f"Processing {entry.mp4_path}. {self.entry_queue.qsize()} entries left in queue."
                )
                gc_s_time = time.time()
                num_gc = gc.collect()
                gc_e_time = time.time()
                print(
                    f"Garbage collection took {gc_e_time - gc_s_time} seconds, collected {num_gc} objects"
                )
            try:
                result = self._process_single_entry(entry)
                # Check if result is error. If not, add to results_queue.
                if not result.caption.startswith("Error"):
                    with self.count_lock:
                        self.success_count += 1
                    self.results_queue.put(result)
                else:
                    with self.count_lock:
                        self.fail_count += 1
                    print(f"Skipping {entry.mp4_path} due to error in captioning.")
            except Exception as e:
                with self.count_lock:
                    self.fail_count += 1
                print(f"Error processing {entry.mp4_path}: {str(e)}")
                traceback.print_exc()
            finally:
                self.entry_queue.task_done()