from sentence_transformers import SentenceTransformer
import torch
import torch.nn.functional as F
import json

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data_path = "/home/yuqian_fu/Projects/PSALM/check_text_select_scene_600_objname_llavatext_correct.json"
#data_path = "/home/yuqian_fu/Projects/PSALM/check_text_select_scene_600_objname_llavatext_correct_fixmusic.json"
#data_path = "/home/yuqian_fu/Projects/PSALM/check_text_byname_600_objname_llavaname_correct.json"  # the original version
with open(data_path, "r") as f:
    datas = json.load(f)

# Similarity threshold for counting a sample as correct
SIMILARITY_THRESHOLD = 0.5

# Load the pretrained Sentence-BERT model
model = SentenceTransformer("all-MiniLM-L6-v2").to(device)
#model = SentenceTransformer("all-mpnet-base-v2").to(device)  # use the MPNet model instead
model.eval()  # set the model to evaluation mode


def get_sbert_embedding(text):
    """Extract a text embedding with Sentence-BERT."""
    with torch.no_grad():
        embedding = model.encode(text, convert_to_tensor=True, device=device)
    return embedding


def calculate_cosine_similarity(embedding1, embedding2):
    """Compute the cosine similarity between two embeddings."""
    similarity = F.cosine_similarity(embedding1.unsqueeze(0), embedding2.unsqueeze(0))
    return similarity.item()


# Counters for correct samples
correct_count = 0
num_total = 0  # only used by the v1 variant below

# Iterate over the data and compute similarities
similarity_list = []

'''v1: per-object accuracy (denominator = number of objects)'''
# for data in datas:
#     annos = data["first_frame_anns"]
#     for anno in annos:
#         num_total += 1
#         obj_name = anno["obj_name"]
#         llava_text = anno["llava_text"]
#         # Extract embeddings
#         obj_embedding = get_sbert_embedding(obj_name)
#         llava_embedding = get_sbert_embedding(llava_text)
#         # Compute similarity
#         similarity = calculate_cosine_similarity(obj_embedding, llava_embedding)
#         print("similarity:", similarity)
#         similarity_list.append(similarity)
#         # Check whether this sample counts as correct
#         if similarity > SIMILARITY_THRESHOLD:
#             correct_count += 1

'''v2: per-video accuracy (denominator = number of video first frames)'''
for data in datas:
    annos = data["first_frame_anns"]
    if not annos:
        continue  # skip videos without annotations to avoid division by zero
    score = 0
    for anno in annos:
        obj_name = anno["obj_name"]
        llava_text = anno["llava_text"]
        # Extract embeddings
        obj_embedding = get_sbert_embedding(obj_name)
        llava_embedding = get_sbert_embedding(llava_text)
        # Compute similarity
        similarity = calculate_cosine_similarity(obj_embedding, llava_embedding)
        score += similarity
    sim_avg = score / len(annos)
    similarity_list.append(sim_avg)
    if sim_avg > SIMILARITY_THRESHOLD:
        correct_count += 1

# Compute the proportion of correct samples
# print("total:", num_total)
total_samples = len(datas)
# total_samples = num_total  # v1 denominator
accuracy = correct_count / total_samples
# Average over the videos actually evaluated (equals total_samples when none are skipped)
average_similarity = sum(similarity_list) / len(similarity_list)

print(f"Correct samples: {correct_count}")
print(f"Total samples: {total_samples}")
print(f"Correct-sample ratio: {accuracy:.2%}")
print(f"Average similarity: {average_similarity:.4f}")
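
'''Optional sketch (an addition, not part of the original script): encoding all
texts for a video in one batched model.encode() call is typically much faster
than calling it once per annotation. This assumes the same JSON layout as v2
above; sentence_transformers.util.cos_sim is the library's built-in
cosine-similarity helper.'''
# from sentence_transformers import util
#
# for data in datas:
#     annos = data["first_frame_anns"]
#     if not annos:
#         continue
#     obj_names = [anno["obj_name"] for anno in annos]
#     llava_texts = [anno["llava_text"] for anno in annos]
#     # One forward pass per field instead of one per annotation
#     obj_embeddings = model.encode(obj_names, convert_to_tensor=True, device=device)
#     llava_embeddings = model.encode(llava_texts, convert_to_tensor=True, device=device)
#     # cos_sim returns an (N, N) matrix; the diagonal pairs obj_names[i] with llava_texts[i]
#     sims = util.cos_sim(obj_embeddings, llava_embeddings).diagonal()
#     sim_avg = sims.mean().item()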