File size: 2,984 Bytes
625a17f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from transformers import BertTokenizer, BertModel
import torch
import torch.nn.functional as F
import json

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data_path = "/work/yuqian_fu/Ego/data_segswap/check_text_byname_600_objname_llavaname.json"
# with open(data_path, "r") as fp:
#     datas = json.load(fp)

# Small inline sample standing in for the full JSON dataset above.
# Each record pairs a ground-truth object name with LLaVA's predicted text.
data = [
    {"obj_name": "apple", "llava_text": "wheel"},
    {"obj_name": "car", "llava_text": "engine"},
    # add more data here...
]

# Cosine-similarity threshold above which a pair counts as a correct match.
SIMILARITY_THRESHOLD = 0.8

# Load the pretrained BERT model and tokenizer (downloads on first run).
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased").to(device)
model.eval()  # evaluation mode: disables dropout for deterministic embeddings

def get_bert_embedding(text):
    """Encode *text* with BERT and return its [CLS] sentence embedding.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.
    Returns a tensor of shape (batch, hidden_size) living on ``device``.
    """
    encoded = tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=128
    )
    # Move every input tensor (input_ids, attention_mask, ...) to the device.
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():  # inference only; no gradients needed
        output = model(**encoded)
    # The final hidden state of the [CLS] token (position 0) serves as the
    # sentence-level representation.
    return output.last_hidden_state[:, 0, :]

def calculate_cosine_similarity(embedding1, embedding2):
    """Return the cosine similarity of two embeddings as a Python float.

    Expects (1, dim) tensors; ``.item()`` assumes a single similarity value.
    """
    return F.cosine_similarity(embedding1, embedding2).item()

# Number of pairs whose similarity clears the threshold ("correct" matches).
correct_count = 0
# Per-pair similarity scores, used for the average at the end.
similarity_list = []

for item in data:
    obj_name = item["obj_name"]
    llava_text = item["llava_text"]

    # Embed both texts with BERT ([CLS] sentence embeddings).
    obj_embedding = get_bert_embedding(obj_name)
    llava_embedding = get_bert_embedding(llava_text)

    # Cosine similarity between the two embeddings.
    similarity = calculate_cosine_similarity(obj_embedding, llava_embedding)
    print("similarity:", similarity)
    similarity_list.append(similarity)

    # Count the pair as correct when it exceeds the threshold.
    if similarity > SIMILARITY_THRESHOLD:
        correct_count += 1

# Guard against an empty dataset so the divisions below cannot raise an
# uninformative ZeroDivisionError.
total_samples = len(data)
if total_samples == 0:
    raise ValueError("`data` is empty; nothing to evaluate")

# Fraction of correct samples and mean similarity over the dataset.
accuracy = correct_count / total_samples
average_similarity = sum(similarity_list) / total_samples

print(f"正确样本数: {correct_count}")
print(f"总样本数: {total_samples}")
print(f"正确样本比例: {accuracy:.2%}")
print(f"平均相似性: {average_similarity:.4f}")