|
|
import json

import numpy as np
from gensim.models import KeyedVectors
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Absolute path to the pretrained GoogleNews Word2Vec weights
# (gzipped C binary format; "negative300" in the name suggests 300-dim
# vectors — TODO confirm against the actual file).
model_path = "/home/yuqian_fu/Projects/VQA/GoogleNews-vectors-negative300.bin.gz"


# Load the pretrained embeddings once at module import time.
# binary=True because the file is in the original C word2vec binary format.
word2vec_model = KeyedVectors.load_word2vec_format(model_path, binary=True)
|
|
|
|
|
|
|
|
# Evaluation pairs: `obj_name` is the ground-truth object label and
# `llava_text` is the text predicted by the LLaVA model for that object.
data = [


    {"obj_name": "bicycle chain", "llava_text": "bicycle wheel"},


    {"obj_name": "cup", "llava_text": "mug"},


    {"obj_name": "pot", "llava_text": "pan"},




]


# A pair is counted as "correct" when its cosine similarity is strictly
# greater than this threshold.
SIMILARITY_THRESHOLD = 0.8
|
|
|
|
|
def get_word2vec_embedding(word):
    """Return the Word2Vec embedding vector for *word*.

    Looks the word up in the module-level ``word2vec_model``. If the word
    is out of vocabulary, a notice is printed and ``None`` is returned so
    the caller can treat the pair as a miss.
    """
    if word in word2vec_model:
        return word2vec_model[word]
    print(f"Word '{word}' not found in Word2Vec vocabulary.")
    return None
|
|
|
|
|
def calculate_cosine_similarity(vec1, vec2):
    """Return the cosine similarity of two equal-length vectors as a float.

    Computed directly with numpy rather than through the loaded
    ``KeyedVectors`` instance: the result is the same, but the helper no
    longer depends on the module-level model and works for any pair of
    numeric sequences.

    Returns 0.0 when either vector has zero magnitude, instead of the
    NaN a naive dot/norm division would produce.
    """
    v1 = np.asarray(vec1, dtype=np.float64)
    v2 = np.asarray(vec2, dtype=np.float64)
    denom = np.linalg.norm(v1) * np.linalg.norm(v2)
    if denom == 0.0:
        return 0.0
    return float(np.dot(v1, v2) / denom)
|
|
|
|
|
|
|
|
# --- Evaluation loop -------------------------------------------------------
correct_count = 0      # pairs whose similarity exceeds SIMILARITY_THRESHOLD
similarity_list = []   # per-pair similarity, used for the average below


for item in data:
    obj_name = item["obj_name"]
    llava_text = item["llava_text"]

    obj_vec = get_word2vec_embedding(obj_name)
    llava_vec = get_word2vec_embedding(llava_text)

    # An out-of-vocabulary word on either side scores 0.0, so the pair
    # can never be counted as correct.
    if obj_vec is None or llava_vec is None:
        similarity = 0.0
    else:
        similarity = calculate_cosine_similarity(obj_vec, llava_vec)

    print(f"Similarity between '{obj_name}' and '{llava_text}': {similarity:.4f}")
    similarity_list.append(similarity)

    if similarity > SIMILARITY_THRESHOLD:
        correct_count += 1


# Summary. Guard the divisions so an empty dataset reports zeros instead
# of raising ZeroDivisionError.
total_samples = len(data)
accuracy = correct_count / total_samples if total_samples else 0.0
average_similarity = sum(similarity_list) / total_samples if total_samples else 0.0

print(f"正确样本数: {correct_count}")
print(f"总样本数: {total_samples}")
print(f"正确样本比例: {accuracy:.2%}")
print(f"平均相似性: {average_similarity:.4f}")