|
|
from transformers import BertTokenizer, BertModel |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import json |
|
|
|
|
|
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Samples to evaluate. Each entry pairs a reference object name ("obj_name")
# with the text it is compared against ("llava_text" -- presumably the output
# of a LLaVA model; confirm with the data producer).
data = [
    {"obj_name": "apple", "llava_text": "wheel"},
    {"obj_name": "car", "llava_text": "engine"},
]

# Cosine-similarity cut-off used by the evaluation loop further down this file:
# a pair scoring strictly above this value is counted as correct.
SIMILARITY_THRESHOLD = 0.8

# Load the pretrained BERT tokenizer and encoder once, at module import time.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased").to(device)
# Switch to evaluation mode so layers such as dropout behave deterministically.
model.eval()
|
|
|
|
|
def get_bert_embedding(text) -> torch.Tensor:
    """Extract a BERT feature vector for ``text``.

    The text is tokenized with the module-level ``tokenizer`` (truncated to at
    most 128 tokens), moved to ``device``, and run through the module-level
    ``model`` with gradient tracking disabled.

    Args:
        text: Input handed straight to the tokenizer. Callers in this file
            pass a single string.

    Returns:
        The slice ``last_hidden_state[:, 0, :]``, i.e. the hidden state at
        sequence position 0 for every row of the batch (presumably the
        ``[CLS]`` token embedding, shape ``(batch, hidden_size)`` -- confirm
        against the BERT model documentation). The tensor lives on ``device``.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    )
    # The model sits on `device`, so every input tensor must be moved there too.
    inputs = {key: value.to(device) for key, value in inputs.items()}

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)

    # Keep only the first position of each sequence as the sentence feature.
    cls_embedding = outputs.last_hidden_state[:, 0, :]
    return cls_embedding
|
|
|
|
|
def calculate_cosine_similarity(embedding1, embedding2) -> float:
    """Return the cosine similarity of two feature vectors as a Python float.

    Args:
        embedding1: Tensor holding one feature vector, either 1-D
            ``(hidden,)`` or 2-D with a single row ``(1, hidden)``.
        embedding2: Tensor broadcastable against ``embedding1``.

    Returns:
        The cosine similarity, a value in ``[-1, 1]``.

    Note:
        The result is converted with ``.item()``, which only works on
        one-element tensors, so inputs that describe more than one pair of
        vectors are rejected by PyTorch.
    """
    # Reduce over the last (feature) axis. For (1, hidden) inputs this equals
    # the default dim=1, and it additionally supports plain 1-D vectors.
    similarity = F.cosine_similarity(embedding1, embedding2, dim=-1)
    return similarity.item()
|
|
|
|
|
|
|
|
# Number of pairs whose similarity exceeds SIMILARITY_THRESHOLD.
correct_count = 0
# Similarity score of every evaluated pair, in input order.
similarity_list = []

for item in data:
    obj_name = item["obj_name"]
    llava_text = item["llava_text"]

    # Embed both texts with BERT.
    obj_embedding = get_bert_embedding(obj_name)
    llava_embedding = get_bert_embedding(llava_text)

    # Score the pair and keep the value for the average reported below.
    similarity = calculate_cosine_similarity(obj_embedding, llava_embedding)
    print("similarity:", similarity)
    similarity_list.append(similarity)

    # Only scores strictly above the threshold count as correct.
    if similarity > SIMILARITY_THRESHOLD:
        correct_count += 1

total_samples = len(data)
# Guard against an empty dataset, which would otherwise raise
# ZeroDivisionError; report 0.0 for both metrics in that case.
accuracy = correct_count / total_samples if total_samples else 0.0
average_similarity = (
    sum(similarity_list) / total_samples if total_samples else 0.0
)

# Summary: correct count, total count, accuracy (percent), mean similarity.
print(f"正确样本数: {correct_count}")
print(f"总样本数: {total_samples}")
print(f"正确样本比例: {accuracy:.2%}")
print(f"平均相似性: {average_similarity:.4f}")