# ObjectRelator-Original / datasets / select_similarity_objname_llavatext.py
# (Hugging Face page header preserved as a comment so the file parses:
#  uploaded by YuqianFu using huggingface_hub, commit 625a17f verified)
import json
from sentence_transformers import SentenceTransformer
import torch
import torch.nn.functional as F
import re
import string
# json_path_1 = "/home/yuqian_fu/Projects/PSALM/egoexo_val_framelevel_newprompt_all_instruction.json"
# json_path_2 = "/home/yuqian_fu/Projects/PSALM/egoexo_val_framelevel_newprompt_all_objname_instruction.json"
# with open(json_path_1, "r") as fp:
# datas_1 = json.load(fp)
# with open(json_path_2, "r") as fp:
# datas_2 = json.load(fp)
# Run on GPU when one is available, otherwise fall back to CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
# Cosine-similarity cutoff for deciding that two object names match.
SIMILARITY_THRESHOLD = 0.5
# Pre-trained Sentence-BERT encoder used for every text embedding below.
model = SentenceTransformer("all-MiniLM-L6-v2").to(device)
model.eval()  # inference only: disable dropout / training-mode layers
def extract_object_name(text):
    """Return the object description following the first " is " in *text*.

    The previous implementation split on the bare substring ``"is"``,
    which broke on any word that merely contains it — e.g.
    ``"This is a cup"`` split inside "This" and yielded ``""``. Matching
    the whole word (space-delimited) fixes that, and keeps everything
    after the first " is " instead of only the text up to a second "is".

    Args:
        text: A caption such as "The object is a green bottle".

    Returns:
        The stripped text after the first " is ", or None when no
        " is " is present.
    """
    _, separator, remainder = text.partition(" is ")
    if separator:
        return remainder.strip()
    return None
def get_sbert_embedding(text):
    """Encode *text* with the module-level Sentence-BERT model.

    Returns the embedding as a torch tensor on the configured device.
    Runs under ``torch.no_grad()`` because gradients are never needed
    for this offline similarity filtering.
    """
    with torch.no_grad():
        return model.encode(text, convert_to_tensor=True, device=device)
def calculate_cosine_similarity(embedding1, embedding2):
    """Return the cosine similarity of two 1-D embedding vectors as a float."""
    # Add a leading batch dimension so F.cosine_similarity compares the
    # two vectors along their feature axis.
    batched_a = embedding1.unsqueeze(0)
    batched_b = embedding2.unsqueeze(0)
    score = F.cosine_similarity(batched_a, batched_b)
    return score.item()
# # 统计正确样本数
# correct_count = 0
# # 遍历数据,计算相似性
# similarity_list = []
# # 总物体数目
# obj_total = 0
# datas_save_1 = []
# datas_save_2 = []
# for data1, data2 in zip(datas_1, datas_2):
# score = 0
# annos_1 = data1["first_frame_anns"]
# annos_2 = data2["first_frame_anns"]
# for anno1, anno2 in zip(annos_1, annos_2):
# obj_total += 1
# # 得到llava_text
# llava_text = anno1["text"]
# llava_text = extract_object_name(llava_text)
# raw_lower = llava_text.lower()
# result = raw_lower.replace("green", "").strip()
# sent = result.translate(str.maketrans('', '', string.punctuation))
# if "a " in sent:
# llava_text = sent.replace("a ", "")
# elif "an " in sent:
# llava_text = sent.replace("an ", "")
# print("llava_text:", llava_text)
# # 得到obj_name
# obj_name = anno2["text"]
# obj_name = re.sub(r'_\d+$', '', obj_name)
# print("obj_name:", obj_name)
# # 提取特征
# obj_embedding = get_sbert_embedding(obj_name)
# llava_embedding = get_sbert_embedding(llava_text)
# # 计算相似性
# similarity = calculate_cosine_similarity(obj_embedding, llava_embedding)
# print("similarity:", similarity)
# similarity_list.append(similarity)
# score += similarity
# # 判断是否为正确样本
# if similarity > SIMILARITY_THRESHOLD:
# correct_count += 1
# score_avg = score / len(annos_1)
# print("score_avg:", score_avg)
# if score_avg > SIMILARITY_THRESHOLD:
# datas_save_1.append(data1)
# datas_save_2.append(data2)
# # 计算正确样本比例
# total_samples_before = len(datas_1)
# print("num before:", total_samples_before)
# accuracy = correct_count / obj_total
# average_similarity = sum(similarity_list) / obj_total
# print(f"正确样本数: {correct_count}")
# print(f"总样本数: {obj_total}")
# print(f"正确样本比例: {accuracy:.2%}")
# print(f"平均相似性: {average_similarity:.4f}")
# print("after filter num:", len(datas_save_1))
# save_path_1 = "/home/yuqian_fu/Projects/PSALM/egoexo_val_llavatext_similarity_new.json"
# save_paht_2 = "/home/yuqian_fu/Projects/PSALM/egoexo_val_objname_similarity_new.json"
# with open(save_path_1, "w") as fp:
# json.dump(datas_save_1, fp)
# with open(save_paht_2, "w") as fp:
# json.dump(datas_save_2, fp)
# json_path_1 = "/home/yuqian_fu/Projects/PSALM/egoexo_val_llavatext_similarity_new.json"
# json_path_2 = "/home/yuqian_fu/Projects/PSALM/egoexo_val_objname_similarity_new.json"
# with open(json_path_1, "r") as fp:
# datas_1 = json.load(fp)
# with open(json_path_2, "r") as fp:
# datas_2 = json.load(fp)
# # 为方便比较,只提取json文件中的obj_name和llava_text
# to_save = []
# for data1, data2 in zip(datas_1, datas_2):
# annos_1 = data1["first_frame_anns"]
# annos_2 = data2["first_frame_anns"]
# for anno1, anno2 in zip(annos_1, annos_2):
# # 得到llava_text
# llava_text = anno1["text"]
# llava_text = extract_object_name(llava_text)
# raw_lower = llava_text.lower()
# result = raw_lower.replace("green", "").strip()
# sent = result.translate(str.maketrans('', '', string.punctuation))
# if "a " in sent:
# llava_text = sent.replace("a ", "")
# elif "an " in sent:
# llava_text = sent.replace("an ", "")
# # print("llava_text:", llava_text)
# # 得到obj_name
# obj_name = anno2["text"]
# obj_name = re.sub(r'_\d+$', '', obj_name)
# # print("obj_name:", obj_name)
# sample = {
# "llava_text": llava_text,
# "obj_name": obj_name
# }
# to_save.append(sample)
# save_path = "/home/yuqian_fu/Projects/PSALM/only_objname_llavetext.json"
# with open(save_path, "w") as fp:
# json.dump(to_save, fp)
# Build a human-readable file pairing each obj_name with its llava_text.
import random

json_path = "/home/yuqian_fu/Projects/PSALM/check_text_select_scene_600_objname_llavatext_correct.json"
with open(json_path, "r") as fp:
    datas = json.load(fp)

# Flatten every per-frame annotation into a {"llava_text", "obj_name"} record.
to_save = [
    {
        "llava_text": anno["llava_text"],
        "obj_name": anno["obj_name"],
    }
    for data in datas
    for anno in data["first_frame_anns"]
]

# Sample at most 600 records for human review. random.sample raises
# ValueError when asked for more items than exist, so clamp to the
# available count instead of crashing on small inputs.
to_save = random.sample(to_save, min(600, len(to_save)))

save_path = "/home/yuqian_fu/Projects/PSALM/only_objname_llavatext_4human.json"
# Write the records as JSON Lines: one dict per line, easy to skim.
with open(save_path, "w") as fp:
    for record in to_save:
        json.dump(record, fp)
        fp.write("\n")