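"""Dataset utilities for training and evaluating LLM-source classifiers.

Provides three torch Dataset variants over jsonl detection data:
- TreeDataset: simple in-memory loading with per-model class ids.
- SCLDataset: rank-sharded loading with on-the-fly adversarial augmentation.
- SCL_RM_Dataset: like SCLDataset, but randomly removes non-human classes.
"""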
import json
import math
import os
import random

from torch.utils.data import Dataset

from .adversarial.alter_number import AlterNumbersAttack
from .adversarial.alternative_spelling import AlternativeSpellingAttack
from .adversarial.article_deletion import ArticleDeletionAttack
from .adversarial.homoglyph import HomoglyphAttack
from .adversarial.insert_paragraphs import InsertParagraphsAttack
from .adversarial.misspelling import MisspellingAttack
from .adversarial.upper_lower import UpperLowerFlipAttack
from .adversarial.whitespace import WhiteSpaceAttack
from .adversarial.zero_width_space import ZeroWidthSpaceAttack
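
# Maps every raw `src` value observed in the data files to a canonical model
# name (e.g. 'ChatGPT' and 'chatGPT' both map to 'chatgpt'). Sources not listed
# here are registered as their own canonical name at load time. The key
# 'gpt-3.5-trubo' is kept as spelled in the source data.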
model_alias_mapping = {
    'chatgpt': 'chatgpt',
    'ChatGPT': 'chatgpt',
    'chatGPT': 'chatgpt',
    'gpt-3.5-trubo': 'gpt-3.5-trubo',
    'GPT4': 'gpt4',
    'gpt4': 'gpt4',
    'text-davinci-002': 'text-davinci-002',
    'text-davinci-003': 'text-davinci-003',
    'davinci': 'text-davinci',
    'gpt1': 'gpt1',
    'gpt2_pytorch': 'gpt2-pytorch',
    'gpt2_large': 'gpt2-large',
    'gpt2_small': 'gpt2-small',
    'gpt2_medium': 'gpt2-medium',
    'gpt2-xl': 'gpt2-xl',
    'GPT2-XL': 'gpt2-xl',
    'gpt2_xl': 'gpt2-xl',
    'gpt2': 'gpt2-xl',
    'gpt3': 'gpt3',
    'GROVER_base': 'grover_base',
    'grover_base': 'grover_base',
    'grover_large': 'grover_large',
    'grover_mega': 'grover_mega',
    'llama2-fine-tuned': 'llama2',
    'opt_125m': 'opt_125m',
    'opt_1.3b': 'opt_1.3b',
    'opt_2.7b': 'opt_2.7b',
    'opt_6.7b': 'opt_6.7b',
    'opt_13b': 'opt_13b',
    'opt_30b': 'opt_30b',
    'opt_350m': 'opt_350m',
    'opt_iml_max_1.3b': 'opt_iml_max_1.3b',
    'opt_iml_30b': 'opt_iml_30b',
    'flan_t5_small': 'flan_t5_small',
    'flan_t5_base': 'flan_t5_base',
    'flan_t5_large': 'flan_t5_large',
    'flan_t5_xl': 'flan_t5_xl',
    'flan_t5_xxl': 'flan_t5_xxl',
    'flan_t5': 'flan_t5_xxl',
    'dolly': 'dolly',
    'GLM130B': 'GLM130B',
    'bloom_7b': 'bloom_7b',
    'bloomz': 'bloomz',
    't0_3b': 't0_3b',
    't0_11b': 't0_11b',
    'gpt_neox': 'gpt_neox',
    'xlm': 'xlm',
    'xlnet_large': 'xlnet_large',
    'xlnet_base': 'xlnet_base',
    'cohere': 'cohere',
    'ctrl': 'ctrl',
    'pplm_gpt2': 'pplm_gpt2',
    'pplm_distil': 'pplm_distil',
    'fair_wmt19': 'fair_wmt19',
    'fair_wmt20': 'fair_wmt20',
    'glm130b': 'GLM130B',
    'jais-30b': 'jais',
    'transfo_xl': 'transfo_xl',
    '7B': '7B',
    '13B': '13B',
    '65B': '65B',
    '30B': '30B',
    'gpt_j': 'gpt_j',
    'mpt': 'mpt',
    'mpt-chat': 'mpt-chat',
    'llama-chat': 'llama-chat',
    'mistral': 'mistral',
    'mistral-chat': 'mistral-chat',
    'cohere-chat': 'cohere-chat',
    'human': 'human',
}
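
# Collects train/test jsonl paths. Expected on-disk layout (inferred from the
# traversal below): <path>/<dataset>/<attack_variant>/<split>.jsonl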
def load_datapath(path, include_adversarial=False, dataset_name='all', include_attack=False):
    data_path = {'train': [], 'test': []}
    if dataset_name == 'all':
        datasets = os.listdir(path)
    elif dataset_name == 'M4':
        datasets = ['M4_monolingual', 'M4_multilingual']
    elif dataset_name == 'RAID_all':
        datasets = ['RAID', 'RAID_extra']
    else:
        datasets = [dataset_name]
    for dataset in datasets:
        dataset_path = os.path.join(path, dataset)
        for adv in os.listdir(dataset_path):
            # Skip attacked variants unless adversarial data was requested.
            if not include_adversarial and 'no_attack' not in adv:
                continue
            if not include_attack and ('perplexity_attack' in adv or 'synonym' in adv):
                continue
            adv_path = os.path.join(dataset_path, adv)
            for data in os.listdir(adv_path):
                if 'train' in data:
                    data_path['train'].append(os.path.join(adv_path, data))
                elif 'test' in data:
                    data_path['test'].append(os.path.join(adv_path, data))
                elif 'valid' in data:
                    # RAID validation files are routed to the test split;
                    # other datasets fold them into train.
                    if 'RAID' in dataset:
                        data_path['test'].append(os.path.join(adv_path, data))
                    else:
                        data_path['train'].append(os.path.join(adv_path, data))
    return data_path
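
# TreeDataset loads 'no_attack' and 'paraphrase' jsonl files into memory and
# exposes (text, label, source-class-id) tuples; paraphrased generations are
# given their own '<model>-paraphrase-*' classes.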
class TreeDataset(Dataset):
    def __init__(self, data_path, need_ids=False):
        self.data_path = data_path
        self.need_ids = need_ids
        self.dataset = self.load_data(data_path)
        # Collect the canonical model names present in the data and build a
        # deterministic name -> class-id mapping.
        LLM_name = set()
        for item in self.dataset:
            name = model_alias_mapping[item['src']]
            LLM_name.add(name)
        self.classes = sorted(LLM_name)
        self.name2id = {name: i for i, name in enumerate(self.classes)}
        self.human_id = self.name2id['human']

    def load_jsonl(self, file_path):
        out = []
        add = ''
        # Paraphrased files get a suffix so paraphrased generations form
        # their own classes.
        if 'paraphrase_by_llm' in file_path:
            add = '-paraphrase-qwen7B'
        elif 'paraphrase' in file_path:
            add = '-paraphrase-dipper'
        else:
            assert 'no_attack' in file_path, f'{file_path}: file path should contain no_attack or paraphrase'
        with open(file_path, mode='r', encoding='utf-8') as jsonl_file:
            for line in jsonl_file:
                now = json.loads(line)
                if add != '':
                    # Human text has no paraphrased counterpart, so skip it here.
                    if 'human' in now['src']:
                        continue
                    src = model_alias_mapping[now['src']] + add
                    # Register the new paraphrase class on the fly.
                    if src not in model_alias_mapping:
                        model_alias_mapping[src] = src
                    now['src'] = src
                out.append(now)
        return out

    def load_data(self, data_path):
        data = []
        for path in data_path:
            if 'no_attack' not in path and 'paraphrase' not in path:
                continue
            print(f'loading {path}')
            data += self.load_jsonl(path)
        return data

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        data_now = self.dataset[idx]
        text = data_now['text']
        label = data_now['label']
        src = model_alias_mapping[data_now['src']]
        src_id = self.name2id[src]
        sample_id = data_now['id']
        if self.need_ids:
            return text, int(label), int(src_id), int(sample_id)
        return text, int(label), int(src_id)
class SCLDataset(Dataset):
    def __init__(self, data_path, fabric, tokenizer, need_ids=False, adv_p=0.5,
                 max_length=530, name2id=None, has_mix=True):
        self.data_path = data_path
        self.adv_p = adv_p
        self.need_ids = need_ids
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.has_mix = has_mix
        self.world_size = fabric.world_size
        self.global_rank = fabric.global_rank
        self.LLM_name = set()
        dataset_len = self.get_data_len(data_path)
        classes = sorted(self.LLM_name)
        if name2id is None:
            self.name2id = {name: i for i, name in enumerate(classes)}
        else:
            self.name2id = name2id
        for name in classes:
            assert name in self.name2id
        self.classes = classes
        print(f'there are {len(classes)} classes in dataset')
        print(f'the classes are {classes}')
        # Shard the data across ranks, padding with the first samples so every
        # rank gets the same number of items (mirrors DistributedSampler).
        self.num_samples = math.ceil(dataset_len / self.world_size)
        total_size = self.num_samples * self.world_size
        indices = list(range(dataset_len))
        padding_size = total_size - len(indices)
        indices += indices[:padding_size]
        assert len(indices) == total_size
        indices = indices[self.global_rank:total_size:self.world_size]
        assert len(indices) == self.num_samples
        self.indices = set(indices)
        # Only materialize the samples assigned to this rank.
        data_dict = self.load_data(data_path)
        self.dataset = [data_dict[i] for i in indices]
        self.dataset_len = len(self.dataset)
    def get_data_len(self, data_path):
        # First pass: count samples and collect class names without storing text.
        total_len = 0
        for path in data_path:
            print(f'reading {path}')
            with open(path, mode='r', encoding='utf-8') as jsonl_file:
                for line in jsonl_file:
                    now = json.loads(line)
                    if now['src'] not in model_alias_mapping:
                        model_alias_mapping[now['src']] = now['src']
                    now['src'] = model_alias_mapping[now['src']]
                    if not self.has_mix:
                        # Drop mixed human/LLM sources such as 'human-...'.
                        if 'human' in now['src'] and now['src'] != 'human':
                            continue
                    self.LLM_name.add(now['src'])
                    total_len += 1
        return total_len

    def truncate_text(self, text):
        tokens = self.tokenizer.encode(text, truncation=True, max_length=self.max_length)
        return self.tokenizer.decode(tokens, skip_special_tokens=True)

    def merge_dict(self, dict1, dict2):
        dict1.update(dict2)
        return dict1

    def load_jsonl(self, file_path, total_len):
        # Second pass: keep only samples whose global index belongs to this
        # rank. Filter on the canonical name so the counting here stays in
        # sync with the first pass in get_data_len.
        out = {}
        cnt = 0
        with open(file_path, mode='r', encoding='utf-8') as jsonl_file:
            for line in jsonl_file:
                now = json.loads(line)
                src = model_alias_mapping.get(now['src'], now['src'])
                if not self.has_mix:
                    if 'human' in src and src != 'human':
                        continue
                if total_len + cnt in self.indices:
                    out[total_len + cnt] = now
                cnt += 1
        return out, cnt

    def load_data(self, data_path):
        data = {}
        total_len = 0
        for path in data_path:
            print(f'loading {path}')
            now_data, now_len = self.load_jsonl(path, total_len)
            data = self.merge_dict(data, now_data)
            total_len += now_len
        return data
    def __len__(self):
        return self.dataset_len

    def __getitem__(self, idx):
        data = self.dataset[idx]
        text = data['text']
        label = data['label']
        src = self.name2id[model_alias_mapping[data['src']]]
        sample_id = data['id']
        # With probability adv_p, truncate the text and apply one randomly
        # chosen character/format-level adversarial attack as augmentation.
        if random.random() < self.adv_p:
            text = self.truncate_text(text)
            attack_method = random.choice([
                AlterNumbersAttack, AlternativeSpellingAttack, ArticleDeletionAttack,
                HomoglyphAttack, InsertParagraphsAttack, MisspellingAttack,
                UpperLowerFlipAttack, WhiteSpaceAttack, ZeroWidthSpaceAttack,
            ])
            text = attack_method(text)
        if self.need_ids:
            return text, int(label), int(src), int(sample_id)
        return text, int(label), int(src)
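
# SCL_RM_Dataset is a variant of SCLDataset that randomly holds out a subset
# of non-human classes (each dropped with probability remove_cls); the
# required name2id keeps class ids consistent with the full dataset.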
class SCL_RM_Dataset(Dataset):
    def __init__(self, data_path, fabric, tokenizer, need_ids=False, adv_p=0.5,
                 max_length=530, name2id=None, has_mix=True, remove_cls=0.9):
        self.data_path = data_path
        self.adv_p = adv_p
        self.need_ids = need_ids
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.has_mix = has_mix
        self.world_size = fabric.world_size
        self.global_rank = fabric.global_rank
        self.LLM_name = set()
        self.remove_cls = remove_cls
        assert name2id is not None, 'name2id is None, please set name2id'
        # Randomly hold out each non-human class with probability remove_cls.
        self.remove_name = set()
        for name in name2id:
            if random.random() < self.remove_cls and name != 'human':
                self.remove_name.add(name)
        dataset_len = self.get_data_len(data_path)
        classes = sorted(self.LLM_name)
        self.name2id = name2id
        for name in classes:
            assert name in self.name2id
        self.classes = classes
        print(f'there are {len(classes)} classes in dataset')
        print(f'the classes are {classes}')
        # Shard the data across ranks, padding with the first samples so every
        # rank gets the same number of items (mirrors DistributedSampler).
        self.num_samples = math.ceil(dataset_len / self.world_size)
        total_size = self.num_samples * self.world_size
        indices = list(range(dataset_len))
        padding_size = total_size - len(indices)
        indices += indices[:padding_size]
        assert len(indices) == total_size
        indices = indices[self.global_rank:total_size:self.world_size]
        assert len(indices) == self.num_samples
        self.indices = set(indices)
        # Only materialize the samples assigned to this rank.
        data_dict = self.load_data(data_path)
        self.dataset = [data_dict[i] for i in indices]
        self.dataset_len = len(self.dataset)
    def get_data_len(self, data_path):
        # First pass: count samples and collect class names without storing text.
        total_len = 0
        for path in data_path:
            print(f'reading {path}')
            with open(path, mode='r', encoding='utf-8') as jsonl_file:
                for line in jsonl_file:
                    now = json.loads(line)
                    if now['src'] not in model_alias_mapping:
                        model_alias_mapping[now['src']] = now['src']
                    now['src'] = model_alias_mapping[now['src']]
                    if not self.has_mix:
                        if 'human' in now['src'] and now['src'] != 'human':
                            continue
                    if now['src'] in self.remove_name:
                        continue
                    self.LLM_name.add(now['src'])
                    total_len += 1
        return total_len

    def truncate_text(self, text):
        tokens = self.tokenizer.encode(text, truncation=True, max_length=self.max_length)
        return self.tokenizer.decode(tokens, skip_special_tokens=True)

    def merge_dict(self, dict1, dict2):
        dict1.update(dict2)
        return dict1

    def load_jsonl(self, file_path, total_len):
        # Second pass: keep only samples whose global index belongs to this
        # rank. Filter on the canonical name so the counting here (including
        # the remove_name hold-out) stays in sync with get_data_len.
        out = {}
        cnt = 0
        with open(file_path, mode='r', encoding='utf-8') as jsonl_file:
            for line in jsonl_file:
                now = json.loads(line)
                src = model_alias_mapping.get(now['src'], now['src'])
                if not self.has_mix:
                    if 'human' in src and src != 'human':
                        continue
                if src in self.remove_name:
                    continue
                if total_len + cnt in self.indices:
                    out[total_len + cnt] = now
                cnt += 1
        return out, cnt

    def load_data(self, data_path):
        data = {}
        total_len = 0
        for path in data_path:
            print(f'loading {path}')
            now_data, now_len = self.load_jsonl(path, total_len)
            data = self.merge_dict(data, now_data)
            total_len += now_len
        return data
    def __len__(self):
        return self.dataset_len

    def __getitem__(self, idx):
        data = self.dataset[idx]
        text = data['text']
        label = data['label']
        src = self.name2id[model_alias_mapping[data['src']]]
        sample_id = data['id']
        # With probability adv_p, truncate the text and apply one randomly
        # chosen character/format-level adversarial attack as augmentation.
        if random.random() < self.adv_p:
            text = self.truncate_text(text)
            attack_method = random.choice([
                AlterNumbersAttack, AlternativeSpellingAttack, ArticleDeletionAttack,
                HomoglyphAttack, InsertParagraphsAttack, MisspellingAttack,
                UpperLowerFlipAttack, WhiteSpaceAttack, ZeroWidthSpaceAttack,
            ])
            text = attack_method(text)
        if self.need_ids:
            return text, int(label), int(src), int(sample_id)
        return text, int(label), int(src)