hzy committed on
Commit
608eb1a
·
verified ·
1 Parent(s): b39a6ee

Initial upload of all project files

Browse files
.ruff_cache/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Automatically created by ruff.
2
+ *
.ruff_cache/0.14.0/8806091925620224500 ADDED
Binary file (95 Bytes). View file
 
.ruff_cache/CACHEDIR.TAG ADDED
@@ -0,0 +1 @@
 
 
1
+ Signature: 8a477f597d28d172789f06886806bc55
compute_video_emb.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import argparse
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+ from transformers import AutoModel, AutoProcessor
7
+ import torch.multiprocessing as mp
8
+ from torch.utils.data import Dataset, DataLoader
9
+ import glob
10
+
11
+ # --- 配置 ---
12
+ MODEL_ID = "/mnt/bn/ziyang-storage-cloudnative-hl/huggingface/siglip-so400m-patch14-384"
13
+ BATCH_SIZE = 1024 # 根据你的 GPU VRAM 调整
14
+
def parse_arguments():
    """Build the CLI parser and return the parsed command-line arguments."""
    parser = argparse.ArgumentParser(
        description="步骤 1: 使用 SigLIP (多GPU) 预计算所有视频帧的嵌入."
    )
    # Both paths are required string arguments; declare them table-driven.
    arg_specs = [
        (("--frames-path", "-fp"), "包含所有视频帧文件夹的基础目录的绝对路径。"),
        (("--output-dir", "-o"), "用于保存嵌入.pt文件的输出目录路径。"),
    ]
    for flags, help_text in arg_specs:
        parser.add_argument(*flags, type=str, required=True, help=help_text)
    return parser.parse_args()

class FrameDataset(Dataset):
    """PyTorch Dataset that lazily loads video frames from disk as RGB images."""

    def __init__(self, frame_paths):
        self.frame_paths = frame_paths

    def __len__(self):
        return len(self.frame_paths)

    def __getitem__(self, idx):
        try:
            return Image.open(self.frame_paths[idx]).convert("RGB")
        except Exception:
            # Unreadable/corrupt frame: signal with None; collate_fn drops it.
            return None

def collate_fn(batch):
    """Drop failed (None) samples from a batch; return None if nothing survives."""
    kept = [sample for sample in batch if sample is not None]
    return kept if kept else None

def process_video_chunk(args_tuple):
    """
    Worker function: embed every frame of a chunk of videos on one GPU.

    args_tuple: (video_dirs_chunk, frames_base_path, gpu_id, output_dir).
    For each video directory, its .jpg frames are sorted, batched through the
    SigLIP model on cuda:<gpu_id>, and the embeddings are saved to
    <output_dir>/<video_name>.pt as {'filenames': [...], 'embeddings': Tensor}.
    Existing outputs are skipped, so interrupted runs can resume.
    """
    video_dirs_chunk, frames_base_path, gpu_id, output_dir = args_tuple
    device = f"cuda:{gpu_id}"

    # Load the model and processor for the assigned GPU inside this worker.
    model = AutoModel.from_pretrained(MODEL_ID).to(device).eval()
    processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=True)

    progress_bar = tqdm(video_dirs_chunk, position=gpu_id, desc=f"GPU-{gpu_id}")

    for video_dir in progress_bar:
        video_name = os.path.basename(video_dir)
        output_path = os.path.join(output_dir, f"{video_name}.pt")

        # Skip if the output already exists, to support resuming.
        if os.path.exists(output_path):
            progress_bar.write(f"Skipping {video_name}, embeddings already exist.")
            continue

        frame_files = [f for f in os.listdir(video_dir) if f.endswith(".jpg")]
        if not frame_files:
            continue
        # Sort numerically by the frame index embedded in the filename;
        # assumes names like "frame_000123.jpg" — TODO confirm upstream naming.
        frame_files.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
        frame_paths = [os.path.join(video_dir, f) for f in frame_files]

        try:
            with torch.no_grad():
                dataset = FrameDataset(frame_paths)
                loader = DataLoader(
                    dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0,
                    pin_memory=True, collate_fn=collate_fn
                )

                all_frame_embeddings = []
                for image_batch in loader:
                    # collate_fn returns None when every frame in the batch failed to load.
                    if image_batch is None:
                        continue

                    image_inputs = processor(images=image_batch, return_tensors="pt").to(device)
                    frame_embeddings = model.get_image_features(**image_inputs)
                    all_frame_embeddings.append(frame_embeddings)

                if not all_frame_embeddings:
                    continue

                all_frame_embeddings = torch.cat(all_frame_embeddings, dim=0)

                # Move tensors to CPU before saving to avoid CUDA device issues
                # when the file is loaded later on a different machine/process.
                data_to_save = {
                    'filenames': frame_files,
                    'embeddings': all_frame_embeddings.cpu()
                }
                torch.save(data_to_save, output_path)

        except Exception as e:
            # Best-effort per-video processing: log and move on to the next video.
            progress_bar.write(f"Error on GPU-{gpu_id} for video '{video_name}': {e}")

def main():
    """Coordinate multi-GPU processing: shard video dirs across GPUs and run workers."""
    args = parse_arguments()

    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        print("错误: 未找到启用CUDA的GPU。正在退出。")
        exit(1)

    print(f"找到 {num_gpus} 个GPU。开始并行处理...")

    # Create the output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Collect all unique video directories under the frames base path
    video_dirs = [d for d in glob.glob(os.path.join(args.frames_path, '*')) if os.path.isdir(d)]

    if not video_dirs:
        print(f"错误: 在 {args.frames_path} 中未找到视频目录。")
        return

    # Split the video directories into one chunk per GPU (ceiling division)
    chunk_size = (len(video_dirs) + num_gpus - 1) // num_gpus
    video_chunks = [video_dirs[i:i + chunk_size] for i in range(0, len(video_dirs), chunk_size)]

    # Prepare the argument tuple for each worker process: chunk i runs on GPU i
    process_args = [(video_chunks[i], args.frames_path, i, args.output_dir) for i in range(len(video_chunks))]

    with mp.Pool(processes=num_gpus) as pool:
        pool.map(process_video_chunk, process_args)

    print("\n所有视频帧嵌入已计算并保存。")

if __name__ == "__main__":
    # CUDA cannot be re-initialized in forked subprocesses; 'spawn' is required.
    mp.set_start_method('spawn', force=True)
    main()
main_adaptive_sampling.py ADDED
@@ -0,0 +1,447 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import base64
4
+ import argparse
5
+ import time
6
+ import re
7
+ from datetime import datetime
8
+ from functools import partial
9
+ from openai import AzureOpenAI, OpenAI
10
+ from volcenginesdkarkruntime import Ark
11
+ from multiprocessing import Pool, Manager, Lock
12
+
# New prompt template for multiple-choice questions with reasoning.
# Used as the system message in evaluate_single_item; the model's reply is
# parsed by extract_json_from_response, which expects the fenced JSON block.
REASONING_MULTIPLE_CHOICE_TEMPLATE = """
You are an AI assistant evaluating video frames to answer a multiple-choice question.
The user will provide you with a set of video frames and a question with several options (e.g., A, B, C, D).

First, provide a step-by-step reasoning process that analyzes the video frames and leads to your conclusion.
After your reasoning, provide the final answer in a JSON block. The JSON object must contain a single key "answer" with the value being one of 'A', 'B', 'C', or 'D'.

Your output should follow this format exactly:
<Your step-by-step reasoning here>
```json
{"answer": "A"}
```
Do not include any other text after the JSON block.
"""

def parse_arguments():
    """
    Build the CLI parser and return the parsed arguments.

    Returns:
        argparse.Namespace: Parsed command line arguments.
    """
    ap = argparse.ArgumentParser(
        description="Video QA Evaluation with Pre-computed Similarity Frame Selection"
    )

    # Model configuration
    ap.add_argument(
        "--target-model", "-tm", type=str, required=True,
        help="Model to be evaluated (e.g., gpt-4o, gpt-4-vision-preview)",
    )

    # Data configuration
    ap.add_argument(
        "--frame-num", "-fn", type=int, default=32,
        help="Number of most similar frames to select for each video (default: 32)",
    )
    ap.add_argument(
        "--frames-path", "-fp", type=str, required=True,
        help="Absolute path to the base directory containing video frame folders.",
    )
    ap.add_argument(
        "--data-file", "-df", type=str, required=True,
        help="Absolute path to the JSON file containing the evaluation dataset.",
    )
    # --- MODIFIED ARGUMENT ---
    ap.add_argument(
        "--similarity-file", "-sf", type=str, required=True,
        help="Absolute path to the pre-computed similarity JSON file (e.g., lv_bench_similarity.json).",
    )

    # Processing configuration
    ap.add_argument(
        "--max-retry-times", "-mr", type=int, default=10,
        help="Maximum number of retries for API calls (default: 10)",
    )
    ap.add_argument(
        "--pool-processes", "-pp", type=int, default=20,
        help="Number of parallel processes for evaluation (default: 20)",
    )

    # API configuration
    ap.add_argument("--base_url", type=str, required=True, help="Azure OpenAI endpoint URL.")
    ap.add_argument("--api_key", type=str, required=True, help="Azure OpenAI API key.")

    return ap.parse_args()

def save_json_file(data, output_file):
    """Serialize *data* as pretty-printed (4-space indented) JSON to *output_file*."""
    with open(output_file, "w", encoding="utf-8") as fh:
        fh.write(json.dumps(data, indent=4))

def extract_json_from_response(response):
    """
    Extract the answer JSON object from a model response.

    The expected format is free-form reasoning followed by a fenced block:

        ```json
        {"answer": "A"}
        ```

    As a fallback, the last bare top-level ``{...}`` object in the text is
    also accepted, since models occasionally omit the code fences. This is a
    backward-compatible extension: fenced responses parse exactly as before.

    Returns:
        dict | None: The parsed JSON object, or None if *response* is empty
        or contains no parseable JSON object.
    """
    if not response:
        return None
    # Prefer the explicitly fenced JSON block.
    match = re.search(r"```json\s*(\{.*?\})\s*```", response, re.DOTALL)
    if match:
        json_str = match.group(1)
    else:
        # Fallback: last unfenced {...} object in the text (the answer is
        # instructed to come after the reasoning).
        candidates = re.findall(r"\{.*?\}", response, re.DOTALL)
        if not candidates:
            return None
        json_str = candidates[-1]
    try:
        return json.loads(json_str)
    except (json.JSONDecodeError, IndexError):
        return None

132
+ def calculate_metrics(results):
133
+ """
134
+ Calculate evaluation metrics from the results.
135
+ """
136
+ total_samples = len(results)
137
+ if total_samples == 0:
138
+ return {
139
+ "total_samples": 0,
140
+ "answered_samples": 0,
141
+ "correct_answers": 0,
142
+ "accuracy": 0.0,
143
+ }
144
+
145
+ answered_samples = sum(1 for x in results if x.get("model_answer") is not None)
146
+ correct_answers = sum(1 for x in results if x.get("is_correct"))
147
+
148
+ accuracy = correct_answers / answered_samples if answered_samples > 0 else 0.0
149
+
150
+ return {
151
+ "total_samples": total_samples,
152
+ "answered_samples": answered_samples,
153
+ "correct_answers": correct_answers,
154
+ "accuracy": accuracy,
155
+ }
156
+
157
+
def call_single_model(client, messages, model, item_id, max_retry_times):
    """
    Make a single chat-completion API call with retry logic.

    Args:
        client: An OpenAI-compatible client (OpenAI / AzureOpenAI / Ark).
        messages: Chat messages payload for the completion request.
        model: Model name; "doubao" models get a smaller max_tokens budget.
        item_id: Identifier of the data item, used only for error reporting.
        max_retry_times: Number of attempts before giving up.

    Returns:
        str | None: The response message content, or None once retries are
        exhausted (the final failure is appended to a per-model error log).
    """
    if "doubao" in model:
        max_tokens = 32768
    else:
        max_tokens = 65535
    retry_times = 0
    while retry_times < max_retry_times:
        try:
            completion = client.chat.completions.create(
                model=model, messages=messages, max_tokens=max_tokens
            )
            return completion.choices[0].message.content
        except Exception as e:
            retry_times += 1
            print(
                f"Error processing item {item_id} with model {model}: {str(e)}. Retrying ({retry_times}/{max_retry_times})..."
            )
            if retry_times == max_retry_times:
                # Persist the final failure so runs can be audited afterwards.
                error_log_file = f"error_log_{model.replace('/', '_')}.txt"
                with open(error_log_file, "a") as f:
                    f.write(
                        f"Error processing item {item_id} with model {model} after {max_retry_times} retries: {str(e)}\n"
                    )
                return None
            time.sleep(5)  # back off before the next attempt

def evaluate_single_item(
    data_item, frames, target_model, api_key, base_url, max_retry_times
):
    """
    Evaluate a single QA item: send the frames and question to the target
    model, parse its JSON answer, and score it against the gold answer.

    Args:
        data_item: Dataset entry; must contain 'key', 'question' and 'answer'.
        frames: Pre-encoded image_url content parts for the chat message.
        target_model: Model name to query.
        api_key / base_url: Endpoint credentials; the URL selects the client.
        max_retry_times: Retry budget forwarded to call_single_model.

    Returns:
        dict: data_item augmented with the raw model output, the cleaned
        answer letter, and an is_correct flag.
    """
    # Choose the client implementation based on the endpoint URL.
    if "ark" in base_url:
        client = Ark(base_url=base_url, api_key=api_key)
    elif "aliyun" in base_url or "127.0.0.1" in base_url:
        client = OpenAI(api_key=api_key, base_url=base_url)
    else:
        client = AzureOpenAI(
            api_version="2023-05-15", api_key=api_key, azure_endpoint=base_url
        )

    messages = [
        {"role": "system", "content": REASONING_MULTIPLE_CHOICE_TEMPLATE},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Here are the video frames:"},
                *frames,
                {"type": "text", "text": f"Question: {data_item['question']}"},
            ],
        },
    ]

    response = call_single_model(
        client, messages, target_model, data_item["key"], max_retry_times
    )

    is_correct = False
    model_answer_cleaned = None
    parsed_json = None

    if response:
        parsed_json = extract_json_from_response(response)
        if parsed_json and "answer" in parsed_json:
            # Normalize both sides to a bare upper-case letter before comparing.
            model_answer_cleaned = str(parsed_json["answer"]).strip().upper()
            gold_answer = data_item["answer"].strip().upper()
            if model_answer_cleaned == gold_answer:
                is_correct = True

    return {
        **data_item,
        "model_reasoning_and_answer": response,
        "model_answer_raw": parsed_json.get("answer") if parsed_json else None,
        "model_answer": model_answer_cleaned,
        "is_correct": is_correct,
    }

def encode_image(image_path):
    """Read the file at *image_path* and return its bytes base64-encoded as str."""
    with open(image_path, "rb") as fh:
        raw = fh.read()
    return base64.b64encode(raw).decode("utf-8")

# --- MODIFIED: New function for selecting frames based on pre-computed similarity file ---
def process_frames_from_similarity_file(
    frames_base_path, frame_num, data_item, similarity_data
):
    """
    Select and encode the top N frames using a pre-computed similarity file.

    similarity_data maps question UIDs (stringified) to frame filenames sorted
    by relevance. The top frame_num entries are re-sorted into temporal order
    before encoding. Returns a list of OpenAI image_url content parts, or []
    when no similarity entry exists or any frame fails to load/encode.
    """
    item_key = data_item["key"]
    question_uid = str(data_item["uid"])

    # Retrieve the sorted list of frame filenames for the current question
    sorted_filenames = similarity_data.get(question_uid)

    if not sorted_filenames:
        print(
            f"Warning: No similarity data found for question UID '{question_uid}', skipping."
        )
        return []

    try:
        # Select the top N filenames
        num_frames_to_select = min(frame_num, len(sorted_filenames))
        selected_filenames = sorted_filenames[:num_frames_to_select]
        # assumes filenames look like "frame_000123.jpg" — TODO confirm against
        # the similarity file's actual naming scheme.
        selected_ids = [int(f.split(".")[0].split("_")[-1]) for f in selected_filenames]
        selected_ids = sorted(selected_ids)
        # Rebuild names in temporal order so frames are presented chronologically.
        selected_filenames = [f"frame_{i:06d}.jpg" for i in selected_ids]

        # Construct full paths for the selected frames
        video_frames_path = os.path.join(frames_base_path, item_key)
        sampled_paths = [os.path.join(video_frames_path, f) for f in selected_filenames]

        # Encode the selected frames
        base64_images = [encode_image(path) for path in sampled_paths]

        return [
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{b64_img}"},
            }
            for b64_img in base64_images
        ]
    except Exception as e:
        # Best-effort: a missing/corrupt frame means this item gets no frames.
        print(f"Error during frame processing for key '{item_key}': {e}")
        return []

def process_single_data(
    data_item,
    args,
    shared_results,
    progress_counter,
    total_items,
    locks,
    similarity_data,
):
    """
    Process a single data item in a multiprocessing context.

    Selects frames via the pre-computed similarity data, queries the target
    model, appends the result to the manager-backed shared_results list, and
    checkpoints all results to disk after every item. Errors go to a per-model
    error log; the progress counter is always advanced in the finally block.
    """
    item_key = data_item["key"]
    try:
        # --- MODIFIED: Call the new frame selection function ---
        frames = process_frames_from_similarity_file(
            args.frames_path, args.frame_num, data_item, similarity_data
        )

        if not frames:
            raise ValueError(
                f"No frames were processed from similarity file for key '{item_key}'"
            )

        result = evaluate_single_item(
            data_item,
            frames,
            args.target_model,
            args.api_key,
            args.base_url,
            args.max_retry_times,
        )

        if result is not None:
            with locks["results"]:
                shared_results.append(result)
                # Checkpoint all results so far so progress survives crashes.
                data_filename_base = os.path.splitext(os.path.basename(args.data_file))[
                    0
                ]
                model_name_safe = args.target_model.replace("/", "_")
                output_prefix = f"{model_name_safe}_{data_filename_base}_{args.frame_num}frames_precomputed_similar"
                results_output_file = f"{output_prefix}_results.json"
                save_json_file(list(shared_results), results_output_file)

    except Exception as e:
        print(f"Error processing video key {item_key}: {str(e)}")
        with locks["file"]:
            error_log_file = f"error_log_{args.target_model.replace('/', '_')}.txt"
            with open(error_log_file, "a") as f:
                f.write(f"Critical error processing video key {item_key}: {str(e)}\n")
    finally:
        # Always bump the shared progress counter, even on failure.
        with locks["counter"]:
            progress_counter.value += 1
            print(
                f"\rProcessed: {progress_counter.value}/{total_items} videos...",
                end="",
                flush=True,
            )

def load_test_data(json_file):
    """Load the evaluation dataset from *json_file*, exiting the process on failure."""
    try:
        fh = open(json_file, "r", encoding="utf-8")
    except FileNotFoundError:
        print(f"Error: Data file not found at {json_file}")
        exit(1)
    with fh:
        try:
            return json.load(fh)
        except json.JSONDecodeError:
            print(f"Error: Could not decode JSON from {json_file}")
            exit(1)

def main():
    """
    Main entry point: load the dataset and similarity file, fan the items out
    to a process pool, then aggregate results and write metrics to disk.
    """
    args = parse_arguments()

    print("--- Evaluation Configuration ---")
    print(f"Target Model: {args.target_model}")
    print(f"Frames to Sample (by pre-computed similarity): {args.frame_num}")
    print(f"Frames Base Path: {args.frames_path}")
    print(f"Similarity File: {args.similarity_file}")  # Print new arg
    print(f"Data File: {args.data_file}")
    print(f"Parallel Processes: {args.pool_processes}")
    print("---------------------------------")

    # Truncate/refresh the per-model error log for this run.
    error_log_file = f"error_log_{args.target_model.replace('/', '_')}.txt"
    with open(error_log_file, "w") as f:
        f.write(
            f"=== Error Log Started at {datetime.now()} for model {args.target_model} ===\n"
        )

    data_filename_base = os.path.splitext(os.path.basename(args.data_file))[0]
    model_name_safe = args.target_model.replace("/", "_")
    output_prefix = f"{model_name_safe}_{data_filename_base}_{args.frame_num}frames_precomputed_similar"

    results_output_file = f"{output_prefix}_results.json"
    metrics_output_file = f"{output_prefix}_metrics.json"

    # Load test data and similarity data
    test_data = load_test_data(args.data_file)
    try:
        with open(args.similarity_file, "r", encoding="utf-8") as f:
            similarity_data = json.load(f)
    except FileNotFoundError:
        print(f"Error: Similarity file not found at {args.similarity_file}")
        exit(1)

    total_videos = len(test_data)
    print(f"\nLoaded {total_videos} videos to process.")

    with Manager() as manager:
        # Manager-backed state shared across worker processes.
        shared_results = manager.list()
        progress_counter = manager.Value("i", 0)
        locks = {
            "results": manager.Lock(),
            "file": manager.Lock(),
            "counter": manager.Lock(),
        }

        # Create a partial function with fixed arguments for the worker pool
        process_func = partial(
            process_single_data,
            args=args,
            shared_results=shared_results,
            progress_counter=progress_counter,
            total_items=total_videos,
            locks=locks,
            similarity_data=similarity_data,
        )

        # Run processing in parallel
        with Pool(processes=args.pool_processes) as pool:
            pool.map(process_func, test_data)

        # Materialize the proxy list while the manager is still alive.
        all_results = list(shared_results)

    print(f"\n\nProcessing complete for model: {args.target_model}")

    final_metrics = calculate_metrics(all_results)
    save_json_file(final_metrics, metrics_output_file)
    print(f"\nMetrics saved to: {metrics_output_file}")
    print(json.dumps(final_metrics, indent=4))

    save_json_file(all_results, results_output_file)
    print(f"Detailed results saved to: {results_output_file}")

if __name__ == "__main__":
    # Script entry point.
    main()
main_agent.py ADDED
@@ -0,0 +1,521 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import base64
4
+ import argparse
5
+ import time
6
+ import re
7
+ import traceback
8
+ from datetime import datetime
9
+ from functools import partial
10
+ from openai import AzureOpenAI, OpenAI
11
+ from volcenginesdkarkruntime import Ark
12
+ import concurrent.futures
13
+ from tqdm import tqdm
14
+
# New system prompt for the agent.
# Drives the agentic loop in evaluate_single_item_agentic: the model requests
# frames via the get_frames_by_id tool and must end with a fenced JSON answer.
AGENT_SYSTEM_PROMPT = """
You are an intelligent AI assistant specialized in video question answering.
Your task is to answer a multiple-choice question based on a video.

You must use the `get_frames_by_id` tool to request specific frames to view.
You will be told the total number of frames available in the video (e.g., "The video has 1250 frames, numbered 1 to 1250.").

Your strategy should be efficient:
1. Based on the task query, think about which part of the video will be related, and then get the frames of this part. If the query’s description is fairly general and you can’t effectively infer the temporal regions where the target visual evidence might appear, you can first uniformly sample some frames for analysis to identify the time intervals where the target visual evidence is likely to appear.
2. Analyze the retrieved frames and the user's question.
3. If you don't have enough information, form a hypothesis about where the answer might be and use the tool again to request more specific frames from that segment.
4. Continue this process of reasoning and tool use until you are confident in your answer. Avoid requesting all frames at once.
5. Please make sure that you find the relevant visual cues and then answer the question instead of guessing the answer.
6. You can access 10 frames at most in each tool call.

Please note that if you have insufficient visual information at the beginning, you can first sample more frames uniformly to understand the video (e.g., sampling 10 frames per tool call). You can then gradually refine the subsequent steps and adopt a coarse-to-fine strategy overall.
For example, the question is "What is the main subject of the video?"
You can first sample 10 frames uniformly from the video (e.g., frame 100, 200, ..., 1200).
After analyzing these frames, you might notice that the main subject is a person in the middle of the screen (between frame 500 and 600).
You can then sample more frames from this region (e.g., frame 500, 520, ..., 590) to get more detailed information.
Finally, you can reason based on the visual cues you have gathered and provide the final answer.
This process might be multi-turn.

After your reasoning, provide the final answer in a JSON block. The JSON object must contain a single key "answer" with the value being one of 'A', 'B', 'C', or 'D'.

Remember that You can access 10 frames at most in each tool call.

Your output should follow this format exactly:
<Your step-by-step reasoning here>
```json
{"answer": "X"}
```
Do not include any other text after the JSON block.
"""

# Tool schema for the get_frames_by_id function.
# Follows the OpenAI function-calling tool format; passed as `tools=` to the
# chat completion request in call_single_model.
GET_FRAMES_TOOL_SCHEMA = {
    "type": "function",
    "function": {
        "name": "get_frames_by_id",
        "description": "Retrieves specific video frames by their numerical IDs. Use this to get visual information from the video.",
        "parameters": {
            "type": "object",
            "properties": {
                "frame_ids": {
                    "type": "array",
                    "items": {"type": "integer"},
                    "description": "A list of frame numbers to retrieve. You can access 10 frames at most in each tool call.",
                },
            },
            "required": ["frame_ids"],
        },
    },
}

def parse_arguments():
    """Build the CLI parser and return parsed arguments for the agentic evaluation."""
    ap = argparse.ArgumentParser(
        description="Video QA Evaluation Framework with Agentic Frame Selection (Refactored)"
    )
    ap.add_argument(
        "--target-model", "-tm", type=str, required=True,
        help="Model to be evaluated (e.g., gpt-4o)",
    )
    ap.add_argument(
        "--frames-path", "-fp", type=str, required=True,
        help="Absolute path to the base directory for video frames.",
    )
    ap.add_argument(
        "--data-file", "-df", type=str, required=True,
        help="Absolute path to the JSON evaluation dataset.",
    )
    ap.add_argument(
        "--max-retry-times", "-mr", type=int, default=10,
        help="Maximum retries for API calls.",
    )
    ap.add_argument(
        "--pool-processes", "-pp", type=int, default=20,
        help="Number of parallel processes.",
    )
    ap.add_argument("--base_url", type=str, required=True, help="Azure OpenAI endpoint URL.")
    ap.add_argument("--api_key", type=str, required=True, help="Azure OpenAI API key.")
    return ap.parse_args()

def save_json_file(data, output_file):
    """Write *data* to *output_file* as 4-space-indented JSON."""
    serialized = json.dumps(data, indent=4)
    with open(output_file, "w", encoding="utf-8") as out:
        out.write(serialized)

def extract_json_from_response(response):
    """Pull the JSON object out of a ```json fenced block in *response*, else None."""
    if not response:
        return None
    fence = re.search(r"```json\s*(\{.*?\})\s*```", response, re.DOTALL)
    if fence is None:
        return None
    try:
        return json.loads(fence.group(1))
    except (json.JSONDecodeError, IndexError):
        return None

def calculate_metrics(results):
    """Compute accuracy metrics, ignoring entries that recorded an 'error' key."""
    valid = [entry for entry in results if "error" not in entry]
    answered = sum(entry.get("model_answer") is not None for entry in valid)
    correct = sum(bool(entry.get("is_correct")) for entry in valid)
    return {
        "total_samples": len(valid),
        "answered_samples": answered,
        "correct_answers": correct,
        "accuracy": correct / answered if answered > 0 else 0.0,
    }

def call_single_model(client, messages, model, item_id, max_retry_times, tools=None):
    """
    Make a single chat-completion API call with retry logic and tool support.

    Args:
        client: An OpenAI-compatible client (OpenAI / AzureOpenAI / Ark).
        messages: Chat messages payload.
        model: Model name; selects per-model token budgets ("o4", "Qwen", other).
        item_id: Data item identifier, used only in error messages.
        max_retry_times: Attempts before giving up.
        tools: Optional tool schemas; when given, tool_choice is set to "auto".

    Returns:
        The full response message object (so callers can inspect tool_calls).

    Raises:
        Exception: re-raises the last API error once retries are exhausted,
        to be handled by the caller's own try/except.
    """
    # Per-model request parameters; "Qwen" models also run greedily (temperature 0).
    if "o4" in model:
        params = {"model": model, "messages": messages, "max_tokens": 65535}
    elif "Qwen" in model:
        params = {
            "model": model,
            "messages": messages,
            "max_tokens": 2048,
            "temperature": 0,
        }
    else:
        params = {"model": model, "messages": messages, "max_tokens": 32768}
    if tools:
        params["tools"] = tools
        params["tool_choice"] = "auto"

    retry_times = 0
    while retry_times < max_retry_times:
        try:
            completion = client.chat.completions.create(**params)
            return completion.choices[0].message
        except Exception as e:
            retry_times += 1
            print(
                f"API Error for item {item_id}: {str(e)}. Retrying ({retry_times}/{max_retry_times})..."
            )
            if retry_times == max_retry_times:
                # Instead of writing to a file here, we'll let the worker return the error
                raise e  # Reraise the exception to be caught by the worker's main try-except block
            time.sleep(5)  # back off before the next attempt

def get_frames_by_id(frame_ids: list, all_frame_paths: list):
    """
    Tool implementation: look up frames by numeric ID and return them as
    base64 image_url content parts. IDs without a matching, existing file
    are silently skipped.
    """
    pattern = re.compile(r"frame_(\d+)\.jpg")
    id_to_path = {}
    for path in all_frame_paths:
        hit = pattern.search(os.path.basename(path))
        if hit:
            id_to_path[int(hit.group(1))] = path

    parts = []
    for frame_id in frame_ids:
        frame_path = id_to_path.get(frame_id)
        if not (frame_path and os.path.exists(frame_path)):
            continue
        encoded = encode_image(frame_path)
        parts.append(
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{encoded}"},
            }
        )
    return parts

225
def evaluate_single_item_agentic(
    data_item, all_frame_paths, target_model, api_key, base_url, max_retry_times
):
    """Evaluates a single item using an agentic loop for dynamic frame selection.

    The model is given the frame count/duration plus the question, and may call
    the `get_frames_by_id` tool (up to `max_tool_calls` rounds) to fetch frames
    before answering. Returns a result dict merged with `data_item`, or None if
    an API call ultimately fails.

    Relies on module-level names not visible in this block:
    GET_FRAMES_TOOL_SCHEMA, get_frames_by_id, AGENT_SYSTEM_PROMPT,
    call_single_model, extract_json_from_response.
    """
    # Pick the client implementation from the endpoint URL.
    if "ark" in base_url:
        client = Ark(
            base_url=base_url,
            api_key=api_key,
        )
    elif "aliyun" in base_url or "127.0.0.1" in base_url:
        client = OpenAI(api_key=api_key, base_url=base_url)
    else:
        client = AzureOpenAI(
            api_version="2023-05-15", api_key=api_key, azure_endpoint=base_url
        )

    tools = [GET_FRAMES_TOOL_SCHEMA]
    available_functions = {"get_frames_by_id": get_frames_by_id}

    total_frames = len(all_frame_paths)
    minutes = data_item["video_info"]["duration_minutes"]
    seconds = int(minutes * 60)
    initial_prompt = (
        f"The video has {total_frames} frames, numbered 1 to {total_frames}. This video is {seconds} seconds long. "
        f"Please answer the following question:\n{data_item['question']}"
    )

    messages = [
        {"role": "system", "content": AGENT_SYSTEM_PROMPT},
        {"role": "user", "content": initial_prompt},
    ]

    response_content = None
    max_tool_calls = 10  # hard cap on tool-call rounds

    for i in range(max_tool_calls):
        response_message = call_single_model(
            client,
            messages,
            target_model,
            data_item["key"],
            max_retry_times,
            tools=tools,
        )
        if response_message is None:
            # API call failed even after retries; give up on this item.
            return None

        messages.append(response_message.model_dump())

        if response_message.tool_calls:
            for tool_call in response_message.tool_calls:
                function_name = tool_call.function.name
                function_to_call = available_functions.get(function_name)
                if function_to_call:
                    function_args = json.loads(tool_call.function.arguments)
                    retrieved_frames = function_to_call(
                        **function_args, all_frame_paths=all_frame_paths
                    )
                    tool_response_content = [
                        {
                            "type": "text",
                            "text": f"Here are the frames you requested (IDs: {function_args.get('frame_ids', [])}).",
                        }
                    ]
                    tool_response_content.extend(retrieved_frames)
                    # The "tool" message only acknowledges success (many APIs
                    # reject image content in tool messages); the frames
                    # themselves are delivered in a follow-up user message.
                    messages.append(
                        {
                            "tool_call_id": tool_call.id,
                            "role": "tool",
                            "name": function_name,
                            "content": json.dumps(
                                {
                                    "status": "success",
                                    "retrieved_frame_count": len(retrieved_frames),
                                }
                            ),
                        }
                    )
                    messages.append({"role": "user", "content": tool_response_content})
        else:
            # No tool calls -> model produced its final answer.
            response_content = response_message.content
            break

    # Budget exhausted while the model was still calling tools: force an answer.
    if response_content is None and response_message and response_message.tool_calls:
        print(
            f"\nMax tool calls reached for item {data_item['key']}. Forcing a final answer."
        )
        final_prompt = "You have reached the maximum number of tool calls. Please provide a final answer in the specified JSON format based on the information you have gathered so far."
        messages.append({"role": "user", "content": final_prompt})
        final_response_message = call_single_model(
            client,
            messages,
            target_model,
            data_item["key"],
            max_retry_times,
            tools=None,  # disable tools so the model must answer
        )
        if final_response_message:
            # NOTE(review): the raw message object is appended here (not
            # model_dump()); the conversation dump below compensates via the
            # hasattr check.
            messages.append(final_response_message)
            response_content = final_response_message.content
    elif response_content is None and response_message:
        response_content = response_message.content

    # Grade the answer: compare the JSON "answer" field against the gold label.
    is_correct = False
    model_answer_cleaned = None
    parsed_json = extract_json_from_response(response_content)
    if parsed_json and "answer" in parsed_json:
        model_answer_cleaned = str(parsed_json["answer"]).strip().upper()
        gold_answer = data_item["answer"].strip().upper()
        if model_answer_cleaned == gold_answer:
            is_correct = True
    return {
        **data_item,
        "agent_conversation": [
            msg.model_dump() if hasattr(msg, "model_dump") else msg for msg in messages
        ],
        "model_reasoning_and_answer": response_content,
        "model_answer": model_answer_cleaned,
        "is_correct": is_correct,
    }
345
+
346
+
347
def encode_image(image_path):
    """Return the contents of the file at *image_path* as a base64 string."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    return base64.b64encode(raw_bytes).decode("utf-8")
351
+
352
+
353
def process_single_data(data_item, args):
    """
    Main processing function for a single video.
    This function is executed by each worker process. It is self-contained.

    Locates the frame directory for `data_item["key"]` under `args.frames_path`,
    collects the frames sorted by their numeric index, and delegates to
    evaluate_single_item_agentic. Any exception is converted into an error
    dict so the worker never crashes the pool.
    """
    item_key = data_item["key"]
    try:
        specific_frames_path = os.path.join(args.frames_path, item_key)
        if not os.path.isdir(specific_frames_path):
            raise FileNotFoundError(f"Frame directory not found for key '{item_key}'")

        # Sort numerically by the index embedded in "frame_<N>.jpg".
        # NOTE(review): a .jpg file not matching that pattern would make
        # re.search return None and raise AttributeError — caught below,
        # but the whole item is then reported as failed.
        all_frame_paths = sorted(
            [
                os.path.join(specific_frames_path, f)
                for f in os.listdir(specific_frames_path)
                if f.endswith(".jpg")
            ],
            key=lambda x: int(re.search(r"frame_(\d+)\.jpg", x).group(1)),
        )

        if not all_frame_paths:
            raise FileNotFoundError(f"No frames found for key '{item_key}'")

        # The core evaluation logic is called here
        result = evaluate_single_item_agentic(
            data_item,
            all_frame_paths,
            args.target_model,
            args.api_key,
            args.base_url,
            args.max_retry_times,
        )
        return result

    except Exception as e:
        # If any error occurs, catch it and return an error dictionary.
        # This prevents the worker process from crashing and allows the main
        # process to log the error gracefully.
        print(f"\nCRITICAL ERROR on key {item_key}: {str(e)}")
        traceback.print_exc()
        return {
            "key": item_key,
            "uid": data_item.get("uid"),
            "error": str(e),
            "traceback": traceback.format_exc(),
        }
399
+
400
+
401
def load_test_data(json_file):
    """Load the evaluation dataset from *json_file*.

    Prints a diagnostic and exits the process with status 1 when the file
    is missing or contains malformed JSON.
    """
    try:
        handle = open(json_file, "r", encoding="utf-8")
    except FileNotFoundError:
        print(f"Error: Data file not found: {json_file}")
        exit(1)
    try:
        with handle:
            return json.load(handle)
    except json.JSONDecodeError:
        print(f"Error: Malformed JSON in {json_file}")
        exit(1)
412
+
413
+
414
def main():
    """Main function to orchestrate the evaluation framework.

    Parses CLI args, resumes from any existing results file (keyed by "uid"),
    fans the remaining items out to a process pool, logs per-item errors to a
    log file, checkpoints results every 10 items, and finally writes the
    detailed results and aggregate metrics as JSON.
    """
    args = parse_arguments()

    print("--- Agentic Video QA Evaluation (Refactored) ---")
    print(f"Target Model: {args.target_model}")
    print(f"Frames Base Path: {args.frames_path}")
    print(f"Data File: {args.data_file}")

    # Output files are derived from model + dataset names.
    # NOTE(review): these are bare filenames (written into the CWD) — if
    # save_json_file calls os.makedirs(os.path.dirname(...)) unguarded, a bare
    # filename (dirname == "") would crash; confirm against its definition.
    model_name_safe = args.target_model.replace("/", "_")
    data_filename_base = os.path.splitext(os.path.basename(args.data_file))[0]

    output_prefix = f"{model_name_safe}_{data_filename_base}_agent_results"
    results_output_file = f"{output_prefix}.json"
    metrics_output_file = f"{output_prefix}_metrics.json"
    error_log_file = f"{output_prefix}_errors.log"

    # Start a new append-mode log session marker.
    with open(error_log_file, "a", encoding="utf-8") as f:
        f.write(
            f"\n=== Log Session Started at {datetime.now()} for {args.target_model} ===\n"
        )

    all_test_data = load_test_data(args.data_file)
    completed_ids = set()
    existing_results = []

    # Resume support: reload prior results and skip items already done.
    if os.path.exists(results_output_file):
        try:
            with open(results_output_file, "r", encoding="utf-8") as f:
                existing_results = json.load(f)
            if isinstance(existing_results, list):
                completed_ids = {
                    item["uid"] for item in existing_results if "uid" in item
                }
                print(
                    f"Found {len(completed_ids)} completed tasks in '{results_output_file}'. Resuming..."
                )
            else:
                existing_results = []
        except (json.JSONDecodeError, IOError) as e:
            print(f"Warning: Could not read results file: {e}. Starting fresh.")
            existing_results = []

    tasks_to_process = [
        item for item in all_test_data if item.get("uid") not in completed_ids
    ]

    if not tasks_to_process:
        print("All tasks are already completed. Calculating final metrics.")
        final_metrics = calculate_metrics(existing_results)
        save_json_file(final_metrics, metrics_output_file)
        print(f"\nFinal metrics saved to: {metrics_output_file}")
        print(json.dumps(final_metrics, indent=4))
        return

    print(
        f"Total tasks: {len(all_test_data)}. Completed: {len(completed_ids)}. To process: {len(tasks_to_process)}."
    )

    # This list will hold all results, both old and new.
    all_results = list(existing_results)

    # Using ProcessPoolExecutor for robust, modern multiprocessing.
    with concurrent.futures.ProcessPoolExecutor(
        max_workers=args.pool_processes
    ) as executor:
        # partial is used to pass the constant `args` to each call of process_single_data
        func = partial(process_single_data, args=args)

        # executor.map processes the tasks in parallel.
        # tqdm provides a progress bar.
        results_iterator = executor.map(func, tasks_to_process)

        for result in tqdm(
            results_iterator, total=len(tasks_to_process), desc="Processing Videos"
        ):
            if result:
                if "error" in result:
                    # Log errors centrally
                    with open(error_log_file, "a", encoding="utf-8") as f:
                        f.write(f"Error on key {result.get('key', 'N/A')}:\n")
                        f.write(f" Error: {result['error']}\n")
                        f.write(f" Traceback: {result['traceback']}\n---\n")

                # Append every result (success or error) to the main list
                all_results.append(result)

                # Periodically save results for resilience
                if len(all_results) % 10 == 0:
                    save_json_file(all_results, results_output_file)

    print("\n\nProcessing complete.")

    # Final save of all combined results
    save_json_file(all_results, results_output_file)
    print(f"Detailed results saved to: {results_output_file}")

    # Calculate and save final metrics
    final_metrics = calculate_metrics(all_results)
    save_json_file(final_metrics, metrics_output_file)
    print(f"\nMetrics saved to: {metrics_output_file}")
    print(json.dumps(final_metrics, indent=4))
516
+
517
+
518
# Script entry point: runs the full agentic evaluation pipeline.
if __name__ == "__main__":
    # To run this script, you'll need to install tqdm:
    # pip install tqdm
    main()
main_i2i_ret.py ADDED
@@ -0,0 +1,548 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import base64
4
+ import argparse
5
+ import time
6
+ import re
7
+ import traceback
8
+ import uuid
9
+ import multiprocessing
10
+ import concurrent.futures
11
+ from datetime import datetime
12
+ from functools import partial
13
+
14
+ import requests
15
+ import torch
16
+ from PIL import Image
17
+ from tqdm import tqdm
18
+ from openai import AzureOpenAI, OpenAI
19
+ from volcenginesdkarkruntime import Ark
20
+ from transformers import AutoModel, AutoProcessor
21
+ from torch.nn.functional import cosine_similarity
22
+
23
+ # --- Model and Configuration Constants ---
24
+
25
+ # SigLIP model for generating image embeddings
26
+ SIGLIP_MODEL_ID = "/mnt/bn/ziyang-storage-cloudnative-hl/huggingface/siglip-so400m-patch14-384"
27
+ # Number of Top-K frames to retrieve for each generated image
28
+ TOP_K_FRAMES = 8
29
+
30
+ # --- Prompt Templates ---
31
+
32
+ # Step 1: System prompt for VLM to analyze video and question, then generate image creation requests
33
+ # The goal of this prompt is not to answer the question, but to plan which keyframes need to be "seen"
34
+ STEP_1_PLANNING_PROMPT = """
35
+ You are a professional video analyst. Your task is to analyze a question and a few initial video sample frames, then plan what keyframes you need to see to answer the question.
36
+
37
+ Do not answer the question directly. Your output must be a JSON array, where each object represents a keyframe you wish to generate.
38
+ Each object must contain the following two keys:
39
+ 1. `reference_image_id`: An integer representing the ID of a frame already provided to you that you wish to use as a generation reference. This ID must be one of the IDs provided by the user.
40
+ 2. `prompt`: A detailed text description to tell the image generation model what kind of scene to draw.
41
+
42
+ For example, if the question is "Where did the man in the red shirt eventually go?", you might generate the following JSON:
43
+ ```json
44
+ [
45
+ {
46
+ "reference_image_id": 120,
47
+ "prompt": "A man in a red shirt is walking towards an open door, with a background similar to the reference image."
48
+ },
49
+ {
50
+ "reference_image_id": 120,
51
+ "prompt": "A man in a red shirt has already walked out the door, and the door is closing, with a background similar to the reference image."
52
+ }
53
+ ]
54
+ ```
55
+ Your output must strictly adhere to this JSON format.
56
+ """
57
+
58
+ # Step 3: System prompt for VLM to perform final reasoning and answer based on all retrieved keyframes
59
+ STEP_3_FINAL_ANSWER_PROMPT = """
60
+ You are an AI video question-answering assistant.
61
+ The user will provide you with a series of keyframes retrieved from a video and a question.
62
+
63
+ First, please provide a step-by-step reasoning process, analyzing these keyframes and deriving your conclusion.
64
+ After your reasoning, provide the final answer. The answer must be in a JSON code block, and the JSON object must contain a key "answer" with a value of one of 'A', 'B', 'C', or 'D'.
65
+
66
+ Your output format must be strictly as follows:
67
+ <Your step-by-step reasoning process>
68
+ ```json
69
+ {"answer": "A"}
70
+ ```
71
+ Do not include any other text after the JSON block.
72
+ """
73
+
74
+
75
def parse_arguments():
    """Define and parse the command-line options for this workflow."""
    cli = argparse.ArgumentParser(description="Image Retrieval-based Video QA Workflow")

    # Model configuration
    cli.add_argument("--target-model", "-tm", type=str, required=True,
                     help="VLM model for inference (e.g., gpt-4o)")

    # Data path configuration
    cli.add_argument("--frames-path", "-fp", type=str, required=True,
                     help="Root directory containing video frame folders")
    cli.add_argument("--data-file", "-df", type=str, required=True,
                     help="JSON data file containing evaluation questions")
    cli.add_argument("--embeddings-path", "-ep", type=str, required=True,
                     help="Directory containing pre-computed embeddings for all video frames")
    cli.add_argument("--output-path", "-op", type=str, default="./results_image_retrieval",
                     help="Directory to store all outputs and generated images")

    # Workflow parameters
    cli.add_argument("--initial-frames-num", "-ifn", type=int, default=8,
                     help="Number of initial uniformly sampled frames for Step 1")

    # Execution configuration
    cli.add_argument("--max-retry-times", "-mr", type=int, default=10,
                     help="Maximum number of retries for API calls")
    cli.add_argument("--pool-processes", "-pp", type=int, default=10,
                     help="Number of parallel processes")

    # API credentials
    cli.add_argument("--base_url", type=str, required=True,
                     help="API Endpoint URL for the VLM model")
    cli.add_argument("--api_key", type=str, required=True,
                     help="API Key for the VLM model")

    return cli.parse_args()
116
+
117
+
118
def save_json_file(data, output_file):
    """Serialize *data* to *output_file* as pretty-printed UTF-8 JSON.

    Creates the parent directory when needed. Fixes a crash in the original:
    ``os.makedirs("")`` raises FileNotFoundError when *output_file* is a bare
    filename with no directory component.
    """
    parent_dir = os.path.dirname(output_file)
    if parent_dir:  # a bare filename needs no mkdir; makedirs("") would raise
        os.makedirs(parent_dir, exist_ok=True)
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4, ensure_ascii=False)
123
+
124
+
125
def extract_json_from_response(response, is_list=False):
    """Pull the first ```json ...``` fenced object or list out of *response*.

    Returns the parsed value, or None when the text is empty, contains no
    fenced JSON block, or the block fails to parse. *is_list* is accepted
    for interface compatibility but is not consulted.
    """
    if not response:
        return None
    # The regex accepts either a JSON object `{...}` or a list `[...]`.
    found = re.search(r"```json\s*([\{\[].*?[\]\}])\s*```", response, re.DOTALL)
    if found is None:
        return None
    candidate = found.group(1)
    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError:
        print(f"JSON parsing failed: {candidate}")
        return None
    return parsed
140
+
141
+
142
def calculate_metrics(results):
    """Summarize evaluation results into counts and accuracy.

    Entries carrying an "error" key are excluded. Accuracy is computed over
    answered samples only; an empty (or all-error) result set yields
    {"accuracy": 0.0}.
    """
    usable = [entry for entry in results if "error" not in entry]
    if not usable:
        return {"accuracy": 0.0}

    answered = 0
    correct = 0
    for entry in usable:
        if entry.get("model_answer") is not None:
            answered += 1
        if entry.get("is_correct"):
            correct += 1

    return {
        "total_samples": len(usable),
        "answered_samples": answered,
        "correct_answers": correct,
        "accuracy": correct / answered if answered > 0 else 0.0,
    }
158
+
159
+
160
def call_vlm_api(client, messages, model, item_id, max_retry_times, json_schema=None):
    """Call the VLM chat-completions API with retries.

    Optionally requests structured output via *json_schema*. On failure it
    retries up to *max_retry_times* attempts (sleeping 5s between attempts)
    and re-raises the last error.
    """
    request_kwargs = {"model": model, "messages": messages, "max_tokens": 4096}
    if json_schema:
        request_kwargs["response_format"] = {"type": "json_object", "schema": json_schema}

    attempt = 0
    while attempt < max_retry_times:
        try:
            reply = client.chat.completions.create(**request_kwargs)
        except Exception as e:
            print(f"API Error (item {item_id}): {e}. Retrying ({attempt + 1}/{max_retry_times})...")
            if attempt == max_retry_times - 1:
                raise e
            time.sleep(5)
            attempt += 1
        else:
            return reply.choices[0].message.content
175
+
176
+
177
def generate_image(reference_image_id, prompt, all_frame_paths, output_dir, generation_idx):
    """Call the image generation API to create a new frame.

    Uses the Volcengine Ark image API (model "doubao-seedream-4-0-250828"),
    conditioning on the reference frame identified by *reference_image_id* in
    the *all_frame_paths* id->path mapping. The generated image is downloaded
    and written to *output_dir*; returns its path, or None on any failure
    during generation/download (the error is printed, not raised).

    Raises ValueError when ARK_API_KEY is unset and FileNotFoundError when
    the reference image is missing — these two happen before the try block,
    so they propagate to the caller.
    """
    print(f"\n[Image Generation] Using Prompt: '{prompt}'")
    ark_api_key = os.environ.get("ARK_API_KEY")
    if not ark_api_key:
        raise ValueError("Environment variable ARK_API_KEY is not set.")

    client = Ark(base_url="https://ark.cn-beijing.volces.com/api/v3", api_key=ark_api_key)

    ref_image_path = all_frame_paths.get(reference_image_id)
    if not ref_image_path or not os.path.exists(ref_image_path):
        raise FileNotFoundError(f"Reference image ID {reference_image_id} not found.")

    try:
        # Reference frame is passed inline as a base64 data URI.
        ref_image_b64 = encode_image(ref_image_path)
        ref_image_data_uri = f"data:image/jpeg;base64,{ref_image_b64}"

        response = client.images.generate(
            model="doubao-seedream-4-0-250828",
            prompt=prompt,
            image=ref_image_data_uri,
            size="1024x1024",
            response_format="url",
            watermark=False,
        )
        image_url = response.data[0].url

        # Download the generated image from the returned URL.
        image_content = requests.get(image_url, timeout=60).content

        new_frame_filename = f"generated_frame_{generation_idx}_ref_{reference_image_id}.jpg"
        new_frame_path = os.path.join(output_dir, new_frame_filename)

        with open(new_frame_path, "wb") as f:
            f.write(image_content)

        print(f"[Image Generation Success] Image saved to: {new_frame_path}")
        return new_frame_path
    except Exception as e:
        # Best-effort: callers fall back to the reference frame when this
        # returns None.
        print(f"Image generation or download failed: {e}")
        traceback.print_exc()
        return None
218
+
219
def retrieve_frames_by_image_embedding(
    image_path, video_embeddings_data, request_queue, results_dict, k
):
    """Retrieve Top-K similar frames from the video using an image embedding.

    The query image's embedding is computed by the separate embedding-server
    process: we enqueue (request_id, image_path) and busy-wait on the shared
    *results_dict* until the server posts the embedding under our request_id.
    Similarity against the precomputed frame embeddings then selects the top
    min(k, num_frames) frames.

    Returns a list of frame paths. NOTE(review): the caller rewrites
    video_embeddings_data['filenames'] to absolute paths, in which case
    os.path.join below simply returns the absolute filename — confirm the
    .pt files always carry consistent path forms.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    frame_filenames = video_embeddings_data["filenames"]
    frame_embeddings = video_embeddings_data["embeddings"].to(device)

    # 1. Send request to the embedding server process
    request_id = str(uuid.uuid4())
    request_queue.put((request_id, image_path))

    # 2. Wait for the result (polling loop; no timeout — a dead server would
    #    hang this worker forever).
    while request_id not in results_dict:
        time.sleep(0.05)
    query_embedding = results_dict.pop(request_id).to(device)

    # 3. Calculate similarity and find Top-K frames
    with torch.no_grad():
        similarities = cosine_similarity(query_embedding, frame_embeddings)
        top_k_indices = torch.topk(similarities, k=min(k, len(frame_filenames)), dim=-1).indices.cpu()

    # Extract absolute paths for the frames from the filenames
    video_frame_dir = os.path.dirname(frame_filenames[0])
    top_k_paths = [os.path.join(video_frame_dir, video_embeddings_data['filenames'][i]) for i in top_k_indices]

    return top_k_paths
246
+
247
def embedding_server_process(model_id, device, request_queue, results_dict):
    """
    An independent server process that loads the SigLIP model and handles
    image embedding requests from worker processes.

    Protocol: workers put (request_id, image_path) on *request_queue*; the
    server writes the CPU image-feature tensor into *results_dict* under
    request_id. A (anything, "STOP") message shuts the loop down. Per-request
    errors are printed and the loop continues, so a failed request leaves its
    worker polling forever (see retrieve_frames_by_image_embedding).
    """
    print(f"Embedding server started (PID: {os.getpid()})...")
    # Load once for the lifetime of the process; workers never touch the model.
    model = AutoModel.from_pretrained(model_id).to(device).eval()
    processor = AutoProcessor.from_pretrained(model_id)
    print("SigLIP model loaded in the embedding server.")

    while True:
        try:
            request_id, image_path = request_queue.get()
            if image_path == "STOP":
                print("Embedding server received stop signal, shutting down.")
                break

            with torch.no_grad():
                image = Image.open(image_path).convert("RGB")
                inputs = processor(images=[image], return_tensors="pt").to(device)
                image_features = model.get_image_features(**inputs)
                # Move to CPU before publishing across process boundaries.
                results_dict[request_id] = image_features.cpu()

        except Exception as e:
            print(f"Error in embedding server: {e}")
            traceback.print_exc()
272
+
273
+
274
def encode_image(image_path):
    """Read *image_path* as raw bytes and return them Base64-encoded (ASCII str)."""
    with open(image_path, "rb") as image_file:
        encoded = base64.b64encode(image_file.read())
    return encoded.decode("utf-8")
278
+
279
+
280
def uniformly_sample_frames_and_encode(frames_dir, num_frames):
    """Uniformly sample *num_frames* JPEG frames from *frames_dir*.

    Returns a pair (encoded_frames, frame_path_map): a list of message content
    parts alternating a text label ("This is Frame ID: <id>") and a base64
    data-URI image for each sampled frame, plus a mapping from frame ID to the
    frame's on-disk path. Returns ([], {}) when the directory is missing or
    holds no .jpg files.
    """
    if not os.path.isdir(frames_dir):
        return [], {}

    def frame_number(name):
        # Frame files are named "frame_<N>.jpg"; sort/label by N.
        return int(re.search(r"frame_(\d+)\.jpg", name).group(1))

    jpg_names = [name for name in os.listdir(frames_dir) if name.endswith(".jpg")]
    jpg_names.sort(key=frame_number)
    if not jpg_names:
        return [], {}

    total = len(jpg_names)
    chosen = [jpg_names[int(pos * total / num_frames)] for pos in range(num_frames)]

    encoded_parts = []
    id_to_path = {}
    for name in chosen:
        full_path = os.path.join(frames_dir, name)
        fid = frame_number(name)
        encoded_parts.append({"type": "text", "text": f"This is Frame ID: {fid}"})
        encoded_parts.append(
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encode_image(full_path)}"}}
        )
        id_to_path[fid] = full_path
    return encoded_parts, id_to_path
304
+
305
+
306
def run_workflow_for_item(
    data_item, args, request_queue, results_dict
):
    """Execute the complete three-step workflow for a single data item.

    Step 1: show uniformly sampled frames to the VLM and ask it to plan
    image-generation requests (reference frame ID + prompt).
    Step 2: generate each planned image (falling back to the reference frame
    on failure) and retrieve Top-K similar real frames by embedding.
    Step 3: feed all retrieved keyframes back to the VLM for the final answer.

    Returns the graded result dict; raises on unrecoverable errors (caught by
    process_single_data_wrapper).
    """
    item_key = data_item["key"]
    print(f"\n--- Starting processing for video: {item_key} ---")

    # Create a separate output directory for each video's generated images
    generated_images_dir = os.path.join(args.output_path, "generated_images", item_key)
    os.makedirs(generated_images_dir, exist_ok=True)

    # Initialize VLM client, chosen by endpoint URL.
    if "ark" in args.base_url:
        client = Ark(base_url=args.base_url, api_key=args.api_key)
    elif "aliyun" in args.base_url or "127.0.0.1" in args.base_url:
        client = OpenAI(api_key=args.api_key, base_url=args.base_url)
    else:
        client = AzureOpenAI(api_version="2023-05-15", api_key=args.api_key, azure_endpoint=args.base_url)

    # --- Step 1: Initial understanding and generating "keyframe profile" requests ---
    print(f"[{item_key}] Step 1: Uniformly sampling and generating keyframe creation requests...")
    video_frames_path = os.path.join(args.frames_path, item_key)
    initial_frames_encoded, initial_frame_paths = uniformly_sample_frames_and_encode(
        video_frames_path, args.initial_frames_num
    )
    if not initial_frames_encoded:
        raise FileNotFoundError(f"Initial frames not found for video {item_key}.")

    planning_messages = [
        {"role": "system", "content": STEP_1_PLANNING_PROMPT},
        {"role": "user", "content": [
            {"type": "text", "text": "Here are the initial sample frames and the question:"},
            *initial_frames_encoded,
            {"type": "text", "text": f"Question: {data_item['question']}"}
        ]}
    ]

    # Define JSON Schema for structured output.
    # NOTE(review): planning_schema is built but never passed to call_vlm_api
    # below (its json_schema parameter is left at None) — dead code or a
    # missing argument; confirm intent.
    planning_schema = {
        "type": "array",
        "items": {
            "type": "object",
            "properties": {
                "reference_image_id": {"type": "integer"},
                "prompt": {"type": "string"}
            },
            "required": ["reference_image_id", "prompt"]
        }
    }

    raw_planning_response = call_vlm_api(client, planning_messages, args.target_model, item_key, args.max_retry_times)
    image_generation_requests = extract_json_from_response(raw_planning_response, is_list=True)

    if not image_generation_requests or not isinstance(image_generation_requests, list):
        raise ValueError(f"Step 1 failed to generate valid JSON-formatted image generation requests. Response: {raw_planning_response}")

    print(f"[{item_key}] Successfully generated {len(image_generation_requests)} keyframe generation requests.")

    # --- Validate and correct reference image IDs ---
    # Snap hallucinated IDs to the nearest frame actually shown to the model.
    valid_ids = list(initial_frame_paths.keys())
    if not valid_ids:
        raise ValueError(f"No valid initial frame IDs found for video {item_key}.")

    for req in image_generation_requests:
        original_id = req.get("reference_image_id")
        if original_id not in valid_ids:
            closest_id = min(valid_ids, key=lambda valid_id: abs(valid_id - original_id))
            print(f"Warning: Model generated a non-existent reference_image_id: {original_id}. Substituting with the closest valid ID: {closest_id}.")
            req["reference_image_id"] = closest_id

    # --- Step 2: Generate images and perform similarity retrieval ---
    print(f"[{item_key}] Step 2: Generating images and retrieving similar frames...")
    all_retrieved_frame_paths = set()
    generated_image_paths = []
    video_embedding_file = os.path.join(args.embeddings_path, f"{item_key}.pt")
    if not os.path.exists(video_embedding_file):
        raise FileNotFoundError(f"Embedding file for video {item_key} not found: {video_embedding_file}")
    video_embeddings_data = torch.load(video_embedding_file, map_location="cpu")

    # Correct path issue, ensure filenames in the embedding file are absolute paths
    video_frame_dir_for_embeddings = os.path.join(args.frames_path, item_key)
    video_embeddings_data['filenames'] = [os.path.join(video_frame_dir_for_embeddings, os.path.basename(f)) for f in video_embeddings_data['filenames']]

    for i, req in enumerate(image_generation_requests):
        # 2a. Generate image
        generated_path = generate_image(
            reference_image_id=req["reference_image_id"],
            prompt=req["prompt"],
            all_frame_paths=initial_frame_paths,
            output_dir=generated_images_dir,
            generation_idx=i + 1,
        )

        path_for_retrieval = None
        if generated_path:
            generated_image_paths.append(generated_path)
            path_for_retrieval = generated_path
        else:
            # Generation failed: retrieve with the reference frame instead.
            print(f"Warning: Generation failed for image {i+1}. Using its reference image (ID: {req['reference_image_id']}) for retrieval instead.")
            path_for_retrieval = initial_frame_paths.get(req["reference_image_id"])

        if not path_for_retrieval:
            print(f"Error: Could not find a path for retrieval for request {i+1}. Skipping.")
            continue

        # 2b. Retrieve frames via image embedding
        retrieved_paths = retrieve_frames_by_image_embedding(
            path_for_retrieval, video_embeddings_data, request_queue, results_dict, k=TOP_K_FRAMES
        )
        all_retrieved_frame_paths.update(retrieved_paths)
        print(f"[{item_key}] Retrieval {i+1}/{len(image_generation_requests)} complete, found {len(retrieved_paths)} frames.")

    if not all_retrieved_frame_paths:
        raise ValueError(f"Failed to retrieve any frames for video {item_key}.")

    print(f"[{item_key}] Step 2 complete. Retrieved a total of {len(all_retrieved_frame_paths)} unique keyframes.")

    # --- Step 3: Consolidate keyframes for final reasoning ---
    print(f"[{item_key}] Step 3: Consolidating keyframes for final reasoning...")
    final_frames_encoded = []
    # sorted() gives a deterministic (lexicographic) frame order.
    for path in sorted(list(all_retrieved_frame_paths)):
        final_frames_encoded.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encode_image(path)}"}})

    final_messages = [
        {"role": "system", "content": STEP_3_FINAL_ANSWER_PROMPT},
        {"role": "user", "content": [
            {"type": "text", "text": "Here are all the keyframes retrieved for you. Please answer the question based on them."},
            *final_frames_encoded,
            {"type": "text", "text": f"Question: {data_item['question']}"}
        ]}
    ]

    final_response_text = call_vlm_api(client, final_messages, args.target_model, item_key, args.max_retry_times)

    # --- Consolidating Results ---
    parsed_answer = extract_json_from_response(final_response_text)
    model_answer = parsed_answer.get("answer", "").strip().upper() if parsed_answer else None
    is_correct = (model_answer == data_item["answer"].strip().upper()) if model_answer else False

    result = {
        **data_item,
        "workflow_steps": {
            "step1_planning_requests": image_generation_requests,
            "step2_generated_images": generated_image_paths,
            "step2_retrieved_frame_paths": sorted(list(all_retrieved_frame_paths)),
            "step3_final_reasoning_and_answer": final_response_text,
        },
        "model_answer": model_answer,
        "is_correct": is_correct,
    }
    return result
458
+
459
+
460
def process_single_data_wrapper(data_item, args, request_queue, results_dict):
    """Run the full workflow for one item, converting any exception into an error record.

    Keeps pool workers alive: failures are reported back as a dict carrying
    "error" and "traceback" keys instead of propagating.
    """
    try:
        return run_workflow_for_item(data_item, args, request_queue, results_dict)
    except Exception as e:
        print(f"\nA critical error occurred while processing video {data_item['key']}: {e}")
        traceback.print_exc()
        failure_record = {
            "key": data_item["key"],
            "uid": data_item.get("uid"),
            "error": str(e),
            "traceback": traceback.format_exc(),
        }
        return failure_record
473
+
474
def main():
    """Main function to orchestrate the entire evaluation workflow.

    Spawns a single embedding-server process (shared via a Manager queue/dict),
    fans items out to a process pool running the 3-step workflow, checkpoints
    results every 10 items, then writes detailed results and metrics.

    NOTE(review): `load_test_data` is called below but is not defined or
    imported anywhere in this module's visible code — confirm it exists, else
    this raises NameError at runtime.
    """
    args = parse_arguments()
    print("--- Image Retrieval-based Video QA Workflow Starting ---")
    print(f"Evaluating Model: {args.target_model}, Dataset: {args.data_file}")

    # "spawn" is required so CUDA state isn't inherited by forked children.
    try:
        multiprocessing.set_start_method("spawn", force=True)
    except RuntimeError:
        pass  # Start method already set

    os.makedirs(args.output_path, exist_ok=True)

    # Define output file paths
    model_safe_name = args.target_model.replace("/", "_")
    data_filename_base = os.path.splitext(os.path.basename(args.data_file))[0]
    output_prefix = f"{model_safe_name}_{data_filename_base}_image_retrieval_{args.initial_frames_num}frames"

    results_file = os.path.join(args.output_path, f"{output_prefix}_results.json")
    metrics_file = os.path.join(args.output_path, f"{output_prefix}_metrics.json")

    test_data = load_test_data(args.data_file)
    all_results = []

    with multiprocessing.Manager() as manager:
        # Shared IPC channels between workers and the embedding server.
        request_queue = manager.Queue()
        results_dict = manager.dict()

        device = "cuda" if torch.cuda.is_available() else "cpu"
        embedding_server = multiprocessing.Process(
            target=embedding_server_process,
            args=(SIGLIP_MODEL_ID, device, request_queue, results_dict),
        )
        embedding_server.start()

        # Wait for the embedding server model to load.
        # NOTE(review): a fixed 15s sleep is a race — workers whose requests
        # arrive before the model is ready rely on the queue buffering them.
        time.sleep(15)

        with concurrent.futures.ProcessPoolExecutor(max_workers=args.pool_processes) as executor:
            func = partial(
                process_single_data_wrapper,
                args=args,
                request_queue=request_queue,
                results_dict=results_dict
            )

            results_iterator = executor.map(func, test_data)

            for result in tqdm(results_iterator, total=len(test_data), desc="Processing Videos"):
                if result:
                    all_results.append(result)
                    # Save results every 10 videos to prevent data loss from interruptions
                    if len(all_results) % 10 == 0:
                        save_json_file(all_results, results_file)

        # Gracefully shut down the embedding server ("STOP" sentinel matches
        # the image_path check in embedding_server_process).
        print("All tasks completed. Shutting down the embedding server...")
        request_queue.put((None, "STOP"))
        embedding_server.join()

    print("\n--- All Videos Processed ---")
    save_json_file(all_results, results_file)
    print(f"Detailed results saved to: {results_file}")

    final_metrics = calculate_metrics(all_results)
    save_json_file(final_metrics, metrics_file)
    print(f"Final evaluation metrics saved to: {metrics_file}")
    print(json.dumps(final_metrics, indent=4))
542
+
543
+
544
# Script entry point for the image-retrieval QA workflow.
if __name__ == "__main__":
    # Before running, please ensure you have set the API Key for the image generation service
    # export ARK_API_KEY="YOUR_VOLCENGINE_ARK_API_KEY"
    main()
548
+
main_mcot.py ADDED
@@ -0,0 +1,608 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import base64
4
+ import argparse
5
+ import time
6
+ import re
7
+ import traceback
8
+ from datetime import datetime
9
+ from functools import partial
10
+ import requests # Import requests library to download images from URLs
11
+ from openai import AzureOpenAI, OpenAI
12
+ from volcenginesdkarkruntime import Ark
13
+ import concurrent.futures
14
+ from tqdm import tqdm
15
+
16
# 1. New Agent System Prompt
# Defines the agent's role and principles, guiding it to use the "imagination" tool when visual evidence is insufficient.
# NOTE: this is a runtime string sent verbatim to the model — keep its wording in sync
# with IMAGINE_FRAME_TOOL_SCHEMA below (tool name, parameter names, answer format).
IMAGINE_AGENT_SYSTEM_PROMPT = """
You are an intelligent AI assistant specializing in answering video question-answering problems through reasoning and imagination.
Your task is to answer a multiple-choice question based on an initial, limited set of video frames.

You will receive a few uniformly sampled frames to get a basic understanding of the video.
These frames may not contain all the visual evidence needed to directly answer the question.

If the provided frame information is insufficient, you must use the `imagine_frame` tool to generate new, imagined frames to fill in the visual gaps and aid your reasoning.
You can call this tool multiple times to construct a sequence of imagined events.

Your strategy should be:
1. Analyze the initial frames and the user's question.
2. Form a hypothesis about the missing content.
3. If you need more visual information, call the `imagine_frame` tool. Provide a text `prompt` describing the scene you want to imagine, and select a `reference_image_id` from existing frames. The `reference_image_id` MUST be one of the IDs explicitly provided to you in the conversation history (e.g., "Frame ID: X" or "New Frame ID: Y"). Do not invent or assume frame IDs.
4. Analyze the newly generated frame in conjunction with the existing ones.
5. Continue this process of reasoning and imagination until you are confident in your answer. Please ensure you have found or created the relevant visual cues before answering the question.
6. Each tool call can only generate one frame.

IMPORTANT: Your text `prompt` for image generation must be safe and general. Avoid descriptions that could be interpreted as sensitive, harmful, or explicit to prevent generation failures.

After your reasoning, provide the final answer in a JSON code block. The JSON object must contain a key "answer" with a value of one of 'A', 'B', 'C', or 'D'.

Your output must strictly follow this format:
<Your step-by-step reasoning process here, including why you chose to imagine a certain frame>
```json
{"answer": "X"}
```
Do not include any other text after the JSON code block.
"""

# 2. New Tool Schema for imagine_frame
# Defines the interface, parameters, and description for the `imagine_frame` tool.
# Follows the OpenAI-style function-calling schema; passed via the `tools` parameter
# in evaluate_single_item_agentic_imagination.
IMAGINE_FRAME_TOOL_SCHEMA = {
    "type": "function",
    "function": {
        "name": "imagine_frame",
        "description": "When visual evidence is insufficient, generates a new image based on a text prompt and a reference image to help answer the question. Use it to imagine what might have happened between the provided frames.",
        "parameters": {
            "type": "object",
            "properties": {
                "reference_image_id": {
                    "type": "integer",
                    "description": "The ID of an existing frame to use as a style and content reference. It can be one of the original frames or a previously generated one.",
                },
                "prompt": {
                    "type": "string",
                    "description": "A detailed text description of the frame you want to imagine and generate.",
                },
            },
            "required": ["reference_image_id", "prompt"],
        },
    },
}
71
+
72
+
73
+ # 3. Implementation of the `imagine_frame` tool
74
+ def imagine_frame(
75
+ reference_image_id: int,
76
+ prompt: str,
77
+ all_frame_paths: dict,
78
+ output_dir: str,
79
+ generation_count: int,
80
+ ):
81
+ """
82
+ Tool implementation: Calls an image generation model to create a new frame.
83
+
84
+ Args:
85
+ reference_image_id (int): The ID of the reference frame.
86
+ prompt (str): The text prompt for image generation.
87
+ all_frame_paths (dict): A dictionary containing IDs and paths of all currently available frames (original + generated).
88
+ output_dir (str): The directory to save the generated image.
89
+ generation_count (int): The current generation count, used for naming the file.
90
+
91
+ Returns:
92
+ str or None: The path of the newly generated image on success, otherwise None.
93
+ """
94
+ print(f"\n[Tool Call] Imagining new frame with prompt: '{prompt}'")
95
+ ark_api_key = os.environ.get("ARK_API_KEY")
96
+ if not ark_api_key:
97
+ raise ValueError("Error: Environment variable ARK_API_KEY is not set.")
98
+
99
+ client = Ark(
100
+ base_url="https://ark.cn-beijing.volces.com/api/v3",
101
+ api_key=ark_api_key,
102
+ )
103
+
104
+ ref_image_path = all_frame_paths.get(reference_image_id)
105
+ if not ref_image_path or not os.path.exists(ref_image_path):
106
+ raise FileNotFoundError(f"Reference image ID not found: {reference_image_id}")
107
+
108
+ try:
109
+ # Encode the reference image to a Base64 Data URI
110
+ ref_image_b64 = encode_image(ref_image_path)
111
+ ref_image_data_uri = f"data:image/jpeg;base64,{ref_image_b64}"
112
+
113
+ imagesResponse = client.images.generate(
114
+ model="doubao-seedream-4-0-250828",
115
+ prompt=prompt,
116
+ image=ref_image_data_uri,
117
+ size="1024x1024", # Can be adjusted as needed, e.g., "2K"
118
+ response_format="url",
119
+ watermark=False,
120
+ )
121
+
122
+ image_url = imagesResponse.data[0].url
123
+
124
+ # Download the image from the URL
125
+ response = requests.get(image_url)
126
+ response.raise_for_status()
127
+
128
+ # Save the image to the specified directory
129
+ new_frame_filename = (
130
+ f"generated_frame_{generation_count}_ref_{reference_image_id}.jpg"
131
+ )
132
+ new_frame_path = os.path.join(output_dir, new_frame_filename)
133
+
134
+ with open(new_frame_path, "wb") as f:
135
+ f.write(response.content)
136
+
137
+ print(f"[Tool Success] Generated frame saved to: {new_frame_path}")
138
+ return new_frame_path
139
+
140
+ except Exception as e:
141
+ print(f"An error occurred during image generation or download: {e}")
142
+ traceback.print_exc()
143
+ return None
144
+
145
+
146
def parse_arguments():
    """Build and evaluate the CLI for the Imagine-and-Reason evaluation script.

    Returns:
        argparse.Namespace with target_model, frames_path, output_path,
        data_file, initial_frames_num, max_retry_times, pool_processes,
        base_url and api_key attributes.
    """
    cli = argparse.ArgumentParser(
        description="Video QA Evaluation Framework with Imagine-and-Reason Agent"
    )
    # Required inputs: model under test, frame directory and dataset file.
    cli.add_argument(
        "--target-model", "-tm", type=str, required=True,
        help="The model to be evaluated (e.g., gpt-4o)",
    )
    cli.add_argument(
        "--frames-path", "-fp", type=str, required=True,
        help="Absolute path to the root directory containing video frames.",
    )
    cli.add_argument(
        "--data-file", "-df", type=str, required=True,
        help="Absolute path to the evaluation dataset JSON file.",
    )
    # Optional knobs with sane defaults.
    cli.add_argument(
        "--output-path", "-op", type=str, default="./generated_outputs",
        help="Path to store generated images and results.",
    )
    cli.add_argument(
        "--initial-frames-num", "-ifn", type=int, default=8,
        help="Number of initial uniformly sampled frames.",
    )
    cli.add_argument(
        "--max-retry-times", "-mr", type=int, default=10,
        help="Maximum number of retries for failed API calls.",
    )
    cli.add_argument(
        "--pool-processes", "-pp", type=int, default=10,
        help="Number of parallel processes.",
    )
    # Target-model API endpoint and credentials.
    cli.add_argument(
        "--base_url", type=str, required=True,
        help="API Endpoint URL for the target model service.",
    )
    cli.add_argument(
        "--api_key", type=str, required=True,
        help="API Key for the target model service.",
    )
    return cli.parse_args()
213
+
214
+
215
def save_json_file(data, output_file):
    """Persist *data* to *output_file* as pretty-printed UTF-8 JSON.

    Non-ASCII characters are written verbatim (ensure_ascii=False) so the
    output stays human-readable.
    """
    serialized = json.dumps(data, indent=4, ensure_ascii=False)
    with open(output_file, "w", encoding="utf-8") as sink:
        sink.write(serialized)
219
+
220
+
221
def extract_json_from_response(response):
    """Return the JSON object embedded in a ```json ...``` fence, or None.

    None is returned for empty input, a missing fence, or malformed JSON.
    """
    if not response:
        return None
    fenced = re.search(r"```json\s*(\{.*?\})\s*```", response, re.DOTALL)
    if fenced is None:
        return None
    try:
        return json.loads(fenced.group(1))
    except (json.JSONDecodeError, IndexError):
        return None
232
+
233
+
234
def calculate_metrics(results):
    """Aggregate accuracy statistics over per-item evaluation results.

    Entries containing an "error" key are excluded.  Accuracy is computed
    over answered samples only (model_answer is not None); it is 0.0 when
    nothing was answered.
    """
    usable = [row for row in results if "error" not in row]
    if not usable:
        return {
            "total_samples": 0,
            "answered_samples": 0,
            "correct_answers": 0,
            "accuracy": 0.0,
        }
    answered = sum(row.get("model_answer") is not None for row in usable)
    correct = sum(bool(row.get("is_correct")) for row in usable)
    return {
        "total_samples": len(usable),
        "answered_samples": answered,
        "correct_answers": correct,
        "accuracy": correct / answered if answered else 0.0,
    }
256
+
257
+
258
def call_single_model(client, messages, model, item_id, max_retry_times, tools=None):
    """Issue one chat-completion request, retrying up to max_retry_times.

    Returns the first choice's message on success; re-raises the last
    exception once the retry budget is exhausted.  Sleeps 5s between retries.
    """
    request = {"model": model, "messages": messages, "max_tokens": 4096}
    if tools:
        request["tools"] = tools
        request["tool_choice"] = "auto"

    for attempt in range(1, max_retry_times + 1):
        try:
            completion = client.chat.completions.create(**request)
        except Exception as e:
            print(
                f"API call error (Item {item_id}): {str(e)}. Retrying ({attempt}/{max_retry_times})..."
            )
            if attempt == max_retry_times:
                raise e
            time.sleep(5)
        else:
            return completion.choices[0].message
278
+
279
+
280
def uniformly_sample_frames_and_encode(frames_dir, num_frames):
    """Uniformly sample up to *num_frames* frames from a directory and encode them.

    Args:
        frames_dir: Directory holding frames named ``frame_<N>.jpg``.
        num_frames: Maximum number of frames to sample (uniform stride over
            the numerically sorted frame list).

    Returns:
        tuple: (encoded_frames, frame_path_map).  encoded_frames alternates
        a text part announcing the frame ID with the image_url part for that
        frame; frame_path_map maps frame ID -> file path.  Both are empty
        when the directory is missing or contains no usable frames.
    """
    if not os.path.isdir(frames_dir):
        return [], {}

    # FIX: only accept files that actually match the frame_<N>.jpg scheme.
    # The previous endswith(".jpg") filter crashed with AttributeError in the
    # sort key on any stray .jpg (e.g. "generated_frame_1_ref_2.jpg"),
    # because it assumed the regex always matched.
    frame_pattern = re.compile(r"frame_(\d+)\.jpg$")
    numbered_files = []
    for name in os.listdir(frames_dir):
        matched = frame_pattern.search(name)
        if matched:
            numbered_files.append((int(matched.group(1)), name))
    numbered_files.sort()

    total_frames = len(numbered_files)
    if total_frames == 0:
        return [], {}

    if total_frames > num_frames:
        # Uniform stride sampling over the sorted frame list.
        indices = [int(i * total_frames / num_frames) for i in range(num_frames)]
        sampled = [numbered_files[i] for i in indices]
    else:
        sampled = numbered_files

    frame_path_map = {}
    encoded_frames = []
    for frame_id, name in sampled:
        path = os.path.join(frames_dir, name)
        b64_image = encode_image(path)
        # Send frame ID and image content as a pair so the agent can refer
        # back to specific frames by ID in its tool calls.
        encoded_frames.append({"type": "text", "text": f"This is Frame ID: {frame_id}"})
        encoded_frames.append(
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{b64_image}"},
            }
        )
        frame_path_map[frame_id] = path

    return encoded_frames, frame_path_map
317
+
318
+
319
def evaluate_single_item_agentic_imagination(
    data_item,
    initial_frames,
    initial_frame_paths,
    generated_images_dir,
    target_model,
    api_key,
    base_url,
    max_retry_times,
):
    """
    Core logic for evaluating a single data item using the Imagine-and-Reason Agent.

    Args:
        data_item: Dict with at least "key" (item id), "question" and
            "answer" (gold choice letter) fields.
        initial_frames: Pre-encoded content parts (text + image_url dicts)
            for the initially sampled frames.
        initial_frame_paths: Dict mapping frame ID -> file path for those frames.
        generated_images_dir: Directory where imagined frames are written.
        target_model: Name of the chat model under evaluation.
        api_key: Credential for the chat endpoint.
        base_url: Endpoint URL; also used to choose the client class below.
        max_retry_times: Retry budget forwarded to call_single_model.

    Returns:
        dict containing the original item fields plus the full conversation,
        the raw final reply, the parsed answer letter and a correctness flag;
        or None when a model call produced no message.
    """
    # 4. New Agent Loop
    # Client selection is keyed on substrings of base_url: Volcengine Ark,
    # OpenAI-compatible (Aliyun / local), otherwise Azure OpenAI.
    if "ark" in base_url:
        client = Ark(base_url=base_url, api_key=api_key)
    elif "aliyun" in base_url or "127.0.0.1" in base_url:
        client = OpenAI(api_key=api_key, base_url=base_url)
    else:
        client = AzureOpenAI(
            api_version="2023-05-15", api_key=api_key, azure_endpoint=base_url
        )

    tools = [IMAGINE_FRAME_TOOL_SCHEMA]

    # Store paths of all available frames (initial + generated) in a dictionary for reference
    available_frame_paths = initial_frame_paths.copy()

    initial_prompt_content = [
        {
            "type": "text",
            "text": "Here are the initial sampled video frames provided to you:",
        },
        *initial_frames,
        {
            "type": "text",
            "text": f"Please answer the following question:\n{data_item['question']}",
        },
    ]

    messages = [
        {"role": "system", "content": IMAGINE_AGENT_SYSTEM_PROMPT},
        {"role": "user", "content": initial_prompt_content},
    ]

    response_content = None
    max_tool_calls = (
        5  # Limit the number of times the agent can imagine to prevent infinite loops
    )
    generation_count = 0

    for i in range(max_tool_calls):
        response_message = call_single_model(
            client,
            messages,
            target_model,
            data_item["key"],
            max_retry_times,
            tools=tools,
        )
        if response_message is None:
            return None

        # Record the assistant turn (including any tool_calls) in the history.
        messages.append(response_message.model_dump(exclude_none=True))

        if response_message.tool_calls:
            # NOTE(review): only the FIRST tool call is executed; if the model
            # emits several tool_calls in one turn, the extras get no tool
            # response — some providers reject the follow-up request in that
            # case.  TODO confirm against the target API.
            tool_call = response_message.tool_calls[
                0
            ]  # Process one tool call at a time
            function_name = tool_call.function.name

            if function_name == "imagine_frame":
                generation_count += 1
                function_args = json.loads(tool_call.function.arguments)
                new_frame_path = imagine_frame(
                    **function_args,
                    all_frame_paths=available_frame_paths,
                    output_dir=generated_images_dir,
                    generation_count=generation_count,
                )

                if new_frame_path:
                    # Create a unique ID for the newly generated frame
                    # (one past the current maximum frame ID).
                    new_frame_id = (
                        max(available_frame_paths.keys())
                        if available_frame_paths
                        else 0
                    ) + 1
                    available_frame_paths[new_frame_id] = new_frame_path

                    b64_image = encode_image(new_frame_path)
                    tool_response_content = [
                        {
                            "type": "text",
                            "text": f"Here is the frame you requested to imagine (New Frame ID: {new_frame_id}). Please use it to continue your reasoning.",
                        },
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{b64_image}"},
                        },
                    ]

                    # The tool message carries only JSON status; the image
                    # itself is delivered in a follow-up user message because
                    # tool messages cannot carry image content parts here.
                    messages.append(
                        {
                            "tool_call_id": tool_call.id,
                            "role": "tool",
                            "name": function_name,
                            "content": json.dumps(
                                {"status": "success", "new_frame_id": new_frame_id}
                            ),
                        }
                    )
                    messages.append({"role": "user", "content": tool_response_content})
                else:  # Tool execution failed
                    messages.append(
                        {
                            "tool_call_id": tool_call.id,
                            "role": "tool",
                            "name": function_name,
                            "content": json.dumps(
                                {
                                    "status": "error",
                                    "message": "Failed to generate image.",
                                }
                            ),
                        }
                    )
        else:  # No tool call means the model is ready to give a final answer
            response_content = response_message.content
            break

    # If the max number of calls is reached without an answer, force a final response
    if response_content is None and response_message:
        final_prompt = "You have reached the maximum number of tool calls. Please provide a final answer in the specified JSON format based on the information you have gathered so far."
        messages.append({"role": "user", "content": final_prompt})
        final_response_message = call_single_model(
            client, messages, target_model, data_item["key"], max_retry_times
        )
        if final_response_message:
            messages.append(final_response_message.model_dump(exclude_none=True))
            response_content = final_response_message.content

    # Grade the reply: compare the parsed answer letter (uppercased) against
    # the gold answer.  Unparseable replies count as unanswered.
    is_correct = False
    model_answer_cleaned = None
    parsed_json = extract_json_from_response(response_content)
    if parsed_json and "answer" in parsed_json:
        model_answer_cleaned = str(parsed_json["answer"]).strip().upper()
        gold_answer = data_item["answer"].strip().upper()
        if model_answer_cleaned == gold_answer:
            is_correct = True

    return {
        **data_item,
        "agent_conversation": messages,
        "model_reasoning_and_answer": response_content,
        "model_answer": model_answer_cleaned,
        "is_correct": is_correct,
        "generated_images_path": generated_images_dir,  # 5. Store the path to intermediate generated images
    }
478
+
479
+
480
def encode_image(image_path):
    """Return the Base64 (ASCII) string encoding of the file at *image_path*."""
    with open(image_path, "rb") as source:
        raw_bytes = source.read()
    return base64.b64encode(raw_bytes).decode("utf-8")
484
+
485
+
486
def process_single_data(data_item, args):
    """Per-item worker: sample frames, run the agent, and never raise.

    Any exception is converted into an error record so the parallel pool
    keeps running; successful items return the agent's result dict.
    """
    item_key = data_item["key"]
    try:
        # Each video gets its own subfolder for imagined frames.
        item_image_dir = os.path.join(
            args.output_path, "generated_images", item_key
        )
        os.makedirs(item_image_dir, exist_ok=True)

        item_frames_dir = os.path.join(args.frames_path, item_key)
        initial_frames, initial_frame_paths = uniformly_sample_frames_and_encode(
            item_frames_dir, args.initial_frames_num
        )
        if not initial_frames:
            raise FileNotFoundError(f"Initial frames not found for item '{item_key}'")

        return evaluate_single_item_agentic_imagination(
            data_item,
            initial_frames,
            initial_frame_paths,
            item_image_dir,
            args.target_model,
            args.api_key,
            args.base_url,
            args.max_retry_times,
        )
    except Exception as e:
        print(f"\nA critical error occurred while processing item {item_key}: {str(e)}")
        traceback.print_exc()
        return {
            "key": item_key,
            "uid": data_item.get("uid"),
            "error": str(e),
            "traceback": traceback.format_exc(),
        }
525
+
526
+
527
def load_test_data(json_file):
    """Load the evaluation dataset from a JSON file.

    Args:
        json_file: Path to the dataset JSON file.

    Returns:
        The parsed JSON content.

    Raises:
        SystemExit(1): After printing a diagnostic, when the file is missing
            or contains malformed JSON.
    """
    try:
        with open(json_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Error: Data file not found: {json_file}")
        # raise SystemExit rather than the site-module exit() helper, which is
        # not guaranteed to exist when Python runs without site initialization
        raise SystemExit(1)
    except json.JSONDecodeError:
        print(f"Error: JSON file is malformed: {json_file}")
        raise SystemExit(1)
538
+
539
+
540
def main():
    """Main function to orchestrate the entire evaluation flow.

    Parses CLI args, fans the dataset out over a process pool, periodically
    checkpoints results to disk, and writes final results + metrics JSON.
    """
    args = parse_arguments()

    print("--- Video QA Imagine-and-Reason Agent Framework ---")
    print(f"Evaluating Model: {args.target_model}")
    print(f"Output Path: {args.output_path}")
    print(f"Dataset: {args.data_file}")
    print("---------------------------------")

    # Create the main output directory
    os.makedirs(args.output_path, exist_ok=True)

    # Output file names embed the model and dataset names so runs don't clobber
    # each other ("/" in model names is sanitized for the filesystem).
    model_name_safe = args.target_model.replace("/", "_")
    data_filename_base = os.path.splitext(os.path.basename(args.data_file))[0]

    output_prefix = f"{model_name_safe}_{data_filename_base}_imagine_agent"
    results_output_file = os.path.join(
        args.output_path, f"{output_prefix}_results.json"
    )
    metrics_output_file = os.path.join(
        args.output_path, f"{output_prefix}_metrics.json"
    )
    error_log_file = os.path.join(args.output_path, f"{output_prefix}_errors.log")

    # The logic for resuming from a checkpoint can be added here, same as in the first script

    all_test_data = load_test_data(args.data_file)
    tasks_to_process = all_test_data

    all_results = []
    # Use ProcessPoolExecutor for parallel processing; process_single_data
    # never raises, so every task yields either a result or an error record.
    with concurrent.futures.ProcessPoolExecutor(
        max_workers=args.pool_processes
    ) as executor:
        func = partial(process_single_data, args=args)
        results_iterator = executor.map(func, tasks_to_process)

        for result in tqdm(
            results_iterator, total=len(tasks_to_process), desc="Processing Videos"
        ):
            if result:
                # Error records are appended to the log but still kept in
                # all_results (calculate_metrics filters them out later).
                if "error" in result:
                    with open(error_log_file, "a", encoding="utf-8") as f:
                        f.write(
                            f"Error on item {result.get('key', 'N/A')}:\n Error: {result['error']}\n---\n"
                        )
                all_results.append(result)

                # Save results every 10 videos to prevent data loss from interruptions
                if len(all_results) % 10 == 0:
                    save_json_file(all_results, results_output_file)

    print("\n\nProcessing complete.")
    # Save the final complete results
    save_json_file(all_results, results_output_file)
    print(f"Detailed results saved to: {results_output_file}")

    # Calculate and save the final metrics
    final_metrics = calculate_metrics(all_results)
    save_json_file(final_metrics, metrics_output_file)
    print(f"\nEvaluation metrics saved to: {metrics_output_file}")
    print(json.dumps(final_metrics, indent=4))
603
+
604
+
605
# Script entry point: only run the evaluation when executed directly.
if __name__ == "__main__":
    # Before running this script, please ensure you have set the environment variable in your terminal:
    # export ARK_API_KEY="YOUR_VOLCENGINE_ARK_API_KEY"
    main()
main_new_agent.py ADDED
@@ -0,0 +1,643 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import base64
4
+ import argparse
5
+ import time
6
+ import re
7
+ import traceback
8
+ from datetime import datetime
9
+ from functools import partial
10
+ from openai import AzureOpenAI, OpenAI
11
+ from volcenginesdkarkruntime import Ark
12
+ import concurrent.futures
13
+ from tqdm import tqdm
14
+ import torch
15
+ from transformers import AutoModel, AutoProcessor
16
+ from torch.nn.functional import cosine_similarity
17
+ # MODIFIED: Added imports for multiprocessing and uuid
18
+ import multiprocessing
19
+ import uuid
20
+
21
# --- Configuration for SigLIP Model ---
# MODIFIED: Updated to the local model path
# Used by the embedding-server side of the similarity search (loaded elsewhere).
SIGLIP_MODEL_ID = (
    "/mnt/bn/ziyang-storage-cloudnative-hl/huggingface/siglip-so400m-patch14-384"
)

# --- MODIFIED: Updated System Prompt explaining the two tools with examples ---
# NOTE: runtime string sent verbatim to the model — keep in sync with the two
# tool schemas defined below (names, parameters, answer format).
AGENT_SYSTEM_PROMPT = """
You are an intelligent AI assistant specialized in video question answering.
Your task is to answer a multiple-choice question based on a video by strategically retrieving and analyzing its frames.

You have two tools to retrieve frames. Both return images directly.

1. `get_frames_by_id(frame_ids)`: Retrieves frames using their specific numerical IDs. Use this when the question provides direct temporal clues or when you need to view specific frames identified by another tool.
* **Example Use Case:** For a question like "What happens at the 1 minute 30 second mark?", you can calculate the approximate frame ID and use this tool to see the visual.
* **Example Use Case:** For "Describe the action in frame 550.", you would call this tool with `frame_ids=[550]`.

2. `get_frames_by_similarity(query)`: Searches the entire video for frames that visually match a text description and returns the top 5 most relevant frames directly. Use this for content-based questions where the timing is unknown.
* **Example Use Case:** For a question like "What color is the main character's car?", you would use this tool with a query like "the main character's car".
* **Example Use Case:** For "Find the scene where a band is playing on stage", you would use the query "a band playing on stage".

Your strategy must be efficient:
1. **Analyze the Query:** First, determine if the question is temporal/logical (better for `get_frames_by_id`) or content-based (requires `get_frames_by_similarity`).
2. **Retrieve & Analyze:** Call the most appropriate tool. Analyze the returned frames to form a hypothesis.
3. **Iterate:** If you need more information, refine your search query for the similarity tool or calculate new frame IDs for the ID tool and call again.
4. **Final Answer:** Once you have gathered enough visual evidence, provide your step-by-step reasoning and then the final answer in the specified JSON format. Do not guess.

Your output should follow this format exactly:
<Your step-by-step reasoning here>
```json
{"answer": "X"}
```
Do not include any other text after the JSON block.
"""

# Tool Schemas
# OpenAI-style function-calling schema for ID-based frame retrieval.
GET_FRAMES_BY_ID_TOOL_SCHEMA = {
    "type": "function",
    "function": {
        "name": "get_frames_by_id",
        "description": "Retrieves specific video frames by their numerical IDs to get visual information.",
        "parameters": {
            "type": "object",
            "properties": {
                "frame_ids": {
                    "type": "array",
                    "items": {"type": "integer"},
                    "description": "A list of up to 10 frame numbers to retrieve.",
                },
            },
            "required": ["frame_ids"],
        },
    },
}

# OpenAI-style function-calling schema for text-to-frame similarity search.
GET_FRAMES_BY_SIMILARITY_TOOL_SCHEMA = {
    "type": "function",
    "function": {
        "name": "get_frames_by_similarity",
        "description": "Searches for and retrieves the top 5 most visually relevant frames for a given text query. Use this to locate visual content when frame numbers are unknown.",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "A concise text description of the visual content to search for (e.g., 'a person playing piano').",
                },
            },
            "required": ["query"],
        },
    },
}
93
+
94
+
95
def parse_arguments():
    """Build and evaluate the CLI for the hybrid-retrieval agent script.

    Returns:
        argparse.Namespace with target_model, frames_path, data_file,
        embeddings_path, max_retry_times, pool_processes, base_url and
        api_key attributes.
    """
    cli = argparse.ArgumentParser(
        description="Agentic Video QA with Hybrid Frame Retrieval"
    )
    # Required inputs.
    cli.add_argument(
        "--target-model", "-tm", type=str, required=True, help="Model to evaluate."
    )
    cli.add_argument(
        "--frames-path", "-fp", type=str, required=True,
        help="Base directory for video frames.",
    )
    cli.add_argument(
        "--data-file", "-df", type=str, required=True,
        help="Path to the evaluation dataset.",
    )
    cli.add_argument(
        "--embeddings-path", "-ep", type=str, required=True,
        help="Directory with pre-computed frame embeddings.",
    )
    # Optional knobs with defaults.
    cli.add_argument(
        "--max-retry-times", "-mr", type=int, default=10,
        help="Max retries for API calls.",
    )
    cli.add_argument(
        "--pool-processes", "-pp", type=int, default=20,
        help="Number of parallel processes.",
    )
    # Target-model endpoint and credentials.
    cli.add_argument("--base_url", type=str, required=True, help="API endpoint URL.")
    cli.add_argument("--api_key", type=str, required=True, help="API key.")
    return cli.parse_args()
141
+
142
+
143
def save_json_file(data, output_file):
    """Saves data to a JSON file as pretty-printed UTF-8.

    ensure_ascii=False keeps non-ASCII characters human-readable in the
    output, matching the save_json_file in the sibling main_mcot.py script
    (the previous version escaped them to \\uXXXX sequences).
    """
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4, ensure_ascii=False)
147
+
148
+
149
def extract_json_from_response(response):
    """Extracts a JSON object from a model's response string.

    Looks for a ```json ...``` fenced block; returns the parsed dict or
    None (empty input, no fence, or malformed JSON).
    """
    if not response:
        return None
    if (hit := re.search(r"```json\s*(\{.*?\})\s*```", response, re.DOTALL)) is None:
        return None
    try:
        return json.loads(hit.group(1))
    except (json.JSONDecodeError, IndexError):
        return None
160
+
161
+
162
def calculate_metrics(results):
    """Calculates accuracy and other metrics from evaluation results.

    Rows carrying an "error" key are dropped; accuracy is over answered
    rows only and 0.0 when nothing was answered (including empty input).
    """
    usable = [row for row in results if "error" not in row]
    answered = len([row for row in usable if row.get("model_answer") is not None])
    correct = len([row for row in usable if row.get("is_correct")])
    return {
        "total_samples": len(usable),
        "answered_samples": answered,
        "correct_answers": correct,
        "accuracy": (correct / answered) if answered > 0 else 0.0,
    }
184
+
185
+
186
def call_single_model(client, messages, model, item_id, max_retry_times, tools=None):
    """Makes a single API call with retry logic and tool support.

    Returns the first choice's message; re-raises the last exception once
    the retry budget is exhausted.  Sleeps 5s between attempts.
    """
    payload = {"model": model, "messages": messages, "max_tokens": 4096}
    if tools:
        payload["tools"] = tools
        payload["tool_choice"] = "auto"

    attempt = 0
    while attempt < max_retry_times:
        try:
            return client.chat.completions.create(**payload).choices[0].message
        except Exception as e:
            print(
                f"API Error for item {item_id}: {str(e)}. Retrying ({attempt + 1}/{max_retry_times})..."
            )
            if attempt == max_retry_times - 1:
                raise e
            time.sleep(5)
            attempt += 1
204
+
205
+
206
def get_frames_by_id(frame_ids: list, all_frame_paths: list):
    """Tool implementation: Retrieves and encodes frames from a list of IDs.

    Paths whose basename does not match ``frame_<N>.jpg`` are ignored, as are
    requested IDs with no existing file; matches are returned as image_url
    content parts in the order requested.
    """
    # Build an ID -> path index once, skipping non-conforming filenames.
    pattern = re.compile(r"frame_(\d+)\.jpg")
    id_to_path = {}
    for candidate in all_frame_paths:
        hit = pattern.search(os.path.basename(candidate))
        if hit:
            id_to_path[int(hit.group(1))] = candidate

    encoded = []
    for frame_id in frame_ids:
        frame_path = id_to_path.get(frame_id)
        if not frame_path or not os.path.exists(frame_path):
            continue
        encoded.append(
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{encode_image(frame_path)}"
                },
            }
        )
    return encoded
225
+
226
+
227
# MODIFIED: This function is now the "client" side of the embedding service.
def get_frames_by_similarity(
    query: str,
    all_frame_paths: list,
    precomputed_data: dict,
    request_queue: multiprocessing.Queue,
    results_dict: dict,
    k: int = 5,
):
    """
    Requests a text embedding from the server process, calculates similarity,
    finds top-k frames, and returns them encoded.

    Args:
        query: Natural-language search text to embed.
        all_frame_paths: All frame file paths for this video.
        precomputed_data: Dict with "filenames" (list) and "embeddings"
            (tensor) produced by the offline embedding step.
        request_queue: Queue to the embedding server process.
        results_dict: Shared dict where the server deposits embeddings,
            keyed by request id.
        k: Number of frames to return (capped at the frame count).

    Returns:
        list: Encoded image_url payload dicts for the top-k frames.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    frame_filenames = precomputed_data["filenames"]
    frame_embeddings = precomputed_data["embeddings"].to(device)

    # 1. Send request to the embedding server process, tagged with a unique
    # id so concurrent workers can share one results_dict.
    request_id = str(uuid.uuid4())
    request_queue.put((request_id, query))

    # 2. Wait for the result.
    # NOTE(review): this busy-wait has no timeout — if the embedding server
    # dies, the worker hangs forever; consider adding a deadline. TODO confirm.
    while request_id not in results_dict:
        time.sleep(0.05)
    query_embedding = results_dict.pop(request_id).to(device)

    # 3. Perform similarity search with the received embedding
    with torch.no_grad():
        similarities = cosine_similarity(query_embedding, frame_embeddings)

    num_frames_to_select = min(k, len(frame_filenames))
    top_k_indices = (
        torch.topk(similarities, k=num_frames_to_select, dim=-1)
        .indices.cpu()
        .flatten()
        .numpy()
    )

    # Map the winning embedding rows back to frame ids via the filenames.
    top_k_filenames = [frame_filenames[i] for i in top_k_indices]
    top_k_frame_ids = [
        int(re.search(r"frame_(\d+)\.jpg", f).group(1)) for f in top_k_filenames
    ]

    # Reuse the id-based tool to load and base64-encode the selected frames.
    retrieved_frames = get_frames_by_id(top_k_frame_ids, all_frame_paths)
    return retrieved_frames
272
+
273
+
274
def evaluate_single_item_agentic(
    data_item,
    all_frame_paths,
    embeddings_data,
    target_model,
    api_key,
    base_url,
    max_retry_times,
    request_queue,  # MODIFIED: Added queue for IPC
    results_dict,  # MODIFIED: Added dict for IPC
):
    """Evaluates a single item using an agentic tool-calling loop.

    The model may call the two frame-retrieval tools up to ``max_tool_calls``
    times; its final (non-tool) reply is parsed for a JSON answer and scored
    against ``data_item["answer"]``.

    Returns:
        dict: data_item merged with the conversation, raw response, cleaned
        answer, and correctness flag — or None if the API never responded.
    """
    # Choose the SDK client from the endpoint URL (heuristic substring match).
    if "ark" in base_url:
        client = Ark(base_url=base_url, api_key=api_key)
    elif "aliyun" in base_url or "127.0.0.1" in base_url:
        client = OpenAI(api_key=api_key, base_url=base_url)
    else:
        client = AzureOpenAI(
            api_version="2023-05-15", api_key=api_key, azure_endpoint=base_url
        )

    tools = [GET_FRAMES_BY_ID_TOOL_SCHEMA, GET_FRAMES_BY_SIMILARITY_TOOL_SCHEMA]

    # Bind per-video context so the tools can be invoked with only the
    # model-supplied arguments.
    get_frames_by_id_with_context = partial(
        get_frames_by_id, all_frame_paths=all_frame_paths
    )
    # MODIFIED: Pass the request queue and results dict to the similarity function
    get_frames_by_similarity_with_context = partial(
        get_frames_by_similarity,
        all_frame_paths=all_frame_paths,
        precomputed_data=embeddings_data,
        request_queue=request_queue,
        results_dict=results_dict,
    )

    available_functions = {
        "get_frames_by_id": get_frames_by_id_with_context,
        "get_frames_by_similarity": get_frames_by_similarity_with_context,
    }

    total_frames = len(all_frame_paths)
    # duration_minutes -> seconds; defaults to 0 when video_info is absent.
    duration = data_item.get("video_info", {}).get("duration_minutes", 0) * 60
    initial_prompt = (
        f"The video has {total_frames} frames (ID 1 to {total_frames}) and is {duration:.0f} seconds long. "
        f"Please answer this question:\n{data_item['question']}"
    )

    messages = [
        {"role": "system", "content": AGENT_SYSTEM_PROMPT},
        {"role": "user", "content": initial_prompt},
    ]
    response_content = None
    max_tool_calls = 10

    for _ in range(max_tool_calls):
        response_message = call_single_model(
            client,
            messages,
            target_model,
            data_item["key"],
            max_retry_times,
            tools=tools,
        )
        if response_message is None:
            return None

        messages.append(response_message)

        if response_message.tool_calls:
            for tool_call in response_message.tool_calls:
                function_name = tool_call.function.name
                function_to_call = available_functions.get(function_name)
                if function_to_call:
                    function_args = json.loads(tool_call.function.arguments)
                    function_response = function_to_call(**function_args)

                    # Tool role carries only a status stub; the actual images
                    # are delivered in a follow-up user message because image
                    # content is not accepted in tool messages.
                    messages.append(
                        {
                            "tool_call_id": tool_call.id,
                            "role": "tool",
                            "name": function_name,
                            "content": json.dumps(
                                {
                                    "status": "success",
                                    "retrieved_frame_count": len(function_response),
                                }
                            ),
                        }
                    )

                    user_message_with_frames = [
                        {
                            "type": "text",
                            "text": f"Here are the {len(function_response)} frames from your call to `{function_name}`.",
                        }
                    ]
                    user_message_with_frames.extend(function_response)
                    messages.append(
                        {"role": "user", "content": user_message_with_frames}
                    )
        else:
            # No tool call: this is the model's final answer.
            response_content = response_message.content
            break

    if response_content is None:
        # Tool budget exhausted — force a final answer without tools.
        final_prompt = "You have reached the maximum number of tool calls. Provide a final answer based on the information gathered so far."
        messages.append({"role": "user", "content": final_prompt})
        final_response = call_single_model(
            client, messages, target_model, data_item["key"], max_retry_times
        )
        response_content = (
            final_response.content
            if final_response
            else "Could not determine an answer after max tool calls."
        )

    # Score: case-insensitive exact match on the extracted "answer" field.
    is_correct = False
    model_answer_cleaned = None
    parsed_json = extract_json_from_response(response_content)
    if parsed_json and "answer" in parsed_json:
        model_answer_cleaned = str(parsed_json["answer"]).strip().upper()
        if model_answer_cleaned == data_item["answer"].strip().upper():
            is_correct = True

    return {
        **data_item,
        # SDK message objects are converted to dicts for JSON serialization.
        "agent_conversation": [
            msg if isinstance(msg, dict) else msg.model_dump() for msg in messages
        ],
        "model_reasoning_and_answer": response_content,
        "model_answer": model_answer_cleaned,
        "is_correct": is_correct,
    }
407
+
408
+
409
def encode_image(image_path):
    """Read the file at *image_path* and return its base64 string (utf-8)."""
    with open(image_path, "rb") as handle:
        raw = handle.read()
    return base64.b64encode(raw).decode("utf-8")
413
+
414
+
415
# MODIFIED: Function signature updated to accept queues for IPC
def process_single_data(data_item, args, request_queue, results_dict):
    """Main processing function for a single video, executed by a worker.

    Locates the frame directory and precomputed embedding file for the item,
    then delegates to evaluate_single_item_agentic. Any exception is caught
    and converted into an error record so one bad video cannot kill the pool.

    Returns:
        dict or None: The evaluation result, an error record, or None when
        the API gave no response.
    """
    item_key = data_item["key"]
    try:
        specific_frames_path = os.path.join(args.frames_path, item_key)
        embedding_file = os.path.join(args.embeddings_path, f"{item_key}.pt")

        if not os.path.isdir(specific_frames_path):
            raise FileNotFoundError(
                f"Frame directory not found: {specific_frames_path}"
            )
        if not os.path.exists(embedding_file):
            raise FileNotFoundError(f"Embedding file not found: {embedding_file}")

        # Sort frames by their numeric id so index order == temporal order.
        all_frame_paths = sorted(
            [
                os.path.join(specific_frames_path, f)
                for f in os.listdir(specific_frames_path)
                if f.endswith(".jpg")
            ],
            key=lambda x: int(re.search(r"frame_(\d+)\.jpg", x).group(1)),
        )
        if not all_frame_paths:
            raise FileNotFoundError(f"No frames found for key '{item_key}'")

        # Load to CPU; similarity search moves tensors to the right device.
        embeddings_data = torch.load(embedding_file, map_location="cpu")

        # MODIFIED: Pass queues to the evaluation function
        result = evaluate_single_item_agentic(
            data_item,
            all_frame_paths,
            embeddings_data,
            args.target_model,
            args.api_key,
            args.base_url,
            args.max_retry_times,
            request_queue,
            results_dict,
        )
        return result

    except Exception as e:
        # Convert failures into a structured record for the error log.
        print(f"\nCRITICAL ERROR on key {item_key}: {str(e)}")
        traceback.print_exc()
        return {
            "key": item_key,
            "uid": data_item.get("uid"),
            "error": str(e),
            "traceback": traceback.format_exc(),
        }
466
+
467
+
468
def load_test_data(json_file):
    """Load the evaluation dataset from *json_file*; exit(1) on any failure."""
    try:
        with open(json_file, "r", encoding="utf-8") as handle:
            payload = json.load(handle)
    except FileNotFoundError:
        print(f"Error: Data file not found: {json_file}")
        exit(1)
    except json.JSONDecodeError:
        print(f"Error: Malformed JSON in {json_file}")
        exit(1)
    return payload
479
+
480
+
481
# MODIFIED: This new function runs in its own process, handling embedding requests.
# It now accepts a model_id and loads the model itself.
def embedding_server_process(model_id, device, request_queue, results_dict):
    """
    A server process that loads the SigLIP model and continuously fetches
    text queries from a queue, computes their embeddings, and places the
    results in a shared dictionary.

    Args:
        model_id: HF model id or local path passed to from_pretrained.
        device: Torch device string ("cuda" or "cpu") to run inference on.
        request_queue: Queue of (request_id, text) pairs; the sentinel
            text "STOP" shuts the server down.
        results_dict: Shared dict; embeddings are stored under request_id.
    """
    print(f"Embedding server started on PID {os.getpid()}...")
    print("Loading SigLIP model in the embedding server process...")
    model = AutoModel.from_pretrained(model_id)
    processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
    print("SigLIP model loaded in server.")

    model.to(device)
    model.eval()

    while True:
        try:
            # Blocks until a request (or the STOP sentinel) arrives.
            request_id, text_query = request_queue.get()
            if text_query == "STOP":
                print("Embedding server received stop signal. Shutting down.")
                break

            with torch.no_grad():
                text_inputs = processor(
                    text=[text_query],
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                ).to(device)
                query_embedding = model.get_text_features(**text_inputs)
                # Move embedding to CPU before sharing across processes
                results_dict[request_id] = query_embedding.cpu()
        except Exception as e:
            # Keep serving on errors so one bad request can't kill the server.
            print(f"Error in embedding server: {e}")
            traceback.print_exc()
518
+
519
+
520
# MODIFIED: The old init_worker is removed.
def main():
    """Main function to orchestrate the evaluation framework.

    Flow: parse args -> set 'spawn' start method -> derive output paths ->
    load data and resume from prior results -> start the embedding server
    process -> fan work out over a ProcessPoolExecutor -> save results and
    metrics.
    """
    args = parse_arguments()
    print("--- Agentic Video QA with Hybrid Retrieval ---")
    print(
        f"Model: {args.target_model}, Data: {args.data_file}, Embeddings: {args.embeddings_path}"
    )

    # MODIFIED: Changed start method to 'spawn' for safety with CUDA and on macOS/Windows.
    try:
        multiprocessing.set_start_method("spawn", force=True)
        print("Multiprocessing start method set to 'spawn'.")
    except RuntimeError:
        print("Start method already set.")

    # MODIFIED: Model is no longer loaded in the main process.
    # It will be loaded in the dedicated embedding_server_process.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Output file names derived from model + dataset so runs don't collide.
    model_name_safe = args.target_model.replace("/", "_")
    data_filename_base = os.path.splitext(os.path.basename(args.data_file))[0]
    output_prefix = f"{model_name_safe}_{data_filename_base}_agent_hybrid"
    results_output_file = f"{output_prefix}_results.json"
    metrics_output_file = f"{output_prefix}_metrics.json"
    error_log_file = f"{output_prefix}_errors.log"

    with open(error_log_file, "a", encoding="utf-8") as f:
        f.write(
            f"\n=== Log Session Started at {datetime.now()} for {args.target_model} ===\n"
        )

    all_test_data = load_test_data(args.data_file)
    # Resume support: reload prior results and skip their uids.
    existing_results = []
    completed_ids = set()
    if os.path.exists(results_output_file):
        try:
            with open(results_output_file, "r", encoding="utf-8") as f:
                existing_results = json.load(f)
            if isinstance(existing_results, list):
                completed_ids = {
                    item["uid"] for item in existing_results if "uid" in item
                }
                print(f"Found {len(completed_ids)} completed tasks. Resuming...")
            else:
                existing_results = []
        except (json.JSONDecodeError, IOError):
            # Corrupt/partial results file: start over rather than crash.
            existing_results = []

    tasks_to_process = [
        item for item in all_test_data if item.get("uid") not in completed_ids
    ]
    if not tasks_to_process:
        print("All tasks are already completed. Calculating final metrics.")
    else:
        print(
            f"Total: {len(all_test_data)}. Completed: {len(completed_ids)}. To process: {len(tasks_to_process)}."
        )

    all_results = list(existing_results)

    if tasks_to_process:
        # MODIFIED: Set up Manager, Queues, and the embedding server process
        with multiprocessing.Manager() as manager:
            request_queue = manager.Queue()
            results_dict = manager.dict()

            # MODIFIED: Start the dedicated embedding server process, passing the model ID.
            embedding_server = multiprocessing.Process(
                target=embedding_server_process,
                args=(
                    SIGLIP_MODEL_ID,
                    device,
                    request_queue,
                    results_dict,
                ),
            )
            embedding_server.start()

            # MODIFIED: The ProcessPoolExecutor no longer needs an initializer for the model
            with concurrent.futures.ProcessPoolExecutor(
                max_workers=args.pool_processes
            ) as executor:
                # MODIFIED: Pass the queues to each worker via partial
                func = partial(
                    process_single_data,
                    args=args,
                    request_queue=request_queue,
                    results_dict=results_dict,
                )
                results_iterator = executor.map(func, tasks_to_process)
                for result in tqdm(
                    results_iterator,
                    total=len(tasks_to_process),
                    desc="Processing Videos",
                ):
                    if result:
                        if "error" in result:
                            with open(error_log_file, "a", encoding="utf-8") as f:
                                f.write(
                                    f"Error on key {result.get('key', 'N/A')}:\n Error: {result['error']}\n Traceback: {result['traceback']}\n---\n"
                                )
                        all_results.append(result)
                        # Periodic checkpoint so a crash loses at most 10 items.
                        if len(all_results) % 10 == 0:
                            save_json_file(all_results, results_output_file)

            # MODIFIED: Gracefully shut down the embedding server
            print("All tasks processed. Sending stop signal to embedding server.")
            request_queue.put((None, "STOP"))
            embedding_server.join()

    print("\n\nProcessing complete.")
    save_json_file(all_results, results_output_file)
    print(f"Detailed results saved to: {results_output_file}")

    final_metrics = calculate_metrics(all_results)
    save_json_file(final_metrics, metrics_output_file)
    print(f"\nMetrics saved to: {metrics_output_file}")
    print(json.dumps(final_metrics, indent=4))
639
+
640
+
641
# Script entry point.
if __name__ == "__main__":
    main()
643
+
main_uniform_sampling.py ADDED
@@ -0,0 +1,501 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import base64
4
+ import argparse
5
+ import time
6
+ import re
7
+ from datetime import datetime
8
+ from functools import partial
9
+ from openai import AzureOpenAI, OpenAI
10
+ from volcenginesdkarkruntime import Ark
11
+ from multiprocessing import Pool, Manager, Lock
12
+
13
# New prompt template for multiple-choice questions with reasoning.
# Used verbatim as the system message; the ```json fence it mandates is the
# format extract_json_from_response() parses downstream.
REASONING_MULTIPLE_CHOICE_TEMPLATE = """
You are an AI assistant evaluating video frames to answer a multiple-choice question.
The user will provide you with a set of video frames and a question with several options (e.g., A, B, C, D).

First, provide a step-by-step reasoning process that analyzes the video frames and leads to your conclusion.
After your reasoning, provide the final answer in a JSON block. The JSON object must contain a single key "answer" with the value being one of 'A', 'B', 'C', or 'D'.

Your output should follow this format exactly:
<Your step-by-step reasoning here>
```json
{"answer": "A"}
```
Do not include any other text after the JSON block.
"""
28
+
29
+
30
def parse_arguments():
    """Parse and return the command-line options for an evaluation run.

    Returns:
        argparse.Namespace: Parsed command line arguments.
    """
    parser = argparse.ArgumentParser(description="Video QA Evaluation Framework")
    add = parser.add_argument

    # Model configuration
    add(
        "--target-model",
        "-tm",
        type=str,
        required=True,
        help="Model to be evaluated (e.g., gpt-4o, gpt-4-vision-preview)",
    )

    # Data configuration
    add(
        "--frame-num",
        "-fn",
        type=int,
        default=32,
        help="Number of frames to uniformly sample from each video (default: 32)",
    )
    add(
        "--frames-path",
        "-fp",
        type=str,
        required=True,
        help="Absolute path to the base directory containing video frame folders.",
    )
    add(
        "--data-file",
        "-df",
        type=str,
        required=True,
        help="Absolute path to the JSON file containing the evaluation dataset.",
    )

    # Processing configuration
    add(
        "--max-retry-times",
        "-mr",
        type=int,
        default=10,
        help="Maximum number of retries for API calls (default: 10)",
    )
    add(
        "--pool-processes",
        "-pp",
        type=int,
        default=20,
        help="Number of parallel processes for evaluation (default: 20)",
    )

    # API configuration
    add("--base_url", type=str, required=True, help="Azure OpenAI endpoint URL.")
    add("--api_key", type=str, required=True, help="Azure OpenAI API key.")

    return parser.parse_args()
96
+
97
+
98
def save_json_file(data, output_file):
    """Serialize *data* to *output_file* as indented JSON.

    Args:
        data (dict or list): JSON-serializable payload.
        output_file (str): Destination path (overwritten if present).
    """
    with open(output_file, "w", encoding="utf-8") as sink:
        sink.write(json.dumps(data, indent=4))
108
+
109
+
110
def extract_json_from_response(response):
    """
    Extract the JSON object from a model response whose reasoning is
    followed by a fenced ```json ... ``` block.

    The previous non-greedy brace regex (\\{.*?\\}) stopped at the first
    closing brace, so any nested object inside the payload produced
    truncated, unparseable JSON. Parsing the whole fenced payload fixes
    that while preserving the old behavior for flat objects.

    Args:
        response (str): The raw response string from the model.

    Returns:
        dict or None: Parsed JSON object, or None if no valid JSON object
        block is found.
    """
    if not response:
        return None
    match = re.search(r"```json\s*(.*?)\s*```", response, re.DOTALL)
    if not match:
        return None
    try:
        parsed = json.loads(match.group(1))
    except json.JSONDecodeError:
        return None
    # Callers expect a mapping ("answer" in parsed_json); reject scalar or
    # array payloads just as the old brace-only regex implicitly did.
    return parsed if isinstance(parsed, dict) else None
131
+
132
+
133
def calculate_metrics(results):
    """
    Compute accuracy statistics over the evaluation results.

    Args:
        results (list): Result dicts carrying 'model_answer' and
            'is_correct' fields.

    Returns:
        dict: total/answered/correct counts plus accuracy, where accuracy
        is correct / answered (0.0 when nothing was answered).
    """
    answered = 0
    correct = 0
    for entry in results:
        if entry.get("model_answer") is not None:
            answered += 1
        if entry.get("is_correct"):
            correct += 1
    return {
        "total_samples": len(results),
        "answered_samples": answered,
        "correct_answers": correct,
        "accuracy": correct / answered if answered else 0.0,
    }
165
+
166
+
167
def call_single_model(client, messages, model, item_id, max_retry_times):
    """
    Call the chat-completions endpoint once, retrying with a fixed backoff.

    Args:
        client: OpenAI-compatible client instance.
        messages (list): Chat message payload.
        model (str): Model name.
        item_id (str): Identifier used in error logs.
        max_retry_times (int): Retry budget.

    Returns:
        str or None: Response text, or None once all retries fail (the
        final failure is appended to a per-model error log file).
    """
    # Doubao models accept a lower completion-token cap than the others.
    token_limit = 32768 if "doubao" in model else 65535

    for attempt in range(1, max_retry_times + 1):
        try:
            completion = client.chat.completions.create(
                model=model, messages=messages, max_tokens=token_limit
            )
        except Exception as e:
            print(
                f"Error processing item {item_id} with model {model}: {str(e)}. Retrying ({attempt}/{max_retry_times})..."
            )
            if attempt == max_retry_times:
                error_log_file = f"error_log_{model.replace('/', '_')}.txt"
                with open(error_log_file, "a") as f:
                    f.write(
                        f"Error processing item {item_id} with model {model} after {max_retry_times} retries: {str(e)}\n"
                    )
                return None
            time.sleep(5)  # Wait before retrying
        else:
            return completion.choices[0].message.content
206
+
207
+
208
def evaluate_single_item(
    data_item, frames, target_model, api_key, base_url, max_retry_times
):
    """
    Evaluate a single data item using the target model and perform exact match.

    Args:
        data_item (dict): Dictionary containing question and answer data.
        frames (list): List of encoded video frames (image_url payload dicts).
        target_model (str): Model to be evaluated.
        api_key (str): API key.
        base_url (str): API base URL; also selects which SDK client is used.
        max_retry_times (int): Maximum number of retries.

    Returns:
        dict: data_item merged with the raw response, cleaned answer, and
        correctness flag.
    """
    # Choose the SDK client from the endpoint URL (heuristic substring match).
    if "ark" in base_url:
        client = Ark(
            base_url=base_url,
            api_key=api_key,
        )
    elif "aliyun" in base_url or "127.0.0.1" in base_url:
        client = OpenAI(api_key=api_key, base_url=base_url)
    else:
        client = AzureOpenAI(
            api_version="2023-05-15", api_key=api_key, azure_endpoint=base_url
        )

    # Construct messages for the model using the reasoning template:
    # system prompt, then frames sandwiched between two text parts.
    messages = [
        {"role": "system", "content": REASONING_MULTIPLE_CHOICE_TEMPLATE},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Here are the video frames:"},
                *frames,
                {"type": "text", "text": f"Question: {data_item['question']}"},
            ],
        },
    ]

    response = call_single_model(
        client, messages, target_model, data_item["key"], max_retry_times
    )

    is_correct = False
    model_answer_cleaned = None
    parsed_json = None

    # Score: case-insensitive exact match on the extracted "answer" field.
    if response:
        parsed_json = extract_json_from_response(response)
        if parsed_json and "answer" in parsed_json:
            model_answer_cleaned = str(parsed_json["answer"]).strip().upper()
            gold_answer = data_item["answer"].strip().upper()
            if model_answer_cleaned == gold_answer:
                is_correct = True

    # Create result dictionary (keeps both raw and normalized answers).
    result = {
        **data_item,
        "model_reasoning_and_answer": response,
        "model_answer_raw": parsed_json.get("answer") if parsed_json else None,
        "model_answer": model_answer_cleaned,
        "is_correct": is_correct,
    }

    return result
276
+
277
+
278
def encode_image(image_path):
    """
    Encode an image file as a base64 string.

    Args:
        image_path (str): Path to the image file.

    Returns:
        str: Base64-encoded file contents (utf-8 text).
    """
    with open(image_path, "rb") as source:
        payload = source.read()
    return base64.b64encode(payload).decode("utf-8")
290
+
291
+
292
def process_frames(frames_path, frame_num):
    """
    Uniformly sample up to *frame_num* frames from a directory and encode
    them for the chat API.

    Args:
        frames_path (str): Directory containing ``frame_<id>.jpg`` files.
        frame_num (int): Maximum number of frames to sample.

    Returns:
        list: ``image_url`` payload dicts, or [] when the directory is
        missing or contains no frames.
    """
    if not os.path.isdir(frames_path):
        print(f"Warning: Frame directory not found at {frames_path}")
        return []

    def frame_index(filename):
        # Numeric id from frame_<id>.jpg, so ordering is temporal not lexical.
        return int(filename.split("_")[1].split(".")[0])

    names = sorted(
        (
            name
            for name in os.listdir(frames_path)
            if name.startswith("frame_") and name.endswith(".jpg")
        ),
        key=frame_index,
    )
    paths = [os.path.join(frames_path, name) for name in names]
    count = len(paths)

    if count == 0:
        return []

    if count > frame_num:
        # Evenly spaced indices across the full frame range.
        paths = [paths[int(i * count / frame_num)] for i in range(frame_num)]

    return [
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{encode_image(path)}"},
        }
        for path in paths
    ]
336
+
337
+
338
def process_single_data(
    data_item, args, shared_results, progress_counter, total_items, locks
):
    """
    Process a single data item in a multiprocessing context.

    Samples/encodes the item's frames, runs the evaluation, appends the
    result to the shared list, and checkpoints all results to disk. All
    exceptions are caught and logged so one bad video cannot kill the pool.

    Args:
        data_item (dict): Single data item to process.
        args: Command line arguments.
        shared_results: Shared list (manager proxy) for storing results.
        progress_counter: Shared counter for progress tracking.
        total_items (int): Total number of items to process.
        locks (dict): Locks guarding "results", "file", and "counter".
    """
    item_key = data_item["key"]
    try:
        # Construct path to the specific video's frames folder
        specific_frames_path = os.path.join(args.frames_path, item_key)
        frames = process_frames(specific_frames_path, args.frame_num)

        if not frames:
            raise FileNotFoundError(
                f"No frames found or processed for key '{item_key}' at path '{specific_frames_path}'"
            )

        result = evaluate_single_item(
            data_item,
            frames,
            args.target_model,
            args.api_key,
            args.base_url,
            args.max_retry_times,
        )

        if result is not None:
            with locks["results"]:
                shared_results.append(result)
                # Define output file names inside the worker
                data_filename_base = os.path.splitext(os.path.basename(args.data_file))[
                    0
                ]
                model_name_safe = args.target_model.replace("/", "_")
                output_prefix = (
                    f"{model_name_safe}_{data_filename_base}_{args.frame_num}frames"
                )
                results_output_file = f"{output_prefix}_results.json"
                # Save the entire updated list of results after each case is processed
                save_json_file(list(shared_results), results_output_file)

    except Exception as e:
        print(f"Error processing video key {item_key}: {str(e)}")
        with locks["file"]:
            error_log_file = f"error_log_{args.target_model.replace('/', '_')}.txt"
            with open(error_log_file, "a") as f:
                f.write(f"Critical error processing video key {item_key}: {str(e)}\n")
    finally:
        # Always update progress counter, even on failure.
        with locks["counter"]:
            progress_counter.value += 1
            print(
                f"\rProcessed: {progress_counter.value}/{total_items} videos...",
                end="",
                flush=True,
            )
402
+
403
+
404
def load_test_data(json_file):
    """
    Load test data from a JSON file, exiting the process on failure.

    Args:
        json_file (str): Path to the JSON file.

    Returns:
        list: List of test data items.
    """
    try:
        with open(json_file, "r", encoding="utf-8") as handle:
            payload = json.load(handle)
    except FileNotFoundError:
        print(f"Error: Data file not found at {json_file}")
        exit(1)
    except json.JSONDecodeError:
        print(f"Error: Could not decode JSON from {json_file}")
        exit(1)
    return payload
423
+
424
+
425
def main():
    """
    Main function to run the video QA evaluation framework.

    Flow: parse args -> derive per-run output/log file names -> load the
    dataset -> fan work out over a multiprocessing Pool (results collected
    in a Manager list) -> compute metrics and save everything.
    """
    args = parse_arguments()

    print("--- Evaluation Configuration ---")
    print(f"Target Model: {args.target_model}")
    print(f"Frames to Sample: {args.frame_num}")
    print(f"Frames Base Path: {args.frames_path}")
    print(f"Data File: {args.data_file}")
    print(f"Parallel Processes: {args.pool_processes}")
    print("---------------------------------")

    # Initialize error log file ("w" truncates any previous run's log).
    error_log_file = f"error_log_{args.target_model.replace('/', '_')}.txt"
    with open(error_log_file, "w") as f:
        f.write(
            f"=== Error Log Started at {datetime.now()} for model {args.target_model} ===\n"
        )

    # Define output file names (model + dataset + frame count in the prefix).
    data_filename_base = os.path.splitext(os.path.basename(args.data_file))[0]
    model_name_safe = args.target_model.replace("/", "_")
    output_prefix = f"{model_name_safe}_{data_filename_base}_{args.frame_num}frames"

    results_output_file = f"{output_prefix}_results.json"
    metrics_output_file = f"{output_prefix}_metrics.json"

    # Load data
    test_data = load_test_data(args.data_file)
    total_videos = len(test_data)
    print(f"\nLoaded {total_videos} videos to process.")

    # Set up multiprocessing: manager-backed shared state + locks.
    with Manager() as manager:
        shared_results = manager.list()
        progress_counter = manager.Value("i", 0)

        locks = {
            "results": manager.Lock(),
            "file": manager.Lock(),
            "counter": manager.Lock(),
        }

        # Create a partial function with fixed arguments for the worker pool
        process_func = partial(
            process_single_data,
            args=args,
            shared_results=shared_results,
            progress_counter=progress_counter,
            total_items=total_videos,
            locks=locks,
        )

        # Run processing in parallel
        with Pool(processes=args.pool_processes) as pool:
            pool.map(process_func, test_data)

        # Convert shared list to a regular list before the manager shuts down.
        all_results = list(shared_results)

    print(f"\n\nProcessing complete for model: {args.target_model}")

    # Calculate and save final metrics
    final_metrics = calculate_metrics(all_results)
    save_json_file(final_metrics, metrics_output_file)
    print(f"\nMetrics saved to: {metrics_output_file}")
    print(json.dumps(final_metrics, indent=4))

    # Save final results
    save_json_file(all_results, results_output_file)
    print(f"Detailed results saved to: {results_output_file}")
498
+
499
+
500
# Script entry point.
if __name__ == "__main__":
    main()
offline_compute_similarity.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ import argparse
5
+ import tempfile
6
+ import glob
7
+ from tqdm import tqdm
8
+ from transformers import AutoModel, AutoProcessor
9
+ from torch.nn.functional import cosine_similarity
10
+ import torch.multiprocessing as mp
11
+
12
# --- Configuration ---
# Local checkpoint used for text-feature extraction in this step.
# NOTE(review): hard-coded absolute path — assumes this storage mount exists
# on the host; confirm before running elsewhere.
MODEL_ID = "/mnt/bn/ziyang-storage-cloudnative-hl/huggingface/siglip-so400m-patch14-384"
14
+
15
+
16
def parse_arguments():
    """Parse command-line arguments for the similarity-computation step.

    Returns:
        argparse.Namespace with `data_file`, `embeddings_path`, and
        `output_file` attributes (all required strings).
    """
    parser = argparse.ArgumentParser(
        description="步骤 2: 从预计算的嵌入加载并计算问-帧相似度。"
    )
    # (long flag, short flag, help text) — every option is a required string.
    arg_specs = [
        ("--data-file", "-df", "包含评估数据集的JSON文件的绝对路径。"),
        ("--embeddings-path", "-ep", "包含预计算嵌入.pt文件的目录的绝对路径。"),
        ("--output-file", "-o", "用于保存最终相似度分数的JSON文件路径。"),
    ]
    for long_flag, short_flag, help_text in arg_specs:
        parser.add_argument(
            long_flag,
            short_flag,
            type=str,
            required=True,
            help=help_text,
        )
    return parser.parse_args()
43
+
44
+
45
def load_test_data(json_file):
    """Load the evaluation dataset from *json_file*.

    On a missing file or malformed JSON, prints an error and exits the
    process with status 1 (this is a CLI helper, not a library function).
    """
    try:
        handle = open(json_file, "r", encoding="utf-8")
    except FileNotFoundError:
        print(f"错误: 在 {json_file} 未找到数据文件")
        exit(1)
    with handle:
        try:
            return json.load(handle)
        except json.JSONDecodeError:
            print(f"错误: 无法从 {json_file} 解码JSON")
            exit(1)
    return []
57
+
58
+
59
def save_json_file(data, output_file):
    """Serialize *data* as pretty-printed JSON to *output_file*."""
    with open(output_file, "w", encoding="utf-8") as out_fp:
        json.dump(data, out_fp, indent=4)
    # Tell the operator where the final similarity results landed.
    print(f"\n成功将最终相似度结果保存到 {output_file}")
65
+
66
+
67
def process_question_chunk(args_tuple):
    """
    Worker: rank precomputed frame embeddings against each question's text
    embedding on one dedicated GPU, appending results incrementally to a
    per-worker JSONL file.

    args_tuple: (data_chunk, embeddings_base_path, gpu_id, temp_dir)
        data_chunk           - list of dataset items (each with "key", "uid",
                               "question") assigned to this worker
        embeddings_base_path - directory holding one "<key>.pt" file per video
        gpu_id               - CUDA device index; also used as tqdm row and
                               in the temp file name
        temp_dir             - shared directory for intermediate JSONL output
    """
    data_chunk, embeddings_base_path, gpu_id, temp_dir = args_tuple
    device = f"cuda:{gpu_id}"

    # Each worker owns a unique output file, so no cross-process lock is needed.
    temp_output_file = os.path.join(temp_dir, f"results_gpu_{gpu_id}.jsonl")

    # Only text features are computed here; frame embeddings were precomputed
    # in step 1 and are loaded from disk below.
    model = AutoModel.from_pretrained(MODEL_ID).to(device).eval()
    processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=True)

    progress_bar = tqdm(data_chunk, position=gpu_id, desc=f"GPU-{gpu_id}")

    # Cache loaded embeddings so repeated video keys don't re-hit the disk.
    embedding_cache = {}

    with open(temp_output_file, "a", encoding="utf-8") as f_out:
        for data_item in progress_bar:
            item_key = data_item["key"]
            question_key = data_item["uid"]
            # Keep only the question stem; drop the multiple-choice options
            # that start at "\n(A)".
            question = data_item["question"].split("\n(A)")[0]

            embedding_file_path = os.path.join(embeddings_base_path, f"{item_key}.pt")
            if not os.path.exists(embedding_file_path):
                progress_bar.write(
                    f"Warning: Embedding file not found for '{item_key}', skipping."
                )
                continue

            try:
                # Load embeddings from the cache, or from disk on first use.
                # The .pt file holds {"filenames": [...], "embeddings": tensor};
                # presumably one embedding row per filename — verify in step 1.
                if item_key not in embedding_cache:
                    loaded_data = torch.load(embedding_file_path, map_location="cpu")
                    embedding_cache[item_key] = {
                        "filenames": loaded_data["filenames"],
                        "embeddings": loaded_data["embeddings"],
                    }

                frame_files = embedding_cache[item_key]["filenames"]
                frame_embeddings = embedding_cache[item_key]["embeddings"].to(device)

                with torch.no_grad():
                    # --- Text embedding ---
                    text_inputs = processor(
                        text=[question],
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                    ).to(device)
                    question_embedding = model.get_text_features(**text_inputs)

                    # --- Similarity computation and ranking ---
                    similarities = cosine_similarity(
                        question_embedding, frame_embeddings
                    )
                    # Sort frames by descending similarity to the question.
                    scored_frames = sorted(
                        zip(frame_files, similarities.cpu().numpy()),
                        key=lambda x: x[1],
                        reverse=True,
                    )
                    sorted_frame_filenames = [frame[0] for frame in scored_frames]

                # One JSON object per line keeps writes append-only and
                # crash-tolerant (partial lines are skipped at merge time).
                single_result = {question_key: sorted_frame_filenames}
                f_out.write(json.dumps(single_result) + "\n")

            except Exception as e:
                # Best-effort: log and continue so one bad item does not
                # kill the whole worker process.
                progress_bar.write(f"Error on GPU-{gpu_id} for item '{item_key}': {e}")
137
+
138
+
139
def main():
    """Coordinate multi-GPU similarity computation and merge per-GPU results."""
    args = parse_arguments()

    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        print("错误: 未找到启用CUDA的GPU。正在退出。")
        exit(1)

    print(f"找到 {num_gpus} 个GPU。开始并行计算相似度...")

    test_data = load_test_data(args.data_file)
    if not test_data:
        return

    # Split the dataset into one contiguous chunk per GPU (ceil division;
    # the last chunk may be smaller).
    chunk_size = (len(test_data) + num_gpus - 1) // num_gpus
    data_chunks = [
        test_data[i : i + chunk_size] for i in range(0, len(test_data), chunk_size)
    ]

    with tempfile.TemporaryDirectory() as temp_dir:
        print(f"使用临时目录存储中间结果: {temp_dir}")

        # One (chunk, embeddings path, gpu id, temp dir) tuple per worker.
        process_args = [
            (data_chunks[i], args.embeddings_path, i, temp_dir)
            for i in range(len(data_chunks))
        ]

        with mp.Pool(processes=num_gpus) as pool:
            pool.map(process_question_chunk, process_args)

        # --- Merge and save final results from the per-worker temp files ---
        print("\n\n所有GPU进程已完成。正在从临时文件合并结果...")
        final_similarity_results = {}

        temp_files = glob.glob(os.path.join(temp_dir, "*.jsonl"))

        for temp_file in tqdm(temp_files, desc="合并文件"):
            with open(temp_file, "r", encoding="utf-8") as f:
                for line in f:
                    try:
                        data = json.loads(line)
                        final_similarity_results.update(data)
                    except json.JSONDecodeError:
                        # A worker that died mid-write can leave a truncated
                        # last line; skip it rather than abort the merge.
                        print(f"警告: 跳过 {temp_file} 中的损坏行")

        save_json_file(final_similarity_results, args.output_file)
        print(f"总共处理的项目数: {len(final_similarity_results)}")
187
+
188
+
189
if __name__ == "__main__":
    # "spawn" is required so each worker process initializes its own clean
    # CUDA context; the default "fork" start method breaks CUDA on Linux.
    mp.set_start_method("spawn", force=True)
    main()
utils/count_frames.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json


# Path to the agent-results JSON produced by the evaluation run.
# NOTE(review): hard-coded absolute path — assumes this storage mount exists.
RESULTS_FILE = "/mnt/bn/ziyang-storage-cloudnative-hl/VideoSimpleQA/o4-mini-2025-04-16_lvbench_agent_results.json"


def count_displayed_frames(conversation):
    """Count the image frames shown to the model in one agent conversation.

    A frame is an 'image_url' entry inside a user message whose content is a
    list of content parts (plain-string user messages carry no frames).

    Args:
        conversation: list of chat turns, each a dict with "role" and "content".

    Returns:
        int: number of image_url content parts across all user turns.
    """
    num_frames = 0
    for turn in conversation:
        # isinstance (not type(...) == list) correctly handles subclasses
        # and is the idiomatic type check.
        if turn["role"] == "user" and isinstance(turn["content"], list):
            num_frames += sum(
                1 for item in turn["content"] if item["type"] == "image_url"
            )
    return num_frames


def main():
    """Report mean/max/min frames-per-conversation across the results file."""
    with open(RESULTS_FILE, "r") as f:
        data = json.load(f)

    frames_count = [count_displayed_frames(d["agent_conversation"]) for d in data]

    # Guard against an empty results file (avoids ZeroDivisionError / max([])).
    if not frames_count:
        print("No conversations found in results file.")
        return

    print(f"mean frames: {sum(frames_count) / len(frames_count)}")
    print(f"max frames: {max(frames_count)}")
    print(f"min frames: {min(frames_count)}")


if __name__ == "__main__":
    main()