File size: 6,657 Bytes
608eb1a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
import os
import json
import torch
import argparse
import tempfile
import glob
from tqdm import tqdm
from transformers import AutoModel, AutoProcessor
from torch.nn.functional import cosine_similarity
import torch.multiprocessing as mp
# --- 配置 ---
MODEL_ID = "/mnt/bn/ziyang-storage-cloudnative-hl/huggingface/siglip-so400m-patch14-384"
def parse_arguments():
    """Parse command-line options for the similarity-scoring step.

    Returns:
        argparse.Namespace with ``data_file``, ``embeddings_path`` and
        ``output_file`` attributes (all required string paths).
    """
    parser = argparse.ArgumentParser(
        description="步骤 2: 从预计算的嵌入加载并计算问-帧相似度。"
    )
    # (long flag, short flag, help text) — all three options share
    # type=str and required=True, so declare them table-style.
    option_specs = (
        ("--data-file", "-df", "包含评估数据集的JSON文件的绝对路径。"),
        ("--embeddings-path", "-ep", "包含预计算嵌入.pt文件的目录的绝对路径。"),
        ("--output-file", "-o", "用于保存最终相似度分数的JSON文件路径。"),
    )
    for long_flag, short_flag, help_text in option_specs:
        parser.add_argument(
            long_flag,
            short_flag,
            type=str,
            required=True,
            help=help_text,
        )
    return parser.parse_args()
def load_test_data(json_file):
    """Load the evaluation dataset from a JSON file.

    Args:
        json_file: Path to the JSON file containing the test items.

    Returns:
        The deserialized JSON content (a list of item dicts is expected
        by the caller).

    Exits the process with status 1 when the file is missing or not
    valid JSON (keeps the original fail-fast CLI behavior).
    """
    try:
        with open(json_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"错误: 在 {json_file} 未找到数据文件")
        exit(1)
    except json.JSONDecodeError:
        print(f"错误: 无法从 {json_file} 解码JSON")
        exit(1)
    # NOTE: the original trailing ``return []`` was unreachable (both
    # except branches exit) and has been removed as dead code.
def save_json_file(data, output_file):
    """Serialize ``data`` to ``output_file`` as indented JSON.

    Args:
        data: Any JSON-serializable object (here: uid -> ranked frame list).
        output_file: Destination path; its parent directory is created on
            demand.

    Fix: the directory-creation line was commented out, so writing to a
    not-yet-existing directory raised FileNotFoundError. It is restored
    with a guard, because ``os.path.dirname`` returns "" for a bare
    filename and ``os.makedirs("")`` would raise.
    """
    parent_dir = os.path.dirname(output_file)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)
    print(f"\n成功将最终相似度结果保存到 {output_file}")
def process_question_chunk(args_tuple):
    """Worker: rank precomputed frame embeddings against each question.

    Runs on a single GPU. For every question in the chunk it loads the
    matching ``<key>.pt`` embedding file, embeds the question text with
    the SigLIP model, sorts the frames by cosine similarity, and appends
    one JSON line ``{uid: [frame, ...]}`` to a per-worker temp file.

    Args:
        args_tuple: (data_chunk, embeddings_base_path, gpu_id, temp_dir).
    """
    data_chunk, embeddings_base_path, gpu_id, temp_dir = args_tuple
    device = f"cuda:{gpu_id}"
    # One output file per worker avoids cross-process write contention.
    temp_output_file = os.path.join(temp_dir, f"results_gpu_{gpu_id}.jsonl")
    # Only text features are computed here; frame features come from disk.
    model = AutoModel.from_pretrained(MODEL_ID).to(device).eval()
    processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=True)
    progress_bar = tqdm(data_chunk, position=gpu_id, desc=f"GPU-{gpu_id}")
    # Cache loaded .pt payloads so repeated keys do not re-hit the disk.
    embedding_cache = {}
    with open(temp_output_file, "a", encoding="utf-8") as f_out:
        for data_item in progress_bar:
            item_key = data_item["key"]
            question_key = data_item["uid"]
            # Keep only the question stem; drop the multiple-choice options.
            question = data_item["question"].split("\n(A)")[0]
            embedding_file_path = os.path.join(
                embeddings_base_path, f"{item_key}.pt"
            )
            if not os.path.exists(embedding_file_path):
                progress_bar.write(
                    f"Warning: Embedding file not found for '{item_key}', skipping."
                )
                continue
            try:
                cached = embedding_cache.get(item_key)
                if cached is None:
                    payload = torch.load(embedding_file_path, map_location="cpu")
                    cached = {
                        "filenames": payload["filenames"],
                        "embeddings": payload["embeddings"],
                    }
                    embedding_cache[item_key] = cached
                frame_files = cached["filenames"]
                frame_embeddings = cached["embeddings"].to(device)
                with torch.no_grad():
                    # Text embedding for this question.
                    text_inputs = processor(
                        text=[question],
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                    ).to(device)
                    question_embedding = model.get_text_features(**text_inputs)
                    # Cosine similarity of the question against every frame.
                    similarities = cosine_similarity(
                        question_embedding, frame_embeddings
                    )
                # Rank frames best-first by similarity score.
                ranked = sorted(
                    zip(frame_files, similarities.cpu().numpy()),
                    key=lambda pair: pair[1],
                    reverse=True,
                )
                ordered_names = [name for name, _score in ranked]
                f_out.write(json.dumps({question_key: ordered_names}) + "\n")
            except Exception as e:
                progress_bar.write(
                    f"Error on GPU-{gpu_id} for item '{item_key}': {e}"
                )
def main():
    """Coordinate multi-GPU similarity scoring and merge worker output.

    Splits the dataset into one chunk per visible GPU, fans the chunks
    out through a process pool, then merges the per-worker JSONL files
    into a single {uid: ranked_frames} JSON result.
    """
    args = parse_arguments()
    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        print("错误: 未找到启用CUDA的GPU。正在退出。")
        exit(1)
    print(f"找到 {num_gpus} 个GPU。开始并行计算相似度...")
    test_data = load_test_data(args.data_file)
    if not test_data:
        return
    # Ceiling division so every item lands in exactly one chunk.
    chunk_size = -(-len(test_data) // num_gpus)
    data_chunks = [
        test_data[start : start + chunk_size]
        for start in range(0, len(test_data), chunk_size)
    ]
    with tempfile.TemporaryDirectory() as temp_dir:
        print(f"使用临时目录存储中间结果: {temp_dir}")
        process_args = [
            (chunk, args.embeddings_path, gpu_id, temp_dir)
            for gpu_id, chunk in enumerate(data_chunks)
        ]
        with mp.Pool(processes=num_gpus) as pool:
            pool.map(process_question_chunk, process_args)
        # All workers are done; merge their JSONL shards into one dict.
        print("\n\n所有GPU进程已完成。正在从临时文件合并结果...")
        final_similarity_results = {}
        shard_files = glob.glob(os.path.join(temp_dir, "*.jsonl"))
        for temp_file in tqdm(shard_files, desc="合并文件"):
            with open(temp_file, "r", encoding="utf-8") as f:
                for line in f:
                    try:
                        final_similarity_results.update(json.loads(line))
                    except json.JSONDecodeError:
                        print(f"警告: 跳过 {temp_file} 中的损坏行")
        save_json_file(final_similarity_results, args.output_file)
        print(f"总共处理的项目数: {len(final_similarity_results)}")
if __name__ == "__main__":
    # NOTE(review): "spawn" gives each worker a fresh interpreter —
    # presumably required so forked children don't inherit CUDA state;
    # confirm before changing.
    mp.set_start_method("spawn", force=True)
    main()
|