|
|
from sentence_transformers import SentenceTransformer |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import json |
|
|
|
|
|
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# JSON file holding the samples to evaluate.
data_path = "/home/yuqian_fu/Projects/PSALM/check_text_select_scene_600_objname_llavatext_correct.json"

# Decode explicitly as UTF-8 (the standard JSON encoding) instead of relying
# on the platform's locale-dependent default, which can mis-decode non-ASCII
# text on some systems.
with open(data_path, "r", encoding="utf-8") as f:
    datas = json.load(f)

# A sample's mean similarity must be strictly greater than this value to be
# counted as correct.
SIMILARITY_THRESHOLD = 0.5

# Sentence-BERT model used to embed text, moved to the selected device and
# switched to evaluation mode.
model = SentenceTransformer("all-MiniLM-L6-v2").to(device)
model.eval()
|
|
|
|
|
def get_sbert_embedding(text):
    """Extract a feature vector for ``text`` with the Sentence-BERT model.

    Encoding runs with gradient tracking disabled, using the module-level
    ``model`` and ``device``.

    Args:
        text: Input passed straight to ``model.encode``.

    Returns:
        The result of ``model.encode`` called with ``convert_to_tensor=True``.
    """
    with torch.no_grad():
        return model.encode(text, convert_to_tensor=True, device=device)
|
|
|
|
|
def calculate_cosine_similarity(embedding1, embedding2):
    """Compute the cosine similarity between two feature vectors.

    Each input receives a leading axis of size 1 before the comparison, and
    the result is extracted with ``.item()``, which requires a single-element
    result (i.e. the function is intended for 1-D vectors).

    Args:
        embedding1: First embedding tensor.
        embedding2: Second embedding tensor.

    Returns:
        The cosine similarity as a Python float.
    """
    lhs = embedding1.unsqueeze(0)
    rhs = embedding2.unsqueeze(0)
    return F.cosine_similarity(lhs, rhs, dim=1).item()
|
|
|
|
|
|
|
|
# Number of samples whose mean similarity exceeds SIMILARITY_THRESHOLD.
correct_count = 0
# Not used anywhere in the visible code; kept in case other code reads it.
num_total = 0
# Samples skipped because they carry no annotations.
skipped_count = 0

# Mean similarity of every evaluated sample.
similarity_list = []

# v1 used the number of objects as the denominator (no v1 code is present here).

# v2: use the number of video frames (samples) as the denominator.
for data in datas:
    annos = data["first_frame_anns"]
    if not annos:
        # A mean over zero annotations is undefined (it used to raise
        # ZeroDivisionError), so the sample is skipped and left out of the
        # denominator rather than scored.
        skipped_count += 1
        continue

    # Sum of the cosine similarities between each object's name and its
    # LLaVA-generated text.
    score = sum(
        calculate_cosine_similarity(
            get_sbert_embedding(anno["obj_name"]),
            get_sbert_embedding(anno["llava_text"]),
        )
        for anno in annos
    )

    sim_avg = score / len(annos)
    similarity_list.append(sim_avg)
    if sim_avg > SIMILARITY_THRESHOLD:
        correct_count += 1

# Only samples that were actually scored count towards the totals; this
# equals len(datas) whenever every sample has at least one annotation.
total_samples = len(similarity_list)

if total_samples:
    accuracy = correct_count / total_samples
    average_similarity = sum(similarity_list) / total_samples
else:
    # Nothing was evaluated: report zeros instead of dividing by zero.
    accuracy = 0.0
    average_similarity = 0.0

print(f"正确样本数: {correct_count}")
print(f"总样本数: {total_samples}")
print(f"正确样本比例: {accuracy:.2%}")
print(f"平均相似性: {average_similarity:.4f}")
if skipped_count:
    # Make skipped samples visible so the reported totals are not misread.
    print(f"跳过的空标注样本数: {skipped_count}")