# ObjectRelator-Original / scripts / bert_similarity.py
# Uploaded to the Hugging Face Hub by YuqianFu (commit 625a17f, verified).
from transformers import BertTokenizer, BertModel
import torch
import torch.nn.functional as F
import json
# Use the GPU when one is available; fall back to the CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# data_path = "/work/yuqian_fu/Ego/data_segswap/check_text_byname_600_objname_llavaname.json"
# with open(data_path, "r") as fp:
#     datas = json.load(fp)
# Toy in-memory sample: each entry pairs a ground-truth object name with
# the text predicted by LLaVA for the same object.
data = [
    {"obj_name": "apple", "llava_text": "wheel"},
    {"obj_name": "car", "llava_text": "engine"},
    # Add more data here...
]
# Similarity threshold above which a pair counts as a correct match.
SIMILARITY_THRESHOLD = 0.8
# Load the pretrained BERT model and its tokenizer.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased").to(device)
model.eval()  # inference only: disable dropout / training behavior
def get_bert_embedding(text):
    """Embed *text* with BERT and return the [CLS] token's hidden state.

    Uses the module-level ``tokenizer``, ``model`` and ``device``. The
    returned tensor has shape (1, hidden_size) and lives on ``device``.
    """
    encoded = tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=128
    )
    # Move every input tensor onto the same device as the model.
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        output = model(**encoded)
    # The first token's final hidden state ([CLS]) serves as the sentence embedding.
    return output.last_hidden_state[:, 0, :]
def calculate_cosine_similarity(embedding1, embedding2):
    """Return the cosine similarity of two (1, dim) embeddings as a Python float."""
    return F.cosine_similarity(embedding1, embedding2).item()
# Number of pairs whose similarity clears the threshold.
correct_count = 0
# Collected per-pair similarities (used for the average at the end).
similarity_list = []
for item in data:
    obj_name = item["obj_name"]
    llava_text = item["llava_text"]
    # Embed both strings with BERT.
    obj_embedding = get_bert_embedding(obj_name)
    llava_embedding = get_bert_embedding(llava_text)
    # Score the pair with cosine similarity.
    similarity = calculate_cosine_similarity(obj_embedding, llava_embedding)
    print("similarity:", similarity)
    similarity_list.append(similarity)
    # A pair counts as correct when its similarity exceeds the threshold.
    if similarity > SIMILARITY_THRESHOLD:
        correct_count += 1

# Aggregate statistics. Guard the divisions so an empty dataset reports
# zeros instead of raising ZeroDivisionError.
total_samples = len(data)
if total_samples:
    accuracy = correct_count / total_samples
    average_similarity = sum(similarity_list) / total_samples
else:
    accuracy = 0.0
    average_similarity = 0.0
print(f"正确样本数: {correct_count}")
print(f"总样本数: {total_samples}")
print(f"正确样本比例: {accuracy:.2%}")
print(f"平均相似性: {average_similarity:.4f}")