Initial upload of all project files
- .ruff_cache/.gitignore +2 -0
- .ruff_cache/0.14.0/8806091925620224500 +0 -0
- .ruff_cache/CACHEDIR.TAG +1 -0
- compute_video_emb.py +154 -0
- main_adaptive_sampling.py +447 -0
- main_agent.py +521 -0
- main_i2i_ret.py +548 -0
- main_mcot.py +608 -0
- main_new_agent.py +643 -0
- main_uniform_sampling.py +501 -0
- offline_compute_similarity.py +191 -0
- utils/count_frames.py +35 -0
.ruff_cache/.gitignore
ADDED
@@ -0,0 +1,2 @@
# Automatically created by ruff.
*
.ruff_cache/0.14.0/8806091925620224500
ADDED
Binary file (95 Bytes).
.ruff_cache/CACHEDIR.TAG
ADDED
@@ -0,0 +1 @@
Signature: 8a477f597d28d172789f06886806bc55
compute_video_emb.py
ADDED
@@ -0,0 +1,154 @@
import os
import torch
import argparse
from PIL import Image
from tqdm import tqdm
from transformers import AutoModel, AutoProcessor
import torch.multiprocessing as mp
from torch.utils.data import Dataset, DataLoader
import glob

# --- Configuration ---
MODEL_ID = "/mnt/bn/ziyang-storage-cloudnative-hl/huggingface/siglip-so400m-patch14-384"
BATCH_SIZE = 1024  # Adjust according to your GPU VRAM

def parse_arguments():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Step 1: Pre-compute embeddings for all video frames with SigLIP (multi-GPU)."
    )
    parser.add_argument(
        "--frames-path",
        "-fp",
        type=str,
        required=True,
        help="Absolute path to the base directory containing all video frame folders.",
    )
    parser.add_argument(
        "--output-dir",
        "-o",
        type=str,
        required=True,
        help="Path to the output directory where the embedding .pt files are saved.",
    )
    return parser.parse_args()

class FrameDataset(Dataset):
    """A PyTorch Dataset for efficiently loading video frames."""
    def __init__(self, frame_paths):
        self.frame_paths = frame_paths

    def __len__(self):
        return len(self.frame_paths)

    def __getitem__(self, idx):
        path = self.frame_paths[idx]
        try:
            image = Image.open(path).convert("RGB")
            return image
        except Exception:
            return None

def collate_fn(batch):
    """Custom collate function that filters None values out of a batch."""
    batch = [item for item in batch if item is not None]
    if not batch:
        return None
    return batch

def process_video_chunk(args_tuple):
    """
    Worker function that processes a chunk of videos on a specific GPU.
    """
    video_dirs_chunk, frames_base_path, gpu_id, output_dir = args_tuple
    device = f"cuda:{gpu_id}"

    # Load the model and processor for the assigned GPU inside the worker process
    model = AutoModel.from_pretrained(MODEL_ID).to(device).eval()
    processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=True)

    progress_bar = tqdm(video_dirs_chunk, position=gpu_id, desc=f"GPU-{gpu_id}")

    for video_dir in progress_bar:
        video_name = os.path.basename(video_dir)
        output_path = os.path.join(output_dir, f"{video_name}.pt")

        # Skip if the file already exists, so interrupted runs can resume
        if os.path.exists(output_path):
            progress_bar.write(f"Skipping {video_name}, embeddings already exist.")
            continue

        frame_files = [f for f in os.listdir(video_dir) if f.endswith(".jpg")]
        if not frame_files:
            continue
        frame_files.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
        frame_paths = [os.path.join(video_dir, f) for f in frame_files]

        try:
            with torch.no_grad():
                dataset = FrameDataset(frame_paths)
                loader = DataLoader(
                    dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0,
                    pin_memory=True, collate_fn=collate_fn
                )

                all_frame_embeddings = []
                for image_batch in loader:
                    if image_batch is None:
                        continue

                    image_inputs = processor(images=image_batch, return_tensors="pt").to(device)
                    frame_embeddings = model.get_image_features(**image_inputs)
                    all_frame_embeddings.append(frame_embeddings)

                if not all_frame_embeddings:
                    continue

                all_frame_embeddings = torch.cat(all_frame_embeddings, dim=0)

                # Move tensors to the CPU before saving to avoid CUDA issues when loading later
                data_to_save = {
                    'filenames': frame_files,
                    'embeddings': all_frame_embeddings.cpu()
                }
                torch.save(data_to_save, output_path)

        except Exception as e:
            progress_bar.write(f"Error on GPU-{gpu_id} for video '{video_name}': {e}")

def main():
    """Main function that coordinates multi-GPU processing."""
    args = parse_arguments()

    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        print("Error: no CUDA-enabled GPU found. Exiting.")
        exit(1)

    print(f"Found {num_gpus} GPUs. Starting parallel processing...")

    # Create the output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Collect all unique video directories
    video_dirs = [d for d in glob.glob(os.path.join(args.frames_path, '*')) if os.path.isdir(d)]

    if not video_dirs:
        print(f"Error: no video directories found in {args.frames_path}.")
        return

    # Split the video directories into chunks, one per GPU
    chunk_size = (len(video_dirs) + num_gpus - 1) // num_gpus
    video_chunks = [video_dirs[i:i + chunk_size] for i in range(0, len(video_dirs), chunk_size)]

    # Prepare the arguments for each worker process
    process_args = [(video_chunks[i], args.frames_path, i, args.output_dir) for i in range(len(video_chunks))]

    with mp.Pool(processes=num_gpus) as pool:
        pool.map(process_video_chunk, process_args)

    print("\nAll video frame embeddings have been computed and saved.")

if __name__ == "__main__":
    mp.set_start_method('spawn', force=True)
    main()
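Each `.pt` file written above pairs the sorted frame filenames with an (N, D) tensor of SigLIP image embeddings. Below is a minimal sketch of how a downstream script might load one of these files and rank frames against a text query; the embedding filename and the query text are hypothetical, and only the model path configured above is taken from the script.

import torch
from transformers import AutoModel, AutoProcessor

MODEL_ID = "/mnt/bn/ziyang-storage-cloudnative-hl/huggingface/siglip-so400m-patch14-384"

# Hypothetical embedding file produced by compute_video_emb.py:
# {'filenames': [...], 'embeddings': Tensor of shape (N, D)}
data = torch.load("embeddings/some_video.pt")
frame_embs = torch.nn.functional.normalize(data["embeddings"], dim=-1)

model = AutoModel.from_pretrained(MODEL_ID).eval()
processor = AutoProcessor.from_pretrained(MODEL_ID)

with torch.no_grad():
    # SigLIP tokenizers are typically used with padding="max_length"
    inputs = processor(text=["a man in a red shirt"], padding="max_length", return_tensors="pt")
    text_emb = torch.nn.functional.normalize(model.get_text_features(**inputs), dim=-1)

# Cosine similarity of every frame to the query; highest first
scores = (frame_embs @ text_emb.T).squeeze(-1)
ranked = [data["filenames"][i] for i in scores.argsort(descending=True)]
print(ranked[:32])  # e.g. candidates for a 32-frame similarity file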
main_adaptive_sampling.py
ADDED
@@ -0,0 +1,447 @@
import os
import json
import base64
import argparse
import time
import re
from datetime import datetime
from functools import partial
from openai import AzureOpenAI, OpenAI
from volcenginesdkarkruntime import Ark
from multiprocessing import Pool, Manager, Lock

# New prompt template for multiple-choice questions with reasoning
REASONING_MULTIPLE_CHOICE_TEMPLATE = """
You are an AI assistant evaluating video frames to answer a multiple-choice question.
The user will provide you with a set of video frames and a question with several options (e.g., A, B, C, D).

First, provide a step-by-step reasoning process that analyzes the video frames and leads to your conclusion.
After your reasoning, provide the final answer in a JSON block. The JSON object must contain a single key "answer" with the value being one of 'A', 'B', 'C', or 'D'.

Your output should follow this format exactly:
<Your step-by-step reasoning here>
```json
{"answer": "A"}
```
Do not include any other text after the JSON block.
"""


def parse_arguments():
    """
    Parse command line arguments for evaluation configuration.

    Returns:
        argparse.Namespace: Parsed command line arguments
    """
    parser = argparse.ArgumentParser(
        description="Video QA Evaluation with Pre-computed Similarity Frame Selection"
    )

    # Model configuration
    parser.add_argument(
        "--target-model",
        "-tm",
        type=str,
        required=True,
        help="Model to be evaluated (e.g., gpt-4o, gpt-4-vision-preview)",
    )

    # Data configuration
    parser.add_argument(
        "--frame-num",
        "-fn",
        type=int,
        default=32,
        help="Number of most similar frames to select for each video (default: 32)",
    )
    parser.add_argument(
        "--frames-path",
        "-fp",
        type=str,
        required=True,
        help="Absolute path to the base directory containing video frame folders.",
    )
    parser.add_argument(
        "--data-file",
        "-df",
        type=str,
        required=True,
        help="Absolute path to the JSON file containing the evaluation dataset.",
    )
    # --- MODIFIED ARGUMENT ---
    parser.add_argument(
        "--similarity-file",
        "-sf",
        type=str,
        required=True,
        help="Absolute path to the pre-computed similarity JSON file (e.g., lv_bench_similarity.json).",
    )

    # Processing configuration
    parser.add_argument(
        "--max-retry-times",
        "-mr",
        type=int,
        default=10,
        help="Maximum number of retries for API calls (default: 10)",
    )
    parser.add_argument(
        "--pool-processes",
        "-pp",
        type=int,
        default=20,
        help="Number of parallel processes for evaluation (default: 20)",
    )

    # API configuration
    parser.add_argument(
        "--base_url", type=str, required=True, help="Azure OpenAI endpoint URL."
    )
    parser.add_argument(
        "--api_key", type=str, required=True, help="Azure OpenAI API key."
    )

    return parser.parse_args()


def save_json_file(data, output_file):
    """
    Save data to a JSON file.
    """
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)


def extract_json_from_response(response):
    """
    Extracts a JSON object from a string that contains reasoning followed by a tagged JSON block.
    """
    if not response:
        return None
    try:
        match = re.search(r"```json\s*(\{.*?\})\s*```", response, re.DOTALL)
        if match:
            json_str = match.group(1)
            return json.loads(json_str)
        return None
    except (json.JSONDecodeError, IndexError):
        return None


def calculate_metrics(results):
    """
    Calculate evaluation metrics from the results.
    """
    total_samples = len(results)
    if total_samples == 0:
        return {
            "total_samples": 0,
            "answered_samples": 0,
            "correct_answers": 0,
            "accuracy": 0.0,
        }

    answered_samples = sum(1 for x in results if x.get("model_answer") is not None)
    correct_answers = sum(1 for x in results if x.get("is_correct"))

    accuracy = correct_answers / answered_samples if answered_samples > 0 else 0.0

    return {
        "total_samples": total_samples,
        "answered_samples": answered_samples,
        "correct_answers": correct_answers,
        "accuracy": accuracy,
    }


def call_single_model(client, messages, model, item_id, max_retry_times):
    """
    Make a single API call to the specified model with retry logic.
    """
    if "doubao" in model:
        max_tokens = 32768
    else:
        max_tokens = 65535
    retry_times = 0
    while retry_times < max_retry_times:
        try:
            completion = client.chat.completions.create(
                model=model, messages=messages, max_tokens=max_tokens
            )
            return completion.choices[0].message.content
        except Exception as e:
            retry_times += 1
            print(
                f"Error processing item {item_id} with model {model}: {str(e)}. Retrying ({retry_times}/{max_retry_times})..."
            )
            if retry_times == max_retry_times:
                error_log_file = f"error_log_{model.replace('/', '_')}.txt"
                with open(error_log_file, "a") as f:
                    f.write(
                        f"Error processing item {item_id} with model {model} after {max_retry_times} retries: {str(e)}\n"
                    )
                return None
            time.sleep(5)


def evaluate_single_item(
    data_item, frames, target_model, api_key, base_url, max_retry_times
):
    """
    Evaluate a single data item using the target model.
    """
    if "ark" in base_url:
        client = Ark(base_url=base_url, api_key=api_key)
    elif "aliyun" in base_url or "127.0.0.1" in base_url:
        client = OpenAI(api_key=api_key, base_url=base_url)
    else:
        client = AzureOpenAI(
            api_version="2023-05-15", api_key=api_key, azure_endpoint=base_url
        )

    messages = [
        {"role": "system", "content": REASONING_MULTIPLE_CHOICE_TEMPLATE},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Here are the video frames:"},
                *frames,
                {"type": "text", "text": f"Question: {data_item['question']}"},
            ],
        },
    ]

    response = call_single_model(
        client, messages, target_model, data_item["key"], max_retry_times
    )

    is_correct = False
    model_answer_cleaned = None
    parsed_json = None

    if response:
        parsed_json = extract_json_from_response(response)
        if parsed_json and "answer" in parsed_json:
            model_answer_cleaned = str(parsed_json["answer"]).strip().upper()
            gold_answer = data_item["answer"].strip().upper()
            if model_answer_cleaned == gold_answer:
                is_correct = True

    return {
        **data_item,
        "model_reasoning_and_answer": response,
        "model_answer_raw": parsed_json.get("answer") if parsed_json else None,
        "model_answer": model_answer_cleaned,
        "is_correct": is_correct,
    }


def encode_image(image_path):
    """
    Encode an image file to base64 string.
    """
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


# --- MODIFIED: New function for selecting frames based on pre-computed similarity file ---
def process_frames_from_similarity_file(
    frames_base_path, frame_num, data_item, similarity_data
):
    """
    Select and encode the top N frames using a pre-computed similarity file.
    """
    item_key = data_item["key"]
    question_uid = str(data_item["uid"])

    # Retrieve the sorted list of frame filenames for the current question
    sorted_filenames = similarity_data.get(question_uid)

    if not sorted_filenames:
        print(
            f"Warning: No similarity data found for question UID '{question_uid}', skipping."
        )
        return []

    try:
        # Select the top N filenames
        num_frames_to_select = min(frame_num, len(sorted_filenames))
        selected_filenames = sorted_filenames[:num_frames_to_select]
        selected_ids = [int(f.split(".")[0].split("_")[-1]) for f in selected_filenames]
        selected_ids = sorted(selected_ids)
        selected_filenames = [f"frame_{i:06d}.jpg" for i in selected_ids]

        # Construct full paths for the selected frames
        video_frames_path = os.path.join(frames_base_path, item_key)
        sampled_paths = [os.path.join(video_frames_path, f) for f in selected_filenames]

        # Encode the selected frames
        base64_images = [encode_image(path) for path in sampled_paths]

        return [
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{b64_img}"},
            }
            for b64_img in base64_images
        ]
    except Exception as e:
        print(f"Error during frame processing for key '{item_key}': {e}")
        return []


def process_single_data(
    data_item,
    args,
    shared_results,
    progress_counter,
    total_items,
    locks,
    similarity_data,
):
    """
    Process a single data item in a multiprocessing context.
    """
    item_key = data_item["key"]
    try:
        # --- MODIFIED: Call the new frame selection function ---
        frames = process_frames_from_similarity_file(
            args.frames_path, args.frame_num, data_item, similarity_data
        )

        if not frames:
            raise ValueError(
                f"No frames were processed from similarity file for key '{item_key}'"
            )

        result = evaluate_single_item(
            data_item,
            frames,
            args.target_model,
            args.api_key,
            args.base_url,
            args.max_retry_times,
        )

        if result is not None:
            with locks["results"]:
                shared_results.append(result)
                data_filename_base = os.path.splitext(os.path.basename(args.data_file))[0]
                model_name_safe = args.target_model.replace("/", "_")
                output_prefix = f"{model_name_safe}_{data_filename_base}_{args.frame_num}frames_precomputed_similar"
                results_output_file = f"{output_prefix}_results.json"
                save_json_file(list(shared_results), results_output_file)

    except Exception as e:
        print(f"Error processing video key {item_key}: {str(e)}")
        with locks["file"]:
            error_log_file = f"error_log_{args.target_model.replace('/', '_')}.txt"
            with open(error_log_file, "a") as f:
                f.write(f"Critical error processing video key {item_key}: {str(e)}\n")
    finally:
        with locks["counter"]:
            progress_counter.value += 1
            print(
                f"\rProcessed: {progress_counter.value}/{total_items} videos...",
                end="",
                flush=True,
            )


def load_test_data(json_file):
    """
    Load test data from a JSON file.
    """
    try:
        with open(json_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Error: Data file not found at {json_file}")
        exit(1)
    except json.JSONDecodeError:
        print(f"Error: Could not decode JSON from {json_file}")
        exit(1)


def main():
    """
    Main function to run the video QA evaluation framework.
    """
    args = parse_arguments()

    print("--- Evaluation Configuration ---")
    print(f"Target Model: {args.target_model}")
    print(f"Frames to Sample (by pre-computed similarity): {args.frame_num}")
    print(f"Frames Base Path: {args.frames_path}")
    print(f"Similarity File: {args.similarity_file}")  # Print new arg
    print(f"Data File: {args.data_file}")
    print(f"Parallel Processes: {args.pool_processes}")
    print("---------------------------------")

    error_log_file = f"error_log_{args.target_model.replace('/', '_')}.txt"
    with open(error_log_file, "w") as f:
        f.write(
            f"=== Error Log Started at {datetime.now()} for model {args.target_model} ===\n"
        )

    data_filename_base = os.path.splitext(os.path.basename(args.data_file))[0]
    model_name_safe = args.target_model.replace("/", "_")
    output_prefix = f"{model_name_safe}_{data_filename_base}_{args.frame_num}frames_precomputed_similar"

    results_output_file = f"{output_prefix}_results.json"
    metrics_output_file = f"{output_prefix}_metrics.json"

    # Load test data and similarity data
    test_data = load_test_data(args.data_file)
    try:
        with open(args.similarity_file, "r", encoding="utf-8") as f:
            similarity_data = json.load(f)
    except FileNotFoundError:
        print(f"Error: Similarity file not found at {args.similarity_file}")
        exit(1)

    total_videos = len(test_data)
    print(f"\nLoaded {total_videos} videos to process.")

    with Manager() as manager:
        shared_results = manager.list()
        progress_counter = manager.Value("i", 0)
        locks = {
            "results": manager.Lock(),
            "file": manager.Lock(),
            "counter": manager.Lock(),
        }

        # Create a partial function with fixed arguments for the worker pool
        process_func = partial(
            process_single_data,
            args=args,
            shared_results=shared_results,
            progress_counter=progress_counter,
            total_items=total_videos,
            locks=locks,
            similarity_data=similarity_data,
        )

        # Run processing in parallel
        with Pool(processes=args.pool_processes) as pool:
            pool.map(process_func, test_data)

        all_results = list(shared_results)

    print(f"\n\nProcessing complete for model: {args.target_model}")

    final_metrics = calculate_metrics(all_results)
    save_json_file(final_metrics, metrics_output_file)
    print(f"\nMetrics saved to: {metrics_output_file}")
    print(json.dumps(final_metrics, indent=4))

    save_json_file(all_results, results_output_file)
    print(f"Detailed results saved to: {results_output_file}")


if __name__ == "__main__":
    main()
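The script assumes the `--similarity-file` JSON maps each question UID (as a string) to a list of frame filenames sorted from most to least similar; `offline_compute_similarity.py` in this upload is presumably its producer. A minimal sketch of that contract, and of how `process_frames_from_similarity_file` re-sorts the top picks into temporal order, follows; the UID and filenames are hypothetical.

import json

# Shape of the file expected by --similarity-file: question UID -> frame
# filenames sorted from most to least similar (hypothetical values below).
similarity_data = {
    "12345": ["frame_000412.jpg", "frame_000398.jpg", "frame_001207.jpg"],
}
with open("lv_bench_similarity.json", "w", encoding="utf-8") as f:
    json.dump(similarity_data, f, indent=4)

# The evaluation script takes the top --frame-num entries, then re-sorts them
# into temporal order before encoding them for the VLM:
top = similarity_data["12345"][:2]
ids = sorted(int(f.split(".")[0].split("_")[-1]) for f in top)
print([f"frame_{i:06d}.jpg" for i in ids])  # ['frame_000398.jpg', 'frame_000412.jpg']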
main_agent.py
ADDED
@@ -0,0 +1,521 @@
import os
import json
import base64
import argparse
import time
import re
import traceback
from datetime import datetime
from functools import partial
from openai import AzureOpenAI, OpenAI
from volcenginesdkarkruntime import Ark
import concurrent.futures
from tqdm import tqdm

# New system prompt for the agent
AGENT_SYSTEM_PROMPT = """
You are an intelligent AI assistant specialized in video question answering.
Your task is to answer a multiple-choice question based on a video.

You must use the `get_frames_by_id` tool to request specific frames to view.
You will be told the total number of frames available in the video (e.g., "The video has 1250 frames, numbered 1 to 1250.").

Your strategy should be efficient:
1. Based on the task query, think about which part of the video will be related, and then get the frames of this part. If the query's description is fairly general and you can't effectively infer the temporal regions where the target visual evidence might appear, you can first uniformly sample some frames for analysis to identify the time intervals where the target visual evidence is likely to appear.
2. Analyze the retrieved frames and the user's question.
3. If you don't have enough information, form a hypothesis about where the answer might be and use the tool again to request more specific frames from that segment.
4. Continue this process of reasoning and tool use until you are confident in your answer. Avoid requesting all frames at once.
5. Please make sure that you find the relevant visual cues and then answer the question instead of guessing the answer.
6. You can access 10 frames at most in each tool call.

Please note that if you have insufficient visual information at the beginning, you can first sample more frames uniformly to understand the video (e.g., sampling 10 frames per tool call). You can then gradually refine the subsequent steps and adopt a coarse-to-fine strategy overall.
For example, suppose the question is "What is the main subject of the video?"
You can first sample 10 frames uniformly from the video (e.g., frame 100, 200, ..., 1200).
After analyzing these frames, you might notice that the main subject is a person in the middle of the screen (between frame 500 and 600).
You can then sample more frames from this region (e.g., frame 500, 520, ..., 590) to get more detailed information.
Finally, you can reason based on the visual cues you have gathered and provide the final answer.
This process might be multi-turn.

After your reasoning, provide the final answer in a JSON block. The JSON object must contain a single key "answer" with the value being one of 'A', 'B', 'C', or 'D'.

Remember that you can access 10 frames at most in each tool call.

Your output should follow this format exactly:
<Your step-by-step reasoning here>
```json
{"answer": "X"}
```
Do not include any other text after the JSON block.
"""

# Tool schema for the get_frames_by_id function
GET_FRAMES_TOOL_SCHEMA = {
    "type": "function",
    "function": {
        "name": "get_frames_by_id",
        "description": "Retrieves specific video frames by their numerical IDs. Use this to get visual information from the video.",
        "parameters": {
            "type": "object",
            "properties": {
                "frame_ids": {
                    "type": "array",
                    "items": {"type": "integer"},
                    "description": "A list of frame numbers to retrieve. You can access 10 frames at most in each tool call.",
                },
            },
            "required": ["frame_ids"],
        },
    },
}


def parse_arguments():
    """
    Parse command line arguments for evaluation configuration.
    """
    parser = argparse.ArgumentParser(
        description="Video QA Evaluation Framework with Agentic Frame Selection (Refactored)"
    )
    parser.add_argument(
        "--target-model",
        "-tm",
        type=str,
        required=True,
        help="Model to be evaluated (e.g., gpt-4o)",
    )
    parser.add_argument(
        "--frames-path",
        "-fp",
        type=str,
        required=True,
        help="Absolute path to the base directory for video frames.",
    )
    parser.add_argument(
        "--data-file",
        "-df",
        type=str,
        required=True,
        help="Absolute path to the JSON evaluation dataset.",
    )
    parser.add_argument(
        "--max-retry-times",
        "-mr",
        type=int,
        default=10,
        help="Maximum retries for API calls.",
    )
    parser.add_argument(
        "--pool-processes",
        "-pp",
        type=int,
        default=20,
        help="Number of parallel processes.",
    )
    parser.add_argument(
        "--base_url", type=str, required=True, help="Azure OpenAI endpoint URL."
    )
    parser.add_argument(
        "--api_key", type=str, required=True, help="Azure OpenAI API key."
    )
    return parser.parse_args()


def save_json_file(data, output_file):
    """Saves data to a JSON file."""
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)


def extract_json_from_response(response):
    """Extracts a JSON object from a model's response string."""
    if not response:
        return None
    match = re.search(r"```json\s*(\{.*?\})\s*```", response, re.DOTALL)
    if match:
        try:
            return json.loads(match.group(1))
        except (json.JSONDecodeError, IndexError):
            return None
    return None


def calculate_metrics(results):
    """Calculates accuracy and other metrics from evaluation results."""
    # Filter out potential error results before calculating
    valid_results = [r for r in results if "error" not in r]
    total_samples = len(valid_results)

    if total_samples == 0:
        return {
            "total_samples": 0,
            "answered_samples": 0,
            "correct_answers": 0,
            "accuracy": 0.0,
        }

    answered_samples = sum(
        1 for x in valid_results if x.get("model_answer") is not None
    )
    correct_answers = sum(1 for x in valid_results if x.get("is_correct"))
    accuracy = correct_answers / answered_samples if answered_samples > 0 else 0.0

    return {
        "total_samples": total_samples,
        "answered_samples": answered_samples,
        "correct_answers": correct_answers,
        "accuracy": accuracy,
    }


def call_single_model(client, messages, model, item_id, max_retry_times, tools=None):
    """Makes a single API call with retry logic and tool support."""
    if "o4" in model:
        params = {"model": model, "messages": messages, "max_tokens": 65535}
    elif "Qwen" in model:
        params = {
            "model": model,
            "messages": messages,
            "max_tokens": 2048,
            "temperature": 0,
        }
    else:
        params = {"model": model, "messages": messages, "max_tokens": 32768}
    if tools:
        params["tools"] = tools
        params["tool_choice"] = "auto"

    retry_times = 0
    while retry_times < max_retry_times:
        try:
            completion = client.chat.completions.create(**params)
            return completion.choices[0].message
        except Exception as e:
            retry_times += 1
            print(
                f"API Error for item {item_id}: {str(e)}. Retrying ({retry_times}/{max_retry_times})..."
            )
            if retry_times == max_retry_times:
                # Instead of writing to a file here, we'll let the worker return the error
                raise e  # Reraise the exception to be caught by the worker's main try-except block
            time.sleep(5)


def get_frames_by_id(frame_ids: list, all_frame_paths: list):
    """Tool implementation: Retrieves and formats frames based on a list of IDs."""
    retrieved_frames = []
    frame_map = {
        int(re.search(r"frame_(\d+)\.jpg", os.path.basename(p)).group(1)): p
        for p in all_frame_paths
        if re.search(r"frame_(\d+)\.jpg", os.path.basename(p))
    }

    for fid in frame_ids:
        path = frame_map.get(fid)
        if path and os.path.exists(path):
            b64_image = encode_image(path)
            retrieved_frames.append(
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{b64_image}"},
                }
            )
    return retrieved_frames


def evaluate_single_item_agentic(
    data_item, all_frame_paths, target_model, api_key, base_url, max_retry_times
):
    """Evaluates a single item using an agentic loop for dynamic frame selection."""
    if "ark" in base_url:
        client = Ark(
            base_url=base_url,
            api_key=api_key,
        )
    elif "aliyun" in base_url or "127.0.0.1" in base_url:
        client = OpenAI(api_key=api_key, base_url=base_url)
    else:
        client = AzureOpenAI(
            api_version="2023-05-15", api_key=api_key, azure_endpoint=base_url
        )

    tools = [GET_FRAMES_TOOL_SCHEMA]
    available_functions = {"get_frames_by_id": get_frames_by_id}

    total_frames = len(all_frame_paths)
    minutes = data_item["video_info"]["duration_minutes"]
    seconds = int(minutes * 60)
    initial_prompt = (
        f"The video has {total_frames} frames, numbered 1 to {total_frames}. This video is {seconds} seconds long. "
        f"Please answer the following question:\n{data_item['question']}"
    )

    messages = [
        {"role": "system", "content": AGENT_SYSTEM_PROMPT},
        {"role": "user", "content": initial_prompt},
    ]

    response_content = None
    max_tool_calls = 10

    for i in range(max_tool_calls):
        response_message = call_single_model(
            client,
            messages,
            target_model,
            data_item["key"],
            max_retry_times,
            tools=tools,
        )
        if response_message is None:
            return None

        messages.append(response_message.model_dump())

        if response_message.tool_calls:
            for tool_call in response_message.tool_calls:
                function_name = tool_call.function.name
                function_to_call = available_functions.get(function_name)
                if function_to_call:
                    function_args = json.loads(tool_call.function.arguments)
                    retrieved_frames = function_to_call(
                        **function_args, all_frame_paths=all_frame_paths
                    )
                    tool_response_content = [
                        {
                            "type": "text",
                            "text": f"Here are the frames you requested (IDs: {function_args.get('frame_ids', [])}).",
                        }
                    ]
                    tool_response_content.extend(retrieved_frames)
                    messages.append(
                        {
                            "tool_call_id": tool_call.id,
                            "role": "tool",
                            "name": function_name,
                            "content": json.dumps(
                                {
                                    "status": "success",
                                    "retrieved_frame_count": len(retrieved_frames),
                                }
                            ),
                        }
                    )
                    messages.append({"role": "user", "content": tool_response_content})
        else:
            response_content = response_message.content
            break

    if response_content is None and response_message and response_message.tool_calls:
        print(
            f"\nMax tool calls reached for item {data_item['key']}. Forcing a final answer."
        )
        final_prompt = "You have reached the maximum number of tool calls. Please provide a final answer in the specified JSON format based on the information you have gathered so far."
        messages.append({"role": "user", "content": final_prompt})
        final_response_message = call_single_model(
            client,
            messages,
            target_model,
            data_item["key"],
            max_retry_times,
            tools=None,
        )
        if final_response_message:
            messages.append(final_response_message)
            response_content = final_response_message.content
    elif response_content is None and response_message:
        response_content = response_message.content

    is_correct = False
    model_answer_cleaned = None
    parsed_json = extract_json_from_response(response_content)
    if parsed_json and "answer" in parsed_json:
        model_answer_cleaned = str(parsed_json["answer"]).strip().upper()
        gold_answer = data_item["answer"].strip().upper()
        if model_answer_cleaned == gold_answer:
            is_correct = True
    return {
        **data_item,
        "agent_conversation": [
            msg.model_dump() if hasattr(msg, "model_dump") else msg for msg in messages
        ],
        "model_reasoning_and_answer": response_content,
        "model_answer": model_answer_cleaned,
        "is_correct": is_correct,
    }


def encode_image(image_path):
    """Encodes an image file to a base64 string."""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


def process_single_data(data_item, args):
    """
    Main processing function for a single video.
    This function is executed by each worker process. It is self-contained.
    """
    item_key = data_item["key"]
    try:
        specific_frames_path = os.path.join(args.frames_path, item_key)
        if not os.path.isdir(specific_frames_path):
            raise FileNotFoundError(f"Frame directory not found for key '{item_key}'")

        all_frame_paths = sorted(
            [
                os.path.join(specific_frames_path, f)
                for f in os.listdir(specific_frames_path)
                if f.endswith(".jpg")
            ],
            key=lambda x: int(re.search(r"frame_(\d+)\.jpg", x).group(1)),
        )

        if not all_frame_paths:
            raise FileNotFoundError(f"No frames found for key '{item_key}'")

        # The core evaluation logic is called here
        result = evaluate_single_item_agentic(
            data_item,
            all_frame_paths,
            args.target_model,
            args.api_key,
            args.base_url,
            args.max_retry_times,
        )
        return result

    except Exception as e:
        # If any error occurs, catch it and return an error dictionary.
        # This prevents the worker process from crashing and allows the main
        # process to log the error gracefully.
        print(f"\nCRITICAL ERROR on key {item_key}: {str(e)}")
        traceback.print_exc()
        return {
            "key": item_key,
            "uid": data_item.get("uid"),
            "error": str(e),
            "traceback": traceback.format_exc(),
        }


def load_test_data(json_file):
    """Loads the evaluation data from a JSON file."""
    try:
        with open(json_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Error: Data file not found: {json_file}")
        exit(1)
    except json.JSONDecodeError:
        print(f"Error: Malformed JSON in {json_file}")
        exit(1)


def main():
    """Main function to orchestrate the evaluation framework."""
    args = parse_arguments()

    print("--- Agentic Video QA Evaluation (Refactored) ---")
    print(f"Target Model: {args.target_model}")
    print(f"Frames Base Path: {args.frames_path}")
    print(f"Data File: {args.data_file}")

    model_name_safe = args.target_model.replace("/", "_")
    data_filename_base = os.path.splitext(os.path.basename(args.data_file))[0]

    output_prefix = f"{model_name_safe}_{data_filename_base}_agent_results"
    results_output_file = f"{output_prefix}.json"
    metrics_output_file = f"{output_prefix}_metrics.json"
    error_log_file = f"{output_prefix}_errors.log"

    with open(error_log_file, "a", encoding="utf-8") as f:
        f.write(
            f"\n=== Log Session Started at {datetime.now()} for {args.target_model} ===\n"
        )

    all_test_data = load_test_data(args.data_file)
    completed_ids = set()
    existing_results = []

    if os.path.exists(results_output_file):
        try:
            with open(results_output_file, "r", encoding="utf-8") as f:
                existing_results = json.load(f)
            if isinstance(existing_results, list):
                completed_ids = {
                    item["uid"] for item in existing_results if "uid" in item
                }
                print(
                    f"Found {len(completed_ids)} completed tasks in '{results_output_file}'. Resuming..."
                )
            else:
                existing_results = []
        except (json.JSONDecodeError, IOError) as e:
            print(f"Warning: Could not read results file: {e}. Starting fresh.")
            existing_results = []

    tasks_to_process = [
        item for item in all_test_data if item.get("uid") not in completed_ids
    ]

    if not tasks_to_process:
        print("All tasks are already completed. Calculating final metrics.")
        final_metrics = calculate_metrics(existing_results)
        save_json_file(final_metrics, metrics_output_file)
        print(f"\nFinal metrics saved to: {metrics_output_file}")
        print(json.dumps(final_metrics, indent=4))
        return

    print(
        f"Total tasks: {len(all_test_data)}. Completed: {len(completed_ids)}. To process: {len(tasks_to_process)}."
    )

    # This list will hold all results, both old and new.
    all_results = list(existing_results)

    # Using ProcessPoolExecutor for robust, modern multiprocessing.
    with concurrent.futures.ProcessPoolExecutor(
        max_workers=args.pool_processes
    ) as executor:
        # partial is used to pass the constant `args` to each call of process_single_data
        func = partial(process_single_data, args=args)

        # executor.map processes the tasks in parallel.
        # tqdm provides a progress bar.
        results_iterator = executor.map(func, tasks_to_process)

        for result in tqdm(
            results_iterator, total=len(tasks_to_process), desc="Processing Videos"
        ):
            if result:
                if "error" in result:
                    # Log errors centrally
                    with open(error_log_file, "a", encoding="utf-8") as f:
                        f.write(f"Error on key {result.get('key', 'N/A')}:\n")
                        f.write(f"  Error: {result['error']}\n")
                        f.write(f"  Traceback: {result['traceback']}\n---\n")

                # Append every result (success or error) to the main list
                all_results.append(result)

                # Periodically save results for resilience
                if len(all_results) % 10 == 0:
                    save_json_file(all_results, results_output_file)

    print("\n\nProcessing complete.")

    # Final save of all combined results
    save_json_file(all_results, results_output_file)
    print(f"Detailed results saved to: {results_output_file}")

    # Calculate and save final metrics
    final_metrics = calculate_metrics(all_results)
    save_json_file(final_metrics, metrics_output_file)
    print(f"\nMetrics saved to: {metrics_output_file}")
    print(json.dumps(final_metrics, indent=4))


if __name__ == "__main__":
    # To run this script, you'll need to install tqdm:
    # pip install tqdm
    main()
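Each round of the agent loop above appends three messages to the transcript: the assistant's tool call, a short JSON acknowledgement in the `tool` role, and a follow-up `user` turn that carries the actual images (the script deliberately keeps image parts out of the `tool` message itself). A sketch of the message shapes involved; the tool-call id, frame IDs, and base64 payload are hypothetical.

# Assistant turn requesting frames (as dumped into `messages`):
assistant_turn = {
    "role": "assistant",
    "content": None,
    "tool_calls": [{
        "id": "call_abc123",  # hypothetical tool-call id
        "type": "function",
        "function": {"name": "get_frames_by_id",
                     "arguments": '{"frame_ids": [100, 200, 300]}'},
    }],
}

# Tool acknowledgement keeps the transcript valid without embedding images:
tool_ack = {
    "tool_call_id": "call_abc123",
    "role": "tool",
    "name": "get_frames_by_id",
    "content": '{"status": "success", "retrieved_frame_count": 3}',
}

# The images themselves travel in a follow-up user turn:
image_turn = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Here are the frames you requested (IDs: [100, 200, 300])."},
        {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}},
    ],
}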
main_i2i_ret.py
ADDED
|
@@ -0,0 +1,548 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
import json
import base64
import argparse
import time
import re
import traceback
import uuid
import multiprocessing
import concurrent.futures
from datetime import datetime
from functools import partial

import requests
import torch
from PIL import Image
from tqdm import tqdm
from openai import AzureOpenAI, OpenAI
from volcenginesdkarkruntime import Ark
from transformers import AutoModel, AutoProcessor
from torch.nn.functional import cosine_similarity

# --- Model and Configuration Constants ---

# SigLIP model for generating image embeddings
SIGLIP_MODEL_ID = "/mnt/bn/ziyang-storage-cloudnative-hl/huggingface/siglip-so400m-patch14-384"
# Number of Top-K frames to retrieve for each generated image
TOP_K_FRAMES = 8

# --- Prompt Templates ---

# Step 1: System prompt for the VLM to analyze the video and question, then generate image-creation requests.
# The goal of this prompt is not to answer the question, but to plan which keyframes need to be "seen".
STEP_1_PLANNING_PROMPT = """
You are a professional video analyst. Your task is to analyze a question and a few initial video sample frames, then plan what keyframes you need to see to answer the question.

Do not answer the question directly. Your output must be a JSON array, where each object represents a keyframe you wish to generate.
Each object must contain the following two keys:
1. `reference_image_id`: An integer representing the ID of a frame already provided to you that you wish to use as a generation reference. This ID must be one of the IDs provided by the user.
2. `prompt`: A detailed text description to tell the image generation model what kind of scene to draw.

For example, if the question is "Where did the man in the red shirt eventually go?", you might generate the following JSON:
```json
[
    {
        "reference_image_id": 120,
        "prompt": "A man in a red shirt is walking towards an open door, with a background similar to the reference image."
    },
    {
        "reference_image_id": 120,
        "prompt": "A man in a red shirt has already walked out the door, and the door is closing, with a background similar to the reference image."
    }
]
```
Your output must strictly adhere to this JSON format.
"""

# Step 3: System prompt for the VLM to perform final reasoning and answer based on all retrieved keyframes.
STEP_3_FINAL_ANSWER_PROMPT = """
You are an AI video question-answering assistant.
The user will provide you with a series of keyframes retrieved from a video and a question.

First, please provide a step-by-step reasoning process, analyzing these keyframes and deriving your conclusion.
After your reasoning, provide the final answer. The answer must be in a JSON code block, and the JSON object must contain a key "answer" with a value of one of 'A', 'B', 'C', or 'D'.

Your output format must be strictly as follows:
<Your step-by-step reasoning process>
```json
{"answer": "A"}
```
Do not include any other text after the JSON block.
"""

def parse_arguments():
    """Parse command-line arguments"""
    parser = argparse.ArgumentParser(
        description="Image Retrieval-based Video QA Workflow"
    )
    # Model Configuration
    parser.add_argument(
        "--target-model", "-tm", type=str, required=True, help="VLM model for inference (e.g., gpt-4o)"
    )
    # Data Path Configuration
    parser.add_argument(
        "--frames-path", "-fp", type=str, required=True, help="Root directory containing video frame folders"
    )
    parser.add_argument(
        "--data-file", "-df", type=str, required=True, help="JSON data file containing evaluation questions"
    )
    parser.add_argument(
        "--embeddings-path", "-ep", type=str, required=True, help="Directory containing pre-computed embeddings for all video frames"
    )
    parser.add_argument(
        "--output-path", "-op", type=str, default="./results_image_retrieval", help="Directory to store all outputs and generated images"
    )
    # Workflow Parameters
    parser.add_argument(
        "--initial-frames-num", "-ifn", type=int, default=8, help="Number of initial uniformly sampled frames for Step 1"
    )
    # Execution Configuration
    parser.add_argument(
        "--max-retry-times", "-mr", type=int, default=10, help="Maximum number of retries for API calls"
    )
    parser.add_argument(
        "--pool-processes", "-pp", type=int, default=10, help="Number of parallel processes"
    )
    # API Credentials
    parser.add_argument(
        "--base_url", type=str, required=True, help="API Endpoint URL for the VLM model"
    )
    parser.add_argument(
        "--api_key", type=str, required=True, help="API Key for the VLM model"
    )
    return parser.parse_args()

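# Illustrative invocation (script name per the upload listing; every path,
# endpoint, and key below is a placeholder, not a value shipped with this repo):
#
#   python main_i2i_ret.py \
#       --target-model gpt-4o \
#       --frames-path /data/videos/frames \
#       --data-file /data/videos/questions.json \
#       --embeddings-path /data/videos/embeddings \
#       --base_url https://example-endpoint/v1 \
#       --api_key $VLM_API_KEY
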
def save_json_file(data, output_file):
    """Save data to a JSON file"""
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4, ensure_ascii=False)


def extract_json_from_response(response, is_list=False):
    """Extract a JSON object or list from the model's response text.

    The regex below matches both objects `{...}` and lists `[...]`; `is_list`
    only documents the caller's expectation and does not change behavior.
    """
    if not response:
        return None
    pattern = r"```json\s*([\{\[].*?[\]\}])\s*```"
    match = re.search(pattern, response, re.DOTALL)
    if match:
        json_str = match.group(1)
        try:
            return json.loads(json_str)
        except json.JSONDecodeError:
            print(f"JSON parsing failed: {json_str}")
            return None
    return None

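# A minimal sketch of what extract_json_from_response accepts; the sample
# reply string is made up for illustration:
#
#   reply = 'The man walks out.\n```json\n{"answer": "B"}\n```'
#   extract_json_from_response(reply)                  # -> {"answer": "B"}
#   extract_json_from_response("no fenced JSON here")  # -> None
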
def calculate_metrics(results):
    """Calculate accuracy and other metrics from evaluation results"""
    valid_results = [r for r in results if "error" not in r]
    total_samples = len(valid_results)
    if total_samples == 0:
        return {"accuracy": 0.0}

    answered = sum(1 for x in valid_results if x.get("model_answer") is not None)
    correct = sum(1 for x in valid_results if x.get("is_correct"))
    accuracy = correct / answered if answered > 0 else 0.0

    return {
        "total_samples": total_samples,
        "answered_samples": answered,
        "correct_answers": correct,
        "accuracy": accuracy,
    }


def call_vlm_api(client, messages, model, item_id, max_retry_times, json_schema=None):
    """Call VLM API, with support for retries and structured output"""
    params = {"model": model, "messages": messages, "max_tokens": 4096}
    if json_schema:
        params["response_format"] = {"type": "json_object", "schema": json_schema}

    for retry in range(max_retry_times):
        try:
            completion = client.chat.completions.create(**params)
            return completion.choices[0].message.content
        except Exception as e:
            print(f"API Error (item {item_id}): {e}. Retrying ({retry + 1}/{max_retry_times})...")
            if retry == max_retry_times - 1:
                raise e
            time.sleep(5)

def generate_image(reference_image_id, prompt, all_frame_paths, output_dir, generation_idx):
    """Call the image generation API to create a new frame"""
    print(f"\n[Image Generation] Using Prompt: '{prompt}'")
    ark_api_key = os.environ.get("ARK_API_KEY")
    if not ark_api_key:
        raise ValueError("Environment variable ARK_API_KEY is not set.")

    client = Ark(base_url="https://ark.cn-beijing.volces.com/api/v3", api_key=ark_api_key)

    ref_image_path = all_frame_paths.get(reference_image_id)
    if not ref_image_path or not os.path.exists(ref_image_path):
        raise FileNotFoundError(f"Reference image ID {reference_image_id} not found.")

    try:
        ref_image_b64 = encode_image(ref_image_path)
        ref_image_data_uri = f"data:image/jpeg;base64,{ref_image_b64}"

        response = client.images.generate(
            model="doubao-seedream-4-0-250828",
            prompt=prompt,
            image=ref_image_data_uri,
            size="1024x1024",
            response_format="url",
            watermark=False,
        )
        image_url = response.data[0].url

        download = requests.get(image_url, timeout=60)
        download.raise_for_status()  # Fail loudly instead of saving an error page as a .jpg

        new_frame_filename = f"generated_frame_{generation_idx}_ref_{reference_image_id}.jpg"
        new_frame_path = os.path.join(output_dir, new_frame_filename)

        with open(new_frame_path, "wb") as f:
            f.write(download.content)

        print(f"[Image Generation Success] Image saved to: {new_frame_path}")
        return new_frame_path
    except Exception as e:
        print(f"Image generation or download failed: {e}")
        traceback.print_exc()
        return None

def retrieve_frames_by_image_embedding(
    image_path, video_embeddings_data, request_queue, results_dict, k
):
    """Retrieve the Top-K most similar frames from the video using an image embedding"""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    frame_filenames = video_embeddings_data["filenames"]
    frame_embeddings = video_embeddings_data["embeddings"].to(device)

    # 1. Send a request to the embedding server process
    request_id = str(uuid.uuid4())
    request_queue.put((request_id, image_path))

    # 2. Wait for the result
    while request_id not in results_dict:
        time.sleep(0.05)
    query_embedding = results_dict.pop(request_id).to(device)

    # 3. Calculate similarity and find the Top-K frames
    with torch.no_grad():
        similarities = cosine_similarity(query_embedding, frame_embeddings)
        top_k_indices = torch.topk(similarities, k=min(k, len(frame_filenames)), dim=-1).indices.cpu()

    # The caller rewrites `filenames` to absolute paths before invoking this
    # function, so indexing into the list directly yields usable frame paths.
    top_k_paths = [frame_filenames[int(i)] for i in top_k_indices]

    return top_k_paths

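# A self-contained sketch of the Top-K step above, using random tensors in
# place of real SigLIP embeddings. The 1152-dim width matches
# siglip-so400m-patch14-384 but is an assumption here, not something the
# retrieval code depends on:
def _demo_topk_retrieval():
    query = torch.randn(1, 1152)      # one query embedding
    frames = torch.randn(100, 1152)   # embeddings for 100 candidate frames
    sims = cosine_similarity(query, frames)        # shape: (100,)
    return torch.topk(sims, k=5).indices.tolist()  # 5 most similar frame indices
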
def embedding_server_process(model_id, device, request_queue, results_dict):
    """
    An independent server process that loads the SigLIP model and handles image embedding requests from worker processes.
    """
    print(f"Embedding server started (PID: {os.getpid()})...")
    model = AutoModel.from_pretrained(model_id).to(device).eval()
    processor = AutoProcessor.from_pretrained(model_id)
    print("SigLIP model loaded in the embedding server.")

    while True:
        try:
            request_id, image_path = request_queue.get()
            if image_path == "STOP":
                print("Embedding server received stop signal, shutting down.")
                break

            with torch.no_grad():
                image = Image.open(image_path).convert("RGB")
                inputs = processor(images=[image], return_tensors="pt").to(device)
                image_features = model.get_image_features(**inputs)
                results_dict[request_id] = image_features.cpu()

        except Exception as e:
            print(f"Error in embedding server: {e}")
            traceback.print_exc()

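# The worker/server handshake above in miniature. `request_queue` and
# `results_dict` come from the multiprocessing.Manager created in main();
# "frame.jpg" is a placeholder path:
#
#   rid = str(uuid.uuid4())
#   request_queue.put((rid, "frame.jpg"))   # worker submits a request
#   ...                                     # server computes the embedding
#   emb = results_dict.pop(rid)             # worker polls until the key appears
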
def encode_image(image_path):
    """Encode an image file to a Base64 string"""
    with open(image_path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

def uniformly_sample_frames_and_encode(frames_dir, num_frames):
    """Uniformly sample frames and encode them, while also returning a mapping of frame IDs to paths"""
    if not os.path.isdir(frames_dir):
        return [], {}

    # Keep only files matching frame_<N>.jpg, so a stray .jpg in the directory
    # cannot crash the sort key below.
    frame_files = sorted(
        [f for f in os.listdir(frames_dir) if re.fullmatch(r"frame_(\d+)\.jpg", f)],
        key=lambda x: int(re.search(r"frame_(\d+)\.jpg", x).group(1)),
    )
    if not frame_files:
        return [], {}

    indices = [int(i * len(frame_files) / num_frames) for i in range(num_frames)]
    sampled_files = [frame_files[i] for i in indices]

    frame_path_map, encoded_frames = {}, []
    for f in sampled_files:
        path = os.path.join(frames_dir, f)
        frame_id = int(re.search(r"frame_(\d+)\.jpg", f).group(1))

        encoded_frames.extend([
            {"type": "text", "text": f"This is Frame ID: {frame_id}"},
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encode_image(path)}"}}
        ])
        frame_path_map[frame_id] = path
    return encoded_frames, frame_path_map

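# Worked example of the sampling arithmetic above: with 100 frames on disk and
# num_frames=8, the chosen positions are
#   [int(i * 100 / 8) for i in range(8)]  ->  [0, 12, 25, 37, 50, 62, 75, 87]
# i.e. the first frame of each of 8 equal spans of the video.
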
def run_workflow_for_item(
    data_item, args, request_queue, results_dict
):
    """Execute the complete three-step workflow for a single data item"""
    item_key = data_item["key"]
    print(f"\n--- Starting processing for video: {item_key} ---")

    # Create a separate output directory for each video's generated images
    generated_images_dir = os.path.join(args.output_path, "generated_images", item_key)
    os.makedirs(generated_images_dir, exist_ok=True)

    # Initialize the VLM client
    if "ark" in args.base_url:
        client = Ark(base_url=args.base_url, api_key=args.api_key)
    elif "aliyun" in args.base_url or "127.0.0.1" in args.base_url:
        client = OpenAI(api_key=args.api_key, base_url=args.base_url)
    else:
        client = AzureOpenAI(api_version="2023-05-15", api_key=args.api_key, azure_endpoint=args.base_url)

    # --- Step 1: Initial understanding and generating "keyframe profile" requests ---
    print(f"[{item_key}] Step 1: Uniformly sampling and generating keyframe creation requests...")
    video_frames_path = os.path.join(args.frames_path, item_key)
    initial_frames_encoded, initial_frame_paths = uniformly_sample_frames_and_encode(
        video_frames_path, args.initial_frames_num
    )
    if not initial_frames_encoded:
        raise FileNotFoundError(f"Initial frames not found for video {item_key}.")

    planning_messages = [
        {"role": "system", "content": STEP_1_PLANNING_PROMPT},
        {"role": "user", "content": [
            {"type": "text", "text": "Here are the initial sample frames and the question:"},
            *initial_frames_encoded,
            {"type": "text", "text": f"Question: {data_item['question']}"}
        ]}
    ]

    # JSON Schema for the planning output. It is kept for reference but not
    # passed to call_vlm_api below, because not every endpoint supports
    # schema-constrained output; the response is parsed from the ```json block.
    planning_schema = {
        "type": "array",
        "items": {
            "type": "object",
            "properties": {
                "reference_image_id": {"type": "integer"},
                "prompt": {"type": "string"}
            },
            "required": ["reference_image_id", "prompt"]
        }
    }

    raw_planning_response = call_vlm_api(client, planning_messages, args.target_model, item_key, args.max_retry_times)
    image_generation_requests = extract_json_from_response(raw_planning_response, is_list=True)

    if not image_generation_requests or not isinstance(image_generation_requests, list):
        raise ValueError(f"Step 1 failed to generate valid JSON-formatted image generation requests. Response: {raw_planning_response}")

    print(f"[{item_key}] Successfully generated {len(image_generation_requests)} keyframe generation requests.")

    # --- Validate and correct reference image IDs ---
    valid_ids = list(initial_frame_paths.keys())
    if not valid_ids:
        raise ValueError(f"No valid initial frame IDs found for video {item_key}.")

    for req in image_generation_requests:
        original_id = req.get("reference_image_id")
        if not isinstance(original_id, int):
            # Guard against a missing or non-integer ID before computing distances.
            print(f"Warning: Model produced an invalid reference_image_id: {original_id!r}. Falling back to the first valid ID: {valid_ids[0]}.")
            req["reference_image_id"] = valid_ids[0]
        elif original_id not in valid_ids:
            closest_id = min(valid_ids, key=lambda valid_id: abs(valid_id - original_id))
            print(f"Warning: Model generated a non-existent reference_image_id: {original_id}. Substituting with the closest valid ID: {closest_id}.")
            req["reference_image_id"] = closest_id

    # --- Step 2: Generate images and perform similarity retrieval ---
    print(f"[{item_key}] Step 2: Generating images and retrieving similar frames...")
    all_retrieved_frame_paths = set()
    generated_image_paths = []
    video_embedding_file = os.path.join(args.embeddings_path, f"{item_key}.pt")
    if not os.path.exists(video_embedding_file):
        raise FileNotFoundError(f"Embedding file for video {item_key} not found: {video_embedding_file}")
    video_embeddings_data = torch.load(video_embedding_file, map_location="cpu")

    # Correct a path issue: ensure filenames in the embedding file are absolute paths
    video_frame_dir_for_embeddings = os.path.join(args.frames_path, item_key)
    video_embeddings_data['filenames'] = [os.path.join(video_frame_dir_for_embeddings, os.path.basename(f)) for f in video_embeddings_data['filenames']]

    for i, req in enumerate(image_generation_requests):
        # 2a. Generate an image
        generated_path = generate_image(
            reference_image_id=req["reference_image_id"],
            prompt=req["prompt"],
            all_frame_paths=initial_frame_paths,
            output_dir=generated_images_dir,
            generation_idx=i + 1,
        )

        path_for_retrieval = None
        if generated_path:
            generated_image_paths.append(generated_path)
            path_for_retrieval = generated_path
        else:
            print(f"Warning: Generation failed for image {i+1}. Using its reference image (ID: {req['reference_image_id']}) for retrieval instead.")
            path_for_retrieval = initial_frame_paths.get(req["reference_image_id"])

        if not path_for_retrieval:
            print(f"Error: Could not find a path for retrieval for request {i+1}. Skipping.")
            continue

        # 2b. Retrieve frames via image embedding
        retrieved_paths = retrieve_frames_by_image_embedding(
            path_for_retrieval, video_embeddings_data, request_queue, results_dict, k=TOP_K_FRAMES
        )
        all_retrieved_frame_paths.update(retrieved_paths)
        print(f"[{item_key}] Retrieval {i+1}/{len(image_generation_requests)} complete, found {len(retrieved_paths)} frames.")

    if not all_retrieved_frame_paths:
        raise ValueError(f"Failed to retrieve any frames for video {item_key}.")

    print(f"[{item_key}] Step 2 complete. Retrieved a total of {len(all_retrieved_frame_paths)} unique keyframes.")

    # --- Step 3: Consolidate keyframes for final reasoning ---
    print(f"[{item_key}] Step 3: Consolidating keyframes for final reasoning...")
    final_frames_encoded = []
    for path in sorted(list(all_retrieved_frame_paths)):
        final_frames_encoded.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encode_image(path)}"}})

    final_messages = [
        {"role": "system", "content": STEP_3_FINAL_ANSWER_PROMPT},
        {"role": "user", "content": [
            {"type": "text", "text": "Here are all the keyframes retrieved for you. Please answer the question based on them."},
            *final_frames_encoded,
            {"type": "text", "text": f"Question: {data_item['question']}"}
        ]}
    ]

    final_response_text = call_vlm_api(client, final_messages, args.target_model, item_key, args.max_retry_times)

    # --- Consolidating Results ---
    parsed_answer = extract_json_from_response(final_response_text)
    model_answer = parsed_answer.get("answer", "").strip().upper() if parsed_answer else None
    is_correct = (model_answer == data_item["answer"].strip().upper()) if model_answer else False

    result = {
        **data_item,
        "workflow_steps": {
            "step1_planning_requests": image_generation_requests,
            "step2_generated_images": generated_image_paths,
            "step2_retrieved_frame_paths": sorted(list(all_retrieved_frame_paths)),
            "step3_final_reasoning_and_answer": final_response_text,
        },
        "model_answer": model_answer,
        "is_correct": is_correct,
    }
    return result

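# Worked example of the ID-correction rule in Step 1 above: with
# valid_ids = [0, 120, 240, 360] and a hallucinated reference_image_id of 200,
#   min([0, 120, 240, 360], key=lambda v: abs(v - 200))  ->  240
# so the request is remapped to the closest frame that actually exists.
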
def process_single_data_wrapper(data_item, args, request_queue, results_dict):
    """Wrapper function to process a single data item, used for exception handling"""
    try:
        return run_workflow_for_item(data_item, args, request_queue, results_dict)
    except Exception as e:
        print(f"\nA critical error occurred while processing video {data_item['key']}: {e}")
        traceback.print_exc()
        return {
            "key": data_item['key'],
            "uid": data_item.get('uid'),
            "error": str(e),
            "traceback": traceback.format_exc(),
        }

def load_test_data(json_file):
    """Load test data from a JSON file"""
    try:
        with open(json_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Error: Data file not found: {json_file}")
        exit(1)
    except json.JSONDecodeError:
        print(f"Error: JSON file is malformed: {json_file}")
        exit(1)


def main():
    """Main function to orchestrate the entire evaluation workflow"""
    args = parse_arguments()
    print("--- Image Retrieval-based Video QA Workflow Starting ---")
    print(f"Evaluating Model: {args.target_model}, Dataset: {args.data_file}")

    try:
        multiprocessing.set_start_method("spawn", force=True)
    except RuntimeError:
        pass  # Start method already set

    os.makedirs(args.output_path, exist_ok=True)

    # Define output file paths
    model_safe_name = args.target_model.replace("/", "_")
    data_filename_base = os.path.splitext(os.path.basename(args.data_file))[0]
    output_prefix = f"{model_safe_name}_{data_filename_base}_image_retrieval_{args.initial_frames_num}frames"

    results_file = os.path.join(args.output_path, f"{output_prefix}_results.json")
    metrics_file = os.path.join(args.output_path, f"{output_prefix}_metrics.json")

    test_data = load_test_data(args.data_file)
    all_results = []

    with multiprocessing.Manager() as manager:
        request_queue = manager.Queue()
        results_dict = manager.dict()

        device = "cuda" if torch.cuda.is_available() else "cpu"
        embedding_server = multiprocessing.Process(
            target=embedding_server_process,
            args=(SIGLIP_MODEL_ID, device, request_queue, results_dict),
        )
        embedding_server.start()

        # Wait for the embedding server model to load
        time.sleep(15)

        with concurrent.futures.ProcessPoolExecutor(max_workers=args.pool_processes) as executor:
            func = partial(
                process_single_data_wrapper,
                args=args,
                request_queue=request_queue,
                results_dict=results_dict
            )

            results_iterator = executor.map(func, test_data)

            for result in tqdm(results_iterator, total=len(test_data), desc="Processing Videos"):
                if result:
                    all_results.append(result)
                    # Save results every 10 videos to prevent data loss from interruptions
                    if len(all_results) % 10 == 0:
                        save_json_file(all_results, results_file)

        # Gracefully shut down the embedding server
        print("All tasks completed. Shutting down the embedding server...")
        request_queue.put((None, "STOP"))
        embedding_server.join()

    print("\n--- All Videos Processed ---")
    save_json_file(all_results, results_file)
    print(f"Detailed results saved to: {results_file}")

    final_metrics = calculate_metrics(all_results)
    save_json_file(final_metrics, metrics_file)
    print(f"Final evaluation metrics saved to: {metrics_file}")
    print(json.dumps(final_metrics, indent=4))


if __name__ == "__main__":
    # Before running, please ensure you have set the API Key for the image generation service:
    # export ARK_API_KEY="YOUR_VOLCENGINE_ARK_API_KEY"
    main()

main_mcot.py
ADDED
@@ -0,0 +1,608 @@
import os
import json
import base64
import argparse
import time
import re
import traceback
from datetime import datetime
from functools import partial
import requests  # Import requests library to download images from URLs
from openai import AzureOpenAI, OpenAI
from volcenginesdkarkruntime import Ark
import concurrent.futures
from tqdm import tqdm

# 1. New Agent System Prompt
# Defines the agent's role and principles, guiding it to use the "imagination" tool when visual evidence is insufficient.
IMAGINE_AGENT_SYSTEM_PROMPT = """
You are an intelligent AI assistant specializing in answering video question-answering problems through reasoning and imagination.
Your task is to answer a multiple-choice question based on an initial, limited set of video frames.

You will receive a few uniformly sampled frames to get a basic understanding of the video.
These frames may not contain all the visual evidence needed to directly answer the question.

If the provided frame information is insufficient, you must use the `imagine_frame` tool to generate new, imagined frames to fill in the visual gaps and aid your reasoning.
You can call this tool multiple times to construct a sequence of imagined events.

Your strategy should be:
1. Analyze the initial frames and the user's question.
2. Form a hypothesis about the missing content.
3. If you need more visual information, call the `imagine_frame` tool. Provide a text `prompt` describing the scene you want to imagine, and select a `reference_image_id` from existing frames. The `reference_image_id` MUST be one of the IDs explicitly provided to you in the conversation history (e.g., "Frame ID: X" or "New Frame ID: Y"). Do not invent or assume frame IDs.
4. Analyze the newly generated frame in conjunction with the existing ones.
5. Continue this process of reasoning and imagination until you are confident in your answer. Please ensure you have found or created the relevant visual cues before answering the question.
6. Each tool call can only generate one frame.

IMPORTANT: Your text `prompt` for image generation must be safe and general. Avoid descriptions that could be interpreted as sensitive, harmful, or explicit to prevent generation failures.

After your reasoning, provide the final answer in a JSON code block. The JSON object must contain a key "answer" with a value of one of 'A', 'B', 'C', or 'D'.

Your output must strictly follow this format:
<Your step-by-step reasoning process here, including why you chose to imagine a certain frame>
```json
{"answer": "X"}
```
Do not include any other text after the JSON code block.
"""

# 2. New Tool Schema for imagine_frame
# Defines the interface, parameters, and description for the `imagine_frame` tool.
IMAGINE_FRAME_TOOL_SCHEMA = {
    "type": "function",
    "function": {
        "name": "imagine_frame",
        "description": "When visual evidence is insufficient, generates a new image based on a text prompt and a reference image to help answer the question. Use it to imagine what might have happened between the provided frames.",
        "parameters": {
            "type": "object",
            "properties": {
                "reference_image_id": {
                    "type": "integer",
                    "description": "The ID of an existing frame to use as a style and content reference. It can be one of the original frames or a previously generated one.",
                },
                "prompt": {
                    "type": "string",
                    "description": "A detailed text description of the frame you want to imagine and generate.",
                },
            },
            "required": ["reference_image_id", "prompt"],
        },
    },
}

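# For orientation, a well-formed `imagine_frame` call emitted against the
# schema above carries arguments shaped like this (values are illustrative):
#
#   {"reference_image_id": 120,
#    "prompt": "The man in the red shirt stepping out through the open door."}
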
# 3. Implementation of the `imagine_frame` tool
def imagine_frame(
    reference_image_id: int,
    prompt: str,
    all_frame_paths: dict,
    output_dir: str,
    generation_count: int,
):
    """
    Tool implementation: Calls an image generation model to create a new frame.

    Args:
        reference_image_id (int): The ID of the reference frame.
        prompt (str): The text prompt for image generation.
        all_frame_paths (dict): A dictionary containing IDs and paths of all currently available frames (original + generated).
        output_dir (str): The directory to save the generated image.
        generation_count (int): The current generation count, used for naming the file.

    Returns:
        str or None: The path of the newly generated image on success, otherwise None.
    """
    print(f"\n[Tool Call] Imagining new frame with prompt: '{prompt}'")
    ark_api_key = os.environ.get("ARK_API_KEY")
    if not ark_api_key:
        raise ValueError("Error: Environment variable ARK_API_KEY is not set.")

    client = Ark(
        base_url="https://ark.cn-beijing.volces.com/api/v3",
        api_key=ark_api_key,
    )

    ref_image_path = all_frame_paths.get(reference_image_id)
    if not ref_image_path or not os.path.exists(ref_image_path):
        raise FileNotFoundError(f"Reference image ID not found: {reference_image_id}")

    try:
        # Encode the reference image to a Base64 Data URI
        ref_image_b64 = encode_image(ref_image_path)
        ref_image_data_uri = f"data:image/jpeg;base64,{ref_image_b64}"

        images_response = client.images.generate(
            model="doubao-seedream-4-0-250828",
            prompt=prompt,
            image=ref_image_data_uri,
            size="1024x1024",  # Can be adjusted as needed, e.g., "2K"
            response_format="url",
            watermark=False,
        )

        image_url = images_response.data[0].url

        # Download the image from the URL
        response = requests.get(image_url, timeout=60)
        response.raise_for_status()

        # Save the image to the specified directory
        new_frame_filename = (
            f"generated_frame_{generation_count}_ref_{reference_image_id}.jpg"
        )
        new_frame_path = os.path.join(output_dir, new_frame_filename)

        with open(new_frame_path, "wb") as f:
            f.write(response.content)

        print(f"[Tool Success] Generated frame saved to: {new_frame_path}")
        return new_frame_path

    except Exception as e:
        print(f"An error occurred during image generation or download: {e}")
        traceback.print_exc()
        return None

def parse_arguments():
    """Parse command-line arguments"""
    parser = argparse.ArgumentParser(
        description="Video QA Evaluation Framework with Imagine-and-Reason Agent"
    )
    parser.add_argument(
        "--target-model",
        "-tm",
        type=str,
        required=True,
        help="The model to be evaluated (e.g., gpt-4o)",
    )
    parser.add_argument(
        "--frames-path",
        "-fp",
        type=str,
        required=True,
        help="Absolute path to the root directory containing video frames.",
    )
    parser.add_argument(
        "--output-path",
        "-op",
        type=str,
        default="./generated_outputs",
        help="Path to store generated images and results.",
    )
    parser.add_argument(
        "--data-file",
        "-df",
        type=str,
        required=True,
        help="Absolute path to the evaluation dataset JSON file.",
    )
    parser.add_argument(
        "--initial-frames-num",
        "-ifn",
        type=int,
        default=8,
        help="Number of initial uniformly sampled frames.",
    )
    parser.add_argument(
        "--max-retry-times",
        "-mr",
        type=int,
        default=10,
        help="Maximum number of retries for failed API calls.",
    )
    parser.add_argument(
        "--pool-processes",
        "-pp",
        type=int,
        default=10,
        help="Number of parallel processes.",
    )
    parser.add_argument(
        "--base_url",
        type=str,
        required=True,
        help="API Endpoint URL for the target model service.",
    )
    parser.add_argument(
        "--api_key",
        type=str,
        required=True,
        help="API Key for the target model service.",
    )
    return parser.parse_args()

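# Illustrative invocation (script name per the header above; paths, endpoint,
# and keys are placeholders):
#
#   export ARK_API_KEY="..."   # image-generation credential read by imagine_frame
#   python main_mcot.py \
#       --target-model gpt-4o \
#       --frames-path /data/videos/frames \
#       --data-file /data/videos/questions.json \
#       --base_url https://example-endpoint/v1 \
#       --api_key $VLM_API_KEY
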
def save_json_file(data, output_file):
    """Save data to a JSON file"""
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4, ensure_ascii=False)


def extract_json_from_response(response):
    """Extract JSON answer from the model's text response"""
    if not response:
        return None
    match = re.search(r"```json\s*(\{.*?\})\s*```", response, re.DOTALL)
    if match:
        try:
            return json.loads(match.group(1))
        except (json.JSONDecodeError, IndexError):
            return None
    return None


def calculate_metrics(results):
    """Calculate various metrics from the evaluation results"""
    valid_results = [r for r in results if "error" not in r]
    total_samples = len(valid_results)
    if total_samples == 0:
        return {
            "total_samples": 0,
            "answered_samples": 0,
            "correct_answers": 0,
            "accuracy": 0.0,
        }
    answered_samples = sum(
        1 for x in valid_results if x.get("model_answer") is not None
    )
    correct_answers = sum(1 for x in valid_results if x.get("is_correct"))
    accuracy = correct_answers / answered_samples if answered_samples > 0 else 0.0
    return {
        "total_samples": total_samples,
        "answered_samples": answered_samples,
        "correct_answers": correct_answers,
        "accuracy": accuracy,
    }


def call_single_model(client, messages, model, item_id, max_retry_times, tools=None):
    """A single model API call with retry logic"""
    params = {"model": model, "messages": messages, "max_tokens": 4096}
    if tools:
        params["tools"] = tools
        params["tool_choice"] = "auto"

    retry_times = 0
    while retry_times < max_retry_times:
        try:
            completion = client.chat.completions.create(**params)
            return completion.choices[0].message
        except Exception as e:
            retry_times += 1
            print(
                f"API call error (Item {item_id}): {str(e)}. Retrying ({retry_times}/{max_retry_times})..."
            )
            if retry_times == max_retry_times:
                raise e
            time.sleep(5)

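# Worked example for calculate_metrics above, on made-up results: two answered
# items (one correct) plus one item that errored out entirely:
#
#   calculate_metrics([
#       {"model_answer": "A", "is_correct": True},
#       {"model_answer": "C", "is_correct": False},
#       {"error": "timeout"},
#   ])
#   # -> {"total_samples": 2, "answered_samples": 2,
#   #     "correct_answers": 1, "accuracy": 0.5}
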
def uniformly_sample_frames_and_encode(frames_dir, num_frames):
    """Uniformly sample a specified number of frames from a directory and encode them"""
    if not os.path.isdir(frames_dir):
        return [], {}

    # Keep only files matching frame_<N>.jpg, so a stray .jpg cannot crash the
    # sort key below.
    frame_files = sorted(
        [f for f in os.listdir(frames_dir) if re.fullmatch(r"frame_(\d+)\.jpg", f)],
        key=lambda x: int(re.search(r"frame_(\d+)\.jpg", x).group(1)),
    )

    total_frames = len(frame_files)
    if total_frames == 0:
        return [], {}

    if total_frames > num_frames:
        indices = [int(i * total_frames / num_frames) for i in range(num_frames)]
        sampled_files = [frame_files[i] for i in indices]
    else:
        sampled_files = frame_files

    frame_path_map = {}
    encoded_frames = []
    for f in sampled_files:
        path = os.path.join(frames_dir, f)
        frame_id = int(re.search(r"frame_(\d+)\.jpg", f).group(1))
        b64_image = encode_image(path)
        # Send the frame ID and the image content as a pair
        encoded_frames.append({"type": "text", "text": f"This is Frame ID: {frame_id}"})
        encoded_frames.append(
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{b64_image}"},
            }
        )
        frame_path_map[frame_id] = path

    return encoded_frames, frame_path_map

def evaluate_single_item_agentic_imagination(
    data_item,
    initial_frames,
    initial_frame_paths,
    generated_images_dir,
    target_model,
    api_key,
    base_url,
    max_retry_times,
):
    """
    Core logic for evaluating a single data item using the Imagine-and-Reason Agent.
    """
    # 4. New Agent Loop
    if "ark" in base_url:
        client = Ark(base_url=base_url, api_key=api_key)
    elif "aliyun" in base_url or "127.0.0.1" in base_url:
        client = OpenAI(api_key=api_key, base_url=base_url)
    else:
        client = AzureOpenAI(
            api_version="2023-05-15", api_key=api_key, azure_endpoint=base_url
        )

    tools = [IMAGINE_FRAME_TOOL_SCHEMA]

    # Store paths of all available frames (initial + generated) in a dictionary for reference
    available_frame_paths = initial_frame_paths.copy()

    initial_prompt_content = [
        {
            "type": "text",
            "text": "Here are the initial sampled video frames provided to you:",
        },
        *initial_frames,
        {
            "type": "text",
            "text": f"Please answer the following question:\n{data_item['question']}",
        },
    ]

    messages = [
        {"role": "system", "content": IMAGINE_AGENT_SYSTEM_PROMPT},
        {"role": "user", "content": initial_prompt_content},
    ]

    response_content = None
    # Limit the number of times the agent can imagine, to prevent infinite loops
    max_tool_calls = 5
    generation_count = 0

    for i in range(max_tool_calls):
        response_message = call_single_model(
            client,
            messages,
            target_model,
            data_item["key"],
            max_retry_times,
            tools=tools,
        )
        if response_message is None:
            return None

        messages.append(response_message.model_dump(exclude_none=True))

        if response_message.tool_calls:
            # Process one tool call at a time; the system prompt instructs the
            # model to emit a single call per turn, so any extra calls are ignored.
            tool_call = response_message.tool_calls[0]
            function_name = tool_call.function.name

            if function_name == "imagine_frame":
                generation_count += 1
                function_args = json.loads(tool_call.function.arguments)
                new_frame_path = imagine_frame(
                    **function_args,
                    all_frame_paths=available_frame_paths,
                    output_dir=generated_images_dir,
                    generation_count=generation_count,
                )

                if new_frame_path:
                    # Create a unique ID for the newly generated frame
                    new_frame_id = (
                        max(available_frame_paths.keys())
                        if available_frame_paths
                        else 0
                    ) + 1
                    available_frame_paths[new_frame_id] = new_frame_path

                    b64_image = encode_image(new_frame_path)
                    tool_response_content = [
                        {
                            "type": "text",
                            "text": f"Here is the frame you requested to imagine (New Frame ID: {new_frame_id}). Please use it to continue your reasoning.",
                        },
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{b64_image}"},
                        },
                    ]

                    messages.append(
                        {
                            "tool_call_id": tool_call.id,
                            "role": "tool",
                            "name": function_name,
                            "content": json.dumps(
                                {"status": "success", "new_frame_id": new_frame_id}
                            ),
                        }
                    )
                    # The tool message carries only the JSON status string; the
                    # new frame itself is delivered in a follow-up user message.
                    messages.append({"role": "user", "content": tool_response_content})
                else:  # Tool execution failed
                    messages.append(
                        {
                            "tool_call_id": tool_call.id,
                            "role": "tool",
                            "name": function_name,
                            "content": json.dumps(
                                {
                                    "status": "error",
                                    "message": "Failed to generate image.",
                                }
                            ),
                        }
                    )
        else:  # No tool call means the model is ready to give a final answer
            response_content = response_message.content
            break

    # If the max number of calls is reached without an answer, force a final response
    if response_content is None and response_message:
        final_prompt = "You have reached the maximum number of tool calls. Please provide a final answer in the specified JSON format based on the information you have gathered so far."
        messages.append({"role": "user", "content": final_prompt})
        final_response_message = call_single_model(
            client, messages, target_model, data_item["key"], max_retry_times
        )
        if final_response_message:
            messages.append(final_response_message.model_dump(exclude_none=True))
            response_content = final_response_message.content

    is_correct = False
    model_answer_cleaned = None
    parsed_json = extract_json_from_response(response_content)
    if parsed_json and "answer" in parsed_json:
        model_answer_cleaned = str(parsed_json["answer"]).strip().upper()
        gold_answer = data_item["answer"].strip().upper()
        if model_answer_cleaned == gold_answer:
            is_correct = True

    return {
        **data_item,
        "agent_conversation": messages,
        "model_reasoning_and_answer": response_content,
        "model_answer": model_answer_cleaned,
        "is_correct": is_correct,
        # 5. Store the path to intermediate generated images
        "generated_images_path": generated_images_dir,
    }

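# For reference, after one successful imagine_frame round the tail of
# `messages` has roughly this shape (abbreviated; exact fields depend on the
# SDK's model_dump serialization, so treat this as illustrative):
#
#   [...,
#    {"role": "assistant", "tool_calls": [{"id": "call_1", ...}]},
#    {"role": "tool", "tool_call_id": "call_1", "name": "imagine_frame",
#     "content": '{"status": "success", "new_frame_id": 9}'},
#    {"role": "user", "content": [<text "New Frame ID: 9">, <image_url part>]}]
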
def encode_image(image_path):
    """Encode an image file to a Base64 string"""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")

def process_single_data(data_item, args):
    """Worker function to process a single data item in parallel"""
    item_key = data_item["key"]
    try:
        # Create a separate subfolder for each video's generated images
        generated_images_dir = os.path.join(
            args.output_path, "generated_images", item_key
        )
        os.makedirs(generated_images_dir, exist_ok=True)

        specific_frames_path = os.path.join(args.frames_path, item_key)
        initial_frames, initial_frame_paths = uniformly_sample_frames_and_encode(
            specific_frames_path, args.initial_frames_num
        )

        if not initial_frames:
            raise FileNotFoundError(f"Initial frames not found for item '{item_key}'")

        result = evaluate_single_item_agentic_imagination(
            data_item,
            initial_frames,
            initial_frame_paths,
            generated_images_dir,
            args.target_model,
            args.api_key,
            args.base_url,
            args.max_retry_times,
        )
        return result

    except Exception as e:
        print(f"\nA critical error occurred while processing item {item_key}: {str(e)}")
        traceback.print_exc()
        return {
            "key": item_key,
            "uid": data_item.get("uid"),
            "error": str(e),
            "traceback": traceback.format_exc(),
        }


def load_test_data(json_file):
    """Load test data from a JSON file"""
    try:
        with open(json_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Error: Data file not found: {json_file}")
        exit(1)
    except json.JSONDecodeError:
        print(f"Error: JSON file is malformed: {json_file}")
        exit(1)

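# load_test_data expects a JSON array of question records. Judging from the
# keys this script accesses ("key", "uid", "question", "answer"), a minimal
# record looks like the following; the values are illustrative only:
#
#   [{"key": "video_0001",
#     "uid": "q-42",
#     "question": "What does the man do next? A) ... B) ... C) ... D) ...",
#     "answer": "B"}]
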
def main():
    """Main function to orchestrate the entire evaluation flow"""
    args = parse_arguments()

    print("--- Video QA Imagine-and-Reason Agent Framework ---")
    print(f"Evaluating Model: {args.target_model}")
    print(f"Output Path: {args.output_path}")
    print(f"Dataset: {args.data_file}")
    print("---------------------------------")

    # Create the main output directory
    os.makedirs(args.output_path, exist_ok=True)

    model_name_safe = args.target_model.replace("/", "_")
    data_filename_base = os.path.splitext(os.path.basename(args.data_file))[0]

    output_prefix = f"{model_name_safe}_{data_filename_base}_imagine_agent"
    results_output_file = os.path.join(
        args.output_path, f"{output_prefix}_results.json"
    )
    metrics_output_file = os.path.join(
        args.output_path, f"{output_prefix}_metrics.json"
    )
    error_log_file = os.path.join(args.output_path, f"{output_prefix}_errors.log")

    # The logic for resuming from a checkpoint can be added here, same as in the first script

    all_test_data = load_test_data(args.data_file)
    tasks_to_process = all_test_data

    all_results = []
    # Use ProcessPoolExecutor for parallel processing
    with concurrent.futures.ProcessPoolExecutor(
        max_workers=args.pool_processes
    ) as executor:
        func = partial(process_single_data, args=args)
        results_iterator = executor.map(func, tasks_to_process)

        for result in tqdm(
            results_iterator, total=len(tasks_to_process), desc="Processing Videos"
        ):
            if result:
                if "error" in result:
                    with open(error_log_file, "a", encoding="utf-8") as f:
                        f.write(
                            f"Error on item {result.get('key', 'N/A')}:\n  Error: {result['error']}\n---\n"
                        )
                all_results.append(result)

                # Save results every 10 videos to prevent data loss from interruptions
                if len(all_results) % 10 == 0:
                    save_json_file(all_results, results_output_file)

    print("\n\nProcessing complete.")
    # Save the final complete results
    save_json_file(all_results, results_output_file)
    print(f"Detailed results saved to: {results_output_file}")

    # Calculate and save the final metrics
    final_metrics = calculate_metrics(all_results)
    save_json_file(final_metrics, metrics_output_file)
    print(f"\nEvaluation metrics saved to: {metrics_output_file}")
    print(json.dumps(final_metrics, indent=4))


if __name__ == "__main__":
    # Before running this script, please ensure you have set the environment variable in your terminal:
    # export ARK_API_KEY="YOUR_VOLCENGINE_ARK_API_KEY"
    main()

main_new_agent.py
ADDED
@@ -0,0 +1,643 @@
import os
import json
import base64
import argparse
import time
import re
import traceback
from datetime import datetime
from functools import partial
from openai import AzureOpenAI, OpenAI
from volcenginesdkarkruntime import Ark
import concurrent.futures
from tqdm import tqdm
import torch
from transformers import AutoModel, AutoProcessor
from torch.nn.functional import cosine_similarity
# MODIFIED: Added imports for multiprocessing and uuid
import multiprocessing
import uuid

# --- Configuration for SigLIP Model ---
# MODIFIED: Updated to the local model path
SIGLIP_MODEL_ID = (
    "/mnt/bn/ziyang-storage-cloudnative-hl/huggingface/siglip-so400m-patch14-384"
)

# --- MODIFIED: Updated System Prompt explaining the two tools with examples ---
AGENT_SYSTEM_PROMPT = """
You are an intelligent AI assistant specialized in video question answering.
Your task is to answer a multiple-choice question based on a video by strategically retrieving and analyzing its frames.

You have two tools to retrieve frames. Both return images directly.

1. `get_frames_by_id(frame_ids)`: Retrieves frames using their specific numerical IDs. Use this when the question provides direct temporal clues or when you need to view specific frames identified by another tool.
    * **Example Use Case:** For a question like "What happens at the 1 minute 30 second mark?", you can calculate the approximate frame ID and use this tool to see the visual.
    * **Example Use Case:** For "Describe the action in frame 550.", you would call this tool with `frame_ids=[550]`.

2. `get_frames_by_similarity(query)`: Searches the entire video for frames that visually match a text description and returns the top 5 most relevant frames directly. Use this for content-based questions where the timing is unknown.
    * **Example Use Case:** For a question like "What color is the main character's car?", you would use this tool with a query like "the main character's car".
    * **Example Use Case:** For "Find the scene where a band is playing on stage", you would use the query "a band playing on stage".

Your strategy must be efficient:
1. **Analyze the Query:** First, determine if the question is temporal/logical (better for `get_frames_by_id`) or content-based (requires `get_frames_by_similarity`).
2. **Retrieve & Analyze:** Call the most appropriate tool. Analyze the returned frames to form a hypothesis.
3. **Iterate:** If you need more information, refine your search query for the similarity tool or calculate new frame IDs for the ID tool and call again.
4. **Final Answer:** Once you have gathered enough visual evidence, provide your step-by-step reasoning and then the final answer in the specified JSON format. Do not guess.

Your output should follow this format exactly:
<Your step-by-step reasoning here>
```json
{"answer": "X"}
```
Do not include any other text after the JSON block.
"""

# Tool Schemas
GET_FRAMES_BY_ID_TOOL_SCHEMA = {
    "type": "function",
    "function": {
        "name": "get_frames_by_id",
        "description": "Retrieves specific video frames by their numerical IDs to get visual information.",
        "parameters": {
            "type": "object",
            "properties": {
                "frame_ids": {
                    "type": "array",
                    "items": {"type": "integer"},
                    "description": "A list of up to 10 frame numbers to retrieve.",
                },
            },
            "required": ["frame_ids"],
        },
    },
}

GET_FRAMES_BY_SIMILARITY_TOOL_SCHEMA = {
    "type": "function",
    "function": {
        "name": "get_frames_by_similarity",
        "description": "Searches for and retrieves the top 5 most visually relevant frames for a given text query. Use this to locate visual content when frame numbers are unknown.",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "A concise text description of the visual content to search for (e.g., 'a person playing piano').",
                },
            },
            "required": ["query"],
        },
    },
}


def parse_arguments():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="Agentic Video QA with Hybrid Frame Retrieval"
    )
    parser.add_argument(
        "--target-model", "-tm", type=str, required=True, help="Model to evaluate."
    )
    parser.add_argument(
        "--frames-path",
        "-fp",
        type=str,
        required=True,
        help="Base directory for video frames.",
    )
    parser.add_argument(
        "--data-file",
        "-df",
        type=str,
        required=True,
        help="Path to the evaluation dataset.",
    )
    parser.add_argument(
        "--embeddings-path",
        "-ep",
        type=str,
        required=True,
        help="Directory with pre-computed frame embeddings.",
    )
    parser.add_argument(
        "--max-retry-times",
        "-mr",
        type=int,
        default=10,
        help="Max retries for API calls.",
    )
    parser.add_argument(
        "--pool-processes",
        "-pp",
        type=int,
        default=20,
        help="Number of parallel processes.",
    )
    parser.add_argument("--base_url", type=str, required=True, help="API endpoint URL.")
    parser.add_argument("--api_key", type=str, required=True, help="API key.")
    return parser.parse_args()


def save_json_file(data, output_file):
    """Saves data to a JSON file."""
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)


def extract_json_from_response(response):
    """Extracts a JSON object from a model's response string."""
    if not response:
        return None
    match = re.search(r"```json\s*(\{.*?\})\s*```", response, re.DOTALL)
    if match:
        try:
            return json.loads(match.group(1))
        except (json.JSONDecodeError, IndexError):
            return None
    return None


def calculate_metrics(results):
    """Calculates accuracy and other metrics from evaluation results."""
    valid_results = [r for r in results if "error" not in r]
    total_samples = len(valid_results)
    if total_samples == 0:
        return {
            "total_samples": 0,
            "answered_samples": 0,
            "correct_answers": 0,
            "accuracy": 0.0,
        }
    answered_samples = sum(
        1 for x in valid_results if x.get("model_answer") is not None
    )
    correct_answers = sum(1 for x in valid_results if x.get("is_correct"))
    accuracy = correct_answers / answered_samples if answered_samples > 0 else 0.0
    return {
        "total_samples": total_samples,
        "answered_samples": answered_samples,
        "correct_answers": correct_answers,
        "accuracy": accuracy,
    }


def call_single_model(client, messages, model, item_id, max_retry_times, tools=None):
    """Makes a single API call with retry logic and tool support."""
    params = {"model": model, "messages": messages, "max_tokens": 4096}
    if tools:
        params["tools"] = tools
        params["tool_choice"] = "auto"

    for retry in range(max_retry_times):
        try:
            completion = client.chat.completions.create(**params)
            return completion.choices[0].message
        except Exception as e:
            print(
                f"API Error for item {item_id}: {str(e)}. Retrying ({retry + 1}/{max_retry_times})..."
            )
            if retry == max_retry_times - 1:
                raise e
            time.sleep(5)


def get_frames_by_id(frame_ids: list, all_frame_paths: list):
    """Tool implementation: Retrieves and encodes frames from a list of IDs."""
    retrieved_frames = []
    frame_map = {
        int(re.search(r"frame_(\d+)\.jpg", os.path.basename(p)).group(1)): p
        for p in all_frame_paths
        if re.search(r"frame_(\d+)\.jpg", os.path.basename(p))
    }
    for fid in frame_ids:
        path = frame_map.get(fid)
        if path and os.path.exists(path):
            b64_image = encode_image(path)
            retrieved_frames.append(
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{b64_image}"},
                }
            )
    return retrieved_frames


# MODIFIED: This function is now the "client" side of the embedding service.
def get_frames_by_similarity(
    query: str,
    all_frame_paths: list,
    precomputed_data: dict,
    request_queue: multiprocessing.Queue,
    results_dict: dict,
    k: int = 5,
):
    """
    Requests a text embedding from the server process, calculates similarity,
    finds top-k frames, and returns them encoded.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    frame_filenames = precomputed_data["filenames"]
    frame_embeddings = precomputed_data["embeddings"].to(device)

    # 1. Send request to the embedding server process
    request_id = str(uuid.uuid4())
    request_queue.put((request_id, query))

    # 2. Wait for the result
    while request_id not in results_dict:
        time.sleep(0.05)
    query_embedding = results_dict.pop(request_id).to(device)

    # 3. Perform similarity search with the received embedding
    with torch.no_grad():
        similarities = cosine_similarity(query_embedding, frame_embeddings)

    num_frames_to_select = min(k, len(frame_filenames))
    top_k_indices = (
        torch.topk(similarities, k=num_frames_to_select, dim=-1)
        .indices.cpu()
        .flatten()
        .numpy()
    )

    top_k_filenames = [frame_filenames[i] for i in top_k_indices]
    top_k_frame_ids = [
        int(re.search(r"frame_(\d+)\.jpg", f).group(1)) for f in top_k_filenames
    ]

    retrieved_frames = get_frames_by_id(top_k_frame_ids, all_frame_paths)
    return retrieved_frames


def evaluate_single_item_agentic(
    data_item,
    all_frame_paths,
    embeddings_data,
    target_model,
    api_key,
    base_url,
    max_retry_times,
    request_queue,  # MODIFIED: Added queue for IPC
    results_dict,  # MODIFIED: Added dict for IPC
):
    """Evaluates a single item using an agentic loop."""
    if "ark" in base_url:
        client = Ark(base_url=base_url, api_key=api_key)
    elif "aliyun" in base_url or "127.0.0.1" in base_url:
        client = OpenAI(api_key=api_key, base_url=base_url)
    else:
        client = AzureOpenAI(
            api_version="2023-05-15", api_key=api_key, azure_endpoint=base_url
        )

    tools = [GET_FRAMES_BY_ID_TOOL_SCHEMA, GET_FRAMES_BY_SIMILARITY_TOOL_SCHEMA]

    get_frames_by_id_with_context = partial(
        get_frames_by_id, all_frame_paths=all_frame_paths
    )
    # MODIFIED: Pass the request queue and results dict to the similarity function
    get_frames_by_similarity_with_context = partial(
        get_frames_by_similarity,
        all_frame_paths=all_frame_paths,
        precomputed_data=embeddings_data,
        request_queue=request_queue,
        results_dict=results_dict,
    )

    available_functions = {
        "get_frames_by_id": get_frames_by_id_with_context,
        "get_frames_by_similarity": get_frames_by_similarity_with_context,
    }

    total_frames = len(all_frame_paths)
    duration = data_item.get("video_info", {}).get("duration_minutes", 0) * 60
    initial_prompt = (
        f"The video has {total_frames} frames (ID 1 to {total_frames}) and is {duration:.0f} seconds long. "
        f"Please answer this question:\n{data_item['question']}"
    )

    messages = [
        {"role": "system", "content": AGENT_SYSTEM_PROMPT},
        {"role": "user", "content": initial_prompt},
    ]
    response_content = None
    max_tool_calls = 10

    for _ in range(max_tool_calls):
        response_message = call_single_model(
            client,
            messages,
            target_model,
            data_item["key"],
            max_retry_times,
            tools=tools,
        )
        if response_message is None:
            return None

        messages.append(response_message)

        if response_message.tool_calls:
            for tool_call in response_message.tool_calls:
                function_name = tool_call.function.name
                function_to_call = available_functions.get(function_name)
                if function_to_call:
                    function_args = json.loads(tool_call.function.arguments)
                    function_response = function_to_call(**function_args)

                    messages.append(
                        {
                            "tool_call_id": tool_call.id,
                            "role": "tool",
                            "name": function_name,
                            "content": json.dumps(
                                {
                                    "status": "success",
                                    "retrieved_frame_count": len(function_response),
                                }
                            ),
                        }
                    )

                    user_message_with_frames = [
                        {
                            "type": "text",
                            "text": f"Here are the {len(function_response)} frames from your call to `{function_name}`.",
                        }
                    ]
                    user_message_with_frames.extend(function_response)
                    messages.append(
                        {"role": "user", "content": user_message_with_frames}
                    )
        else:
            response_content = response_message.content
            break

    if response_content is None:
        final_prompt = "You have reached the maximum number of tool calls. Provide a final answer based on the information gathered so far."
        messages.append({"role": "user", "content": final_prompt})
        final_response = call_single_model(
            client, messages, target_model, data_item["key"], max_retry_times
        )
        response_content = (
            final_response.content
            if final_response
            else "Could not determine an answer after max tool calls."
        )

    is_correct = False
    model_answer_cleaned = None
    parsed_json = extract_json_from_response(response_content)
    if parsed_json and "answer" in parsed_json:
        model_answer_cleaned = str(parsed_json["answer"]).strip().upper()
        if model_answer_cleaned == data_item["answer"].strip().upper():
            is_correct = True

    return {
        **data_item,
        "agent_conversation": [
            msg if isinstance(msg, dict) else msg.model_dump() for msg in messages
        ],
        "model_reasoning_and_answer": response_content,
        "model_answer": model_answer_cleaned,
        "is_correct": is_correct,
    }


def encode_image(image_path):
    """Encodes an image file to a base64 string."""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


# MODIFIED: Function signature updated to accept queues for IPC
def process_single_data(data_item, args, request_queue, results_dict):
    """Main processing function for a single video, executed by a worker."""
    item_key = data_item["key"]
    try:
        specific_frames_path = os.path.join(args.frames_path, item_key)
        embedding_file = os.path.join(args.embeddings_path, f"{item_key}.pt")

        if not os.path.isdir(specific_frames_path):
            raise FileNotFoundError(
                f"Frame directory not found: {specific_frames_path}"
            )
        if not os.path.exists(embedding_file):
            raise FileNotFoundError(f"Embedding file not found: {embedding_file}")

        all_frame_paths = sorted(
            [
                os.path.join(specific_frames_path, f)
                for f in os.listdir(specific_frames_path)
                if f.endswith(".jpg")
            ],
            key=lambda x: int(re.search(r"frame_(\d+)\.jpg", x).group(1)),
        )
        if not all_frame_paths:
            raise FileNotFoundError(f"No frames found for key '{item_key}'")

        embeddings_data = torch.load(embedding_file, map_location="cpu")

        # MODIFIED: Pass queues to the evaluation function
        result = evaluate_single_item_agentic(
            data_item,
            all_frame_paths,
            embeddings_data,
            args.target_model,
            args.api_key,
            args.base_url,
            args.max_retry_times,
            request_queue,
            results_dict,
        )
        return result

    except Exception as e:
        print(f"\nCRITICAL ERROR on key {item_key}: {str(e)}")
        traceback.print_exc()
        return {
            "key": item_key,
            "uid": data_item.get("uid"),
            "error": str(e),
            "traceback": traceback.format_exc(),
        }


def load_test_data(json_file):
    """Loads the evaluation data from a JSON file."""
    try:
        with open(json_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Error: Data file not found: {json_file}")
        exit(1)
    except json.JSONDecodeError:
        print(f"Error: Malformed JSON in {json_file}")
        exit(1)


# MODIFIED: This new function runs in its own process, handling embedding requests.
# It now accepts a model_id and loads the model itself.
def embedding_server_process(model_id, device, request_queue, results_dict):
    """
    A server process that loads the SigLIP model and continuously fetches
    text queries from a queue, computes their embeddings, and places the
    results in a shared dictionary.
    """
    print(f"Embedding server started on PID {os.getpid()}...")
    print("Loading SigLIP model in the embedding server process...")
    model = AutoModel.from_pretrained(model_id)
    processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
    print("SigLIP model loaded in server.")

    model.to(device)
    model.eval()

    while True:
        try:
            request_id, text_query = request_queue.get()
            if text_query == "STOP":
                print("Embedding server received stop signal. Shutting down.")
                break

            with torch.no_grad():
                text_inputs = processor(
                    text=[text_query],
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                ).to(device)
                query_embedding = model.get_text_features(**text_inputs)
            # Move embedding to CPU before sharing across processes
            results_dict[request_id] = query_embedding.cpu()
        except Exception as e:
            print(f"Error in embedding server: {e}")
            traceback.print_exc()


# MODIFIED: The old init_worker is removed.
def main():
    """Main function to orchestrate the evaluation framework."""
    args = parse_arguments()
    print("--- Agentic Video QA with Hybrid Retrieval ---")
    print(
        f"Model: {args.target_model}, Data: {args.data_file}, Embeddings: {args.embeddings_path}"
    )

    # MODIFIED: Changed start method to 'spawn' for safety with CUDA and on macOS/Windows.
    try:
        multiprocessing.set_start_method("spawn", force=True)
        print("Multiprocessing start method set to 'spawn'.")
    except RuntimeError:
        print("Start method already set.")

    # MODIFIED: Model is no longer loaded in the main process.
    # It will be loaded in the dedicated embedding_server_process.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model_name_safe = args.target_model.replace("/", "_")
    data_filename_base = os.path.splitext(os.path.basename(args.data_file))[0]
    output_prefix = f"{model_name_safe}_{data_filename_base}_agent_hybrid"
    results_output_file = f"{output_prefix}_results.json"
    metrics_output_file = f"{output_prefix}_metrics.json"
    error_log_file = f"{output_prefix}_errors.log"

    with open(error_log_file, "a", encoding="utf-8") as f:
        f.write(
            f"\n=== Log Session Started at {datetime.now()} for {args.target_model} ===\n"
        )

    all_test_data = load_test_data(args.data_file)
    existing_results = []
    completed_ids = set()
    if os.path.exists(results_output_file):
        try:
            with open(results_output_file, "r", encoding="utf-8") as f:
                existing_results = json.load(f)
            if isinstance(existing_results, list):
                completed_ids = {
                    item["uid"] for item in existing_results if "uid" in item
                }
                print(f"Found {len(completed_ids)} completed tasks. Resuming...")
            else:
                existing_results = []
        except (json.JSONDecodeError, IOError):
            existing_results = []

    tasks_to_process = [
        item for item in all_test_data if item.get("uid") not in completed_ids
    ]
    if not tasks_to_process:
        print("All tasks are already completed. Calculating final metrics.")
    else:
        print(
            f"Total: {len(all_test_data)}. Completed: {len(completed_ids)}. To process: {len(tasks_to_process)}."
        )

    all_results = list(existing_results)

    if tasks_to_process:
        # MODIFIED: Set up Manager, Queues, and the embedding server process
        with multiprocessing.Manager() as manager:
            request_queue = manager.Queue()
            results_dict = manager.dict()

            # MODIFIED: Start the dedicated embedding server process, passing the model ID.
            embedding_server = multiprocessing.Process(
                target=embedding_server_process,
                args=(
                    SIGLIP_MODEL_ID,
                    device,
                    request_queue,
                    results_dict,
                ),
            )
            embedding_server.start()

            # MODIFIED: The ProcessPoolExecutor no longer needs an initializer for the model
            with concurrent.futures.ProcessPoolExecutor(
                max_workers=args.pool_processes
            ) as executor:
                # MODIFIED: Pass the queues to each worker via partial
                func = partial(
                    process_single_data,
                    args=args,
                    request_queue=request_queue,
                    results_dict=results_dict,
                )
                results_iterator = executor.map(func, tasks_to_process)
                for result in tqdm(
                    results_iterator,
                    total=len(tasks_to_process),
                    desc="Processing Videos",
                ):
                    if result:
                        if "error" in result:
                            with open(error_log_file, "a", encoding="utf-8") as f:
                                f.write(
                                    f"Error on key {result.get('key', 'N/A')}:\n Error: {result['error']}\n Traceback: {result['traceback']}\n---\n"
                                )
                        all_results.append(result)
                        if len(all_results) % 10 == 0:
                            save_json_file(all_results, results_output_file)

            # MODIFIED: Gracefully shut down the embedding server
            print("All tasks processed. Sending stop signal to embedding server.")
            request_queue.put((None, "STOP"))
            embedding_server.join()

    print("\n\nProcessing complete.")
    save_json_file(all_results, results_output_file)
    print(f"Detailed results saved to: {results_output_file}")

    final_metrics = calculate_metrics(all_results)
    save_json_file(final_metrics, metrics_output_file)
    print(f"\nMetrics saved to: {metrics_output_file}")
    print(json.dumps(final_metrics, indent=4))


if __name__ == "__main__":
    main()
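The queue-plus-shared-dict handshake above is the core of the hybrid design: worker processes never load the GPU model; they exchange work with the single embedding server through a `multiprocessing.Manager` queue and dict. Below is a minimal, self-contained sketch of that protocol, with a toy list standing in for the SigLIP `get_text_features` output; everything else mirrors the pattern used in `main_new_agent.py`.

```python
import multiprocessing
import time
import uuid


def server(request_queue, results_dict):
    # Stand-in for embedding_server_process: exactly one process owns the model.
    while True:
        request_id, query = request_queue.get()
        if query == "STOP":
            break
        fake_embedding = [float(len(query))]  # toy stand-in for get_text_features
        results_dict[request_id] = fake_embedding


if __name__ == "__main__":
    multiprocessing.set_start_method("spawn", force=True)
    with multiprocessing.Manager() as manager:
        request_queue = manager.Queue()
        results_dict = manager.dict()
        p = multiprocessing.Process(target=server, args=(request_queue, results_dict))
        p.start()

        # Client side, as in get_frames_by_similarity: tag the request, poll, pop.
        request_id = str(uuid.uuid4())
        request_queue.put((request_id, "a band playing on stage"))
        while request_id not in results_dict:
            time.sleep(0.05)
        embedding = results_dict.pop(request_id)
        print(embedding)

        request_queue.put((None, "STOP"))
        p.join()
```

Popping the result instead of merely reading it keeps the shared dict from growing without bound, which is why `get_frames_by_similarity` uses `results_dict.pop(request_id)`.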
main_uniform_sampling.py
ADDED
@@ -0,0 +1,501 @@
import os
import json
import base64
import argparse
import time
import re
from datetime import datetime
from functools import partial
from openai import AzureOpenAI, OpenAI
from volcenginesdkarkruntime import Ark
from multiprocessing import Pool, Manager, Lock

# New prompt template for multiple-choice questions with reasoning
REASONING_MULTIPLE_CHOICE_TEMPLATE = """
You are an AI assistant evaluating video frames to answer a multiple-choice question.
The user will provide you with a set of video frames and a question with several options (e.g., A, B, C, D).

First, provide a step-by-step reasoning process that analyzes the video frames and leads to your conclusion.
After your reasoning, provide the final answer in a JSON block. The JSON object must contain a single key "answer" with the value being one of 'A', 'B', 'C', or 'D'.

Your output should follow this format exactly:
<Your step-by-step reasoning here>
```json
{"answer": "A"}
```
Do not include any other text after the JSON block.
"""


def parse_arguments():
    """
    Parse command line arguments for evaluation configuration.

    Returns:
        argparse.Namespace: Parsed command line arguments
    """
    parser = argparse.ArgumentParser(description="Video QA Evaluation Framework")

    # Model configuration
    parser.add_argument(
        "--target-model",
        "-tm",
        type=str,
        required=True,
        help="Model to be evaluated (e.g., gpt-4o, gpt-4-vision-preview)",
    )

    # Data configuration
    parser.add_argument(
        "--frame-num",
        "-fn",
        type=int,
        default=32,
        help="Number of frames to uniformly sample from each video (default: 32)",
    )
    parser.add_argument(
        "--frames-path",
        "-fp",
        type=str,
        required=True,
        help="Absolute path to the base directory containing video frame folders.",
    )
    parser.add_argument(
        "--data-file",
        "-df",
        type=str,
        required=True,
        help="Absolute path to the JSON file containing the evaluation dataset.",
    )

    # Processing configuration
    parser.add_argument(
        "--max-retry-times",
        "-mr",
        type=int,
        default=10,
        help="Maximum number of retries for API calls (default: 10)",
    )
    parser.add_argument(
        "--pool-processes",
        "-pp",
        type=int,
        default=20,
        help="Number of parallel processes for evaluation (default: 20)",
    )

    # API configuration
    parser.add_argument(
        "--base_url", type=str, required=True, help="Azure OpenAI endpoint URL."
    )
    parser.add_argument(
        "--api_key", type=str, required=True, help="Azure OpenAI API key."
    )

    return parser.parse_args()


def save_json_file(data, output_file):
    """
    Save data to a JSON file.

    Args:
        data (dict or list): Data to be saved.
        output_file (str): Path to the output file.
    """
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)


def extract_json_from_response(response):
    """
    Extracts a JSON object from a string that contains reasoning followed by a tagged JSON block.

    Args:
        response (str): The raw response string from the model.

    Returns:
        dict or None: Parsed JSON object or None if no valid JSON block is found.
    """
    if not response:
        return None
    try:
        # Regex to find the content inside ```json ... ```
        match = re.search(r"```json\s*(\{.*?\})\s*```", response, re.DOTALL)
        if match:
            json_str = match.group(1)
            return json.loads(json_str)
        return None
    except (json.JSONDecodeError, IndexError):
        return None


def calculate_metrics(results):
    """
    Calculate evaluation metrics from the results.

    Args:
        results (list): List of results with 'is_correct' field.

    Returns:
        dict: Dictionary containing calculated metrics.
    """
    total_samples = len(results)
    if total_samples == 0:
        return {
            "total_samples": 0,
            "answered_samples": 0,
            "correct_answers": 0,
            "accuracy": 0.0,
        }

    answered_samples = sum(1 for x in results if x.get("model_answer") is not None)
    correct_answers = sum(1 for x in results if x.get("is_correct"))

    accuracy = correct_answers / answered_samples if answered_samples > 0 else 0.0

    metrics = {
        "total_samples": total_samples,
        "answered_samples": answered_samples,
        "correct_answers": correct_answers,
        "accuracy": accuracy,
    }

    return metrics


def call_single_model(client, messages, model, item_id, max_retry_times):
    """
    Make a single API call to the specified model with retry logic.

    Args:
        client: OpenAI client instance.
        messages (list): List of messages for the API call.
        model (str): Model name to use.
        item_id (str): ID of the item being processed (for error logging).
        max_retry_times (int): Maximum number of retries.

    Returns:
        str or None: Model response or None if all retries failed.
    """
    if "doubao" in model:
        max_tokens = 32768
    else:
        max_tokens = 65535
    retry_times = 0
    while retry_times < max_retry_times:
        try:
            # Set max_tokens to a larger value to allow for reasoning
            completion = client.chat.completions.create(
                model=model, messages=messages, max_tokens=max_tokens
            )
            return completion.choices[0].message.content
        except Exception as e:
            retry_times += 1
            print(
                f"Error processing item {item_id} with model {model}: {str(e)}. Retrying ({retry_times}/{max_retry_times})..."
            )
            if retry_times == max_retry_times:
                error_log_file = f"error_log_{model.replace('/', '_')}.txt"
                with open(error_log_file, "a") as f:
                    f.write(
                        f"Error processing item {item_id} with model {model} after {max_retry_times} retries: {str(e)}\n"
                    )
                return None
            time.sleep(5)  # Wait before retrying


def evaluate_single_item(
    data_item, frames, target_model, api_key, base_url, max_retry_times
):
    """
    Evaluate a single data item using the target model and perform exact match.

    Args:
        data_item (dict): Dictionary containing question and answer data.
        frames (list): List of encoded video frames.
        target_model (str): Model to be evaluated.
        api_key (str): API key.
        base_url (str): API base URL.
        max_retry_times (int): Maximum number of retries.

    Returns:
        dict: Evaluation result.
    """
    if "ark" in base_url:
        client = Ark(
            base_url=base_url,
            api_key=api_key,
        )
    elif "aliyun" in base_url or "127.0.0.1" in base_url:
        client = OpenAI(api_key=api_key, base_url=base_url)
    else:
        client = AzureOpenAI(
            api_version="2023-05-15", api_key=api_key, azure_endpoint=base_url
        )

    # Construct messages for the model using the new template
    messages = [
        {"role": "system", "content": REASONING_MULTIPLE_CHOICE_TEMPLATE},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Here are the video frames:"},
                *frames,
                {"type": "text", "text": f"Question: {data_item['question']}"},
            ],
        },
    ]

    response = call_single_model(
        client, messages, target_model, data_item["key"], max_retry_times
    )

    is_correct = False
    model_answer_cleaned = None
    parsed_json = None

    if response:
        parsed_json = extract_json_from_response(response)
        if parsed_json and "answer" in parsed_json:
            model_answer_cleaned = str(parsed_json["answer"]).strip().upper()
            gold_answer = data_item["answer"].strip().upper()
            if model_answer_cleaned == gold_answer:
                is_correct = True

    # Create result dictionary
    result = {
        **data_item,
        "model_reasoning_and_answer": response,
        "model_answer_raw": parsed_json.get("answer") if parsed_json else None,
        "model_answer": model_answer_cleaned,
        "is_correct": is_correct,
    }

    return result


def encode_image(image_path):
    """
    Encode an image file to base64 string.

    Args:
        image_path (str): Path to the image file.

    Returns:
        str: Base64 encoded image string.
    """
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


def process_frames(frames_path, frame_num):
    """
    Process and uniformly sample video frames from a directory, then encode them.

    Args:
        frames_path (str): Path to the directory containing video frames.
        frame_num (int): The number of frames to sample.

    Returns:
        list: List of encoded frame objects for API consumption.
    """
    if not os.path.isdir(frames_path):
        print(f"Warning: Frame directory not found at {frames_path}")
        return []

    frame_files = [
        f
        for f in os.listdir(frames_path)
        if f.startswith("frame_") and f.endswith(".jpg")
    ]
    # Sort frames numerically based on the ID in frame_{id}.jpg
    frame_files.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))

    frame_path_list = [os.path.join(frames_path, f) for f in frame_files]
    total_frames = len(frame_path_list)

    if total_frames == 0:
        return []

    # Uniformly sample frame paths
    if total_frames > frame_num:
        indices = [int(i * total_frames / frame_num) for i in range(frame_num)]
        sampled_paths = [frame_path_list[i] for i in indices]
    else:
        sampled_paths = frame_path_list  # Use all frames if fewer than requested

    # Encode only the sampled frames
    base64_images = [encode_image(path) for path in sampled_paths]

    # Create frame objects for API payload
    return [
        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_img}"}}
        for b64_img in base64_images
    ]


def process_single_data(
    data_item, args, shared_results, progress_counter, total_items, locks
):
    """
    Process a single data item in a multiprocessing context.

    Args:
        data_item (dict): Single data item to process.
        args: Command line arguments.
        shared_results: Shared list for storing results.
        progress_counter: Shared counter for progress tracking.
        total_items (int): Total number of items to process.
        locks (dict): Dictionary of locks for thread-safe operations.
    """
    item_key = data_item["key"]
    try:
        # Construct path to the specific video's frames folder
        specific_frames_path = os.path.join(args.frames_path, item_key)
        frames = process_frames(specific_frames_path, args.frame_num)

        if not frames:
            raise FileNotFoundError(
                f"No frames found or processed for key '{item_key}' at path '{specific_frames_path}'"
            )

        result = evaluate_single_item(
            data_item,
            frames,
            args.target_model,
            args.api_key,
            args.base_url,
            args.max_retry_times,
        )

        if result is not None:
            with locks["results"]:
                shared_results.append(result)
                # Define output file names inside the worker
                data_filename_base = os.path.splitext(os.path.basename(args.data_file))[
                    0
                ]
                model_name_safe = args.target_model.replace("/", "_")
                output_prefix = (
                    f"{model_name_safe}_{data_filename_base}_{args.frame_num}frames"
                )
                results_output_file = f"{output_prefix}_results.json"
                # Save the entire updated list of results after each case is processed
                save_json_file(list(shared_results), results_output_file)

    except Exception as e:
        print(f"Error processing video key {item_key}: {str(e)}")
        with locks["file"]:
            error_log_file = f"error_log_{args.target_model.replace('/', '_')}.txt"
            with open(error_log_file, "a") as f:
                f.write(f"Critical error processing video key {item_key}: {str(e)}\n")
    finally:
        # Always update progress counter
        with locks["counter"]:
            progress_counter.value += 1
            print(
                f"\rProcessed: {progress_counter.value}/{total_items} videos...",
                end="",
                flush=True,
            )


def load_test_data(json_file):
    """
    Load test data from a JSON file.

    Args:
        json_file (str): Path to the JSON file.

    Returns:
        list: List of test data items.
    """
    try:
        with open(json_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Error: Data file not found at {json_file}")
        exit(1)
    except json.JSONDecodeError:
        print(f"Error: Could not decode JSON from {json_file}")
        exit(1)


def main():
    """
    Main function to run the video QA evaluation framework.
    """
    args = parse_arguments()

    print("--- Evaluation Configuration ---")
    print(f"Target Model: {args.target_model}")
    print(f"Frames to Sample: {args.frame_num}")
    print(f"Frames Base Path: {args.frames_path}")
    print(f"Data File: {args.data_file}")
    print(f"Parallel Processes: {args.pool_processes}")
    print("---------------------------------")

    # Initialize error log file
    error_log_file = f"error_log_{args.target_model.replace('/', '_')}.txt"
    with open(error_log_file, "w") as f:
        f.write(
            f"=== Error Log Started at {datetime.now()} for model {args.target_model} ===\n"
        )

    # Define output file names
    data_filename_base = os.path.splitext(os.path.basename(args.data_file))[0]
    model_name_safe = args.target_model.replace("/", "_")
    output_prefix = f"{model_name_safe}_{data_filename_base}_{args.frame_num}frames"

    results_output_file = f"{output_prefix}_results.json"
    metrics_output_file = f"{output_prefix}_metrics.json"

    # Load data
    test_data = load_test_data(args.data_file)
    total_videos = len(test_data)
    print(f"\nLoaded {total_videos} videos to process.")

    # Set up multiprocessing
    with Manager() as manager:
        shared_results = manager.list()
        progress_counter = manager.Value("i", 0)

        locks = {
            "results": manager.Lock(),
            "file": manager.Lock(),
            "counter": manager.Lock(),
        }

        # Create a partial function with fixed arguments for the worker pool
        process_func = partial(
            process_single_data,
            args=args,
            shared_results=shared_results,
            progress_counter=progress_counter,
            total_items=total_videos,
            locks=locks,
        )

        # Run processing in parallel
        with Pool(processes=args.pool_processes) as pool:
            pool.map(process_func, test_data)

        # Convert shared list to a regular list for final processing
        all_results = list(shared_results)

    print(f"\n\nProcessing complete for model: {args.target_model}")

    # Calculate and save final metrics
    final_metrics = calculate_metrics(all_results)
    save_json_file(final_metrics, metrics_output_file)
    print(f"\nMetrics saved to: {metrics_output_file}")
    print(json.dumps(final_metrics, indent=4))

    # Save final results
    save_json_file(all_results, results_output_file)
    print(f"Detailed results saved to: {results_output_file}")


if __name__ == "__main__":
    main()
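The sampling rule in `process_frames` picks index `int(i * total_frames / frame_num)` for each of the `frame_num` slots, so the samples are spread evenly from frame 0 onward; note that the last frame of the video is generally not selected. A quick worked sketch of the arithmetic:

```python
# Worked example of the uniform-sampling indices used in process_frames.
total_frames, frame_num = 100, 8
indices = [int(i * total_frames / frame_num) for i in range(frame_num)]
print(indices)  # [0, 12, 25, 37, 50, 62, 75, 87]: evenly spaced, frame 99 never chosen
```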
offline_compute_similarity.py
ADDED
@@ -0,0 +1,191 @@
import os
import json
import torch
import argparse
import tempfile
import glob
from tqdm import tqdm
from transformers import AutoModel, AutoProcessor
from torch.nn.functional import cosine_similarity
import torch.multiprocessing as mp

# --- Configuration ---
MODEL_ID = "/mnt/bn/ziyang-storage-cloudnative-hl/huggingface/siglip-so400m-patch14-384"


def parse_arguments():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="Step 2: Load pre-computed embeddings and compute question-frame similarity."
    )
    parser.add_argument(
        "--data-file",
        "-df",
        type=str,
        required=True,
        help="Absolute path to the JSON file containing the evaluation dataset.",
    )
    parser.add_argument(
        "--embeddings-path",
        "-ep",
        type=str,
        required=True,
        help="Absolute path to the directory containing the pre-computed embedding .pt files.",
    )
    parser.add_argument(
        "--output-file",
        "-o",
        type=str,
        required=True,
        help="Path to the JSON file for saving the final similarity scores.",
    )
    return parser.parse_args()


def load_test_data(json_file):
    """Load test data from a JSON file."""
    try:
        with open(json_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Error: Data file not found at {json_file}")
        exit(1)
    except json.JSONDecodeError:
        print(f"Error: Could not decode JSON from {json_file}")
        exit(1)
    return []


def save_json_file(data, output_file):
    """Save data to a JSON file."""
    # os.makedirs(os.path.dirname(output_file), exist_ok=True)
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)
    print(f"\nSuccessfully saved the final similarity results to {output_file}")


def process_question_chunk(args_tuple):
    """
    Worker function that processes a batch of questions and saves results incrementally.
    """
    data_chunk, embeddings_base_path, gpu_id, temp_dir = args_tuple
    device = f"cuda:{gpu_id}"

    # Define a unique temporary output file for this worker process
    temp_output_file = os.path.join(temp_dir, f"results_gpu_{gpu_id}.jsonl")

    # Only the model is needed to compute text features
    model = AutoModel.from_pretrained(MODEL_ID).to(device).eval()
    processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=True)

    progress_bar = tqdm(data_chunk, position=gpu_id, desc=f"GPU-{gpu_id}")

    # Cache loaded embeddings to avoid repeated I/O
    embedding_cache = {}

    with open(temp_output_file, "a", encoding="utf-8") as f_out:
        for data_item in progress_bar:
            item_key = data_item["key"]
            question_key = data_item["uid"]
            question = data_item["question"].split("\n(A)")[0]

            embedding_file_path = os.path.join(embeddings_base_path, f"{item_key}.pt")
            if not os.path.exists(embedding_file_path):
                progress_bar.write(
                    f"Warning: Embedding file not found for '{item_key}', skipping."
                )
                continue

            try:
                # Load the embeddings from the cache or from disk
                if item_key not in embedding_cache:
                    loaded_data = torch.load(embedding_file_path, map_location="cpu")
                    embedding_cache[item_key] = {
                        "filenames": loaded_data["filenames"],
                        "embeddings": loaded_data["embeddings"],
                    }

                frame_files = embedding_cache[item_key]["filenames"]
                frame_embeddings = embedding_cache[item_key]["embeddings"].to(device)

                with torch.no_grad():
                    # --- Text embedding ---
                    text_inputs = processor(
                        text=[question],
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                    ).to(device)
                    question_embedding = model.get_text_features(**text_inputs)

                    # --- Similarity computation and ranking ---
                    similarities = cosine_similarity(
                        question_embedding, frame_embeddings
                    )
                    scored_frames = sorted(
                        zip(frame_files, similarities.cpu().numpy()),
                        key=lambda x: x[1],
                        reverse=True,
                    )
                    sorted_frame_filenames = [frame[0] for frame in scored_frames]

                single_result = {question_key: sorted_frame_filenames}
                f_out.write(json.dumps(single_result) + "\n")

            except Exception as e:
                progress_bar.write(f"Error on GPU-{gpu_id} for item '{item_key}': {e}")


def main():
    """Main function coordinating multi-GPU processing."""
    args = parse_arguments()

    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        print("Error: No CUDA-enabled GPUs found. Exiting.")
        exit(1)

    print(f"Found {num_gpus} GPUs. Starting parallel similarity computation...")

    test_data = load_test_data(args.data_file)
    if not test_data:
        return

    chunk_size = (len(test_data) + num_gpus - 1) // num_gpus
    data_chunks = [
        test_data[i : i + chunk_size] for i in range(0, len(test_data), chunk_size)
    ]

    with tempfile.TemporaryDirectory() as temp_dir:
        print(f"Using temporary directory for intermediate results: {temp_dir}")

        process_args = [
            (data_chunks[i], args.embeddings_path, i, temp_dir)
            for i in range(len(data_chunks))
        ]

        with mp.Pool(processes=num_gpus) as pool:
            pool.map(process_question_chunk, process_args)

        # --- Merge and save the final results from the temporary files ---
        print("\n\nAll GPU processes finished. Merging results from temporary files...")
        final_similarity_results = {}

        temp_files = glob.glob(os.path.join(temp_dir, "*.jsonl"))

        for temp_file in tqdm(temp_files, desc="Merging files"):
            with open(temp_file, "r", encoding="utf-8") as f:
                for line in f:
                    try:
                        data = json.loads(line)
                        final_similarity_results.update(data)
                    except json.JSONDecodeError:
                        print(f"Warning: skipping a corrupted line in {temp_file}")

        save_json_file(final_similarity_results, args.output_file)
        print(f"Total items processed: {len(final_similarity_results)}")


if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    main()
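Each worker appends one JSON object per question to its own `.jsonl` shard, and `main` then merges the shards into a single mapping from question `uid` to frame filenames ranked by descending similarity. The merged output file therefore has the following shape; the uids and filenames below are illustrative, not from a real run.

```python
# Illustrative shape of the merged similarity file; keys and filenames are made up.
example_output = {
    "uid_0001": ["frame_73.jpg", "frame_74.jpg", "frame_12.jpg"],
    "uid_0002": ["frame_5.jpg", "frame_210.jpg", "frame_6.jpg"],
}
```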
utils/count_frames.py
ADDED
@@ -0,0 +1,35 @@
import json


with open(
    "/mnt/bn/ziyang-storage-cloudnative-hl/VideoSimpleQA/o4-mini-2025-04-16_lvbench_agent_results.json",
    "r",
) as f:
    data = json.load(f)

# print(data[0])
# import pdb; pdb.set_trace()
# print(data[1])

frames_count = []
# for d in data:
#     num_frames = 0
#     for turn in d["agent_conversation"]:
#         if turn["role"] == "assistant" and turn["tool_calls"]:
#             for tool_call in turn["tool_calls"]:
#                 frame_ids = json.loads(tool_call["function"]["arguments"])["frame_ids"]
#                 num_frames += len(frame_ids)
#     frames_count.append(num_frames)

for d in data:
    num_frames = 0
    for turn in d["agent_conversation"]:
        if turn["role"] == "user" and isinstance(turn["content"], list):
            for item in turn["content"]:
                if item["type"] == "image_url":
                    num_frames += 1
    frames_count.append(num_frames)

print(f"mean frames: {sum(frames_count) / len(frames_count)}")
print(f"max frames: {max(frames_count)}")
print(f"min frames: {min(frames_count)}")