# VideoSimpleQA / offline_compute_similarity.py
# Uploaded by hzy — "Initial upload of all project files" (commit 608eb1a, verified)
import os
import json
import torch
import argparse
import tempfile
import glob
from tqdm import tqdm
from transformers import AutoModel, AutoProcessor
from torch.nn.functional import cosine_similarity
import torch.multiprocessing as mp
# --- 配置 ---
MODEL_ID = "/mnt/bn/ziyang-storage-cloudnative-hl/huggingface/siglip-so400m-patch14-384"
def parse_arguments():
"""解析命令行参数"""
parser = argparse.ArgumentParser(
description="步骤 2: 从预计算的嵌入加载并计算问-帧相似度。"
)
parser.add_argument(
"--data-file",
"-df",
type=str,
required=True,
help="包含评估数据集的JSON文件的绝对路径。",
)
parser.add_argument(
"--embeddings-path",
"-ep",
type=str,
required=True,
help="包含预计算嵌入.pt文件的目录的绝对路径。",
)
parser.add_argument(
"--output-file",
"-o",
type=str,
required=True,
help="用于保存最终相似度分数的JSON文件路径。",
)
return parser.parse_args()
def load_test_data(json_file):
"""从JSON文件加载测试数据"""
try:
with open(json_file, "r", encoding="utf-8") as f:
return json.load(f)
except FileNotFoundError:
print(f"错误: 在 {json_file} 未找到数据文件")
exit(1)
except json.JSONDecodeError:
print(f"错误: 无法从 {json_file} 解码JSON")
exit(1)
return []
def save_json_file(data, output_file):
"""将数据保存到JSON文件"""
# os.makedirs(os.path.dirname(output_file), exist_ok=True)
with open(output_file, "w", encoding="utf-8") as f:
json.dump(data, f, indent=4)
print(f"\n成功将最终相似度结果保存到 {output_file}")
def process_question_chunk(args_tuple):
"""
工作函数,用于处理一批问题并增量保存结果。
"""
data_chunk, embeddings_base_path, gpu_id, temp_dir = args_tuple
device = f"cuda:{gpu_id}"
# 为此工作进程定义一个唯一的临时输出文件
temp_output_file = os.path.join(temp_dir, f"results_gpu_{gpu_id}.jsonl")
# 只需要模型来计算文本特征
model = AutoModel.from_pretrained(MODEL_ID).to(device).eval()
processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=True)
progress_bar = tqdm(data_chunk, position=gpu_id, desc=f"GPU-{gpu_id}")
# 缓存已加载的嵌入以避免重复IO
embedding_cache = {}
with open(temp_output_file, "a", encoding="utf-8") as f_out:
for data_item in progress_bar:
item_key = data_item["key"]
question_key = data_item["uid"]
question = data_item["question"].split("\n(A)")[0]
embedding_file_path = os.path.join(embeddings_base_path, f"{item_key}.pt")
if not os.path.exists(embedding_file_path):
progress_bar.write(
f"Warning: Embedding file not found for '{item_key}', skipping."
)
continue
try:
# 从缓存或文件中加载嵌入
if item_key not in embedding_cache:
loaded_data = torch.load(embedding_file_path, map_location="cpu")
embedding_cache[item_key] = {
"filenames": loaded_data["filenames"],
"embeddings": loaded_data["embeddings"],
}
frame_files = embedding_cache[item_key]["filenames"]
frame_embeddings = embedding_cache[item_key]["embeddings"].to(device)
with torch.no_grad():
# --- 文本嵌入 ---
text_inputs = processor(
text=[question],
return_tensors="pt",
padding=True,
truncation=True,
).to(device)
question_embedding = model.get_text_features(**text_inputs)
# --- 相似度计算和排序 ---
similarities = cosine_similarity(
question_embedding, frame_embeddings
)
scored_frames = sorted(
zip(frame_files, similarities.cpu().numpy()),
key=lambda x: x[1],
reverse=True,
)
sorted_frame_filenames = [frame[0] for frame in scored_frames]
single_result = {question_key: sorted_frame_filenames}
f_out.write(json.dumps(single_result) + "\n")
except Exception as e:
progress_bar.write(f"Error on GPU-{gpu_id} for item '{item_key}': {e}")
def main():
"""主函数,用于协调多GPU处理"""
args = parse_arguments()
num_gpus = torch.cuda.device_count()
if num_gpus == 0:
print("错误: 未找到启用CUDA的GPU。正在退出。")
exit(1)
print(f"找到 {num_gpus} 个GPU。开始并行计算相似度...")
test_data = load_test_data(args.data_file)
if not test_data:
return
chunk_size = (len(test_data) + num_gpus - 1) // num_gpus
data_chunks = [
test_data[i : i + chunk_size] for i in range(0, len(test_data), chunk_size)
]
with tempfile.TemporaryDirectory() as temp_dir:
print(f"使用临时目录存储中间结果: {temp_dir}")
process_args = [
(data_chunks[i], args.embeddings_path, i, temp_dir)
for i in range(len(data_chunks))
]
with mp.Pool(processes=num_gpus) as pool:
pool.map(process_question_chunk, process_args)
# --- 从临时文件合并和保存最终结果 ---
print("\n\n所有GPU进程已完成。正在从临时文件合并结果...")
final_similarity_results = {}
temp_files = glob.glob(os.path.join(temp_dir, "*.jsonl"))
for temp_file in tqdm(temp_files, desc="合并文件"):
with open(temp_file, "r", encoding="utf-8") as f:
for line in f:
try:
data = json.loads(line)
final_similarity_results.update(data)
except json.JSONDecodeError:
print(f"警告: 跳过 {temp_file} 中的损坏行")
save_json_file(final_similarity_results, args.output_file)
print(f"总共处理的项目数: {len(final_similarity_results)}")
if __name__ == "__main__":
mp.set_start_method("spawn", force=True)
main()