MAS-AI-0000 committed on
Commit
2d84a53
·
verified ·
1 Parent(s): 31fbcf8

Upload 3 files

Browse files
detree/utils/dataset.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import math
3
+ import os
4
+ import random
5
+ import torch
6
+ from torch.utils.data import Dataset
7
+ from .adversarial.alter_number import AlterNumbersAttack
8
+ from .adversarial.alternative_spelling import AlternativeSpellingAttack
9
+ from .adversarial.article_deletion import ArticleDeletionAttack
10
+ from .adversarial.homoglyph import HomoglyphAttack
11
+ from .adversarial.insert_paragraphs import InsertParagraphsAttack
12
+ from .adversarial.misspelling import MisspellingAttack
13
+ from .adversarial.upper_lower import UpperLowerFlipAttack
14
+ from .adversarial.whitespace import WhiteSpaceAttack
15
+ from .adversarial.zero_width_space import ZeroWidthSpaceAttack
16
+
17
# Canonicalisation table: every raw 'src' value seen in the jsonl data maps
# to one canonical model name. Mutated at runtime by the dataset classes,
# which add identity entries for previously unseen sources.
model_alias_mapping = {
    # ChatGPT / OpenAI API models
    'chatgpt': 'chatgpt',
    'ChatGPT': 'chatgpt',
    'chatGPT': 'chatgpt',
    # NOTE(review): 'trubo' looks like a typo for 'turbo', but the key must
    # match the dataset contents — do not "fix" without checking the data.
    'gpt-3.5-trubo': 'gpt-3.5-trubo',
    'GPT4': 'gpt4',
    'gpt4': 'gpt4',
    'text-davinci-002': 'text-davinci-002',
    'text-davinci-003': 'text-davinci-003',
    'davinci': 'text-davinci',
    # GPT-1/2/3 family (all GPT-2 size variants except the listed ones
    # collapse to themselves; bare 'gpt2' is treated as the XL model)
    'gpt1': 'gpt1',
    'gpt2_pytorch': 'gpt2-pytorch',
    'gpt2_large': 'gpt2-large',
    'gpt2_small': 'gpt2-small',
    'gpt2_medium': 'gpt2-medium',
    'gpt2-xl': 'gpt2-xl',
    'GPT2-XL': 'gpt2-xl',
    'gpt2_xl': 'gpt2-xl',
    'gpt2': 'gpt2-xl',
    'gpt3': 'gpt3',
    # Grover
    'GROVER_base': 'grover_base',
    'grover_base': 'grover_base',
    'grover_large': 'grover_large',
    'grover_mega': 'grover_mega',
    # LLaMA family
    'llama2-fine-tuned': 'llama2',
    # OPT
    'opt_125m': 'opt_125m',
    'opt_1.3b': 'opt_1.3b',
    'opt_2.7b': 'opt_2.7b',
    'opt_6.7b': 'opt_6.7b',
    'opt_13b': 'opt_13b',
    'opt_30b': 'opt_30b',
    'opt_350m': 'opt_350m',
    'opt_iml_max_1.3b': 'opt_iml_max_1.3b',
    'opt_iml_30b': 'opt_iml_30b',
    # FLAN-T5 (bare 'flan_t5' means the XXL variant)
    'flan_t5_small': 'flan_t5_small',
    'flan_t5_base': 'flan_t5_base',
    'flan_t5_large': 'flan_t5_large',
    'flan_t5_xl': 'flan_t5_xl',
    'flan_t5_xxl': 'flan_t5_xxl',
    'flan_t5': 'flan_t5_xxl',
    # Miscellaneous open models
    'dolly': 'dolly',
    'GLM130B': 'GLM130B',
    'bloom_7b': 'bloom_7b',
    'bloomz': 'bloomz',
    't0_3b': 't0_3b',
    't0_11b': 't0_11b',
    'gpt_neox': 'gpt_neox',
    'xlm': 'xlm',
    'xlnet_large': 'xlnet_large',
    'xlnet_base': 'xlnet_base',
    'cohere': 'cohere',
    'ctrl': 'ctrl',
    'pplm_gpt2': 'pplm_gpt2',
    'pplm_distil': 'pplm_distil',
    'fair_wmt19': 'fair_wmt19',
    'fair_wmt20': 'fair_wmt20',
    'glm130b': 'GLM130B',
    'jais-30b': 'jais',
    'transfo_xl': 'transfo_xl',
    # LLaMA size tags as used by some datasets
    '7B': '7B',
    '13B': '13B',
    '65B': '65B',
    '30B': '30B',
    'gpt_j': 'gpt_j',
    'mpt': 'mpt',
    'mpt-chat': 'mpt-chat',
    'llama-chat': 'llama-chat',
    'mistral': 'mistral',
    'mistral-chat': 'mistral-chat',
    'cohere-chat': 'cohere-chat',
    # Human-written text
    'human': 'human',
}
89
+
90
+
91
def load_datapath(path,include_adversarial=False,dataset_name='all',include_attack=False):
    """Collect train/test jsonl file paths under a dataset root.

    Expected layout: ``path/<dataset>/<attack_dir>/<split file>``.

    Args:
        path: root directory containing one sub-directory per dataset.
        include_adversarial: when False, only attack dirs whose name
            contains 'no_attack' are kept.
        dataset_name: 'all' (every sub-directory), a single dataset name,
            or the aggregates 'M4' / 'RAID_all'.
        include_attack: when False, 'perplexity_attack' and 'synonym'
            attack dirs are skipped even if include_adversarial is True.

    Returns:
        dict with 'train' and 'test' lists of absolute-ish file paths.
        Validation files join 'test' for RAID datasets and 'train'
        otherwise (RAID's valid split is used for evaluation).
    """
    data_path = {'train': [], 'test': []}
    if dataset_name == 'all':
        datasets = os.listdir(path)
    elif dataset_name == 'M4':
        datasets = ['M4_monolingual', 'M4_multilingual']
    elif dataset_name == 'RAID_all':
        datasets = ['RAID', 'RAID_extra']
    else:
        datasets = [dataset_name]
    for dataset in datasets:
        dataset_path = os.path.join(path, dataset)
        for adv in os.listdir(dataset_path):
            # Idiomatic boolean tests (was `== False` comparisons).
            if not include_adversarial and 'no_attack' not in adv:
                continue
            if not include_attack and ('perplexity_attack' in adv or 'synonym' in adv):
                continue
            adv_path = os.path.join(dataset_path, adv)
            for data in os.listdir(adv_path):
                if 'train' in data:
                    data_path['train'].append(os.path.join(adv_path, data))
                elif 'test' in data:
                    data_path['test'].append(os.path.join(adv_path, data))
                elif 'valid' in data:
                    if 'RAID' in dataset:
                        data_path['test'].append(os.path.join(adv_path, data))
                    else:
                        data_path['train'].append(os.path.join(adv_path, data))

    return data_path
121
+
122
class TreeDataset(Dataset):
    """Dataset of jsonl detection records.

    Each record is a dict with at least 'text', 'label', 'src' and 'id'.
    Source-model names are canonicalised through the module-level
    model_alias_mapping; each canonical name present in the data gets a
    deterministic integer class id (sorted order).
    """

    def __init__(self,data_path,need_ids=False):
        # data_path: list of jsonl file paths; only files whose path
        # contains 'no_attack' or 'paraphrase' are loaded (see load_data).
        self.data_path = data_path
        self.need_ids=need_ids
        self.dataset = self.load_data(data_path)

        # Collect the set of canonical source names actually present.
        LLM_name=set()
        for item in self.dataset:
            name = model_alias_mapping[item['src']]
            LLM_name.add(name)
        self.classes = list(LLM_name)
        self.classes = sorted(self.classes)

        # Sorted order makes name -> id assignment deterministic across runs.
        self.name2id={}
        for i,name in enumerate(self.classes):
            self.name2id[name]=i
        # Assumes 'human' records are always present — raises KeyError otherwise.
        self.human_id = self.name2id['human']

    def load_jsonl(self,file_path):
        """Load one jsonl file, tagging paraphrased sources with a suffix.

        Files under a 'paraphrase*' directory have their 'src' renamed with
        a paraphrase suffix and their human records dropped (only machine
        text is paraphrased). Newly created suffixed names are registered
        in the global model_alias_mapping as identity entries.
        """
        out = []
        add = ''
        # 'paraphrase_by_llm' must be checked before the generic 'paraphrase'.
        if 'paraphrase_by_llm' in file_path:
            add='-paraphrase-qwen7B'
        elif 'paraphrase' in file_path:
            add='-paraphrase-dipper'
        else:
            assert 'no_attack' in file_path,file_path+'file path should contain no_attack or paraphrase'

        with open(file_path, mode='r', encoding='utf-8') as jsonl_file:
            for line in jsonl_file:
                now = json.loads(line)
                if add != '':
                    if 'human' in now['src']:
                        continue
                    src = model_alias_mapping[now['src']]+add
                    # Mutates the module-level mapping so the suffixed name
                    # resolves to itself in later lookups.
                    if src not in model_alias_mapping:
                        model_alias_mapping[src]=src
                    now['src']=src
                out.append(now)
        return out

    def load_data(self,data_path):
        """Concatenate records from every eligible file in *data_path*."""
        data = []
        for path in data_path:
            # Silently skip anything that is neither a clean nor a
            # paraphrase file (load_jsonl would assert on it).
            if 'no_attack' not in path and 'paraphrase' not in path:
                continue
            print(f'loading {path}')
            data+=self.load_jsonl(path)
        return data

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return (text, label, src_id) or (text, label, src_id, id)."""
        data_now = self.dataset[idx]
        text = data_now['text']
        label = data_now['label']
        src = model_alias_mapping[data_now['src']]
        src_id = self.name2id[src]
        id = data_now['id']
        if self.need_ids:
            return text,int(label),int(src_id),int(id)
        else:
            return text,int(label),int(src_id)
186
+
187
class SCLDataset(Dataset):
    """Rank-sharded jsonl dataset with on-the-fly adversarial perturbation.

    Each process loads only the slice of the concatenated files belonging
    to its rank (DistributedSampler-style: indices are padded so every rank
    gets the same number of samples). At __getitem__ time the text is
    truncated and attacked with probability *adv_p*.
    """

    def __init__(self, data_path,fabric,tokenizer,need_ids=False,adv_p=0.5,max_length=530,name2id=None,has_mix=True):
        # data_path: list of jsonl file paths.
        # fabric: object providing world_size / global_rank (Lightning
        #   Fabric-like — presumably; confirm against callers).
        # has_mix=False drops mixed human/LLM sources (src containing
        #   'human' but not exactly equal to 'human').
        self.data_path = data_path
        self.adv_p = adv_p
        self.need_ids=need_ids
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.has_mix = has_mix

        self.world_size = fabric.world_size
        self.global_rank = fabric.global_rank
        self.LLM_name=set()
        # Side effect: fills self.LLM_name with every eligible src seen.
        dataset_len = self.get_data_len(data_path)

        classes = sorted(list(self.LLM_name))
        if name2id is None:
            self.name2id={}
            for i,name in enumerate(classes):
                self.name2id[name]=i
        else:
            # A caller-provided mapping must cover every class in the data.
            self.name2id = name2id
            for name in classes:
                assert name in self.name2id
        self.classes = classes
        print(f'there are {len(classes)} classes in dataset')
        print(f'the classes are {classes}')

        # Pad indices so the total divides evenly by world_size, then take
        # every world_size-th index starting at this rank.
        self.num_samples = math.ceil(dataset_len / self.world_size)
        total_size = self.num_samples * self.world_size
        indices = list(range(dataset_len))
        padding_size = total_size - len(indices)
        indices += indices[:padding_size]
        assert len(indices) == total_size
        indices = indices[self.global_rank : total_size : self.world_size]
        assert len(indices) == self.num_samples
        self.indices = set(indices)

        # Only records whose global index is owned by this rank are loaded.
        data_dict = self.load_data(data_path)
        self.dataset = [data_dict[i] for i in indices]
        self.dataset_len = len(self.dataset)


    def get_data_len(self,data_path):
        """Count eligible records; also registers src names and aliases."""
        total_len = 0
        for path in data_path:
            print(f'reading {path}')
            with open(path, mode='r', encoding='utf-8') as jsonl_file:
                for line in jsonl_file:
                    now = json.loads(line)
                    # Unknown sources become identity entries in the global
                    # alias map (mutates module state).
                    if now['src'] not in model_alias_mapping:
                        model_alias_mapping[now['src']]=now['src']
                    now['src'] = model_alias_mapping[now['src']]
                    if self.has_mix == False:
                        if 'human' in now['src'] and now['src'] != 'human':
                            continue
                    if now['src'] not in self.LLM_name:
                        self.LLM_name.add(now['src'])
                    total_len+=1
        return total_len

    def truncate_text(self,text):
        """Token-truncate *text* to max_length and decode it back to a string."""
        tokens = self.tokenizer.encode(text, truncation=True, max_length=self.max_length)
        truncated_text = self.tokenizer.decode(tokens, skip_special_tokens=True)
        return truncated_text

    def merge_dict(self,dict1,dict2):
        # In-place update of dict1 with dict2 (dict2 wins on key clashes).
        for key in dict2:
            dict1[key]=dict2[key]
        return dict1

    def load_jsonl(self,file_path,total_len):
        """Load only the records whose global index this rank owns.

        *total_len* is the count of eligible records in earlier files; the
        per-file counter must replicate get_data_len's filtering so global
        indices line up.
        NOTE(review): unlike get_data_len, 'src' is not passed through
        model_alias_mapping before the 'human' check here — counts could
        diverge if an alias changes whether 'human' appears; verify.
        """
        out = {}
        cnt=0
        with open(file_path, mode='r', encoding='utf-8') as jsonl_file:
            for line in jsonl_file:
                now = json.loads(line)
                if self.has_mix == False:
                    if 'human' in now['src'] and now['src'] != 'human':
                        continue
                if total_len+cnt in self.indices:
                    out[total_len+cnt]=now
                cnt+=1
        return out,cnt

    def load_data(self,data_path):
        """Merge per-file dicts of {global_index: record} for this rank."""
        data = {}
        total_len = 0
        for path in data_path:
            print(f'loading {path}')
            now_data,now_len=self.load_jsonl(path,total_len)
            data = self.merge_dict(data,now_data)
            total_len+=now_len
        return data

    def __len__(self):
        return self.dataset_len

    def __getitem__(self, idx):
        """Return (text, label, src_id[, id]); text may be adversarially attacked."""
        data = self.dataset[idx]
        text = data['text']
        label = data['label']
        src = self.name2id[model_alias_mapping[data['src']]]
        id = data['id']

        if random.random()<self.adv_p:
            text = self.truncate_text(text)
            attack_method = random.choice([AlterNumbersAttack,AlternativeSpellingAttack,ArticleDeletionAttack,\
                HomoglyphAttack,InsertParagraphsAttack,MisspellingAttack,UpperLowerFlipAttack,WhiteSpaceAttack,ZeroWidthSpaceAttack])
            # NOTE(review): attack_method is the class itself; this assumes
            # calling the class with the text returns the attacked string —
            # confirm against the attack implementations.
            text = attack_method(text)
        if self.need_ids:
            return text,int(label),int(src),int(id)
        return text,int(label),int(src)
300
+
301
+
302
class SCL_RM_Dataset(Dataset):
    """SCLDataset variant that randomly removes a fraction of LLM classes.

    On construction, each non-'human' class in *name2id* is dropped with
    probability *remove_cls*; records from dropped sources are excluded
    from both counting and loading. Rank sharding works as in SCLDataset.
    """

    def __init__(self, data_path,fabric,tokenizer,need_ids=False,adv_p=0.5,max_length=530,name2id=None,has_mix=True,remove_cls=0.9):
        self.data_path = data_path
        self.adv_p = adv_p
        self.need_ids=need_ids
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.has_mix = has_mix

        self.world_size = fabric.world_size
        self.global_rank = fabric.global_rank
        self.LLM_name=set()
        self.remove_cls = remove_cls
        assert name2id is not None, 'name2id is None, please set name2id'
        # Randomly select classes to drop; 'human' is always kept.
        # NOTE(review): random.random() is per-process state — ranks with
        # unsynchronised RNG seeds would drop different classes; confirm
        # the seed is shared across ranks.
        self.remove_name = set()
        for name in name2id:
            if random.random()<self.remove_cls and name != 'human':
                self.remove_name.add(name)
        dataset_len = self.get_data_len(data_path)

        classes = sorted(list(self.LLM_name))
        # name2id is asserted non-None above, so the first branch is
        # unreachable; kept for symmetry with SCLDataset.
        if name2id is None:
            self.name2id={}
            for i,name in enumerate(classes):
                self.name2id[name]=i
        else:
            self.name2id = name2id
            for name in classes:
                assert name in self.name2id
        self.classes = classes
        print(f'there are {len(classes)} classes in dataset')
        print(f'the classes are {classes}')

        # Pad indices so the total divides evenly by world_size, then take
        # every world_size-th index starting at this rank.
        self.num_samples = math.ceil(dataset_len / self.world_size)
        total_size = self.num_samples * self.world_size
        indices = list(range(dataset_len))
        padding_size = total_size - len(indices)
        indices += indices[:padding_size]
        assert len(indices) == total_size
        indices = indices[self.global_rank : total_size : self.world_size]
        assert len(indices) == self.num_samples
        self.indices = set(indices)

        data_dict = self.load_data(data_path)
        self.dataset = [data_dict[i] for i in indices]
        self.dataset_len = len(self.dataset)


    def get_data_len(self,data_path):
        """Count eligible records (dropped classes excluded); registers src names."""
        total_len = 0
        for path in data_path:
            print(f'reading {path}')
            with open(path, mode='r', encoding='utf-8') as jsonl_file:
                for line in jsonl_file:
                    now = json.loads(line)
                    if now['src'] not in model_alias_mapping:
                        model_alias_mapping[now['src']]=now['src']
                    now['src'] = model_alias_mapping[now['src']]
                    if self.has_mix == False:
                        if 'human' in now['src'] and now['src'] != 'human':
                            continue
                    # Skip records whose class was randomly removed.
                    if now['src'] in self.remove_name:
                        continue
                    if now['src'] not in self.LLM_name:
                        self.LLM_name.add(now['src'])
                    total_len+=1
        return total_len

    def truncate_text(self,text):
        """Token-truncate *text* to max_length and decode it back to a string."""
        tokens = self.tokenizer.encode(text, truncation=True, max_length=self.max_length)
        truncated_text = self.tokenizer.decode(tokens, skip_special_tokens=True)
        return truncated_text

    def merge_dict(self,dict1,dict2):
        # In-place update of dict1 with dict2 (dict2 wins on key clashes).
        for key in dict2:
            dict1[key]=dict2[key]
        return dict1

    def load_jsonl(self,file_path,total_len):
        """Load only the records whose global index this rank owns.

        Filtering must replicate get_data_len so global indices line up.
        NOTE(review): here 'src' is not re-mapped through
        model_alias_mapping before the checks; see the same caveat in
        SCLDataset.load_jsonl.
        """
        out = {}
        cnt=0
        with open(file_path, mode='r', encoding='utf-8') as jsonl_file:
            for line in jsonl_file:
                now = json.loads(line)
                if self.has_mix == False:
                    if 'human' in now['src'] and now['src'] != 'human':
                        continue
                if now['src'] in self.remove_name:
                    continue
                if total_len+cnt in self.indices:
                    out[total_len+cnt]=now
                cnt+=1
        return out,cnt

    def load_data(self,data_path):
        """Merge per-file dicts of {global_index: record} for this rank."""
        data = {}
        total_len = 0
        for path in data_path:
            print(f'loading {path}')
            now_data,now_len=self.load_jsonl(path,total_len)
            data = self.merge_dict(data,now_data)
            total_len+=now_len
        return data

    def __len__(self):
        return self.dataset_len

    def __getitem__(self, idx):
        """Return (text, label, src_id[, id]); text may be adversarially attacked."""
        data = self.dataset[idx]
        text = data['text']
        label = data['label']
        src = self.name2id[model_alias_mapping[data['src']]]
        id = data['id']

        if random.random()<self.adv_p:
            text = self.truncate_text(text)
            attack_method = random.choice([AlterNumbersAttack,AlternativeSpellingAttack,ArticleDeletionAttack,\
                HomoglyphAttack,InsertParagraphsAttack,MisspellingAttack,UpperLowerFlipAttack,WhiteSpaceAttack,ZeroWidthSpaceAttack])
            # NOTE(review): assumes calling the attack class with the text
            # returns the attacked string — confirm implementations.
            text = attack_method(text)
        if self.need_ids:
            return text,int(label),int(src),int(id)
        return text,int(label),int(src)
detree/utils/index.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ import pickle
9
+ from typing import List, Tuple
10
+
11
+ import faiss
12
+ import numpy as np
13
+ from tqdm import tqdm
14
+
15
class Indexer(object):
    """GPU-sharded FAISS inner-product index with external-id bookkeeping.

    Vectors are spread across one flat IP sub-index per visible GPU.
    index_id_to_db_id maps internal sequential positions back to caller
    ids; label_dict (filled by callers) maps db ids to labels.
    """

    def __init__(self, vector_sz, n_subquantizers=0, n_bits=16):
        # n_subquantizers / n_bits are accepted for interface compatibility
        # but unused: the implementation always builds flat IP shards.
        self.vector_sz = vector_sz
        self.index = self._create_sharded_index()
        self.index_id_to_db_id = []
        self.label_dict = {}

    def _create_sharded_index(self):
        """Build an IndexShards with one GpuIndexFlatIP per available GPU."""
        ngpu = faiss.get_num_gpus()
        # threaded=True, successive_ids=True keeps ids globally unique
        # across shards.
        index = faiss.IndexShards(self.vector_sz, True, True)
        for gpu_id in range(ngpu):
            res = faiss.StandardGpuResources()
            flat_config = faiss.GpuIndexFlatConfig()
            flat_config.device = gpu_id  # pin this shard to GPU gpu_id
            sub_index = faiss.GpuIndexFlatIP(res, self.vector_sz, flat_config)
            index.add_shard(sub_index)
        return index

    def index_data(self, ids, embeddings):
        """Append *embeddings* to the index, remembering their external *ids*."""
        self._update_id_mapping(ids)
        self.index.add(embeddings)
        print(f'Total data indexed {self.index.ntotal}')

    def search_knn(self, query_vectors: np.array, top_docs: int, index_batch_size: int = 8) -> List[Tuple[List[object], List[float]]]:
        """Batched k-NN search.

        Returns one (db_ids, scores, db_labels) tuple per query vector,
        with internal positions translated to external ids and labels.
        """
        result = []
        nbatch = (len(query_vectors) - 1) // index_batch_size + 1
        for k in tqdm(range(nbatch)):
            start_idx = k * index_batch_size
            end_idx = min((k + 1) * index_batch_size, len(query_vectors))
            q = query_vectors[start_idx: end_idx]
            scores, indexes = self.index.search(q, top_docs)
            # Convert internal sequential ids to external db ids / labels.
            db_ids = [[str(self.index_id_to_db_id[i]) for i in query_top_idxs] for query_top_idxs in indexes]
            db_labels = [[self.label_dict[self.index_id_to_db_id[i]] for i in query_top_idxs] for query_top_idxs in indexes]
            result.extend([(db_ids[i], scores[i], db_labels[i]) for i in range(len(db_ids))])
        return result

    def serialize(self, dir_path):
        """Write the FAISS index and the pickled id mapping under *dir_path*."""
        index_file = os.path.join(dir_path, 'index.faiss')
        meta_file = os.path.join(dir_path, 'index_meta.faiss')
        print(f'Serializing index to {index_file}, meta data to {meta_file}')

        faiss.write_index(self.index, index_file)
        with open(meta_file, mode='wb') as f:
            pickle.dump(self.index_id_to_db_id, f)

    def deserialize_from(self, dir_path):
        """Load an index + id mapping previously written by serialize()."""
        index_file = os.path.join(dir_path, 'index.faiss')
        meta_file = os.path.join(dir_path, 'index_meta.faiss')
        print(f'Loading index from {index_file}, meta data from {meta_file}')

        self.index = faiss.read_index(index_file)
        # BUG FIX: was print('... %s ... %d', a, b) — printf-style args
        # passed to print(), which printed the raw format string and the
        # arguments instead of interpolating them.
        print(f'Loaded index of type {type(self.index)} and size {self.index.ntotal}')

        with open(meta_file, "rb") as reader:
            self.index_id_to_db_id = pickle.load(reader)
        assert len(
            self.index_id_to_db_id) == self.index.ntotal, 'Deserialized index_id_to_db_id should match faiss index size'

    def _update_id_mapping(self, db_ids: List):
        # Internal position i corresponds to db_ids appended in order.
        self.index_id_to_db_id.extend(db_ids)

    def reset(self):
        """Drop all indexed vectors and the id mapping."""
        self.index.reset()
        self.index_id_to_db_id = []
        print(f'Index reset, total data indexed {self.index.ntotal}')
detree/utils/utils.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+ import pickle
4
+ import numpy as np
5
+ from sklearn.metrics import precision_recall_curve, auc, roc_auc_score,roc_curve
6
+
7
def stable_long_hash(input_string):
    """Deterministically map *input_string* to a non-negative 63-bit int.

    Unlike the builtin hash(), the result is stable across interpreter
    runs because it is derived from SHA-256.
    """
    digest = hashlib.sha256(input_string.encode()).hexdigest()
    return int(digest, 16) & ((1 << 63) - 1)
13
+
14
def load_pkl(path):
    """Deserialize and return the object pickled at *path*."""
    with open(path, 'rb') as fh:
        return pickle.load(fh)
17
+
18
+
19
def save_pkl(obj, path):
    """Pickle *obj* to *path*, overwriting any existing file."""
    with open(path, 'wb') as fh:
        pickle.dump(obj, fh)
22
+
23
+
24
+
25
def find_top_n(embeddings,n,index,data):
    """Return the *n* nearest-neighbour records for each query embedding.

    Args:
        embeddings: query vector(s); a single 1-D vector is promoted to a
            batch of one.
        n: number of neighbours per query.
        index: object exposing search_knn(embeddings, n). Each per-query
            result is a tuple whose FIRST element is the list of db ids;
            trailing elements (scores, labels, ...) are ignored here.
        data: triple of parallel sequences indexed by db id.

    Returns:
        One list per query of (data[0][id], data[1][id], data[2][id]) tuples.
    """
    if len(embeddings.shape) == 1:
        embeddings = embeddings.reshape(1, -1)
    data_ans = []
    # BUG FIX: the original unpacked each result as (ids, scores), but
    # Indexer.search_knn returns (ids, scores, labels) 3-tuples, which
    # raised ValueError. Unpack defensively: take the ids, ignore the rest.
    for ids, *_ in index.search_knn(embeddings, n):
        data_now = [(data[0][int(db_id)], data[1][int(db_id)], data[2][int(db_id)]) for db_id in ids]
        data_ans.append(data_now)
    return data_ans
36
+
37
+
38
+
39
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
40
+
41
def print_line(class_name, metrics, is_header=False):
    """Print one row of the metrics table.

    When *is_header* is True, *metrics* supplies the column names (the
    row label is the literal 'Class') and a separator rule is printed
    underneath; otherwise the values in *metrics* are shown to three
    decimals under the given *class_name*.
    """
    if is_header:
        cells = [f"{name:<10}" for name in metrics]
        label = 'Class'
    else:
        cells = [f"{metrics[name]:<10.3f}" for name in metrics]
        label = class_name
    row = f"| {label:<10} | " + " | ".join(cells)
    print(row)
    if is_header:
        print('-' * len(row))
49
+
50
def calculate_per_class_metrics(classes, ground_truth, predictions):
    """Print a one-vs-rest metrics table for every class plus an Overall row.

    *ground_truth* and *predictions* hold class indices (anything
    int()-convertible) aligned with positions in *classes*.
    """
    gt = np.array([int(g) for g in ground_truth])
    pred = np.array([int(p) for p in predictions])

    results = {}
    for idx, class_name in enumerate(classes):
        # Binarise as "this class vs. everything else".
        gt_bin = (gt == idx).astype(int)
        pred_bin = (pred == idx).astype(int)

        # zero_division=0 keeps classes absent from gt/pred from raising.
        rec = recall_score(gt_bin, pred_bin, zero_division=0)
        rest_rec = recall_score(1 - gt_bin, 1 - pred_bin, zero_division=0)

        results[class_name] = {
            'Precision': precision_score(gt_bin, pred_bin, zero_division=0),
            'Recall': rec,
            'F1 Score': f1_score(gt_bin, pred_bin, zero_division=0),
            'Accuracy': np.mean(gt_bin == pred_bin),
            'Avg Recall (with rest)': (rec + rest_rec) / 2
        }

    print_line("Metric", results[classes[0]], is_header=True)
    for class_name, metrics in results.items():
        print_line(class_name, metrics)
    # Unweighted mean of each metric across classes.
    overall_metrics = {m: np.mean([v[m] for v in results.values()]) for m in results[classes[0]].keys()}
    print_line("Overall", overall_metrics)
82
+
83
def calculate_metrics(labels, preds):
    """Return (accuracy, macro-precision, macro-recall, macro-F1)."""
    return (
        accuracy_score(labels, preds),
        precision_score(labels, preds, average='macro'),
        recall_score(labels, preds, average='macro'),
        f1_score(labels, preds, average='macro'),
    )
89
+
90
def compute_three_recalls(labels, preds):
    """Compute per-class recalls (as percentages) for a binary task.

    Label '0' counts toward all_p/tp (reported as machine_rec) and label
    '1' toward all_n/tn (reported as human_rec). A prediction of None
    never matches, i.e. it is treated as incorrect for both classes.

    Returns:
        (human_rec, machine_rec, avg_rec); a class absent from *labels*
        yields recall 0.
    """
    all_n, all_p, tn, tp = 0, 0, 0, 0
    for label, pred in zip(labels, preds):
        if label == '0':
            all_p += 1
        elif label == '1':
            all_n += 1
        # None predictions simply never match.
        if pred is not None and label == pred == '0':
            tp += 1
        if pred is not None and label == pred == '1':
            tn += 1
    # (The original ended the loop with a no-op `if pred is None: continue`;
    # removed as dead code.)
    machine_rec = tp * 100 / all_p if all_p != 0 else 0
    human_rec = tn * 100 / all_n if all_n != 0 else 0
    avg_rec = (human_rec + machine_rec) / 2
    return (human_rec, machine_rec, avg_rec)
108
+
109
+
110
def compute_metrics(labels, preds,ids=None):
    """Return (human_rec, machine_rec, avg_rec, acc, precision, recall, f1).

    If *ids* is given, entries sharing an id are deduplicated first; the
    last label/prediction seen for each id wins.
    """
    if ids is not None:
        dedup_labels, dedup_preds = {}, {}
        for pos in range(len(ids)):
            dedup_labels[ids[pos]] = labels[pos]
            dedup_preds[ids[pos]] = preds[pos]
        labels = list(dedup_labels.values())
        preds = list(dedup_preds.values())

    human_rec, machine_rec, avg_rec = compute_three_recalls(labels, preds)
    acc = accuracy_score(labels, preds)
    # Binary metrics treat string label '1' as the positive class.
    precision = precision_score(labels, preds, pos_label='1')
    recall = recall_score(labels, preds, pos_label='1')
    f1 = f1_score(labels, preds, pos_label='1')
    return (human_rec, machine_rec, avg_rec, acc, precision, recall, f1)
129
+
130
def evaluate_max_f1_metrics(test_labels, y_score):
    """Pick the F1-maximising threshold on the PR curve and report metrics.

    Returns a dict with auroc, pr_auc, F1, Precision, Recall, threshold,
    acc, avg_recall, pos_recall and neg_recall, all evaluated at that
    threshold.
    """
    test_labels = np.array(test_labels)
    y_score = np.array(y_score)

    auroc = roc_auc_score(test_labels, y_score)
    precision, recall, thresholds = precision_recall_curve(test_labels, y_score, pos_label=1)
    pr_auc = auc(recall, precision)

    epsilon = 1e-6
    f1_scores = 2 * precision * recall / (precision + recall + epsilon)
    best_index = f1_scores.argmax()

    # precision_recall_curve yields one fewer threshold than PR points;
    # the final point corresponds to threshold 1.0.
    threshold = thresholds[best_index] if best_index < len(thresholds) else 1.0
    y_pred_max_f1 = (y_score >= threshold).astype(int)

    acc = (y_pred_max_f1 == test_labels).mean()
    tp = sum((y_pred_max_f1 == 1) & (test_labels == 1))
    fn = sum((y_pred_max_f1 == 0) & (test_labels == 1))
    fp = sum((y_pred_max_f1 == 1) & (test_labels == 0))
    tn = sum((y_pred_max_f1 == 0) & (test_labels == 0))

    pos_recall = tp / (tp + fn + epsilon)
    neg_recall = tn / (tn + fp + epsilon)

    return {
        'auroc': auroc, 'pr_auc': pr_auc,
        'F1': f1_scores[best_index],
        'Precision': precision[best_index],
        'Recall': recall[best_index],
        'threshold': threshold, 'acc': acc,
        'avg_recall': (pos_recall + neg_recall) / 2,
        'pos_recall': pos_recall, 'neg_recall': neg_recall,
    }
161
+
162
def evaluate_metrics(test_labels, y_score, threshold_param=-1,target_fpr = 0.05):
    """Evaluate binary detection scores at a chosen or F1-optimal threshold.

    Args:
        test_labels: 0/1 ground-truth labels (list or array).
        y_score: detector scores, higher = more positive.
        threshold_param: -1 selects the F1-maximising threshold from the
            PR curve; otherwise a fixed threshold in [0, 1].
        target_fpr: FPR at which to report TPR from the ROC curve.

    Returns:
        dict with auroc, pr_auc, F1, Precision, Recall, threshold, acc,
        avg_recall, pos_recall, neg_recall, tpr_at_fpr and
        tpr_at_fpr_threshold.

    Raises:
        ValueError: if a fixed threshold_param lies outside [0, 1].
    """
    if isinstance(test_labels, list):
        test_labels = np.array(test_labels)
    if isinstance(y_score, list):
        y_score = np.array(y_score)

    if threshold_param != -1:
        if not (0 <= threshold_param <= 1):
            raise ValueError("Threshold must be between 0 and 1.")

    auroc = roc_auc_score(test_labels, y_score)

    precision, recall, thresholds = precision_recall_curve(test_labels, y_score, pos_label=1)
    pr_auc = auc(recall, precision)

    epsilon = 1e-6
    f1_scores = 2 * precision * recall / (precision + recall + epsilon)

    if threshold_param == -1:
        best_index = f1_scores.argmax()
        F1 = f1_scores[best_index]
        Precision = precision[best_index]
        Recall = recall[best_index]
        # precision_recall_curve returns one fewer threshold than PR
        # points; the last point corresponds to threshold 1.0.
        threshold = thresholds[best_index] if best_index < len(thresholds) else 1.0
    else:
        threshold = threshold_param
        # First PR-curve point whose threshold is at least the requested one.
        index = np.where(thresholds >= threshold)[0][0]
        Precision = precision[index]
        Recall = recall[index]
        F1 = f1_scores[index]

    y_pred = (y_score >= threshold).astype(int)
    acc = (y_pred == test_labels).mean()

    tp = ((y_pred == 1) & (test_labels == 1)).sum()
    fn = ((y_pred == 0) & (test_labels == 1)).sum()
    fp = ((y_pred == 1) & (test_labels == 0)).sum()
    tn = ((y_pred == 0) & (test_labels == 0)).sum()

    pos_recall = tp / (tp + fn + epsilon)  # TPR
    neg_recall = tn / (tn + fp + epsilon)  # TNR
    avg_recall = (pos_recall + neg_recall) / 2

    fpr, tpr, thds = roc_curve(test_labels, y_score)
    if len(fpr) > 0 and len(tpr) > 0:
        idx = np.argmin(np.abs(fpr - target_fpr))
        tpr_at_fpr = tpr[idx]
        tpr_at_fpr_threshold = thds[idx]
    else:
        # BUG FIX: tpr_at_fpr_threshold was never assigned on this branch,
        # so building the metric dict below raised NameError.
        tpr_at_fpr = 0.0
        tpr_at_fpr_threshold = 1.0

    metric = {'auroc': auroc, 'pr_auc': pr_auc, 'F1': F1, 'Precision': Precision,'Recall': Recall,\
        'threshold': threshold, 'acc': acc, 'avg_recall': avg_recall,'pos_recall': pos_recall,\
        'neg_recall': neg_recall, 'tpr_at_fpr': tpr_at_fpr, 'tpr_at_fpr_threshold': tpr_at_fpr_threshold}

    return metric
220
+ # return (auroc, pr_auc, best_f1, best_precision, best_recall, threshold,
221
+ # acc, avg_recall, pos_recall, neg_recall, tpr_at_fpr5)
222
+
223
+
224
def load_datapath(path,include_adversarial=False,dataset_name='all',attack_type='all'):
    """Collect train/valid/test jsonl file paths under a dataset root.

    Expected layout: ``path/<dataset>/<attack_dir>/<split>.jsonl``.

    Args:
        path: root directory containing one sub-directory per dataset.
        include_adversarial: when False, only 'no_attack' dirs are kept.
        dataset_name: 'all', a single dataset name, or the aggregates
            'M4' / 'RAID_all'.
        attack_type: 'all', or a substring used to filter attack dirs.

    Returns:
        dict mapping 'train' / 'valid' / 'test' to lists of file paths
        (matched on 'train.', 'test.', 'valid.' in the filename).
    """
    data_path = {'train': [], 'valid': [], 'test': []}
    if dataset_name == 'all':
        datasets = os.listdir(path)
    elif dataset_name == 'M4':
        datasets = ['M4_monolingual', 'M4_multilingual']
    elif dataset_name == 'RAID_all':
        datasets = ['RAID', 'RAID_extra']
    else:
        datasets = [dataset_name]

    for dataset in datasets:
        dataset_path = os.path.join(path, dataset)
        entries = os.listdir(dataset_path)
        if attack_type != 'all':
            entries = [entry for entry in entries if attack_type in entry]
        for adv in entries:
            if not include_adversarial and 'no_attack' not in adv:
                continue
            adv_path = os.path.join(dataset_path, adv)
            for fname in os.listdir(adv_path):
                if 'train.' in fname:
                    data_path['train'].append(os.path.join(adv_path, fname))
                elif 'test.' in fname:
                    data_path['test'].append(os.path.join(adv_path, fname))
                elif 'valid.' in fname:
                    data_path['valid'].append(os.path.join(adv_path, fname))
    return data_path