File size: 3,152 Bytes
625a17f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from sentence_transformers import SentenceTransformer
import torch
import torch.nn.functional as F
import json

# Use GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Evaluation data: JSON produced upstream; alternate versions kept for reference.
data_path = "/home/yuqian_fu/Projects/PSALM/check_text_select_scene_600_objname_llavatext_correct.json"
#data_path = "/home/yuqian_fu/Projects/PSALM/check_text_select_scene_600_objname_llavatext_correct_fixmusic.json"
#data_path = "/home/yuqian_fu/Projects/PSALM/check_text_byname_600_objname_llavaname_correct.json" # the original first version
with open(data_path, "r") as f:
    datas = json.load(f)

# Cosine-similarity threshold above which a sample counts as "correct".
SIMILARITY_THRESHOLD = 0.5

# Load the pretrained Sentence-BERT model.
model = SentenceTransformer("all-MiniLM-L6-v2").to(device)
#model = SentenceTransformer("all-mpnet-base-v2").to(device)  # alternative: MPNet model
model.eval()  # evaluation mode: disables dropout etc. for deterministic inference

def get_sbert_embedding(text):
    """Encode *text* into a feature vector with the module-level Sentence-BERT model.

    Inference only, so gradient tracking is disabled; the embedding is
    returned as a tensor on the module-level ``device``.
    """
    with torch.no_grad():
        return model.encode(text, convert_to_tensor=True, device=device)

def calculate_cosine_similarity(embedding1, embedding2):
    """Return the cosine similarity of two 1-D feature vectors as a Python float."""
    # F.cosine_similarity operates along a batch dimension, so promote
    # each vector to shape (1, dim) before comparing.
    batched_a = embedding1.unsqueeze(0)
    batched_b = embedding2.unsqueeze(0)
    return F.cosine_similarity(batched_a, batched_b).item()

# Running count of "correct" samples (similarity above the threshold).
correct_count = 0
# Total object count — only used as the denominator by the v1 variant below.
num_total = 0
# Iterate over the data and compute similarities.
similarity_list = []

'''v1:以物体个数为分母'''
# v1 (kept for reference): per-object accuracy — the denominator is the
# total number of annotated objects across all videos.
# for data in datas:
#     annos = data["first_frame_anns"]
#     for anno in annos:
#         num_total += 1
#         obj_name = anno["obj_name"]
#         llava_text = anno["llava_text"]

#         # extract embeddings
#         obj_embedding = get_sbert_embedding(obj_name)
#         llava_embedding = get_sbert_embedding(llava_text)

#         # compute cosine similarity
#         similarity = calculate_cosine_similarity(obj_embedding, llava_embedding)
#         print("similarity:", similarity)
#         similarity_list.append(similarity)

#         # count as correct when above the threshold
#         if similarity > SIMILARITY_THRESHOLD:
#             correct_count += 1


'''v2:以视频帧数为分母'''
# v2: per-video accuracy — average the per-object similarities within each
# video, then use the number of videos as the denominator.
for data in datas:
    annos = data["first_frame_anns"]
    score = 0
    for anno in annos:
        obj_name = anno["obj_name"]
        llava_text = anno["llava_text"]

        # Embed both the ground-truth object name and the LLaVA description.
        obj_embedding = get_sbert_embedding(obj_name)
        llava_embedding = get_sbert_embedding(llava_text)

        # Accumulate cosine similarity over this video's objects.
        similarity = calculate_cosine_similarity(obj_embedding, llava_embedding)
        score += similarity

    # BUGFIX: a video with no annotations previously raised ZeroDivisionError
    # on `score / len(annos)`; treat such videos as average similarity 0.0 so
    # they still count as one (incorrect) sample.
    sim_avg = score / len(annos) if annos else 0.0
    similarity_list.append(sim_avg)
    if sim_avg > SIMILARITY_THRESHOLD:
        correct_count += 1
    


# Compute and report the fraction of correct samples.
# print("total:", num_total)
total_samples = len(datas)  # v2: one sample per video
# total_samples = num_total  # v1: one sample per object
# NOTE(review): both divisions raise ZeroDivisionError if `datas` is empty —
# assumed non-empty for this evaluation; confirm against the data file.
accuracy = correct_count / total_samples
average_similarity = sum(similarity_list) / total_samples

print(f"正确样本数: {correct_count}")
print(f"总样本数: {total_samples}")
print(f"正确样本比例: {accuracy:.2%}")
print(f"平均相似性: {average_similarity:.4f}")