File size: 5,973 Bytes
485127c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 | # Copyright (c) Guangsheng Bao.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import json
import logging
import random
import re
import sys
import time
import nltk
import numpy as np
import torch
from tqdm import tqdm
import os
os.chdir(os.path.dirname(os.path.abspath(__file__)))
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
def bt_translation(src, browser):
zh2en_url = f'https://translate.google.com/?hl=zh&sl=en&tl=zh-CN&text={src}&op=translate'
browser.get(zh2en_url) # 访问相对应链接 browser.close#关闭浏览器
time.sleep(random.randint(1, 2))
browser.find_element_by_xpath(
'//*[@id="yDmH0d"]/c-wiz/div/div[2]/c-wiz/div[2]/c-wiz/div[1]/div[2]/div[3]/c-wiz[1]/span/span/div/div[2]/div[1]').send_keys(
src)
browser.refresh()
# time.sleep(50)
time.sleep(random.randint(2, 3))
text = browser.find_element_by_xpath(
'/html/body/c-wiz/div/div[2]/c-wiz/div[2]/c-wiz/div[1]/div[2]/div[3]/c-wiz[2]/div/div[8]/div/div[1]/span[1]').text
en_text = text.replace("翻譯搜尋結果\n", "").replace("\n", "")
en2zh_url = f'https://translate.google.com/?hl=zh&sl=zh-CN&tl=en&text={en_text}&op=translate'
browser.get(en2zh_url) # 访问相对应链接 browser.close#关闭浏览器
time.sleep(random.randint(1, 2))
browser.refresh()
time.sleep(random.randint(2, 3))
text = browser.find_element_by_xpath(
'/html/body/c-wiz/div/div[2]/c-wiz/div[2]/c-wiz/div[1]/div[2]/div[3]/c-wiz[2]/div/div[8]/div/div[1]/span[1]').text
tgt = text.replace("翻譯搜尋結果\n", "").replace("\n", "")
return tgt
def read_data(json_path):
with open(json_path, 'r', encoding="utf-8") as f:
data = json.load(f)
return data
def count_sentences_in_paragraph(paragraph):
sentences = nltk.sent_tokenize(paragraph)
return len(sentences)
def save_json_data(data, path):
with open(path, "w", encoding="utf-8") as outfile:
json.dump(data, outfile, ensure_ascii=False, indent=4)
def replace_line_breaks(s):
s = re.sub('\n', ' ', s)
return s
def truncate_to_last_sentence(s):
# 从字符串尾部向前找句号的位置
last_period = s.rfind('.') or s.rfind('!') or s.rfind('?')
# 如果找到句号,就截断到这个位置(包括句号)
if last_period != -1:
s = s[:last_period + 1]
return s
def check_paragraphs(texts):
if count_sentences_in_paragraph(texts) >= 4:
return True
else:
return False
def load_data(input_file):
data_file = f"{input_file}.raw_data.json"
with open(data_file, "r") as fin:
data = json.load(fin)
print(f"Raw data loaded from {data_file}")
return data
def save_data(output_file, data):
# write the data to a json file in the save folder
data_file = f"{output_file}.raw_data.json"
with open(data_file, "w") as fout:
json.dump(data, fout, indent=4)
print(f"Raw data written into {data_file}")
def run(args):
# set seed
random.seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
# load data
data = load_data(args.input_path)
# perturbation attacks prepare
if "perturbation" in args.method:
from textattack.augmentation import TextBuggerAugmenter
from textattack.augmentation import TextFoolerAugmenter
from textattack.augmentation import DeepWordBugAugmenter
word_augmenter = TextFoolerAugmenter()
character_augmenter = DeepWordBugAugmenter()
word_character_augmenter = TextBuggerAugmenter()
human_key = "original"
llm_key = 'sampled'
n_samples = len(data)
for i in tqdm(range(n_samples)):
human = data[human_key][i]
llm = data[llm_key][i]
# llms selection
if "perturbation" in args.method:
humans = count_sentences_in_paragraph(human)
llms = count_sentences_in_paragraph(llm)
for attack in ["perturbation_character", "perturbation_word", "perturbation_sent"]:
if attack == "perturbation_character":
augmenter = character_augmenter
elif attack == "perturbation_word":
augmenter = word_augmenter
elif attack == "perturbation_sent":
augmenter = word_character_augmenter
else:
raise ValueError(f"{attack} is not in perturbation_attacks")
try:
# final_data = []
# for d in range(humans):
# final_data.append(augmenter.augment(d)[0])
# human_result = ' '.join(final_data)
# data[human_key][i] = human_result
# logging.info(f"{attack} human finished")
final_data = []
for d in range(llms):
final_data.append(augmenter.augment(d)[0])
llm_result = ' '.join(final_data)
data[llm_key][i] = llm_result
logging.info(f"{attack} llm finished")
except Exception as e:
logging.info(f"error: {e}")
pass
save_data(args.output_file, data)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--input_path', required=False, default="./exp_main/data/xsum_opt-2.7b", type=str)
parser.add_argument('--output_path', required=False, default="./exp_main/data/xsum_opt-2.7b", type=str)
parser.add_argument('--method', default="perturbation_sent", type=str, choices=["perturbation_character", "perturbation_word", "perturbation_sent"], required=False)
parser.add_argument('--seed', default=2023, type=int, required=False)
args = parser.parse_args()
run(args)
|