# ObjectRelator-Original / scripts / sbert_similarity.py
# (uploaded via huggingface_hub by YuqianFu, revision 625a17f)
from sentence_transformers import SentenceTransformer
import torch
import torch.nn.functional as F
import json
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Path to the JSON annotation file; each entry is expected to carry a
# "first_frame_anns" list with "obj_name" / "llava_text" fields (see loop below).
data_path = "/home/yuqian_fu/Projects/PSALM/check_text_select_scene_600_objname_llavatext_correct.json"
#data_path = "/home/yuqian_fu/Projects/PSALM/check_text_select_scene_600_objname_llavatext_correct_fixmusic.json"
#data_path = "/home/yuqian_fu/Projects/PSALM/check_text_byname_600_objname_llavaname_correct.json"  # the very first version
with open(data_path, "r") as f:
    datas = json.load(f)

# Cosine-similarity threshold above which a sample counts as "correct".
SIMILARITY_THRESHOLD = 0.5
# Load the pre-trained Sentence-BERT model.
model = SentenceTransformer("all-MiniLM-L6-v2").to(device)
#model = SentenceTransformer("all-mpnet-base-v2").to(device)  # use the MPNet model instead
model.eval()  # set the model to evaluation mode
def get_sbert_embedding(text):
    """Encode *text* with the module-level Sentence-BERT model.

    Returns the embedding as a torch tensor on the module-level device;
    gradients are disabled since this is inference only.
    """
    with torch.no_grad():
        return model.encode(text, convert_to_tensor=True, device=device)
def calculate_cosine_similarity(embedding1, embedding2):
    """Return the cosine similarity between two 1-D feature vectors as a float."""
    # F.cosine_similarity works on batches, so lift each vector to shape (1, d).
    batched_a = embedding1.unsqueeze(0)
    batched_b = embedding2.unsqueeze(0)
    score = F.cosine_similarity(batched_a, batched_b)
    return score.item()
# Count of samples whose (average) similarity clears SIMILARITY_THRESHOLD.
correct_count = 0
# Total number of individual objects seen — only used by the commented-out
# v1 scheme below, which scored per object instead of per sample.
num_total = 0
# Iterate over the data and compute similarities; one score per sample.
similarity_list = []
# v1: use the number of objects as the denominator (kept for reference, disabled).
'''v1:以物体个数为分母'''
# for data in datas:
#     annos = data["first_frame_anns"]
#     for anno in annos:
#         num_total += 1
#         obj_name = anno["obj_name"]
#         llava_text = anno["llava_text"]
#         # extract features
#         obj_embedding = get_sbert_embedding(obj_name)
#         llava_embedding = get_sbert_embedding(llava_text)
#         # compute similarity
#         similarity = calculate_cosine_similarity(obj_embedding, llava_embedding)
#         print("similarity:", similarity)
#         similarity_list.append(similarity)
#         # decide whether this is a correct sample
#         if similarity > SIMILARITY_THRESHOLD:
#             correct_count += 1
# v2: use the number of video frames (samples) as the denominator.
'''v2:以视频帧数为分母'''
# v2 scoring: average the per-object similarity within each sample, then
# count the sample as "correct" when that average clears the threshold.
for data in datas:
    annos = data["first_frame_anns"]
    score = 0
    for anno in annos:
        obj_name = anno["obj_name"]
        llava_text = anno["llava_text"]
        # Extract features for both the ground-truth name and the LLaVA text.
        obj_embedding = get_sbert_embedding(obj_name)
        llava_embedding = get_sbert_embedding(llava_text)
        # Compute similarity and accumulate it for this sample.
        similarity = calculate_cosine_similarity(obj_embedding, llava_embedding)
        score += similarity
    # Guard: a sample with no annotations previously raised ZeroDivisionError;
    # treat it as similarity 0.0 (never counted as correct).
    sim_avg = score / len(annos) if annos else 0.0
    similarity_list.append(sim_avg)
    if sim_avg > SIMILARITY_THRESHOLD:
        correct_count += 1

# Report aggregate metrics over all samples (frames).
# print("total:", num_total)
total_samples = len(datas)
# total_samples = num_total
# Guard against an empty dataset to avoid dividing by zero.
accuracy = correct_count / total_samples if total_samples else 0.0
average_similarity = sum(similarity_list) / total_samples if total_samples else 0.0
print(f"正确样本数: {correct_count}")
print(f"总样本数: {total_samples}")
print(f"正确样本比例: {accuracy:.2%}")
print(f"平均相似性: {average_similarity:.4f}")