import json
import random
import re
import string

import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
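# This script post-processes EgoExo val annotations in three stages; the first
# two stages are kept commented out for reference:
#   1. Score each llava_text / obj_name pair with Sentence-BERT cosine
#      similarity and filter out low-scoring samples.
#   2. Flatten the filtered files into plain {llava_text, obj_name} pairs.
#   3. Build a human-readable JSON Lines file of 600 randomly sampled pairs.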
# json_path_1 = "/home/yuqian_fu/Projects/PSALM/egoexo_val_framelevel_newprompt_all_instruction.json"
# json_path_2 = "/home/yuqian_fu/Projects/PSALM/egoexo_val_framelevel_newprompt_all_objname_instruction.json"
# with open(json_path_1, "r") as fp:
#     datas_1 = json.load(fp)
# with open(json_path_2, "r") as fp:
#     datas_2 = json.load(fp)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Similarity threshold for accepting a llava_text / obj_name pair as a match
SIMILARITY_THRESHOLD = 0.5

# Load the pretrained Sentence-BERT model
model = SentenceTransformer("all-MiniLM-L6-v2").to(device)
model.eval()  # set the model to evaluation mode
def extract_object_name(text):
    """Return the phrase following the first " is " in the text.
    Splitting on " is " (with spaces) avoids matching "is" inside words
    such as "this" or "dish"."""
    parts = text.split(" is ", 1)
    if len(parts) > 1:
        return parts[1].strip()
    return None
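# extract_object_name assumes captions of the form "... is <object phrase>",
# e.g. "The object in this image is a green spatula" -> "a green spatula"
# (the example sentence is illustrative, not taken from the dataset).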
def get_sbert_embedding(text):
    """Extract a Sentence-BERT embedding for the given text."""
    with torch.no_grad():
        embedding = model.encode(text, convert_to_tensor=True, device=device)
    return embedding
def calculate_cosine_similarity(embedding1, embedding2):
    """Compute the cosine similarity between two embedding vectors."""
    similarity = F.cosine_similarity(embedding1.unsqueeze(0), embedding2.unsqueeze(0))
    return similarity.item()
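# Quick sanity check for the helpers above (the strings are illustrative,
# not taken from the dataset):
#   emb_a = get_sbert_embedding("wooden spatula")
#   emb_b = get_sbert_embedding("spatula")
#   calculate_cosine_similarity(emb_a, emb_b)  # well above SIMILARITY_THRESHOLD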
# # Number of samples counted as correct
# correct_count = 0
# # Iterate over the data and compute similarities
# similarity_list = []
# # Total number of objects
# obj_total = 0
# datas_save_1 = []
# datas_save_2 = []
# for data1, data2 in zip(datas_1, datas_2):
#     score = 0
#     annos_1 = data1["first_frame_anns"]
#     annos_2 = data2["first_frame_anns"]
#     for anno1, anno2 in zip(annos_1, annos_2):
#         obj_total += 1
#         # Get the llava_text
#         llava_text = anno1["text"]
#         llava_text = extract_object_name(llava_text)
#         raw_lower = llava_text.lower()
#         result = raw_lower.replace("green", "").strip()
#         sent = result.translate(str.maketrans('', '', string.punctuation))
#         # Strip a leading article; otherwise fall back to the cleaned text
#         if sent.startswith("a "):
#             llava_text = sent[len("a "):]
#         elif sent.startswith("an "):
#             llava_text = sent[len("an "):]
#         else:
#             llava_text = sent
#         print("llava_text:", llava_text)
#         # Get the obj_name, dropping its trailing numeric suffix (e.g. "_3")
#         obj_name = anno2["text"]
#         obj_name = re.sub(r'_\d+$', '', obj_name)
#         print("obj_name:", obj_name)
#         # Extract embeddings
#         obj_embedding = get_sbert_embedding(obj_name)
#         llava_embedding = get_sbert_embedding(llava_text)
#         # Compute the similarity
#         similarity = calculate_cosine_similarity(obj_embedding, llava_embedding)
#         print("similarity:", similarity)
#         similarity_list.append(similarity)
#         score += similarity
#         # Count the pair as correct if it clears the threshold
#         if similarity > SIMILARITY_THRESHOLD:
#             correct_count += 1
#     score_avg = score / len(annos_1)
#     print("score_avg:", score_avg)
#     if score_avg > SIMILARITY_THRESHOLD:
#         datas_save_1.append(data1)
#         datas_save_2.append(data2)
# # Compute the proportion of correct samples
# total_samples_before = len(datas_1)
# print("num before:", total_samples_before)
# accuracy = correct_count / obj_total
# average_similarity = sum(similarity_list) / obj_total
# print(f"correct samples: {correct_count}")
# print(f"total objects: {obj_total}")
# print(f"correct sample ratio: {accuracy:.2%}")
# print(f"average similarity: {average_similarity:.4f}")
# print("after filter num:", len(datas_save_1))
# save_path_1 = "/home/yuqian_fu/Projects/PSALM/egoexo_val_llavatext_similarity_new.json"
# save_path_2 = "/home/yuqian_fu/Projects/PSALM/egoexo_val_objname_similarity_new.json"
# with open(save_path_1, "w") as fp:
#     json.dump(datas_save_1, fp)
# with open(save_path_2, "w") as fp:
#     json.dump(datas_save_2, fp)
# json_path_1 = "/home/yuqian_fu/Projects/PSALM/egoexo_val_llavatext_similarity_new.json"
# json_path_2 = "/home/yuqian_fu/Projects/PSALM/egoexo_val_objname_similarity_new.json"
# with open(json_path_1, "r") as fp:
#     datas_1 = json.load(fp)
# with open(json_path_2, "r") as fp:
#     datas_2 = json.load(fp)
# # For easier comparison, keep only obj_name and llava_text from the JSON files
# to_save = []
# for data1, data2 in zip(datas_1, datas_2):
#     annos_1 = data1["first_frame_anns"]
#     annos_2 = data2["first_frame_anns"]
#     for anno1, anno2 in zip(annos_1, annos_2):
#         # Get the llava_text
#         llava_text = anno1["text"]
#         llava_text = extract_object_name(llava_text)
#         raw_lower = llava_text.lower()
#         result = raw_lower.replace("green", "").strip()
#         sent = result.translate(str.maketrans('', '', string.punctuation))
#         # Strip a leading article; otherwise fall back to the cleaned text
#         if sent.startswith("a "):
#             llava_text = sent[len("a "):]
#         elif sent.startswith("an "):
#             llava_text = sent[len("an "):]
#         else:
#             llava_text = sent
#         # print("llava_text:", llava_text)
#         # Get the obj_name, dropping its trailing numeric suffix
#         obj_name = anno2["text"]
#         obj_name = re.sub(r'_\d+$', '', obj_name)
#         # print("obj_name:", obj_name)
#         sample = {
#             "llava_text": llava_text,
#             "obj_name": obj_name,
#         }
#         to_save.append(sample)
# save_path = "/home/yuqian_fu/Projects/PSALM/only_objname_llavetext.json"
# with open(save_path, "w") as fp:
#     json.dump(to_save, fp)
# Build a human-readable file of obj_name / llava_text pairs
json_path = "/home/yuqian_fu/Projects/PSALM/check_text_select_scene_600_objname_llavatext_correct.json"
with open(json_path, "r") as fp:
    datas = json.load(fp)

to_save = []
for data in datas:
    annos = data["first_frame_anns"]
    for anno in annos:
        obj_name = anno["obj_name"]
        llava_text = anno["llava_text"]
        sample = {
            "llava_text": llava_text,
            "obj_name": obj_name,
        }
        to_save.append(sample)
# Randomly sample 600 pairs for human review
to_save = random.sample(to_save, 600)
save_path = "/home/yuqian_fu/Projects/PSALM/only_objname_llavatext_4human.json"

# Write the samples to the JSON file one per line (JSON Lines format)
with open(save_path, "w") as fp:
    for data in to_save:
        json.dump(data, fp)
        fp.write("\n")  # one dict per line
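# A minimal sketch for reading the review file back; note it is JSON Lines
# (one object per line), not a single JSON array, so json.load() on the
# whole file would fail:
# with open(save_path, "r") as fp:
#     samples = [json.loads(line) for line in fp if line.strip()]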