"""Evaluate LLaVA-predicted object names against ground-truth object names.

For each (obj_name, llava_text) pair, extract a BERT [CLS] sentence
embedding for both strings, compute their cosine similarity, and count the
pair as "correct" when the similarity exceeds SIMILARITY_THRESHOLD.
Finally print per-pair similarities plus overall accuracy and mean
similarity.
"""
import json

import torch
import torch.nn.functional as F
from transformers import BertModel, BertTokenizer

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Example pairs: ground-truth object name vs. the name predicted by LLaVA.
# (A JSON annotation file can be loaded here instead — see git history.)
data = [
    {"obj_name": "apple", "llava_text": "wheel"},
    {"obj_name": "car", "llava_text": "engine"},
]

# A pair counts as a correct prediction when cosine similarity exceeds this.
SIMILARITY_THRESHOLD = 0.8

# Load the pretrained BERT model and tokenizer once at module level.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased").to(device)
model.eval()  # evaluation mode: disables dropout for deterministic outputs


def get_bert_embedding(text):
    """Return the BERT [CLS] embedding of *text* as a (1, hidden_size) tensor."""
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=128
    )
    # Move every input tensor onto the same device as the model.
    inputs = {key: value.to(device) for key, value in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    # The last-layer output at the [CLS] token serves as the sentence embedding.
    return outputs.last_hidden_state[:, 0, :]


def calculate_cosine_similarity(embedding1, embedding2):
    """Return the cosine similarity of two embedding tensors as a Python float."""
    return F.cosine_similarity(embedding1, embedding2).item()


correct_count = 0  # number of pairs above the threshold
similarity_list = []  # per-pair similarities, for the mean at the end

for item in data:
    obj_embedding = get_bert_embedding(item["obj_name"])
    llava_embedding = get_bert_embedding(item["llava_text"])
    similarity = calculate_cosine_similarity(obj_embedding, llava_embedding)
    print("similarity:", similarity)
    similarity_list.append(similarity)
    if similarity > SIMILARITY_THRESHOLD:
        correct_count += 1

# Summary statistics. Guard against division by zero on an empty dataset.
total_samples = len(data)
accuracy = correct_count / total_samples if total_samples else 0.0
average_similarity = sum(similarity_list) / total_samples if total_samples else 0.0
print(f"正确样本数: {correct_count}")
print(f"总样本数: {total_samples}")
print(f"正确样本比例: {accuracy:.2%}")
print(f"平均相似性: {average_similarity:.4f}")