import os
import sys
import json
import glob
import argparse
import tempfile

import torch
import torch.multiprocessing as mp
from torch.nn.functional import cosine_similarity
from tqdm import tqdm
from transformers import AutoModel, AutoProcessor

# --- Configuration ---
# Local path to the SigLIP checkpoint used for text embeddings.
MODEL_ID = "/mnt/bn/ziyang-storage-cloudnative-hl/huggingface/siglip-so400m-patch14-384"


def parse_arguments():
    """Parse command-line arguments for the similarity-computation step."""
    parser = argparse.ArgumentParser(
        description="步骤 2: 从预计算的嵌入加载并计算问-帧相似度。"
    )
    parser.add_argument(
        "--data-file",
        "-df",
        type=str,
        required=True,
        help="包含评估数据集的JSON文件的绝对路径。",
    )
    parser.add_argument(
        "--embeddings-path",
        "-ep",
        type=str,
        required=True,
        help="包含预计算嵌入.pt文件的目录的绝对路径。",
    )
    parser.add_argument(
        "--output-file",
        "-o",
        type=str,
        required=True,
        help="用于保存最终相似度分数的JSON文件路径。",
    )
    return parser.parse_args()


def load_test_data(json_file):
    """Load the evaluation dataset from a JSON file.

    Exits the process (status 1) if the file is missing or not valid JSON.
    """
    try:
        with open(json_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"错误: 在 {json_file} 未找到数据文件")
        sys.exit(1)
    except json.JSONDecodeError:
        print(f"错误: 无法从 {json_file} 解码JSON")
        sys.exit(1)


def save_json_file(data, output_file):
    """Serialize ``data`` to ``output_file`` as pretty-printed JSON.

    Creates the parent directory if it does not exist (the previous version
    left this commented out, which crashed on a missing output directory).
    """
    out_dir = os.path.dirname(output_file)
    if out_dir:  # dirname is "" for a bare filename in the CWD
        os.makedirs(out_dir, exist_ok=True)
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)
    print(f"\n成功将最终相似度结果保存到 {output_file}")


def process_question_chunk(args_tuple):
    """Worker: score frames against questions for one chunk, on one GPU.

    Args (packed in ``args_tuple`` because ``Pool.map`` takes one argument):
        data_chunk: list of dataset items, each with "key", "uid", "question".
        embeddings_base_path: directory holding per-item ``<key>.pt`` files,
            each a dict with "filenames" and "embeddings" tensors.
        gpu_id: CUDA device index this worker is pinned to.
        temp_dir: directory for this worker's incremental JSONL output.

    Results are appended line-by-line so partial progress survives a crash.
    """
    data_chunk, embeddings_base_path, gpu_id, temp_dir = args_tuple
    device = f"cuda:{gpu_id}"

    # A per-worker output file avoids any cross-process write contention.
    temp_output_file = os.path.join(temp_dir, f"results_gpu_{gpu_id}.jsonl")

    # Only the text tower is exercised here; frame embeddings are precomputed.
    model = AutoModel.from_pretrained(MODEL_ID).to(device).eval()
    processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=True)

    progress_bar = tqdm(data_chunk, position=gpu_id, desc=f"GPU-{gpu_id}")

    # Cache loaded embeddings (on CPU) so items sharing a key avoid repeat IO.
    embedding_cache = {}

    with open(temp_output_file, "a", encoding="utf-8") as f_out:
        for data_item in progress_bar:
            item_key = data_item["key"]
            question_key = data_item["uid"]
            # Strip the multiple-choice options; keep only the question stem.
            question = data_item["question"].split("\n(A)")[0]

            embedding_file_path = os.path.join(
                embeddings_base_path, f"{item_key}.pt"
            )
            if not os.path.exists(embedding_file_path):
                progress_bar.write(
                    f"Warning: Embedding file not found for '{item_key}', skipping."
                )
                continue

            try:
                # Load embeddings from cache or disk.
                if item_key not in embedding_cache:
                    loaded_data = torch.load(
                        embedding_file_path, map_location="cpu"
                    )
                    embedding_cache[item_key] = {
                        "filenames": loaded_data["filenames"],
                        "embeddings": loaded_data["embeddings"],
                    }

                frame_files = embedding_cache[item_key]["filenames"]
                frame_embeddings = embedding_cache[item_key]["embeddings"].to(
                    device
                )

                with torch.no_grad():
                    # --- Text embedding ---
                    text_inputs = processor(
                        text=[question],
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                    ).to(device)
                    question_embedding = model.get_text_features(**text_inputs)

                    # --- Similarity computation and ranking ---
                    # (1, D) vs (N, D) broadcasts to N per-frame scores.
                    similarities = cosine_similarity(
                        question_embedding, frame_embeddings
                    )
                    scored_frames = sorted(
                        zip(frame_files, similarities.cpu().numpy()),
                        key=lambda x: x[1],
                        reverse=True,
                    )
                    sorted_frame_filenames = [
                        frame[0] for frame in scored_frames
                    ]

                single_result = {question_key: sorted_frame_filenames}
                f_out.write(json.dumps(single_result) + "\n")
            except Exception as e:
                # Best-effort: log and move on so one bad item doesn't kill
                # the whole worker.
                progress_bar.write(
                    f"Error on GPU-{gpu_id} for item '{item_key}': {e}"
                )


def main():
    """Coordinate multi-GPU similarity scoring and merge the results."""
    args = parse_arguments()

    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        print("错误: 未找到启用CUDA的GPU。正在退出。")
        sys.exit(1)
    print(f"找到 {num_gpus} 个GPU。开始并行计算相似度...")

    test_data = load_test_data(args.data_file)
    if not test_data:
        return

    # Split the dataset into one contiguous chunk per GPU (ceiling division).
    chunk_size = (len(test_data) + num_gpus - 1) // num_gpus
    data_chunks = [
        test_data[i : i + chunk_size]
        for i in range(0, len(test_data), chunk_size)
    ]

    with tempfile.TemporaryDirectory() as temp_dir:
        print(f"使用临时目录存储中间结果: {temp_dir}")
        process_args = [
            (data_chunks[i], args.embeddings_path, i, temp_dir)
            for i in range(len(data_chunks))
        ]
        with mp.Pool(processes=num_gpus) as pool:
            pool.map(process_question_chunk, process_args)

        # --- Merge per-worker JSONL files into the final result dict ---
        print("\n\n所有GPU进程已完成。正在从临时文件合并结果...")
        final_similarity_results = {}
        temp_files = glob.glob(os.path.join(temp_dir, "*.jsonl"))
        for temp_file in tqdm(temp_files, desc="合并文件"):
            with open(temp_file, "r", encoding="utf-8") as f:
                for line in f:
                    try:
                        data = json.loads(line)
                        final_similarity_results.update(data)
                    except json.JSONDecodeError:
                        print(f"警告: 跳过 {temp_file} 中的损坏行")

        save_json_file(final_similarity_results, args.output_file)
        print(f"总共处理的项目数: {len(final_similarity_results)}")


if __name__ == "__main__":
    # CUDA requires the "spawn" start method; "fork" would corrupt the
    # parent's CUDA context in the children.
    mp.set_start_method("spawn", force=True)
    main()