| |
| |
| |
| |
| import argparse |
| import json |
| import logging |
| import random |
| import re |
| import sys |
| import time |
| import nltk |
| import numpy as np |
| import torch |
| from tqdm import tqdm |
|
|
| import os |
# Switch the working directory to the folder containing this script, so that
# relative paths used later are resolved against the script's location.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
os.chdir(_SCRIPT_DIR)

# Root logger: INFO level, each record prefixed with timestamp and level name.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
|
|
|
|
def bt_translation(src, browser):
    """Back-translate ``src`` through the Google Translate web page.

    The text is first submitted with ``sl=en&tl=zh-CN`` and the page's result
    is then submitted again with ``sl=zh-CN&tl=en``; the second result is
    returned.

    Args:
        src: Text to back-translate.
        browser: Selenium WebDriver (or compatible object) providing ``get``,
            ``refresh`` and ``find_element``.

    Returns:
        The text scraped from the result element after the second pass, with
        the "翻譯搜尋結果" banner line and all newlines removed.

    Note:
        The XPaths below are hard-coded against one specific page layout and
        will raise from ``find_element`` if the page structure differs.
    """
    # Local import keeps this block self-contained.
    from urllib.parse import quote

    input_xpath = '//*[@id="yDmH0d"]/c-wiz/div/div[2]/c-wiz/div[2]/c-wiz/div[1]/div[2]/div[3]/c-wiz[1]/span/span/div/div[2]/div[1]'
    output_xpath = '/html/body/c-wiz/div/div[2]/c-wiz/div[2]/c-wiz/div[1]/div[2]/div[3]/c-wiz[2]/div/div[8]/div/div[1]/span[1]'

    def _scrape_result():
        # "xpath" is the value of selenium's By.XPATH; passing the literal
        # works on Selenium 3 and 4 (find_element_by_xpath was removed in 4.3).
        raw = browser.find_element("xpath", output_xpath).text
        return raw.replace("翻譯搜尋結果\n", "").replace("\n", "")

    # First pass: sl=en -> tl=zh-CN.  The text is percent-encoded so that
    # characters such as '&', '#', '+' or spaces cannot corrupt the query.
    forward_url = (
        'https://translate.google.com/?hl=zh&sl=en&tl=zh-CN'
        f'&text={quote(src, safe="")}&op=translate'
    )
    browser.get(forward_url)
    time.sleep(random.randint(1, 2))
    # Kept from the original flow: type the text into the input box, then
    # reload.  NOTE(review): the reload presumably discards the typed text and
    # re-reads it from the URL -- confirm whether this step is still needed.
    browser.find_element("xpath", input_xpath).send_keys(src)
    browser.refresh()
    time.sleep(random.randint(2, 3))
    intermediate_text = _scrape_result()

    # Second pass: sl=zh-CN -> tl=en, feeding in the first pass's result.
    backward_url = (
        'https://translate.google.com/?hl=zh&sl=zh-CN&tl=en'
        f'&text={quote(intermediate_text, safe="")}&op=translate'
    )
    browser.get(backward_url)
    time.sleep(random.randint(1, 2))
    browser.refresh()
    time.sleep(random.randint(2, 3))
    return _scrape_result()
|
|
|
|
def read_data(json_path):
    """Parse the UTF-8 encoded JSON file at ``json_path`` and return its content."""
    with open(json_path, mode='r', encoding="utf-8") as handle:
        return json.load(handle)
|
|
|
|
def count_sentences_in_paragraph(paragraph):
    """Return the number of sentences ``nltk.sent_tokenize`` finds in ``paragraph``."""
    return len(nltk.sent_tokenize(paragraph))
|
|
|
|
def save_json_data(data, path):
    """Write ``data`` to ``path`` as UTF-8 JSON.

    The output is indented by four spaces and non-ASCII characters are
    written as-is rather than as ``\\uXXXX`` escapes.
    """
    with open(path, mode="w", encoding="utf-8") as sink:
        json.dump(obj=data, fp=sink, indent=4, ensure_ascii=False)
|
|
|
|
def replace_line_breaks(s):
    """Return ``s`` with every newline character replaced by one space."""
    newline = re.compile('\n')
    return newline.sub(' ', s)
|
|
|
|
def truncate_to_last_sentence(s):
    """Cut ``s`` right after its last sentence terminator.

    The terminators considered are ``.``, ``!`` and ``?``.  Whatever follows
    the right-most terminator is dropped.

    Args:
        s: Text to truncate.

    Returns:
        ``s`` up to and including the last terminator, or ``s`` unchanged when
        it contains none of them (including the empty string).
    """
    # str.rfind returns -1 on a miss, and -1 is truthy, so chaining the three
    # lookups with `or` would never consult '!' or '?'.  Taking the maximum
    # index picks the right-most terminator of any kind.
    last_terminator = max(s.rfind(mark) for mark in ".!?")
    if last_terminator == -1:
        return s
    return s[:last_terminator + 1]
|
|
|
|
def check_paragraphs(texts):
    """Return True when ``texts`` contains at least four sentences, else False."""
    return count_sentences_in_paragraph(texts) >= 4
|
|
def load_data(input_file):
    """Load the raw-data JSON that belongs to ``input_file``.

    Args:
        input_file: Path prefix; the file actually read is
            ``"{input_file}.raw_data.json"``.

    Returns:
        The parsed JSON content.
    """
    data_file = f"{input_file}.raw_data.json"
    # Read explicitly as UTF-8 (as read_data does) instead of relying on the
    # platform's locale encoding, which breaks on non-ASCII content elsewhere.
    with open(data_file, "r", encoding="utf-8") as fin:
        data = json.load(fin)
    print(f"Raw data loaded from {data_file}")
    return data
|
|
def save_data(output_file, data):
    """Dump ``data`` as 4-space-indented JSON to ``"{output_file}.raw_data.json"``.

    A confirmation line naming the written file is printed afterwards.
    """
    target = f"{output_file}.raw_data.json"
    with open(target, mode="w") as sink:
        json.dump(obj=data, fp=sink, indent=4)
    print(f"Raw data written into {target}")
|
|
|
|
def _build_augmenter(method):
    """Instantiate the TextAttack augmenter that implements ``method``."""
    # Imported here so textattack is only required for perturbation runs.
    from textattack.augmentation import DeepWordBugAugmenter
    from textattack.augmentation import TextBuggerAugmenter
    from textattack.augmentation import TextFoolerAugmenter

    factories = {
        "perturbation_character": DeepWordBugAugmenter,
        "perturbation_word": TextFoolerAugmenter,
        "perturbation_sent": TextBuggerAugmenter,
    }
    if method not in factories:
        raise ValueError(f"{method} is not in perturbation_attacks")
    return factories[method]()


def _perturb_text(text, augmenter):
    """Augment ``text`` sentence by sentence and re-join with single spaces."""
    perturbed = []
    for sentence in nltk.sent_tokenize(text):
        candidates = augmenter.augment(sentence)
        # Keep the sentence untouched if the augmenter produced no candidate.
        perturbed.append(candidates[0] if candidates else sentence)
    return ' '.join(perturbed)


def run(args):
    """Perturb every LLM-generated sample with the selected attack and save.

    The raw data is loaded with ``load_data(args.input_path)``.  When
    ``args.method`` contains ``"perturbation"``, each entry of
    ``data['sampled']`` is split into sentences, each sentence is passed
    through the augmenter selected by ``args.method``, and the entry is
    replaced in place by the re-joined result.  The data is then written with
    ``save_data``.

    Args:
        args: Namespace with ``seed``, ``input_path``, ``method`` and
            ``output_path``.  An ``output_file`` attribute is accepted as a
            fallback when ``output_path`` is absent.

    Raises:
        ValueError: If ``args.method`` contains ``"perturbation"`` but is not
            one of ``perturbation_character``, ``perturbation_word`` or
            ``perturbation_sent``.
    """
    # Seed every RNG in use so runs are repeatable.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    data = load_data(args.input_path)

    llm_key = 'sampled'

    if "perturbation" in args.method:
        # Only the requested attack is built and applied; running all three
        # in sequence would just let the last one overwrite the others.
        augmenter = _build_augmenter(args.method)

        # Iterate over the samples themselves: len(data) would only count the
        # top-level keys of the loaded dict.
        samples = data[llm_key]
        for i in tqdm(range(len(samples))):
            try:
                samples[i] = _perturb_text(samples[i], augmenter)
                logging.info(f"{args.method} llm finished")
            except Exception as e:
                # Best effort: leave this sample unperturbed, but record the
                # failure with its traceback instead of hiding it at INFO.
                logging.exception("error: %s", e)

    # The CLI defines --output_path; fall back to the legacy attribute name
    # for callers that still build their namespace with ``output_file``.
    output_prefix = getattr(args, "output_path", None)
    if output_prefix is None:
        output_prefix = args.output_file
    save_data(output_prefix, data)
|
|
if __name__ == '__main__':
    # Script entry point: collect the CLI options and hand them to run().
    cli = argparse.ArgumentParser()
    # Path prefixes; the data helpers append ".raw_data.json" to them.
    cli.add_argument(
        '--input_path',
        type=str,
        default="./exp_main/data/xsum_opt-2.7b",
        required=False,
    )
    cli.add_argument(
        '--output_path',
        type=str,
        default="./exp_main/data/xsum_opt-2.7b",
        required=False,
    )
    # Which perturbation attack to apply.
    cli.add_argument(
        '--method',
        type=str,
        default="perturbation_sent",
        choices=["perturbation_character", "perturbation_word", "perturbation_sent"],
        required=False,
    )
    # Seed forwarded to the random number generators in run().
    cli.add_argument('--seed', type=int, default=2023, required=False)
    run(cli.parse_args())
|
|