import json
import re
import string

import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer

# json_path_1 = "/home/yuqian_fu/Projects/PSALM/egoexo_val_framelevel_newprompt_all_instruction.json"
# json_path_2 = "/home/yuqian_fu/Projects/PSALM/egoexo_val_framelevel_newprompt_all_objname_instruction.json"
# with open(json_path_1, "r") as fp:
#     datas_1 = json.load(fp)
# with open(json_path_2, "r") as fp:
#     datas_2 = json.load(fp)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Similarity threshold: two texts above this cosine score are treated as
# referring to the same object.
SIMILARITY_THRESHOLD = 0.5

# Load the pretrained Sentence-BERT model.
model = SentenceTransformer("all-MiniLM-L6-v2").to(device)
model.eval()  # evaluation mode (disables dropout etc.)


def extract_object_name(text):
    """Return the fragment of *text* after the first standalone " is ".

    Example: "the object is a red cup" -> "a red cup".
    Returns None when *text* contains no " is " separator.
    """
    # Split on the whole word " is " (with surrounding spaces) so that the
    # letters "is" inside words such as "this" or "dish" are not matched;
    # the original bare split("is") broke on such inputs.  maxsplit=1 keeps
    # any later " is " occurrences inside the returned fragment.
    parts = text.split(" is ", 1)
    if len(parts) > 1:
        return parts[1].strip()
    return None


def get_sbert_embedding(text):
    """Encode *text* with Sentence-BERT and return the embedding tensor."""
    with torch.no_grad():  # inference only: no gradients needed
        embedding = model.encode(text, convert_to_tensor=True, device=device)
    return embedding


def calculate_cosine_similarity(embedding1, embedding2):
    """Return the cosine similarity of two 1-D embedding tensors as a float."""
    similarity = F.cosine_similarity(embedding1.unsqueeze(0), embedding2.unsqueeze(0))
    return similarity.item()


# --- Stage 1 (commented out): score llava_text vs. obj_name similarity and
# --- keep only samples whose average similarity exceeds the threshold.
# # count correct samples
# correct_count = 0
# # iterate over the data and compute similarities
# similarity_list = []
# # total number of objects
# obj_total = 0
# datas_save_1 = []
# datas_save_2 = []
# for data1, data2 in zip(datas_1, datas_2):
#     score = 0
#     annos_1 = data1["first_frame_anns"]
#     annos_2 = data2["first_frame_anns"]
#     for anno1, anno2 in zip(annos_1, annos_2):
#         obj_total += 1
#         # get llava_text
#         llava_text = anno1["text"]
#         llava_text = extract_object_name(llava_text)
#         raw_lower = llava_text.lower()
#         result = raw_lower.replace("green", "").strip()
#         sent = result.translate(str.maketrans('', '', string.punctuation))
#         if "a " in sent:
#             llava_text = sent.replace("a ", "")
#         elif "an " in sent:
#             llava_text = sent.replace("an ", "")
#         print("llava_text:", llava_text)
#         # get obj_name
#         obj_name = anno2["text"]
#         obj_name = re.sub(r'_\d+$', '', obj_name)
#         print("obj_name:", obj_name)
#         # extract features
#         obj_embedding = get_sbert_embedding(obj_name)
#         llava_embedding = get_sbert_embedding(llava_text)
#         # compute similarity
#         similarity = calculate_cosine_similarity(obj_embedding, llava_embedding)
#         print("similarity:", similarity)
#         similarity_list.append(similarity)
#         score += similarity
#         # check whether this counts as a correct sample
#         if similarity > SIMILARITY_THRESHOLD:
#             correct_count += 1
#     score_avg = score / len(annos_1)
#     print("score_avg:", score_avg)
#     if score_avg > SIMILARITY_THRESHOLD:
#         datas_save_1.append(data1)
#         datas_save_2.append(data2)
# # compute the proportion of correct samples
# total_samples_before = len(datas_1)
# print("num before:", total_samples_before)
# accuracy = correct_count / obj_total
# average_similarity = sum(similarity_list) / obj_total
# print(f"正确样本数: {correct_count}")
# print(f"总样本数: {obj_total}")
# print(f"正确样本比例: {accuracy:.2%}")
# print(f"平均相似性: {average_similarity:.4f}")
# print("after filter num:", len(datas_save_1))
# NOTE(review): "save_paht_2" below is a typo for "save_path_2" in the dead
# code — fix before uncommenting.
# save_path_1 = "/home/yuqian_fu/Projects/PSALM/egoexo_val_llavatext_similarity_new.json"
# save_paht_2 = "/home/yuqian_fu/Projects/PSALM/egoexo_val_objname_similarity_new.json"
# with open(save_path_1, "w") as fp:
#     json.dump(datas_save_1, fp)
# with open(save_paht_2, "w") as fp:
#     json.dump(datas_save_2, fp)
#
# --- Stage 2 (commented out): for easier comparison, extract only obj_name
# --- and llava_text from the filtered json files.
# json_path_1 = "/home/yuqian_fu/Projects/PSALM/egoexo_val_llavatext_similarity_new.json"
# json_path_2 = "/home/yuqian_fu/Projects/PSALM/egoexo_val_objname_similarity_new.json"
# with open(json_path_1, "r") as fp:
#     datas_1 = json.load(fp)
# with open(json_path_2, "r") as fp:
#     datas_2 = json.load(fp)
# to_save = []
# for data1, data2 in zip(datas_1, datas_2):
#     annos_1 = data1["first_frame_anns"]
#     annos_2 = data2["first_frame_anns"]
#     for anno1, anno2 in zip(annos_1, annos_2):
#         # get llava_text
#         llava_text = anno1["text"]
#         llava_text = extract_object_name(llava_text)
#         raw_lower = llava_text.lower()
#         result = raw_lower.replace("green", "").strip()
#         sent = result.translate(str.maketrans('', '', string.punctuation))
#         if "a " in sent:
#             llava_text = sent.replace("a ", "")
#         elif "an " in sent:
#             llava_text = sent.replace("an ", "")
#         # print("llava_text:", llava_text)
#         # get obj_name
#         obj_name = anno2["text"]
#         obj_name = re.sub(r'_\d+$', '', obj_name)
#         # print("obj_name:", obj_name)
#         sample = {
#             "llava_text": llava_text,
#             "obj_name": obj_name
#         }
#         to_save.append(sample)
# save_path = "/home/yuqian_fu/Projects/PSALM/only_objname_llavetext.json"
# with open(save_path, "w") as fp:
#     json.dump(to_save, fp)

# --- Stage 3 (active): create a human-readable objname / llava_text file.
import random

# How many (obj_name, llava_text) pairs to hand to human annotators.
NUM_HUMAN_SAMPLES = 600

json_path = "/home/yuqian_fu/Projects/PSALM/check_text_select_scene_600_objname_llavatext_correct.json"
with open(json_path, "r") as fp:
    datas = json.load(fp)

# Flatten every frame annotation into one {"llava_text", "obj_name"} record.
# NOTE(review): assumes each entry of "first_frame_anns" carries "obj_name"
# and "llava_text" keys — verify against the producer of this file.
to_save = [
    {"llava_text": anno["llava_text"], "obj_name": anno["obj_name"]}
    for data in datas
    for anno in data["first_frame_anns"]
]

# Draw a random subset for human review; clamp so that files with fewer than
# NUM_HUMAN_SAMPLES records no longer raise ValueError from random.sample.
to_save = random.sample(to_save, min(NUM_HUMAN_SAMPLES, len(to_save)))

save_path = "/home/yuqian_fu/Projects/PSALM/only_objname_llavatext_4human.json"
# Write the records to the file one per line (JSON Lines: one dict per line).
with open(save_path, "w") as fp:
    for record in to_save:
        json.dump(record, fp)
        fp.write("\n")