# ObjectRelator-Original/scripts/word2vec_similarity.py
# Uploaded by YuqianFu via huggingface_hub (commit 625a17f, verified).
import json
import os

import numpy as np
from gensim.models import KeyedVectors
# Load the pre-trained Word2Vec vectors (Google News dataset), stored in the
# binary word2vec format.
# Download: https://code.google.com/archive/p/word2vec/
# Set the WORD2VEC_MODEL_PATH environment variable to point at your local copy;
# the fallback below is a machine-specific path kept for backward compatibility.
model_path = os.environ.get(
    "WORD2VEC_MODEL_PATH",
    "/home/yuqian_fu/Projects/VQA/GoogleNews-vectors-negative300.bin.gz",
)
word2vec_model = KeyedVectors.load_word2vec_format(model_path, binary=True)
# 示例数据
data = [
{"obj_name": "bicycle chain", "llava_text": "bicycle wheel"},
{"obj_name": "cup", "llava_text": "mug"},
{"obj_name": "pot", "llava_text": "pan"},
# 添加更多数据...
]
# 定义相似性阈值
SIMILARITY_THRESHOLD = 0.8
def get_word2vec_embedding(word):
"""
获取单词的 Word2Vec 嵌入向量
"""
try:
return word2vec_model[word]
except KeyError:
print(f"Word '{word}' not found in Word2Vec vocabulary.")
return None
def calculate_cosine_similarity(vec1, vec2):
"""
计算两个向量的余弦相似性
"""
return word2vec_model.cosine_similarities(vec1, [vec2])[0]
# 统计正确样本数
correct_count = 0
similarity_list = []
for item in data:
obj_name = item["obj_name"]
llava_text = item["llava_text"]
# 提取特征
obj_vec = get_word2vec_embedding(obj_name)
llava_vec = get_word2vec_embedding(llava_text)
if obj_vec is None or llava_vec is None:
similarity = 0.0 # 如果任一单词不在词汇表中,相似性设为 0
else:
# 计算相似性
similarity = calculate_cosine_similarity(obj_vec, llava_vec)
print(f"Similarity between '{obj_name}' and '{llava_text}': {similarity:.4f}")
similarity_list.append(similarity)
# 判断是否为正确样本
if similarity > SIMILARITY_THRESHOLD:
correct_count += 1
# 计算正确样本比例
total_samples = len(data)
accuracy = correct_count / total_samples
average_similarity = sum(similarity_list) / total_samples
print(f"正确样本数: {correct_count}")
print(f"总样本数: {total_samples}")
print(f"正确样本比例: {accuracy:.2%}")
print(f"平均相似性: {average_similarity:.4f}")