Upload 3 files
- detree/utils/dataset.py +424 -0
- detree/utils/index.py +105 -0
- detree/utils/utils.py +251 -0
detree/utils/dataset.py
ADDED
@@ -0,0 +1,424 @@
import json
import math
import os
import random

import torch
from torch.utils.data import Dataset

from .adversarial.alter_number import AlterNumbersAttack
from .adversarial.alternative_spelling import AlternativeSpellingAttack
from .adversarial.article_deletion import ArticleDeletionAttack
from .adversarial.homoglyph import HomoglyphAttack
from .adversarial.insert_paragraphs import InsertParagraphsAttack
from .adversarial.misspelling import MisspellingAttack
from .adversarial.upper_lower import UpperLowerFlipAttack
from .adversarial.whitespace import WhiteSpaceAttack
from .adversarial.zero_width_space import ZeroWidthSpaceAttack

# Maps the raw `src` field found in the data files to a canonical model name.
# Keys such as 'gpt-3.5-trubo' are kept verbatim (sic) to match the labels that
# actually appear in the source datasets.
model_alias_mapping = {
    'chatgpt': 'chatgpt',
    'ChatGPT': 'chatgpt',
    'chatGPT': 'chatgpt',
    'gpt-3.5-trubo': 'gpt-3.5-trubo',
    'GPT4': 'gpt4',
    'gpt4': 'gpt4',
    'text-davinci-002': 'text-davinci-002',
    'text-davinci-003': 'text-davinci-003',
    'davinci': 'text-davinci',
    'gpt1': 'gpt1',
    'gpt2_pytorch': 'gpt2-pytorch',
    'gpt2_large': 'gpt2-large',
    'gpt2_small': 'gpt2-small',
    'gpt2_medium': 'gpt2-medium',
    'gpt2-xl': 'gpt2-xl',
    'GPT2-XL': 'gpt2-xl',
    'gpt2_xl': 'gpt2-xl',
    'gpt2': 'gpt2-xl',
    'gpt3': 'gpt3',
    'GROVER_base': 'grover_base',
    'grover_base': 'grover_base',
    'grover_large': 'grover_large',
    'grover_mega': 'grover_mega',
    'llama2-fine-tuned': 'llama2',
    'opt_125m': 'opt_125m',
    'opt_1.3b': 'opt_1.3b',
    'opt_2.7b': 'opt_2.7b',
    'opt_6.7b': 'opt_6.7b',
    'opt_13b': 'opt_13b',
    'opt_30b': 'opt_30b',
    'opt_350m': 'opt_350m',
    'opt_iml_max_1.3b': 'opt_iml_max_1.3b',
    'opt_iml_30b': 'opt_iml_30b',
    'flan_t5_small': 'flan_t5_small',
    'flan_t5_base': 'flan_t5_base',
    'flan_t5_large': 'flan_t5_large',
    'flan_t5_xl': 'flan_t5_xl',
    'flan_t5_xxl': 'flan_t5_xxl',
    'flan_t5': 'flan_t5_xxl',
    'dolly': 'dolly',
    'GLM130B': 'GLM130B',
    'bloom_7b': 'bloom_7b',
    'bloomz': 'bloomz',
    't0_3b': 't0_3b',
    't0_11b': 't0_11b',
    'gpt_neox': 'gpt_neox',
    'xlm': 'xlm',
    'xlnet_large': 'xlnet_large',
    'xlnet_base': 'xlnet_base',
    'cohere': 'cohere',
    'ctrl': 'ctrl',
    'pplm_gpt2': 'pplm_gpt2',
    'pplm_distil': 'pplm_distil',
    'fair_wmt19': 'fair_wmt19',
    'fair_wmt20': 'fair_wmt20',
    'glm130b': 'GLM130B',
    'jais-30b': 'jais',
    'transfo_xl': 'transfo_xl',
    '7B': '7B',
    '13B': '13B',
    '65B': '65B',
    '30B': '30B',
    'gpt_j': 'gpt_j',
    'mpt': 'mpt',
    'mpt-chat': 'mpt-chat',
    'llama-chat': 'llama-chat',
    'mistral': 'mistral',
    'mistral-chat': 'mistral-chat',
    'cohere-chat': 'cohere-chat',
    'human': 'human',
}


def load_datapath(path, include_adversarial=False, dataset_name='all', include_attack=False):
    """Collect the train/test jsonl file paths under `path` for the requested dataset."""
    data_path = {'train': [], 'test': []}
    if dataset_name == 'all':
        datasets = os.listdir(path)
    elif dataset_name == 'M4':
        datasets = ['M4_monolingual', 'M4_multilingual']
    elif dataset_name == 'RAID_all':
        datasets = ['RAID', 'RAID_extra']
    else:
        datasets = [dataset_name]
    for dataset in datasets:
        dataset_path = os.path.join(path, dataset)
        for adv in os.listdir(dataset_path):
            if not include_adversarial and 'no_attack' not in adv:
                continue
            if not include_attack and ('perplexity_attack' in adv or 'synonym' in adv):
                continue
            adv_path = os.path.join(dataset_path, adv)
            for data in os.listdir(adv_path):
                if 'train' in data:
                    data_path['train'].append(os.path.join(adv_path, data))
                elif 'test' in data:
                    data_path['test'].append(os.path.join(adv_path, data))
                elif 'valid' in data:
                    # RAID validation files are used as test data; elsewhere the
                    # validation split is folded into training.
                    if 'RAID' in dataset:
                        data_path['test'].append(os.path.join(adv_path, data))
                    else:
                        data_path['train'].append(os.path.join(adv_path, data))

    return data_path


class TreeDataset(Dataset):
    def __init__(self, data_path, need_ids=False):
        self.data_path = data_path
        self.need_ids = need_ids
        self.dataset = self.load_data(data_path)

        LLM_name = set()
        for item in self.dataset:
            name = model_alias_mapping[item['src']]
            LLM_name.add(name)
        self.classes = sorted(LLM_name)

        self.name2id = {}
        for i, name in enumerate(self.classes):
            self.name2id[name] = i
        self.human_id = self.name2id['human']

    def load_jsonl(self, file_path):
        out = []
        add = ''
        if 'paraphrase_by_llm' in file_path:
            add = '-paraphrase-qwen7B'
        elif 'paraphrase' in file_path:
            add = '-paraphrase-dipper'
        else:
            assert 'no_attack' in file_path, file_path + ': file path should contain no_attack or paraphrase'

        with open(file_path, mode='r', encoding='utf-8') as jsonl_file:
            for line in jsonl_file:
                now = json.loads(line)
                if add != '':
                    # Paraphrased human text is skipped; paraphrased machine text
                    # becomes its own class, e.g. 'gpt4-paraphrase-dipper'.
                    if 'human' in now['src']:
                        continue
                    src = model_alias_mapping[now['src']] + add
                    if src not in model_alias_mapping:
                        model_alias_mapping[src] = src
                    now['src'] = src
                out.append(now)
        return out

    def load_data(self, data_path):
        data = []
        for path in data_path:
            if 'no_attack' not in path and 'paraphrase' not in path:
                continue
            print(f'loading {path}')
            data += self.load_jsonl(path)
        return data

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        data_now = self.dataset[idx]
        text = data_now['text']
        label = data_now['label']
        src = model_alias_mapping[data_now['src']]
        src_id = self.name2id[src]
        id = data_now['id']
        if self.need_ids:
            return text, int(label), int(src_id), int(id)
        else:
            return text, int(label), int(src_id)


class SCLDataset(Dataset):
    def __init__(self, data_path, fabric, tokenizer, need_ids=False, adv_p=0.5,
                 max_length=530, name2id=None, has_mix=True):
        self.data_path = data_path
        self.adv_p = adv_p
        self.need_ids = need_ids
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.has_mix = has_mix

        self.world_size = fabric.world_size
        self.global_rank = fabric.global_rank
        self.LLM_name = set()
        dataset_len = self.get_data_len(data_path)

        classes = sorted(self.LLM_name)
        if name2id is None:
            self.name2id = {}
            for i, name in enumerate(classes):
                self.name2id[name] = i
        else:
            self.name2id = name2id
            for name in classes:
                assert name in self.name2id
        self.classes = classes
        print(f'there are {len(classes)} classes in dataset')
        print(f'the classes are {classes}')

        # Shard the data across ranks: pad the index list so it divides evenly by
        # world_size, then take every world_size-th index starting at this rank.
        self.num_samples = math.ceil(dataset_len / self.world_size)
        total_size = self.num_samples * self.world_size
        indices = list(range(dataset_len))
        padding_size = total_size - len(indices)
        indices += indices[:padding_size]
        assert len(indices) == total_size
        indices = indices[self.global_rank : total_size : self.world_size]
        assert len(indices) == self.num_samples
        self.indices = set(indices)

        data_dict = self.load_data(data_path)
        self.dataset = [data_dict[i] for i in indices]
        self.dataset_len = len(self.dataset)

    def get_data_len(self, data_path):
        total_len = 0
        for path in data_path:
            print(f'reading {path}')
            with open(path, mode='r', encoding='utf-8') as jsonl_file:
                for line in jsonl_file:
                    now = json.loads(line)
                    if now['src'] not in model_alias_mapping:
                        model_alias_mapping[now['src']] = now['src']
                    now['src'] = model_alias_mapping[now['src']]
                    if not self.has_mix:
                        # Skip mixed human/machine sources such as 'human-xxx'.
                        if 'human' in now['src'] and now['src'] != 'human':
                            continue
                    if now['src'] not in self.LLM_name:
                        self.LLM_name.add(now['src'])
                    total_len += 1
        return total_len

    def truncate_text(self, text):
        tokens = self.tokenizer.encode(text, truncation=True, max_length=self.max_length)
        truncated_text = self.tokenizer.decode(tokens, skip_special_tokens=True)
        return truncated_text

    def merge_dict(self, dict1, dict2):
        for key in dict2:
            dict1[key] = dict2[key]
        return dict1

    def load_jsonl(self, file_path, total_len):
        out = {}
        cnt = 0
        with open(file_path, mode='r', encoding='utf-8') as jsonl_file:
            for line in jsonl_file:
                now = json.loads(line)
                if not self.has_mix:
                    if 'human' in now['src'] and now['src'] != 'human':
                        continue
                if total_len + cnt in self.indices:
                    out[total_len + cnt] = now
                cnt += 1
        return out, cnt

    def load_data(self, data_path):
        data = {}
        total_len = 0
        for path in data_path:
            print(f'loading {path}')
            now_data, now_len = self.load_jsonl(path, total_len)
            data = self.merge_dict(data, now_data)
            total_len += now_len
        return data

    def __len__(self):
        return self.dataset_len

    def __getitem__(self, idx):
        data = self.dataset[idx]
        text = data['text']
        label = data['label']
        src = self.name2id[model_alias_mapping[data['src']]]
        id = data['id']

        # With probability adv_p, truncate the text and apply one randomly chosen
        # adversarial perturbation as data augmentation.
        if random.random() < self.adv_p:
            text = self.truncate_text(text)
            attack_method = random.choice([
                AlterNumbersAttack, AlternativeSpellingAttack, ArticleDeletionAttack,
                HomoglyphAttack, InsertParagraphsAttack, MisspellingAttack,
                UpperLowerFlipAttack, WhiteSpaceAttack, ZeroWidthSpaceAttack])
            text = attack_method(text)
        if self.need_ids:
            return text, int(label), int(src), int(id)
        return text, int(label), int(src)


class SCL_RM_Dataset(Dataset):
    """Variant of SCLDataset that randomly drops each non-human class with probability remove_cls."""

    def __init__(self, data_path, fabric, tokenizer, need_ids=False, adv_p=0.5,
                 max_length=530, name2id=None, has_mix=True, remove_cls=0.9):
        self.data_path = data_path
        self.adv_p = adv_p
        self.need_ids = need_ids
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.has_mix = has_mix

        self.world_size = fabric.world_size
        self.global_rank = fabric.global_rank
        self.LLM_name = set()
        self.remove_cls = remove_cls
        assert name2id is not None, 'name2id is None, please set name2id'
        self.remove_name = set()
        for name in name2id:
            if random.random() < self.remove_cls and name != 'human':
                self.remove_name.add(name)
        dataset_len = self.get_data_len(data_path)

        classes = sorted(self.LLM_name)
        if name2id is None:  # unreachable here (see assert above); kept to mirror SCLDataset
            self.name2id = {}
            for i, name in enumerate(classes):
                self.name2id[name] = i
        else:
            self.name2id = name2id
            for name in classes:
                assert name in self.name2id
        self.classes = classes
        print(f'there are {len(classes)} classes in dataset')
        print(f'the classes are {classes}')

        self.num_samples = math.ceil(dataset_len / self.world_size)
        total_size = self.num_samples * self.world_size
        indices = list(range(dataset_len))
        padding_size = total_size - len(indices)
        indices += indices[:padding_size]
        assert len(indices) == total_size
        indices = indices[self.global_rank : total_size : self.world_size]
        assert len(indices) == self.num_samples
        self.indices = set(indices)

        data_dict = self.load_data(data_path)
        self.dataset = [data_dict[i] for i in indices]
        self.dataset_len = len(self.dataset)

    def get_data_len(self, data_path):
        total_len = 0
        for path in data_path:
            print(f'reading {path}')
            with open(path, mode='r', encoding='utf-8') as jsonl_file:
                for line in jsonl_file:
                    now = json.loads(line)
                    if now['src'] not in model_alias_mapping:
                        model_alias_mapping[now['src']] = now['src']
                    now['src'] = model_alias_mapping[now['src']]
                    if not self.has_mix:
                        if 'human' in now['src'] and now['src'] != 'human':
                            continue
                    if now['src'] in self.remove_name:
                        continue
                    if now['src'] not in self.LLM_name:
                        self.LLM_name.add(now['src'])
                    total_len += 1
        return total_len

    def truncate_text(self, text):
        tokens = self.tokenizer.encode(text, truncation=True, max_length=self.max_length)
        truncated_text = self.tokenizer.decode(tokens, skip_special_tokens=True)
        return truncated_text

    def merge_dict(self, dict1, dict2):
        for key in dict2:
            dict1[key] = dict2[key]
        return dict1

    def load_jsonl(self, file_path, total_len):
        out = {}
        cnt = 0
        with open(file_path, mode='r', encoding='utf-8') as jsonl_file:
            for line in jsonl_file:
                now = json.loads(line)
                if not self.has_mix:
                    if 'human' in now['src'] and now['src'] != 'human':
                        continue
                if now['src'] in self.remove_name:
                    continue
                if total_len + cnt in self.indices:
                    out[total_len + cnt] = now
                cnt += 1
        return out, cnt

    def load_data(self, data_path):
        data = {}
        total_len = 0
        for path in data_path:
            print(f'loading {path}')
            now_data, now_len = self.load_jsonl(path, total_len)
            data = self.merge_dict(data, now_data)
            total_len += now_len
        return data

    def __len__(self):
        return self.dataset_len

    def __getitem__(self, idx):
        data = self.dataset[idx]
        text = data['text']
        label = data['label']
        src = self.name2id[model_alias_mapping[data['src']]]
        id = data['id']

        if random.random() < self.adv_p:
            text = self.truncate_text(text)
            attack_method = random.choice([
                AlterNumbersAttack, AlternativeSpellingAttack, ArticleDeletionAttack,
                HomoglyphAttack, InsertParagraphsAttack, MisspellingAttack,
                UpperLowerFlipAttack, WhiteSpaceAttack, ZeroWidthSpaceAttack])
            text = attack_method(text)
        if self.need_ids:
            return text, int(label), int(src), int(id)
        return text, int(label), int(src)
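
For orientation, a minimal usage sketch of the loaders above. The 'data/' root and the 'M4' choice are placeholders; load_datapath scans a layout of the form <root>/<dataset>/<attack_dir>/<split files>, and TreeDataset only keeps files whose path contains 'no_attack' or 'paraphrase':

# Hypothetical usage sketch, not part of the upload.
from torch.utils.data import DataLoader

from detree.utils.dataset import TreeDataset, load_datapath

paths = load_datapath('data/', include_adversarial=False, dataset_name='M4')
train_set = TreeDataset(paths['train'])
print(f'{len(train_set.classes)} classes, human id = {train_set.human_id}')
text, label, src_id = train_set[0]  # raw text, binary label, generator class id
loader = DataLoader(train_set, batch_size=32, shuffle=True)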
detree/utils/index.py
ADDED
@@ -0,0 +1,105 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import pickle
from typing import List, Tuple

import faiss
import numpy as np
from tqdm import tqdm


class Indexer(object):

    def __init__(self, vector_sz, n_subquantizers=0, n_bits=16):
        # A product-quantized or single flat index could be used instead, e.g.
        # faiss.IndexPQ(vector_sz, n_subquantizers, n_bits, faiss.METRIC_INNER_PRODUCT)
        # or faiss.IndexFlatIP(vector_sz).
        self.vector_sz = vector_sz
        self.index = self._create_sharded_index()
        self.index_id_to_db_id = []
        # Maps external db ids to labels; expected to be populated by the caller
        # before search_knn is used.
        self.label_dict = {}

    def _create_sharded_index(self):
        # Determine the number of available GPUs
        ngpu = faiss.get_num_gpus()
        # Create an IndexShards object with successive_ids=True to keep ids globally unique
        index = faiss.IndexShards(self.vector_sz, True, True)
        # Create a sub-index for each GPU and add it to the IndexShards container
        for i in range(ngpu):
            # Create a standard GPU resource object
            res = faiss.StandardGpuResources()
            # Configure the GPU index
            flat_config = faiss.GpuIndexFlatConfig()
            # flat_config.useFloat16 = True  # enable to reduce memory usage with half precision
            flat_config.device = i  # assign the GPU device id
            # Create the GPU index
            sub_index = faiss.GpuIndexFlatIP(res, self.vector_sz, flat_config)
            # Add the sub-index into the sharded index
            index.add_shard(sub_index)
        return index

    def index_data(self, ids, embeddings):
        self._update_id_mapping(ids)
        self.index.add(embeddings)
        print(f'Total data indexed {self.index.ntotal}')

    def search_knn(self, query_vectors: np.ndarray, top_docs: int,
                   index_batch_size: int = 8) -> List[Tuple[List[str], List[float], List[object]]]:
        """Batched k-NN search; returns one (ids, scores, labels) triple per query."""
        result = []
        nbatch = (len(query_vectors) - 1) // index_batch_size + 1
        for k in tqdm(range(nbatch)):
            start_idx = k * index_batch_size
            end_idx = min((k + 1) * index_batch_size, len(query_vectors))
            q = query_vectors[start_idx:end_idx]
            scores, indexes = self.index.search(q, top_docs)
            # convert internal faiss ids to external db ids (and their labels)
            db_ids = [[str(self.index_id_to_db_id[i]) for i in query_top_idxs]
                      for query_top_idxs in indexes]
            db_labels = [[self.label_dict[self.index_id_to_db_id[i]] for i in query_top_idxs]
                         for query_top_idxs in indexes]
            result.extend([(db_ids[i], scores[i], db_labels[i]) for i in range(len(db_ids))])
        return result

    def serialize(self, dir_path):
        index_file = os.path.join(dir_path, 'index.faiss')
        meta_file = os.path.join(dir_path, 'index_meta.faiss')
        print(f'Serializing index to {index_file}, meta data to {meta_file}')

        # NOTE: GPU-backed indexes generally need faiss.index_gpu_to_cpu() before
        # they can be written to disk.
        faiss.write_index(self.index, index_file)
        with open(meta_file, mode='wb') as f:
            pickle.dump(self.index_id_to_db_id, f)

    def deserialize_from(self, dir_path):
        index_file = os.path.join(dir_path, 'index.faiss')
        meta_file = os.path.join(dir_path, 'index_meta.faiss')
        print(f'Loading index from {index_file}, meta data from {meta_file}')

        self.index = faiss.read_index(index_file)
        print(f'Loaded index of type {type(self.index)} and size {self.index.ntotal}')

        with open(meta_file, 'rb') as reader:
            self.index_id_to_db_id = pickle.load(reader)
        assert len(self.index_id_to_db_id) == self.index.ntotal, \
            'Deserialized index_id_to_db_id should match faiss index size'

    def _update_id_mapping(self, db_ids: List):
        self.index_id_to_db_id.extend(db_ids)

    def reset(self):
        self.index.reset()
        self.index_id_to_db_id = []
        print(f'Index reset, total data indexed {self.index.ntotal}')
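
A sketch of how the class might be driven, assuming at least one GPU is visible to faiss (with zero GPUs the sharded index has no shards) and using random float32 vectors in place of real embeddings:

# Hypothetical usage sketch with random data standing in for real embeddings.
import numpy as np

from detree.utils.index import Indexer

dim = 768
indexer = Indexer(dim)
embeddings = np.random.rand(1000, dim).astype(np.float32)
ids = list(range(1000))
indexer.label_dict = {i: i % 2 for i in ids}  # search_knn expects a label for every id
indexer.index_data(ids, embeddings)

queries = np.random.rand(4, dim).astype(np.float32)
for doc_ids, scores, labels in indexer.search_knn(queries, top_docs=5):
    print(doc_ids, labels)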
detree/utils/utils.py
ADDED
@@ -0,0 +1,251 @@
import hashlib
import os
import pickle

import numpy as np
from sklearn.metrics import (accuracy_score, auc, f1_score, precision_recall_curve,
                             precision_score, recall_score, roc_auc_score, roc_curve)


def stable_long_hash(input_string):
    """Deterministic 63-bit integer hash of a string (unlike Python's salted hash())."""
    hash_object = hashlib.sha256(input_string.encode())
    hex_digest = hash_object.hexdigest()
    int_hash = int(hex_digest, 16)
    long_long_hash = int_hash & ((1 << 63) - 1)
    return long_long_hash


def load_pkl(path):
    with open(path, 'rb') as f:
        return pickle.load(f)


def save_pkl(obj, path):
    with open(path, 'wb') as f:
        pickle.dump(obj, f)


def find_top_n(embeddings, n, index, data):
    if len(embeddings.shape) == 1:
        embeddings = embeddings.reshape(1, -1)
    top_ids_and_scores = index.search_knn(embeddings, n)
    data_ans = []
    # Indexer.search_knn returns (ids, scores, labels) triples; only ids are used here.
    for ids, scores, _ in top_ids_and_scores:
        data_now = []
        for id in ids:
            data_now.append((data[0][int(id)], data[1][int(id)], data[2][int(id)]))
        data_ans.append(data_now)
    return data_ans


def print_line(class_name, metrics, is_header=False):
    if is_header:
        line = f"| {'Class':<10} | " + " | ".join([f"{metric:<10}" for metric in metrics])
    else:
        line = f"| {class_name:<10} | " + " | ".join([f"{metrics[metric]:<10.3f}" for metric in metrics])
    print(line)
    if is_header:
        print('-' * len(line))


def calculate_per_class_metrics(classes, ground_truth, predictions):
    # Convert ground truth and predictions to numeric format
    gt_numeric = np.array([int(gt) for gt in ground_truth])
    pred_numeric = np.array([int(pred) for pred in predictions])

    results = {}
    for i, class_name in enumerate(classes):
        # For each class, compute the one-vs-rest binary labels
        gt_binary = (gt_numeric == i).astype(int)
        pred_binary = (pred_numeric == i).astype(int)

        # Calculate metrics, handling classes absent from predictions or ground truth
        precision = precision_score(gt_binary, pred_binary, zero_division=0)
        recall = recall_score(gt_binary, pred_binary, zero_division=0)
        f1 = f1_score(gt_binary, pred_binary, zero_division=0)
        acc = np.mean(gt_binary == pred_binary)
        # Recall for all other classes pooled together as 'rest'
        rest_recall = recall_score(1 - gt_binary, 1 - pred_binary, zero_division=0)

        results[class_name] = {
            'Precision': precision,
            'Recall': recall,
            'F1 Score': f1,
            'Accuracy': acc,
            'Avg Recall (with rest)': (recall + rest_recall) / 2,
        }

    print_line("Metric", results[classes[0]], is_header=True)
    for class_name, metrics in results.items():
        print_line(class_name, metrics)
    overall_metrics = {metric_name: np.mean([metrics[metric_name] for metrics in results.values()])
                       for metric_name in results[classes[0]].keys()}
    print_line("Overall", overall_metrics)


def calculate_metrics(labels, preds):
    acc = accuracy_score(labels, preds)
    precision = precision_score(labels, preds, average='macro')
    recall = recall_score(labels, preds, average='macro')
    f1 = f1_score(labels, preds, average='macro')
    return acc, precision, recall, f1


def compute_three_recalls(labels, preds):
    # Labels are strings: '0' is counted toward machine recall, '1' toward human recall.
    all_n, all_p, tn, tp = 0, 0, 0, 0
    for label, pred in zip(labels, preds):
        if label == '0':
            all_p += 1
        if label == '1':
            all_n += 1
        # None predictions are treated as incorrect
        if pred is not None and label == pred == '0':
            tp += 1
        if pred is not None and label == pred == '1':
            tn += 1
    machine_rec = tp * 100 / all_p if all_p != 0 else 0
    human_rec = tn * 100 / all_n if all_n != 0 else 0
    avg_rec = (human_rec + machine_rec) / 2
    return (human_rec, machine_rec, avg_rec)


def compute_metrics(labels, preds, ids=None):
    # None values in preds are treated as incorrect predictions.
    if ids is not None:
        # Deduplicate labels and predictions for repeated ids
        dict_labels, dict_preds = {}, {}
        for i in range(len(ids)):
            dict_labels[ids[i]] = labels[i]
            dict_preds[ids[i]] = preds[i]
        labels = list(dict_labels.values())
        preds = list(dict_preds.values())

    human_rec, machine_rec, avg_rec = compute_three_recalls(labels, preds)
    acc = accuracy_score(labels, preds)
    precision = precision_score(labels, preds, pos_label='1')
    recall = recall_score(labels, preds, pos_label='1')
    f1 = f1_score(labels, preds, pos_label='1')
    return (human_rec, machine_rec, avg_rec, acc, precision, recall, f1)


def evaluate_max_f1_metrics(test_labels, y_score):
    test_labels = np.array(test_labels)
    y_score = np.array(y_score)

    auroc = roc_auc_score(test_labels, y_score)
    precision, recall, thresholds = precision_recall_curve(test_labels, y_score, pos_label=1)
    pr_auc = auc(recall, precision)
    epsilon = 1e-6
    f1_scores = 2 * precision * recall / (precision + recall + epsilon)
    best_index = f1_scores.argmax()
    best_f1 = f1_scores[best_index]
    best_precision = precision[best_index]
    best_recall = recall[best_index]

    # precision_recall_curve returns one fewer threshold than curve points
    threshold = thresholds[best_index] if best_index < len(thresholds) else 1.0
    y_pred_max_f1 = (y_score >= threshold).astype(int)

    acc = (y_pred_max_f1 == test_labels).mean()
    tp = ((y_pred_max_f1 == 1) & (test_labels == 1)).sum()
    fn = ((y_pred_max_f1 == 0) & (test_labels == 1)).sum()
    fp = ((y_pred_max_f1 == 1) & (test_labels == 0)).sum()
    tn = ((y_pred_max_f1 == 0) & (test_labels == 0)).sum()

    pos_recall = tp / (tp + fn + epsilon)  # recall for the positive class
    neg_recall = tn / (tn + fp + epsilon)  # recall for the negative class
    avg_recall = (pos_recall + neg_recall) / 2  # average recall across classes

    metric = {'auroc': auroc, 'pr_auc': pr_auc, 'F1': best_f1, 'Precision': best_precision,
              'Recall': best_recall, 'threshold': threshold, 'acc': acc,
              'avg_recall': avg_recall, 'pos_recall': pos_recall, 'neg_recall': neg_recall}
    return metric


def evaluate_metrics(test_labels, y_score, threshold_param=-1, target_fpr=0.05):
    if isinstance(test_labels, list):
        test_labels = np.array(test_labels)
    if isinstance(y_score, list):
        y_score = np.array(y_score)

    if threshold_param != -1:
        if not (0 <= threshold_param <= 1):
            raise ValueError("Threshold must be between 0 and 1.")

    auroc = roc_auc_score(test_labels, y_score)

    precision, recall, thresholds = precision_recall_curve(test_labels, y_score, pos_label=1)
    pr_auc = auc(recall, precision)

    epsilon = 1e-6
    f1_scores = 2 * precision * recall / (precision + recall + epsilon)

    if threshold_param == -1:
        # Pick the threshold that maximizes F1
        best_index = f1_scores.argmax()
        F1 = f1_scores[best_index]
        Precision = precision[best_index]
        Recall = recall[best_index]
        threshold = thresholds[best_index] if best_index < len(thresholds) else 1.0
    else:
        # Evaluate at the user-supplied threshold
        threshold = threshold_param
        index = np.where(thresholds >= threshold)[0][0]
        Precision = precision[index]
        Recall = recall[index]
        F1 = f1_scores[index]

    y_pred = (y_score >= threshold).astype(int)
    acc = (y_pred == test_labels).mean()

    tp = ((y_pred == 1) & (test_labels == 1)).sum()
    fn = ((y_pred == 0) & (test_labels == 1)).sum()
    fp = ((y_pred == 1) & (test_labels == 0)).sum()
    tn = ((y_pred == 0) & (test_labels == 0)).sum()

    pos_recall = tp / (tp + fn + epsilon)  # TPR
    neg_recall = tn / (tn + fp + epsilon)  # TNR
    avg_recall = (pos_recall + neg_recall) / 2

    # TPR at (approximately) the target false-positive rate
    fpr, tpr, thds = roc_curve(test_labels, y_score)
    if len(fpr) > 0 and len(tpr) > 0:
        idx = np.argmin(np.abs(fpr - target_fpr))
        tpr_at_fpr = tpr[idx]
        tpr_at_fpr_threshold = thds[idx]
    else:
        tpr_at_fpr = 0.0
        tpr_at_fpr_threshold = 1.0  # fallback so the metric dict below is always defined

    metric = {'auroc': auroc, 'pr_auc': pr_auc, 'F1': F1, 'Precision': Precision, 'Recall': Recall,
              'threshold': threshold, 'acc': acc, 'avg_recall': avg_recall, 'pos_recall': pos_recall,
              'neg_recall': neg_recall, 'tpr_at_fpr': tpr_at_fpr, 'tpr_at_fpr_threshold': tpr_at_fpr_threshold}

    return metric


def load_datapath(path, include_adversarial=False, dataset_name='all', attack_type='all'):
    data_path = {'train': [], 'valid': [], 'test': []}
    if dataset_name == 'all':
        datasets = os.listdir(path)
    elif dataset_name == 'M4':
        datasets = ['M4_monolingual', 'M4_multilingual']
    elif dataset_name == 'RAID_all':
        datasets = ['RAID', 'RAID_extra']
    else:
        datasets = [dataset_name]
    for dataset in datasets:
        dataset_path = os.path.join(path, dataset)
        if attack_type != 'all':
            dataset_path_list = [pth for pth in os.listdir(dataset_path) if attack_type in pth]
        else:
            dataset_path_list = os.listdir(dataset_path)
        for adv in dataset_path_list:
            if not include_adversarial and 'no_attack' not in adv:
                continue
            adv_path = os.path.join(dataset_path, adv)
            for data in os.listdir(adv_path):
                if 'train.' in data:
                    data_path['train'].append(os.path.join(adv_path, data))
                elif 'test.' in data:
                    data_path['test'].append(os.path.join(adv_path, data))
                elif 'valid.' in data:
                    data_path['valid'].append(os.path.join(adv_path, data))
    return data_path
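
As a quick, self-contained illustration of evaluate_metrics (toy labels and scores, not real detector output):

# Toy example: five samples with scores for the positive class (label 1).
labels = [0, 0, 1, 1, 1]
scores = [0.10, 0.40, 0.35, 0.80, 0.90]
m = evaluate_metrics(labels, scores)          # threshold_param=-1 picks the max-F1 threshold
print(m['auroc'], m['F1'], m['threshold'])
m_fixed = evaluate_metrics(labels, scores, threshold_param=0.5)
print(m_fixed['acc'], m_fixed['tpr_at_fpr'])  # TPR nearest the 5% FPR target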