File size: 2,130 Bytes
625a17f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import json

import numpy as np
from gensim.models import KeyedVectors

# Load the pretrained Word2Vec model (Google News dataset).
# Download: https://code.google.com/archive/p/word2vec/
# Replace the path below with the actual model file location.
model_path = "/home/yuqian_fu/Projects/VQA/GoogleNews-vectors-negative300.bin.gz"
# binary=True: the GoogleNews vectors ship in the binary word2vec format.
word2vec_model = KeyedVectors.load_word2vec_format(model_path, binary=True)

# Sample evaluation pairs: ground-truth object name vs. text predicted by
# LLaVA for the same object (VQA evaluation).
data = [
    {"obj_name": "bicycle chain", "llava_text": "bicycle wheel"},
    {"obj_name": "cup", "llava_text": "mug"},
    {"obj_name": "pot", "llava_text": "pan"},
    # Add more data here...
]

# Cosine-similarity cutoff above which a prediction counts as correct.
SIMILARITY_THRESHOLD = 0.8

def get_word2vec_embedding(word):
    """Return the Word2Vec embedding for *word*, or None if unavailable.

    The original lookup failed for multi-word phrases such as
    "bicycle chain" (present in the sample data): GoogleNews stores
    phrases with underscores ("bicycle_chain"). We therefore try the
    raw key, then the underscore-joined key, and finally fall back to
    averaging the vectors of the in-vocabulary tokens.

    Args:
        word: A word or whitespace-separated phrase.

    Returns:
        A 1-D numpy vector, or None when no token is in the vocabulary.
    """
    # Exact-key lookups first: raw string, then underscore-joined phrase.
    for key in (word, word.replace(" ", "_")):
        try:
            return word2vec_model[key]
        except KeyError:
            pass
    # Fallback: mean of the vectors of tokens that are in the vocabulary.
    token_vecs = [word2vec_model[t] for t in word.split() if t in word2vec_model]
    if token_vecs:
        return sum(token_vecs) / len(token_vecs)
    print(f"Word '{word}' not found in Word2Vec vocabulary.")
    return None

def calculate_cosine_similarity(vec1, vec2):
    """Return the cosine similarity between two 1-D vectors.

    Computed directly with numpy instead of routing through the global
    ``word2vec_model.cosine_similarities`` call, which needlessly coupled
    this pure math helper to the loaded model and yielded nan for
    zero-length vectors.

    Args:
        vec1: First vector (any 1-D sequence of numbers).
        vec2: Second vector, same length as *vec1*.

    Returns:
        Cosine similarity as a Python float in [-1.0, 1.0];
        0.0 when either vector has zero norm.
    """
    v1 = np.asarray(vec1, dtype=np.float64)
    v2 = np.asarray(vec2, dtype=np.float64)
    denom = np.linalg.norm(v1) * np.linalg.norm(v2)
    if denom == 0.0:
        # Zero vector: similarity is undefined; treat as "no similarity".
        return 0.0
    return float(np.dot(v1, v2) / denom)

# Count samples whose similarity exceeds the threshold.
correct_count = 0
similarity_list = []

for item in data:
    obj_name = item["obj_name"]
    llava_text = item["llava_text"]

    # Embed both strings; None signals an out-of-vocabulary word.
    obj_vec = get_word2vec_embedding(obj_name)
    llava_vec = get_word2vec_embedding(llava_text)

    if obj_vec is None or llava_vec is None:
        # Out-of-vocabulary on either side: treat as completely dissimilar.
        similarity = 0.0
    else:
        similarity = calculate_cosine_similarity(obj_vec, llava_vec)

    print(f"Similarity between '{obj_name}' and '{llava_text}': {similarity:.4f}")
    similarity_list.append(similarity)

    # A sample is "correct" when similarity strictly exceeds the threshold.
    if similarity > SIMILARITY_THRESHOLD:
        correct_count += 1

# Aggregate metrics. Guard against an empty dataset so the divisions
# below cannot raise ZeroDivisionError (the original crashed on data=[]).
total_samples = len(data)
if total_samples > 0:
    accuracy = correct_count / total_samples
    average_similarity = sum(similarity_list) / total_samples
else:
    accuracy = 0.0
    average_similarity = 0.0

print(f"正确样本数: {correct_count}")
print(f"总样本数: {total_samples}")
print(f"正确样本比例: {accuracy:.2%}")
print(f"平均相似性: {average_similarity:.4f}")