import json
from sentence_transformers import SentenceTransformer
import torch
import torch.nn.functional as F
import re
import string

# json_path_1 = "/home/yuqian_fu/Projects/PSALM/egoexo_val_framelevel_newprompt_all_instruction.json"
# json_path_2 = "/home/yuqian_fu/Projects/PSALM/egoexo_val_framelevel_newprompt_all_objname_instruction.json"

# with open(json_path_1, "r") as fp:
#     datas_1 = json.load(fp)
# with open(json_path_2, "r") as fp:
#     datas_2 = json.load(fp)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Similarity threshold used to decide whether two texts refer to the same object
SIMILARITY_THRESHOLD = 0.5

# Load the pretrained Sentence-BERT model
model = SentenceTransformer("all-MiniLM-L6-v2").to(device)
model.eval()  # set the model to evaluation mode

def extract_object_name(text):
    """Return the object description after the first " is " in a sentence."""
    # Split on " is " (with surrounding spaces) so words such as "this" are not broken.
    parts = text.split(" is ", 1)
    if len(parts) > 1:
        return parts[1].strip()
    return None
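
# Minimal usage sketch (the sentence is a made-up example, not taken from the
# annotation files): everything after the first " is " becomes the object name.
print(extract_object_name("The object being manipulated is a wooden spoon"))  # -> "a wooden spoon"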


def get_sbert_embedding(text):
    """
    Encode a piece of text into a Sentence-BERT embedding vector.
    """
    with torch.no_grad():
        embedding = model.encode(text, convert_to_tensor=True, device=device)
    return embedding

def calculate_cosine_similarity(embedding1, embedding2):
    """
    Compute the cosine similarity between two embedding vectors.
    """
    similarity = F.cosine_similarity(embedding1.unsqueeze(0), embedding2.unsqueeze(0))
    return similarity.item()
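
# Quick sanity check of the two helpers above. The strings are hypothetical
# examples; the real inputs are the obj_name / llava_text fields read from the
# annotation JSON files in the (commented-out) pipeline below.
_emb_obj = get_sbert_embedding("cutting board")
_emb_text = get_sbert_embedding("wooden cutting board")
print("example similarity:", calculate_cosine_similarity(_emb_obj, _emb_text))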

# # Number of correct samples
# correct_count = 0
# # Iterate over the data and compute similarities
# similarity_list = []
# # Total number of objects
# obj_total = 0

# datas_save_1 = []
# datas_save_2 = []
# for data1, data2 in zip(datas_1, datas_2):
#     score = 0
#     annos_1 = data1["first_frame_anns"]
#     annos_2 = data2["first_frame_anns"]
#     for anno1, anno2 in zip(annos_1, annos_2):
#         obj_total += 1

#         # Get llava_text
#         llava_text = anno1["text"]
#         llava_text = extract_object_name(llava_text)
#         raw_lower = llava_text.lower()
#         result = raw_lower.replace("green", "").strip() 
#         sent = result.translate(str.maketrans('', '', string.punctuation))
#         if "a " in sent:
#             llava_text = sent.replace("a ", "")
#         elif "an " in sent:
#             llava_text = sent.replace("an ", "")
#         print("llava_text:", llava_text)

#         # Get obj_name
#         obj_name = anno2["text"]
#         obj_name = re.sub(r'_\d+$', '', obj_name) 
#         print("obj_name:", obj_name) 

#         # Extract embeddings for both texts
#         obj_embedding = get_sbert_embedding(obj_name)
#         llava_embedding = get_sbert_embedding(llava_text)
        
#         # Compute the cosine similarity
#         similarity = calculate_cosine_similarity(obj_embedding, llava_embedding)
#         print("similarity:", similarity)
#         similarity_list.append(similarity)
        
#         score += similarity
#         # Count as a correct sample if above the threshold
#         if similarity > SIMILARITY_THRESHOLD:
#             correct_count += 1

#     score_avg = score / len(annos_1)
#     print("score_avg:", score_avg)
#     if score_avg > SIMILARITY_THRESHOLD:
#         datas_save_1.append(data1)
#         datas_save_2.append(data2)
        

# # Compute the proportion of correct samples
# total_samples_before = len(datas_1)
# print("num before:", total_samples_before)
# accuracy = correct_count / obj_total
# average_similarity = sum(similarity_list) / obj_total

# print(f"Number of correct samples: {correct_count}")
# print(f"Total number of samples: {obj_total}")
# print(f"Proportion of correct samples: {accuracy:.2%}")
# print(f"Average similarity: {average_similarity:.4f}")
# print("after filter num:", len(datas_save_1))


# save_path_1 = "/home/yuqian_fu/Projects/PSALM/egoexo_val_llavatext_similarity_new.json"
# save_path_2 = "/home/yuqian_fu/Projects/PSALM/egoexo_val_objname_similarity_new.json"

# with open(save_path_1, "w") as fp:
#     json.dump(datas_save_1, fp)
# with open(save_path_2, "w") as fp:
#     json.dump(datas_save_2, fp)

# json_path_1 = "/home/yuqian_fu/Projects/PSALM/egoexo_val_llavatext_similarity_new.json"
# json_path_2 = "/home/yuqian_fu/Projects/PSALM/egoexo_val_objname_similarity_new.json"

# with open(json_path_1, "r") as fp:
#     datas_1 = json.load(fp)
# with open(json_path_2, "r") as fp:
#     datas_2 = json.load(fp)

# # For easier comparison, extract only obj_name and llava_text from the JSON files
# to_save = []
# for data1, data2 in zip(datas_1, datas_2):
#     annos_1 = data1["first_frame_anns"]
#     annos_2 = data2["first_frame_anns"]
#     for anno1, anno2 in zip(annos_1, annos_2):
        
#         # Get llava_text
#         llava_text = anno1["text"]
#         llava_text = extract_object_name(llava_text)
#         raw_lower = llava_text.lower()
#         result = raw_lower.replace("green", "").strip() 
#         sent = result.translate(str.maketrans('', '', string.punctuation))
#         if "a " in sent:
#             llava_text = sent.replace("a ", "")
#         elif "an " in sent:
#             llava_text = sent.replace("an ", "")
#         # print("llava_text:", llava_text)

#         # Get obj_name
#         obj_name = anno2["text"]
#         obj_name = re.sub(r'_\d+$', '', obj_name) 
#         # print("obj_name:", obj_name) 

#         sample = {

#             "llava_text": llava_text,
#             "obj_name": obj_name
#         }
#         to_save.append(sample)

# save_path = "/home/yuqian_fu/Projects/PSALM/only_objname_llavetext.json"

# with open(save_path, "w") as fp:
#     json.dump(to_save, fp)








# Build a human-readable file of obj_name / llava_text pairs
import random

json_path = "/home/yuqian_fu/Projects/PSALM/check_text_select_scene_600_objname_llavatext_correct.json"
with open(json_path, "r") as fp:
    datas = json.load(fp)


to_save = []
for data in datas:
    annos = data["first_frame_anns"]
    for anno in annos:
        obj_name = anno["obj_name"]
        llava_text = anno["llava_text"]
        sample = {
            "llava_text": llava_text,
            "obj_name": obj_name
        }
        to_save.append(sample)

# Randomly sample 600 pairs for human review (or fewer if the pool is smaller)
to_save = random.sample(to_save, min(600, len(to_save)))

save_path = "/home/yuqian_fu/Projects/PSALM/only_objname_llavatext_4human.json"
# Write the data as JSON Lines: one record per line
with open(save_path, "w") as fp:
    for data in to_save:
        json.dump(data, fp)
        fp.write("\n")  # one dict per line
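
# Because the file above is written as JSON Lines (one JSON object per line),
# it must be read back line by line with json.loads rather than a single
# json.load call. A minimal sketch, reading back the file just written:
with open(save_path, "r") as fp:
    reloaded = [json.loads(line) for line in fp if line.strip()]
print("reloaded", len(reloaded), "obj_name / llava_text pairs")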