Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected/bnc_spoken_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected/childes_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected/gutenberg_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected/open_subtitles_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected/simple_wiki_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected/switchboard_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected_sents/bnc_spoken_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected_sents/childes_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected_sents/gutenberg_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected_sents/open_subtitles_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected_sents/simple_wiki_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected_sents/switchboard_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected/bnc_spoken_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected/childes_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected/gutenberg_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected/open_subtitles_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected/simple_wiki_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected/switchboard_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected_sents/bnc_spoken_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected_sents/childes_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected_sents/gutenberg_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected_sents/open_subtitles_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected_sents/simple_wiki_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected_sents/switchboard_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected/bnc_spoken_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected/childes_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected/gutenberg_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected/open_subtitles_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected/simple_wiki_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected/switchboard_unaffected.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected_sents/bnc_spoken_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected_sents/childes_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected_sents/gutenberg_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected_sents/open_subtitles_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected_sents/simple_wiki_unaffected_sents.test +0 -0
- data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected_sents/switchboard_unaffected_sents.test +0 -0
- data/perturb.py +359 -0
- data/perturb.sh +35 -0
- data/perturb_llama.py +361 -0
- data/perturb_model.sh +40 -0
- data/perturb_qwen.py +361 -0
- data/tag.py +153 -0
- data/tag_1.py +166 -0
- data/tag_distributed.py +106 -0
- data/tag_single.py +140 -0
- perplexities/perplexity_results/Qwen2.5-0.5B/reverse_full/Qwen2.5-0.5B_seed1_test_reverse_full_checkpoint-1000.csv +2 -0
- perplexities/perplexity_results/Qwen2.5-0.5B/reverse_full/Qwen2.5-0.5B_seed1_test_reverse_full_checkpoint-10000.csv +2 -0
- perplexities/perplexity_results/Qwen2.5-0.5B/reverse_full/Qwen2.5-0.5B_seed1_test_reverse_full_checkpoint-11500.csv +2 -0
- perplexities/perplexity_results/Qwen2.5-0.5B/reverse_full/Qwen2.5-0.5B_seed1_test_reverse_full_checkpoint-1500.csv +2 -0
- perplexities/perplexity_results/Qwen2.5-0.5B/reverse_full/Qwen2.5-0.5B_seed1_test_reverse_full_checkpoint-2000.csv +2 -0
data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected/bnc_spoken_unaffected.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected/childes_unaffected.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected/gutenberg_unaffected.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected/open_subtitles_unaffected.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected/simple_wiki_unaffected.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected/switchboard_unaffected.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected_sents/bnc_spoken_unaffected_sents.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected_sents/childes_unaffected_sents.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected_sents/gutenberg_unaffected_sents.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected_sents/open_subtitles_unaffected_sents.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected_sents/simple_wiki_unaffected_sents.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_reverse_control/babylm_test_unaffected_sents/switchboard_unaffected_sents.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected/bnc_spoken_unaffected.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected/childes_unaffected.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected/gutenberg_unaffected.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected/open_subtitles_unaffected.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected/simple_wiki_unaffected.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected/switchboard_unaffected.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected_sents/bnc_spoken_unaffected_sents.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected_sents/childes_unaffected_sents.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected_sents/gutenberg_unaffected_sents.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected_sents/open_subtitles_unaffected_sents.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected_sents/simple_wiki_unaffected_sents.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local3/babylm_test_unaffected_sents/switchboard_unaffected_sents.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected/bnc_spoken_unaffected.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected/childes_unaffected.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected/gutenberg_unaffected.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected/open_subtitles_unaffected.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected/simple_wiki_unaffected.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected/switchboard_unaffected.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected_sents/bnc_spoken_unaffected_sents.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected_sents/childes_unaffected_sents.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected_sents/gutenberg_unaffected_sents.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected_sents/open_subtitles_unaffected_sents.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected_sents/simple_wiki_unaffected_sents.test
ADDED
|
File without changes
|
data/Perturbed_data/Llama-3.2-3B/babylm_shuffle_local5/babylm_test_unaffected_sents/switchboard_unaffected_sents.test
ADDED
|
File without changes
|
data/perturb.py
ADDED
|
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# perturb.py
|
| 2 |
+
# Author: Julie Kallini
|
| 3 |
+
|
| 4 |
+
# For importing utils
|
| 5 |
+
import sys
|
| 6 |
+
sys.path.append("..")
|
| 7 |
+
|
| 8 |
+
from utils import PERTURBATIONS, BABYLM_SPLITS, BABYLM_DATA_PATH, \
|
| 9 |
+
GENRES, MARKER_TOKEN_IDS, marker_sg_token, marker_pl_token, marker_rev_token, write_file
|
| 10 |
+
from glob import glob
|
| 11 |
+
import numpy as np
|
| 12 |
+
import itertools
|
| 13 |
+
import json
|
| 14 |
+
import os
|
| 15 |
+
import tqdm
|
| 16 |
+
import argparse
|
| 17 |
+
import pytest
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def lines_equivalent_3pres(file1_path, file2_path):
    """Check that two token files match once agreement markers are ignored.

    Both files are read line by line. Every line is split on whitespace and
    tokens whose integer value equals ``marker_sg_token`` or
    ``marker_pl_token`` are dropped before the remaining tokens are compared.

    Parameters:
        - file1_path (str): Path to the first token file.
        - file2_path (str): Path to the second token file.

    Returns:
        - bool: True when both files have the same number of lines and every
          pair of lines is equal after marker removal; False otherwise. The
          first pair of differing lines is printed before returning False.

    Raises:
        - ValueError: If a token cannot be converted with ``int``.
    """
    markers = (marker_sg_token, marker_pl_token)

    with open(file1_path, 'r') as file1, open(file2_path, 'r') as file2:
        # zip_longest pads the shorter file with None, so a length mismatch of
        # any size is detected. Plain zip() consumes one extra line of file1
        # when file2 ends first, which hid a one-line difference.
        for line1, line2 in itertools.zip_longest(file1, file2):
            if line1 is None or line2 is None:
                return False

            # Split each line and compare the marker-free token lists
            res1 = [i for i in line1.split() if int(i) not in markers]
            res2 = [i for i in line2.split() if int(i) not in markers]
            if res1 != res2:
                print(line1)
                print(line2)
                return False

    return True
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# Pairs of 3pres perturbations whose output files are compared by
# test_3pres_all_equivalent.
perturbation_pairs_3pres = [
    ("0tokens", "4tokens"),
    ("0tokens", "4words"),
    ("4tokens", "4words"),
]

# Yj: combined tests over the perturbation pairs related to third-person
# singular/plural marking.
# Yj: generate the different (split, genre, pair) combinations used in the test.
test_data = itertools.product(
    ["100M", "dev", "test_affected", "test_unaffected"],
    GENRES.keys(),
    perturbation_pairs_3pres,
)

# Yj: used by test functions such as test_3pres_all_equivalent to generate the
# test combinations, including the different perturbation strategies.
# Yj: distinguishes the affected and unaffected test subsets, to compare the
# effect before and after the perturbation.


# The test function runs once for every parameter set in test_data.
@pytest.mark.parametrize("split, genre, perturbation_pair", test_data)
def test_3pres_all_equivalent(split, genre, perturbation_pair):
    """Assert that two 3pres perturbations produced equivalent files.

    Parameters:
        - split (str): Dataset split name used in the directory name.
        - genre (str): Key of ``GENRES``; used as the file name stem.
          (Yj: genres are the different kinds of corpus, see utils.py.)
        - perturbation_pair (tuple): The two 3pres perturbation names to compare.

    Raises:
        - ValueError: If ``split`` has no known file naming scheme.
    """
    perturbation1, perturbation2 = perturbation_pair

    if split in ("100M", "10M"):
        filename = f"{genre}.train"
    elif split == "test_affected":
        filename = f"{genre}_affected.test"
    elif split == "test_unaffected":
        filename = f"{genre}_unaffected.test"
    elif split == "dev":
        # Yj: the development set is similar to a validation set
        filename = f"{genre}.dev"
    else:
        # Previously an unknown split fell through and failed later with an
        # UnboundLocalError on `filename`.
        raise ValueError(f"Unknown split: {split!r}")

    path1 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed/babylm_3pres_{perturbation1}/babylm_{split}/{filename}"
    path2 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed/babylm_3pres_{perturbation2}/babylm_{split}/{filename}"

    # Yj: compare the two files found at the two paths
    assert lines_equivalent_3pres(path1, path2), f"File {filename} of " + \
        f"3pres_{perturbation1} and 3pres_{perturbation2} have non-equivalent lines!"
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def lines_equivalent_reversal(rev_path, ident_path):
    """Check that a reversal file is the identity file with its tail reversed.

    For every pair of lines, the tokens up to and including the first
    occurrence of ``marker_rev_token`` in the reversal line must be equal in
    both files, and the tokens after that position in the reversal line,
    read backwards, must equal the tokens after it in the identity line.

    Parameters:
        - rev_path (str): Path to the reversal token file.
        - ident_path (str): Path to the identity token file.

    Returns:
        - bool: True when both files have the same number of lines and every
          pair of lines satisfies the relation above; False otherwise.

    Raises:
        - ValueError: If a line of the reversal file does not contain the
          reversal marker token.
    """
    rev_marker = str(marker_rev_token)

    with open(rev_path, 'r') as file1, open(ident_path, 'r') as file2:
        # zip_longest pads the shorter file with None, so a length mismatch of
        # any size is detected. Plain zip() consumes one extra line of file1
        # when file2 ends first, which hid a one-line difference.
        for line1, line2 in itertools.zip_longest(file1, file2):
            if line1 is None or line2 is None:
                return False

            line1_tokens = line1.split()
            line2_tokens = line2.split()

            # Position just past the REV marker in the reversal line
            tail_start = line1_tokens.index(rev_marker) + 1

            # Tokens up to and including the marker must be identical
            if line1_tokens[:tail_start] != line2_tokens[:tail_start]:
                return False

            # Reversing the rest of the reversal line must give the identity
            if line1_tokens[tail_start:][::-1] != line2_tokens[tail_start:]:
                return False

    return True
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# Pairs of perturbations whose output files are compared by
# test_reversal_all_equivalent.
perturbation_pairs_reversal = [
    ("reversal", "reversal_identity"),
]

# Yj: combined tests over the reversal perturbation pairs
test_data = itertools.product(
    ["100M", "dev", "test_affected"],
    GENRES.keys(),
    perturbation_pairs_reversal,
)


@pytest.mark.parametrize("split, genre, perturbation_pair", test_data)
def test_reversal_all_equivalent(split, genre, perturbation_pair):
    """Assert that a reversal file and its identity counterpart are equivalent.

    Parameters:
        - split (str): Dataset split name used in the directory name.
        - genre (str): Key of ``GENRES``; used as the file name stem.
        - perturbation_pair (tuple): The two perturbation names to compare;
          the first is passed as the reversal file, the second as the
          identity file.

    Raises:
        - ValueError: If ``split`` has no known file naming scheme.
    """
    perturbation1, perturbation2 = perturbation_pair

    if split in ("100M", "10M"):
        filename = f"{genre}.train"
    elif split == "test_affected":
        filename = f"{genre}_affected.test"
    elif split == "test_unaffected":
        filename = f"{genre}_unaffected.test"
    elif split == "dev":
        filename = f"{genre}.dev"
    else:
        # Previously an unknown split fell through and failed later with an
        # UnboundLocalError on `filename`.
        raise ValueError(f"Unknown split: {split!r}")

    path1 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed/babylm_{perturbation1}/babylm_{split}/{filename}"
    path2 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed/babylm_{perturbation2}/babylm_{split}/{filename}"

    assert lines_equivalent_reversal(path1, path2), f"File {filename} of " + \
        f"{perturbation1} and {perturbation2} have non-equivalent lines!"
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def lines_equivalent_determiner_swap(det_path, ident_path):
    """Check that two token files contain the same set of tokens on each line.

    Every line is split on whitespace and the resulting tokens are compared
    as sets, so token order and repeated tokens within a line are ignored.

    Parameters:
        - det_path (str): Path to the determiner-swap token file.
        - ident_path (str): Path to the identity token file.

    Returns:
        - bool: True when both files have the same number of lines and every
          pair of lines has equal token sets; False otherwise. The first pair
          of differing lines is printed (as token lists) before returning False.
    """
    with open(det_path, 'r') as file1, open(ident_path, 'r') as file2:
        # zip_longest pads the shorter file with None, so a length mismatch of
        # any size is detected. Plain zip() consumes one extra line of file1
        # when file2 ends first, which hid a one-line difference.
        for line1, line2 in itertools.zip_longest(file1, file2):
            if line1 is None or line2 is None:
                return False

            tokens1 = line1.split()
            tokens2 = line2.split()
            if set(tokens1) != set(tokens2):
                print(tokens1)
                print(tokens2)
                return False

    return True
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
# Pairs of perturbations whose output files are compared by
# test_determiner_swap_all_equivalent. This list used to be assigned to
# `perturbation_pairs_reversal`, overwriting the reversal pairs.
perturbation_pairs_determiner_swap = [
    ("determiner_swap", "determiner_swap_identity"),
]
test_data = itertools.product(
    ["100M", "dev", "test_affected", "test_unaffected"],
    GENRES.keys(),
    perturbation_pairs_determiner_swap,
)


@pytest.mark.parametrize("split, genre, perturbation_pair", test_data)
def test_determiner_swap_all_equivalent(split, genre, perturbation_pair):
    """Assert that a determiner-swap file and its identity file are equivalent.

    Parameters:
        - split (str): Dataset split name used in the directory name.
        - genre (str): Key of ``GENRES``; used as the file name stem.
        - perturbation_pair (tuple): The two perturbation names to compare.

    Raises:
        - ValueError: If ``split`` has no known file naming scheme.
    """
    perturbation1, perturbation2 = perturbation_pair

    if split in ("100M", "10M"):
        filename = f"{genre}.train"
    elif split == "test_affected":
        filename = f"{genre}_affected.test"
    elif split == "test_unaffected":
        filename = f"{genre}_unaffected.test"
    elif split == "dev":
        filename = f"{genre}.dev"
    else:
        # Previously an unknown split fell through and failed later with an
        # UnboundLocalError on `filename`.
        raise ValueError(f"Unknown split: {split!r}")

    path1 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed/babylm_{perturbation1}/babylm_{split}/{filename}"
    path2 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed/babylm_{perturbation2}/babylm_{split}/{filename}"

    assert lines_equivalent_determiner_swap(path1, path2), f"File {filename} of " + \
        f"{perturbation1} and {perturbation2} have non-equivalent lines!"
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def flatten_list(l):
    """Return one list holding the items of every element of ``l``, in order."""
    flat = []
    for inner in l:
        flat.extend(inner)
    return flat
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def process_line(line):
    """
    Perturb every sentence of one dataset line and sort the results into
    affected and unaffected groups.

    Parameters:
        - line (dict): One entry of the dataset; its "sent_annotations" value
          is iterated sentence by sentence.

    Returns:
        - tuple: Three lists, in this order:
            1. Token lines (space-separated tokens plus a newline) of affected
               sentences that `filter_function` accepted.
            2. Token lines of unaffected sentences.
            3. The "sent_text" (plus a newline) of the unaffected sentences.

    Note:
        - `perturbation_function`, `affect_function`, `filter_function` and
          `MARKER_TOKEN_IDS` are read from the global scope.
        - Sentences with at most one non-marker token are skipped, and
          affected sentences rejected by `filter_function` are dropped.
    """
    affected_lines, unaffected_lines, unaffected_texts = [], [], []

    # Yj: unclear why the annotations are used here rather than the text?
    for sent in line["sent_annotations"]:
        tokens = perturbation_function(sent)

        # Ignore sentences that have no more than one real (non-marker) token
        real_token_count = sum(
            1 for tok in tokens if tok not in MARKER_TOKEN_IDS)
        if real_token_count <= 1:
            continue

        token_line = " ".join(map(str, tokens)) + "\n"

        if not affect_function(sent):
            # Unaffected sentences keep both the token line and the raw text
            unaffected_lines.append(token_line)
            unaffected_texts.append(sent["sent_text"] + "\n")
        elif filter_function(sent):
            # Affected sentences are kept only when the filter accepts them
            affected_lines.append(token_line)

    return affected_lines, unaffected_lines, unaffected_texts
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
if __name__ == "__main__":
    # Command-line entry point: perturb every "*_parsed.json" file of one
    # BabyLM split and write the resulting token files below
    # BABYLM_DATA_PATH/babylm_data_perturbed/babylm_<perturbation_type>/.

    parser = argparse.ArgumentParser(
        prog='Perturb BabyLM dataset',
        description='Perturb BabyLM dataset by altering POS-tagged data')
    parser.add_argument('perturbation_type',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=PERTURBATIONS.keys(),
                        help='Perturbation function used to transform BabyLM dataset')
    parser.add_argument('babylm_dataset',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=BABYLM_SPLITS,
                        help='BabyLM dataset choice')

    # Get args
    args = parser.parse_args()

    # Load dataset (only json files containing tagged data)
    babylm_dataset = args.babylm_dataset
    json_ext = "_parsed.json"
    # NOTE: the pattern is relative to the current working directory; the
    # BABYLM_DATA_PATH-based variant was
    # f"{BABYLM_DATA_PATH}/babylm_data/babylm_{babylm_dataset}/*{json_ext}".
    babylm_data = glob(f"babylm_data/babylm_{babylm_dataset}/*{json_ext}")
    print("babylm_data:", babylm_data)

    # Get perturbation, affect, and filter functions. These stay module-level
    # names because process_line reads them from the global scope.
    perturbation_function = PERTURBATIONS[args.perturbation_type]['perturbation_function']
    affect_function = PERTURBATIONS[args.perturbation_type]['affect_function']
    filter_function = PERTURBATIONS[args.perturbation_type]['filter_function']
    gpt2_tokenizer = PERTURBATIONS[args.perturbation_type]['gpt2_tokenizer']

    data_write_directory = f"{BABYLM_DATA_PATH}/babylm_data_perturbed"

    # Iterate over files and do transform
    for file in babylm_data:
        print(file)
        # The context manager closes the file even when json.load raises
        with open(file) as f:
            data = json.load(f)

        # Perturb data iteratively
        results = [process_line(line) for line in tqdm.tqdm(data)]

        if results:
            new_lines_affected, new_lines_unaffected, unaffected_sents = zip(
                *results)
        else:
            # An empty JSON list has nothing to unzip; write empty outputs
            new_lines_affected, new_lines_unaffected, unaffected_sents = [], [], []
        new_lines_affected = flatten_list(new_lines_affected)
        new_lines_unaffected = flatten_list(new_lines_unaffected)
        unaffected_sents = flatten_list(unaffected_sents)

        base_name = os.path.basename(file)

        # Yj: why is babylm_dataset "test" here?
        # BABYLM_SPLITS = ['100M', '10M', 'dev', 'test', 'unittest']
        if babylm_dataset == "test":
            # The test split is written as three separate files: affected
            # token lines, unaffected token lines, and unaffected raw text.

            # Name new files
            new_file_affected = base_name.replace(json_ext, "_affected.test")
            new_file_unaffected = base_name.replace(
                json_ext, "_unaffected.test")
            file_unaffected_sents = base_name.replace(
                json_ext, "_unaffected_sents.test")

            # Create directories (exist_ok avoids the check-then-create race)
            directory_affected = f"{data_write_directory}/babylm_{args.perturbation_type}/babylm_test_affected/"
            os.makedirs(directory_affected, exist_ok=True)
            directory_unaffected = f"{data_write_directory}/babylm_{args.perturbation_type}/babylm_test_unaffected/"
            os.makedirs(directory_unaffected, exist_ok=True)
            directory_unaffected_sents = f"{data_write_directory}/babylm_{args.perturbation_type}/babylm_test_unaffected_sents/"
            os.makedirs(directory_unaffected_sents, exist_ok=True)

            # Write files
            write_file(directory_affected,
                       new_file_affected, new_lines_affected)
            write_file(directory_unaffected,
                       new_file_unaffected, new_lines_unaffected)
            write_file(directory_unaffected_sents,
                       file_unaffected_sents, unaffected_sents)

        else:
            # Yj: BABYLM_SPLITS = ['100M', '10M', 'dev', 'test', 'unittest']

            # Combine affected and unaffected sentences
            new_lines = new_lines_unaffected + new_lines_affected

            # Name new file
            if babylm_dataset == "dev":
                new_file = base_name.replace(json_ext, ".dev")
            elif babylm_dataset == 'unittest':
                new_file = base_name.replace(json_ext, ".test")

                # For unittest, follow every token line with its decoded string
                new_lines_with_strings = []
                for token_line in new_lines:
                    decoded = gpt2_tokenizer.decode(
                        [int(tok) for tok in token_line.split()])
                    new_lines_with_strings.append(token_line)
                    new_lines_with_strings.append(decoded + "\n")
                new_lines = new_lines_with_strings

            else:
                # '10M' and '100M' are the training sets
                new_file = base_name.replace(json_ext, ".train")

            # Create directory and write file
            directory = f"{BABYLM_DATA_PATH}/babylm_data_perturbed/babylm_{args.perturbation_type}/babylm_{babylm_dataset}/"
            os.makedirs(directory, exist_ok=True)
            write_file(directory, new_file, new_lines)
|
data/perturb.sh
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/sh
|
| 2 |
+
# perturb.sh
|
| 3 |
+
# author: Julie Kallini
|
| 4 |
+
|
| 5 |
+
# Usage: perturb.sh <perturbation type> <train set>
# Runs perturb.py for the given train set and then for the dev, test and
# unittest splits.

echo "
-------------------------------------------------------------------------------
Arguments
-------------------------------------------------------------------------------
"
echo "Perturbation type: $1"
echo "Train set: $2"


# Create perturbed dataset for all splits
echo "
-------------------------------------------------------------------------------
Creating perturbed dataset for all splits
-------------------------------------------------------------------------------
"

# Stop if the data directory is missing, rather than running perturb.py from
# the wrong directory.
cd ../data || exit 1

# ${N:+"$N"} passes the argument as a single quoted word, and passes nothing
# at all when it is empty or unset (as the unquoted form did).
echo "python3 perturb.py $1 $2"
python3 perturb.py ${1:+"$1"} ${2:+"$2"}
echo "
python3 perturb.py $1 dev"
python3 perturb.py ${1:+"$1"} dev
echo "
python3 perturb.py $1 test"
python3 perturb.py ${1:+"$1"} test
echo "
python3 perturb.py $1 unittest"
python3 perturb.py ${1:+"$1"} unittest

cd ..
|
data/perturb_llama.py
ADDED
|
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# perturb_llama.py
|
| 2 |
+
# Author: Julie Kallini
|
| 3 |
+
|
| 4 |
+
# For importing utils
|
| 5 |
+
import sys
|
| 6 |
+
sys.path.append("..")
|
| 7 |
+
|
| 8 |
+
from utils_llama import PERTURBATIONS, BABYLM_SPLITS, BABYLM_DATA_PATH, \
|
| 9 |
+
GENRES, MARKER_TOKEN_IDS, marker_sg_token, marker_pl_token, marker_rev_token, write_file
|
| 10 |
+
from glob import glob
|
| 11 |
+
import numpy as np
|
| 12 |
+
import itertools
|
| 13 |
+
import json
|
| 14 |
+
import os
|
| 15 |
+
import tqdm
|
| 16 |
+
import argparse
|
| 17 |
+
import pytest
|
| 18 |
+
|
| 19 |
+
MODEL_NAME = "Llama-3.2-3B"
|
| 20 |
+
|
| 21 |
+
def lines_equivalent_3pres(file1_path, file2_path):
    """Check that two token files match once sg/pl marker tokens are ignored.

    Each line is split on whitespace, every token is parsed as an int, and
    tokens equal to ``marker_sg_token`` or ``marker_pl_token`` are dropped.
    The remaining tokens of the two lines must then be identical and in the
    same order.

    Parameters:
    - file1_path (str): Path to the first token file.
    - file2_path (str): Path to the second token file.

    Returns:
    - bool: True if both files have the same number of lines and every pair
      of lines is equivalent; False otherwise. The first mismatching pair of
      lines is printed before returning.

    Raises:
    - ValueError: If a whitespace-separated token cannot be parsed as an int.
    """

    def strip_markers(line):
        # Tokens stay strings; int() is only used for the marker test.
        return [tok for tok in line.split()
                if int(tok) not in (marker_sg_token, marker_pl_token)]

    with open(file1_path, 'r') as file1, open(file2_path, 'r') as file2:
        # zip_longest pads the shorter file with None. Plain zip() pulls a
        # line from file1 before it notices file2 is exhausted, so a single
        # extra trailing line in file1 used to slip through the length check.
        for line1, line2 in itertools.zip_longest(file1, file2):
            if line1 is None or line2 is None:
                return False

            if strip_markers(line1) != strip_markers(line2):
                print(line1)
                print(line2)
                return False

    return True
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# Pairs of 3pres perturbation variants whose outputs are compared with each
# other (third-person singular/plural marker perturbations).
perturbation_pairs_3pres = [
    ("0tokens", "4tokens"),
    ("0tokens", "4words"),
    ("4tokens", "4words"),
]

# Every (split, genre, pair) combination exercised by the test below. The
# affected and unaffected test subsets are checked separately.
test_data = itertools.product(
    ["100M", "dev", "test_affected", "test_unaffected"],
    GENRES.keys(),
    perturbation_pairs_3pres,
)


# The test function runs once for each parameter tuple in test_data.
@pytest.mark.parametrize("split, genre, perturbation_pair", test_data)
def test_3pres_all_equivalent(split, genre, perturbation_pair):
    """Assert two 3pres variants of one corpus file are marker-equivalent.

    ``genre`` is one of the keys of ``GENRES``; ``split`` selects both the
    sub-directory and the file extension of the file being compared.
    """
    first_variant, second_variant = perturbation_pair

    # Pick the file name used for this split.
    if split == "dev":
        filename = f"{genre}.dev"
    elif split == "test_unaffected":
        filename = f"{genre}_unaffected.test"
    elif split == "test_affected":
        filename = f"{genre}_affected.test"
    elif split in ("100M", "10M"):
        filename = f"{genre}.train"

    perturbed_root = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_llama"
    path1 = f"{perturbed_root}/babylm_3pres_{first_variant}/babylm_{split}/{filename}"
    path2 = f"{perturbed_root}/babylm_3pres_{second_variant}/babylm_{split}/{filename}"

    # Compare the two files line by line.
    assert lines_equivalent_3pres(path1, path2), (
        f"File {filename} of 3pres_{first_variant} and 3pres_{second_variant} "
        "have non-equivalent lines!")
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def lines_equivalent_reversal(rev_path, ident_path):
    """Check that a reversal file is the identity file with its tail reversed.

    For each pair of lines, the tokens up to and including the first
    occurrence of ``marker_rev_token`` (located in the reversal line) must be
    identical, and the tokens after it in the reversal line must be the
    reverse of the tokens after the same position in the identity line.

    Parameters:
    - rev_path (str): Path to the reversal token file.
    - ident_path (str): Path to the reversal-identity token file.

    Returns:
    - bool: True if both files have the same number of lines and every pair
      of lines satisfies the relation above; False otherwise.

    Raises:
    - ValueError: If a line of the reversal file does not contain the
      reversal marker token.
    """
    with open(rev_path, 'r') as file1, open(ident_path, 'r') as file2:
        # zip_longest pads the shorter file with None. Plain zip() pulls a
        # line from file1 before it notices file2 is exhausted, so a single
        # extra trailing line in file1 used to slip through the length check.
        for line1, line2 in itertools.zip_longest(file1, file2):
            if line1 is None or line2 is None:
                return False

            line1_tokens = line1.split()
            line2_tokens = line2.split()

            # Position just past the REV marker in the reversal line
            tail_start = line1_tokens.index(str(marker_rev_token)) + 1

            # Tokens up to and including the marker must all be the same
            if line1_tokens[:tail_start] != line2_tokens[:tail_start]:
                return False

            # Reversal of the rest of the line must equal the identity line
            if line1_tokens[tail_start:][::-1] != line2_tokens[tail_start:]:
                return False

    return True
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# Reversal output is compared against its identity counterpart.
perturbation_pairs_reversal = [
    ("reversal", "reversal_identity"),
]

# Every (split, genre, pair) combination exercised by the reversal test.
test_data = itertools.product(
    ["100M", "dev", "test_affected"],
    GENRES.keys(),
    perturbation_pairs_reversal,
)


@pytest.mark.parametrize("split, genre, perturbation_pair", test_data)
def test_reversal_all_equivalent(split, genre, perturbation_pair):
    """Assert the reversal file matches the identity file with its tail reversed."""
    reversed_name, identity_name = perturbation_pair

    # Pick the file name used for this split.
    if split == "dev":
        filename = f"{genre}.dev"
    elif split == "test_unaffected":
        filename = f"{genre}_unaffected.test"
    elif split == "test_affected":
        filename = f"{genre}_affected.test"
    elif split in ("100M", "10M"):
        filename = f"{genre}.train"

    perturbed_root = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_llama"
    path1 = f"{perturbed_root}/babylm_{reversed_name}/babylm_{split}/{filename}"
    path2 = f"{perturbed_root}/babylm_{identity_name}/babylm_{split}/{filename}"

    assert lines_equivalent_reversal(path1, path2), (
        f"File {filename} of {reversed_name} and {identity_name} "
        "have non-equivalent lines!")
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def lines_equivalent_determiner_swap(det_path, ident_path):
    """Check that two files contain the same set of tokens on every line.

    Each line is split on whitespace and the two lines are compared as sets,
    so token order and repeated tokens are ignored.

    Parameters:
    - det_path (str): Path to the determiner-swap token file.
    - ident_path (str): Path to the determiner-swap-identity token file.

    Returns:
    - bool: True if both files have the same number of lines and every pair
      of lines has the same token set; False otherwise. The token lists of
      the first mismatching pair are printed before returning.
    """
    with open(det_path, 'r') as file1, open(ident_path, 'r') as file2:
        # zip_longest pads the shorter file with None. Plain zip() pulls a
        # line from file1 before it notices file2 is exhausted, so a single
        # extra trailing line in file1 used to slip through the length check.
        for line1, line2 in itertools.zip_longest(file1, file2):
            if line1 is None or line2 is None:
                return False

            line1_tokens = line1.split()
            line2_tokens = line2.split()
            if set(line1_tokens) != set(line2_tokens):
                print(line1_tokens)
                print(line2_tokens)
                return False

    return True
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
# Determiner-swap output is compared against its identity counterpart.
# NOTE(review): this rebinds the module-level name perturbation_pairs_reversal.
perturbation_pairs_reversal = [
    ("determiner_swap", "determiner_swap_identity"),
]

# Every (split, genre, pair) combination exercised by the determiner-swap test.
test_data = itertools.product(
    ["100M", "dev", "test_affected", "test_unaffected"],
    GENRES.keys(),
    perturbation_pairs_reversal,
)


@pytest.mark.parametrize("split, genre, perturbation_pair", test_data)
def test_determiner_swap_all_equivalent(split, genre, perturbation_pair):
    """Assert determiner-swap and identity files share the same tokens per line."""
    swapped_name, identity_name = perturbation_pair

    # Pick the file name used for this split.
    if split == "dev":
        filename = f"{genre}.dev"
    elif split == "test_unaffected":
        filename = f"{genre}_unaffected.test"
    elif split == "test_affected":
        filename = f"{genre}_affected.test"
    elif split in ("100M", "10M"):
        filename = f"{genre}.train"

    perturbed_root = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_llama"
    path1 = f"{perturbed_root}/babylm_{swapped_name}/babylm_{split}/{filename}"
    path2 = f"{perturbed_root}/babylm_{identity_name}/babylm_{split}/{filename}"

    assert lines_equivalent_determiner_swap(path1, path2), (
        f"File {filename} of {swapped_name} and {identity_name} "
        "have non-equivalent lines!")
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def flatten_list(l):
    """Flatten one level of nesting by concatenating the elements of ``l``."""
    flat = []
    for inner in l:
        flat.extend(inner)
    return flat
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def process_line(line):
    """
    Perturb every sentence of one dataset entry and sort the results into
    affected and unaffected groups.

    Parameters:
    - line (dict): One dataset entry; its "sent_annotations" value is iterated
      and each item is passed to the perturbation, affect and filter functions.

    Returns:
    - tuple: Three lists:
        1. Token lines (space-separated tokens plus newline) of affected
           sentences that also pass `filter_function`.
        2. Token lines of unaffected sentences.
        3. The "sent_text" of each unaffected sentence plus newline.

    Note:
    - `perturbation_function`, `affect_function` and `filter_function` are read
      from the global scope. Sentences whose perturbed output has at most one
      non-marker token are dropped, as are affected sentences rejected by
      `filter_function`.
    """
    affected_token_lines = []
    unaffected_token_lines = []
    unaffected_texts = []

    # NOTE(review): an earlier reader asked why the sentence annotations are
    # used here rather than the raw text.
    for sent in line["sent_annotations"]:
        tokens = perturbation_function(sent)

        # Skip sentences left with at most one token that is not a marker.
        content_count = sum(1 for tok in tokens if tok not in MARKER_TOKEN_IDS)
        if content_count <= 1:
            continue

        token_line = " ".join(map(str, tokens)) + "\n"

        if not affect_function(sent):
            # Unaffected sentences keep both their tokens and their raw text.
            unaffected_token_lines.append(token_line)
            unaffected_texts.append(sent["sent_text"] + "\n")
        elif filter_function(sent):
            # Affected sentences are kept only when the filter accepts them.
            affected_token_lines.append(token_line)

    return affected_token_lines, unaffected_token_lines, unaffected_texts
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
if __name__ == "__main__":
    # Command-line entry point: perturb every "*_parsed.json" file of one
    # BabyLM split with one perturbation and write the token files under
    # {BABYLM_DATA_PATH}/Perturbed_data/{MODEL_NAME}/.

    parser = argparse.ArgumentParser(
        prog='Perturb BabyLM dataset',
        description='Perturb BabyLM dataset by altering POS-tagged data')
    # NOTE(review): both positionals default to 'all'; whether 'all' is an
    # accepted perturbation or split is not visible from this file - confirm.
    parser.add_argument('perturbation_type',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=PERTURBATIONS.keys(),
                        help='Perturbation function used to transform BabyLM dataset')
    parser.add_argument('babylm_dataset',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=BABYLM_SPLITS,
                        help='BabyLM dataset choice')

    # Get args
    args = parser.parse_args()

    # Load dataset (only json files containing tagged data)
    babylm_dataset = args.babylm_dataset
    json_ext = "_parsed.json"
    # NOTE(review): inputs are globbed relative to the current working
    # directory, not under BABYLM_DATA_PATH like the outputs.
    babylm_data = glob(f"babylm_data/babylm_{babylm_dataset}/*{json_ext}")
    print("babylm_data:", babylm_data)

    # Get perturbation, affect, and filter functions. These must stay
    # module-level globals because process_line() reads them from there.
    perturbation_function = PERTURBATIONS[args.perturbation_type]['perturbation_function']
    affect_function = PERTURBATIONS[args.perturbation_type]['affect_function']
    filter_function = PERTURBATIONS[args.perturbation_type]['filter_function']
    llama_tokenizer = PERTURBATIONS[args.perturbation_type]['llama_tokenizer']

    # Root of everything written for this model and perturbation
    data_write_directory = f"{BABYLM_DATA_PATH}/Perturbed_data/{MODEL_NAME}"
    perturbation_directory = f"{data_write_directory}/babylm_{args.perturbation_type}"

    if babylm_dataset == "test":
        # The test split keeps affected and unaffected sentences in separate
        # files and also stores the raw text of the unaffected sentences.

        # Iterate over files and do transform
        for file in babylm_data:
            print(file)
            # Context manager closes the file even if json.load() raises
            with open(file, encoding="utf-8") as f:
                data = json.load(f)

            # Perturb data iteratively
            results = [process_line(line) for line in tqdm.tqdm(data)]

            # Flatten per-entry results; unlike zip(*results) this also works
            # when the input file has no entries.
            new_lines_affected = flatten_list(r[0] for r in results)
            new_lines_unaffected = flatten_list(r[1] for r in results)
            unaffected_sents = flatten_list(r[2] for r in results)

            # Name new files
            base_name = os.path.basename(file)
            new_file_affected = base_name.replace(json_ext, "_affected.test")
            new_file_unaffected = base_name.replace(
                json_ext, "_unaffected.test")
            file_unaffected_sents = base_name.replace(
                json_ext, "_unaffected_sents.test")

            # Create directories; exist_ok avoids a race when several
            # perturbation processes run in parallel.
            directory_affected = f"{perturbation_directory}/babylm_test_affected/"
            directory_unaffected = f"{perturbation_directory}/babylm_test_unaffected/"
            directory_unaffected_sents = f"{perturbation_directory}/babylm_test_unaffected_sents/"
            os.makedirs(directory_affected, exist_ok=True)
            os.makedirs(directory_unaffected, exist_ok=True)
            os.makedirs(directory_unaffected_sents, exist_ok=True)

            # Write files
            write_file(directory_affected,
                       new_file_affected, new_lines_affected)
            write_file(directory_unaffected,
                       new_file_unaffected, new_lines_unaffected)
            write_file(directory_unaffected_sents,
                       file_unaffected_sents, unaffected_sents)

    else:
        # Any other split gets one combined file per input file.

        # Iterate over files and do transform
        for file in babylm_data:
            print(file)
            # Context manager closes the file even if json.load() raises
            with open(file, encoding="utf-8") as f:
                data = json.load(f)

            # Perturb data iteratively
            results = [process_line(line) for line in tqdm.tqdm(data)]

            # Flatten per-entry results; unlike zip(*results) this also works
            # when the input file has no entries.
            new_lines_affected = flatten_list(r[0] for r in results)
            new_lines_unaffected = flatten_list(r[1] for r in results)

            # Combine affected and unaffected sentences
            new_lines = new_lines_unaffected + new_lines_affected

            # Name new file
            if babylm_dataset == "dev":
                new_file = os.path.basename(file).replace(json_ext, ".dev")
            elif babylm_dataset == 'unittest':
                new_file = os.path.basename(file).replace(json_ext, ".test")

                # For unittest, follow every token line with its decoded string
                new_lines_with_strings = []
                for token_line in new_lines:
                    decoded = llama_tokenizer.decode(
                        [int(tok) for tok in token_line.split()]) + "\n"
                    new_lines_with_strings.append(token_line)
                    new_lines_with_strings.append(decoded)
                new_lines = new_lines_with_strings

            else:
                # Remaining splits are written as training files
                new_file = os.path.basename(file).replace(json_ext, ".train")

            # Create directory and write file
            directory = f"{perturbation_directory}/babylm_{babylm_dataset}/"
            print("directory:", directory)
            os.makedirs(directory, exist_ok=True)
            write_file(directory, new_file, new_lines)
|
data/perturb_model.sh
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash

# Launch perturb_llama.py once per (perturbation, split) combination, running
# at most one background job per GPU listed in SPECIFIED_GPUS.

# Define your perturbations and BabyLM splits
PERTURBATIONS=("hop_control" "hop_tokens4" "hop_words4" "reverse_control" "reverse_partial" "reverse_full" "shuffle_control"
               "shuffle_nondeterministic" "shuffle_deterministic21" "shuffle_deterministic57" "shuffle_deterministic84" "shuffle_local3"
               "shuffle_local5" "shuffle_local10" "shuffle_even_odd")

# BABYLM_SPLITS=("100M" "10M" "dev" "test" "unittest")  # Add more splits as needed
BABYLM_SPLITS=("dev")

# Specify the GPUs to use
SPECIFIED_GPUS=(1 2 3 4 5 6 7)  # Set these to the GPUs you want to use

# Map each GPU id to the PID of the job last started on it
declare -A GPU_PROCESS_MAP

# Iterate over all combinations of perturbations and splits
for perturbation in "${PERTURBATIONS[@]}"; do
    for split in "${BABYLM_SPLITS[@]}"; do

        # Poll until some GPU has no running job
        while true; do
            for gpu in "${SPECIFIED_GPUS[@]}"; do
                pid="${GPU_PROCESS_MAP[$gpu]:-}"
                # A GPU is free when no job was started on it yet or that job
                # has exited. The explicit empty check replaces relying on
                # `ps -p` failing when called without a PID.
                if [[ -z "$pid" ]] || ! ps -p "$pid" > /dev/null 2>&1; then
                    # Run the Python perturbation script on the available GPU
                    CUDA_VISIBLE_DEVICES="$gpu" python perturb_llama.py "$perturbation" "$split" &
                    GPU_PROCESS_MAP[$gpu]=$!
                    echo "Running on GPU $gpu: Perturbation=$perturbation, Split=$split, PID=$!"
                    break 2  # Leave both the GPU loop and the polling loop
                fi
            done
            sleep 1  # Wait a second before checking again
        done
    done
done

# Wait for all processes to finish
wait
echo "All tasks completed."
|
data/perturb_qwen.py
ADDED
|
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# perturb_qwen.py
|
| 2 |
+
# Author: Julie Kallini
|
| 3 |
+
|
| 4 |
+
# For importing utils
|
| 5 |
+
import sys
|
| 6 |
+
sys.path.append("..")
|
| 7 |
+
|
| 8 |
+
from utils_qwen import PERTURBATIONS, BABYLM_SPLITS, BABYLM_DATA_PATH, \
|
| 9 |
+
GENRES, MARKER_TOKEN_IDS, marker_sg_token, marker_pl_token, marker_rev_token, write_file
|
| 10 |
+
from glob import glob
|
| 11 |
+
import numpy as np
|
| 12 |
+
import itertools
|
| 13 |
+
import json
|
| 14 |
+
import os
|
| 15 |
+
import tqdm
|
| 16 |
+
import argparse
|
| 17 |
+
import pytest
|
| 18 |
+
|
| 19 |
+
MODEL_NAME = "Qwen2.5-7B"
|
| 20 |
+
|
| 21 |
+
def lines_equivalent_3pres(file1_path, file2_path):
    """Check that two token files match once sg/pl marker tokens are ignored.

    Each line is split on whitespace, every token is parsed as an int, and
    tokens equal to ``marker_sg_token`` or ``marker_pl_token`` are dropped.
    The remaining tokens of the two lines must then be identical and in the
    same order.

    Parameters:
    - file1_path (str): Path to the first token file.
    - file2_path (str): Path to the second token file.

    Returns:
    - bool: True if both files have the same number of lines and every pair
      of lines is equivalent; False otherwise. The first mismatching pair of
      lines is printed before returning.

    Raises:
    - ValueError: If a whitespace-separated token cannot be parsed as an int.
    """

    def strip_markers(line):
        # Tokens stay strings; int() is only used for the marker test.
        return [tok for tok in line.split()
                if int(tok) not in (marker_sg_token, marker_pl_token)]

    with open(file1_path, 'r') as file1, open(file2_path, 'r') as file2:
        # zip_longest pads the shorter file with None. Plain zip() pulls a
        # line from file1 before it notices file2 is exhausted, so a single
        # extra trailing line in file1 used to slip through the length check.
        for line1, line2 in itertools.zip_longest(file1, file2):
            if line1 is None or line2 is None:
                return False

            if strip_markers(line1) != strip_markers(line2):
                print(line1)
                print(line2)
                return False

    return True
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# Pairs of 3pres perturbation variants whose outputs are compared with each
# other (third-person singular/plural marker perturbations).
perturbation_pairs_3pres = [
    ("0tokens", "4tokens"),
    ("0tokens", "4words"),
    ("4tokens", "4words"),
]

# Every (split, genre, pair) combination exercised by the test below. The
# affected and unaffected test subsets are checked separately.
test_data = itertools.product(
    ["100M", "dev", "test_affected", "test_unaffected"],
    GENRES.keys(),
    perturbation_pairs_3pres,
)


# The test function runs once for each parameter tuple in test_data.
@pytest.mark.parametrize("split, genre, perturbation_pair", test_data)
def test_3pres_all_equivalent(split, genre, perturbation_pair):
    """Assert two 3pres variants of one corpus file are marker-equivalent.

    ``genre`` is one of the keys of ``GENRES``; ``split`` selects both the
    sub-directory and the file extension of the file being compared.
    """
    first_variant, second_variant = perturbation_pair

    # Pick the file name used for this split.
    if split == "dev":
        filename = f"{genre}.dev"
    elif split == "test_unaffected":
        filename = f"{genre}_unaffected.test"
    elif split == "test_affected":
        filename = f"{genre}_affected.test"
    elif split in ("100M", "10M"):
        filename = f"{genre}.train"

    perturbed_root = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_qwen"
    path1 = f"{perturbed_root}/babylm_3pres_{first_variant}/babylm_{split}/{filename}"
    path2 = f"{perturbed_root}/babylm_3pres_{second_variant}/babylm_{split}/{filename}"

    # Compare the two files line by line.
    assert lines_equivalent_3pres(path1, path2), (
        f"File {filename} of 3pres_{first_variant} and 3pres_{second_variant} "
        "have non-equivalent lines!")
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def lines_equivalent_reversal(rev_path, ident_path):
    """Check that a reversal file is the identity file with its tail reversed.

    For each pair of lines, the tokens up to and including the first
    occurrence of ``marker_rev_token`` (located in the reversal line) must be
    identical, and the tokens after it in the reversal line must be the
    reverse of the tokens after the same position in the identity line.

    Parameters:
    - rev_path (str): Path to the reversal token file.
    - ident_path (str): Path to the reversal-identity token file.

    Returns:
    - bool: True if both files have the same number of lines and every pair
      of lines satisfies the relation above; False otherwise.

    Raises:
    - ValueError: If a line of the reversal file does not contain the
      reversal marker token.
    """
    with open(rev_path, 'r') as file1, open(ident_path, 'r') as file2:
        # zip_longest pads the shorter file with None. Plain zip() pulls a
        # line from file1 before it notices file2 is exhausted, so a single
        # extra trailing line in file1 used to slip through the length check.
        for line1, line2 in itertools.zip_longest(file1, file2):
            if line1 is None or line2 is None:
                return False

            line1_tokens = line1.split()
            line2_tokens = line2.split()

            # Position just past the REV marker in the reversal line
            tail_start = line1_tokens.index(str(marker_rev_token)) + 1

            # Tokens up to and including the marker must all be the same
            if line1_tokens[:tail_start] != line2_tokens[:tail_start]:
                return False

            # Reversal of the rest of the line must equal the identity line
            if line1_tokens[tail_start:][::-1] != line2_tokens[tail_start:]:
                return False

    return True
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# Reversal output is compared against its identity counterpart.
perturbation_pairs_reversal = [
    ("reversal", "reversal_identity"),
]

# Every (split, genre, pair) combination exercised by the reversal test.
test_data = itertools.product(
    ["100M", "dev", "test_affected"],
    GENRES.keys(),
    perturbation_pairs_reversal,
)


@pytest.mark.parametrize("split, genre, perturbation_pair", test_data)
def test_reversal_all_equivalent(split, genre, perturbation_pair):
    """Assert the reversal file matches the identity file with its tail reversed."""
    reversed_name, identity_name = perturbation_pair

    # Pick the file name used for this split.
    if split == "dev":
        filename = f"{genre}.dev"
    elif split == "test_unaffected":
        filename = f"{genre}_unaffected.test"
    elif split == "test_affected":
        filename = f"{genre}_affected.test"
    elif split in ("100M", "10M"):
        filename = f"{genre}.train"

    perturbed_root = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_qwen"
    path1 = f"{perturbed_root}/babylm_{reversed_name}/babylm_{split}/{filename}"
    path2 = f"{perturbed_root}/babylm_{identity_name}/babylm_{split}/{filename}"

    assert lines_equivalent_reversal(path1, path2), (
        f"File {filename} of {reversed_name} and {identity_name} "
        "have non-equivalent lines!")
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def lines_equivalent_determiner_swap(det_path, ident_path):
    """Check that two files contain the same set of tokens on every line.

    Each line is split on whitespace and the two lines are compared as sets,
    so token order and repeated tokens are ignored.

    Parameters:
    - det_path (str): Path to the determiner-swap token file.
    - ident_path (str): Path to the determiner-swap-identity token file.

    Returns:
    - bool: True if both files have the same number of lines and every pair
      of lines has the same token set; False otherwise. The token lists of
      the first mismatching pair are printed before returning.
    """
    with open(det_path, 'r') as file1, open(ident_path, 'r') as file2:
        # zip_longest pads the shorter file with None. Plain zip() pulls a
        # line from file1 before it notices file2 is exhausted, so a single
        # extra trailing line in file1 used to slip through the length check.
        for line1, line2 in itertools.zip_longest(file1, file2):
            if line1 is None or line2 is None:
                return False

            line1_tokens = line1.split()
            line2_tokens = line2.split()
            if set(line1_tokens) != set(line2_tokens):
                print(line1_tokens)
                print(line2_tokens)
                return False

    return True
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
# Determiner-swap output is compared against its identity counterpart.
# NOTE(review): this rebinds the module-level name perturbation_pairs_reversal.
perturbation_pairs_reversal = [
    ("determiner_swap", "determiner_swap_identity"),
]

# Every (split, genre, pair) combination exercised by the determiner-swap test.
test_data = itertools.product(
    ["100M", "dev", "test_affected", "test_unaffected"],
    GENRES.keys(),
    perturbation_pairs_reversal,
)


@pytest.mark.parametrize("split, genre, perturbation_pair", test_data)
def test_determiner_swap_all_equivalent(split, genre, perturbation_pair):
    """Assert determiner-swap and identity files share the same tokens per line."""
    swapped_name, identity_name = perturbation_pair

    # Pick the file name used for this split.
    if split == "dev":
        filename = f"{genre}.dev"
    elif split == "test_unaffected":
        filename = f"{genre}_unaffected.test"
    elif split == "test_affected":
        filename = f"{genre}_affected.test"
    elif split in ("100M", "10M"):
        filename = f"{genre}.train"

    perturbed_root = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_qwen"
    path1 = f"{perturbed_root}/babylm_{swapped_name}/babylm_{split}/{filename}"
    path2 = f"{perturbed_root}/babylm_{identity_name}/babylm_{split}/{filename}"

    assert lines_equivalent_determiner_swap(path1, path2), (
        f"File {filename} of {swapped_name} and {identity_name} "
        "have non-equivalent lines!")
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def flatten_list(l):
    """Flatten one level of nesting by concatenating the elements of ``l``."""
    flat = []
    for inner in l:
        flat.extend(inner)
    return flat
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def process_line(line):
    """
    Perturb every sentence of one dataset line and sort the results into
    affected and unaffected sentences.

    Parameters:
    - line (dict): A dictionary representing a line from the dataset; its
      "sent_annotations" entry holds the sentence annotations.

    Returns:
    - tuple: A tuple containing three lists:
        1. new_lines_affected (list of str): Token lines of sentences that were
           affected by the transformation and accepted by the filter.
        2. new_lines_unaffected (list of str): Token lines of sentences that
           were not affected by the transformation.
        3. sents_unaffected (list of str): The "sent_text" of each unaffected
           sentence, in the same order as new_lines_unaffected.
      Every returned string ends with a newline.

    Note:
    - The transformation functions (`perturbation_function`, `affect_function`, `filter_function`)
      and `MARKER_TOKEN_IDS` are expected to be available in the global scope.
    - Affected sentences rejected by `filter_function` are dropped entirely.
    """

    new_lines_affected = []
    new_lines_unaffected = []
    sents_unaffected = []

    # Apply transformation to each sentence on line
    # Yj: unclear why the annotations are used here rather than the raw text
    for sent in line["sent_annotations"]:

        tokens = perturbation_function(sent)

        # Skip sentences with at most one token once marker tokens are ignored
        num_content_tokens = sum(
            1 for tok in tokens if tok not in MARKER_TOKEN_IDS)
        if num_content_tokens <= 1:
            continue

        token_line = " ".join([str(tok) for tok in tokens])

        # Check if sent is affected
        if affect_function(sent):

            # Check if this affected sentence should be filtered or not
            if filter_function(sent):
                new_lines_affected.append(token_line + "\n")

        else:  # Unaffected sentences
            new_lines_unaffected.append(token_line + "\n")
            sents_unaffected.append(sent["sent_text"] + "\n")

    return new_lines_affected, new_lines_unaffected, sents_unaffected
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        prog='Perturb BabyLM dataset',
        description='Perturb BabyLM dataset by altering POS-tagged data')
    parser.add_argument('perturbation_type',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=PERTURBATIONS.keys(),
                        help='Perturbation function used to transform BabyLM dataset')
    parser.add_argument('babylm_dataset',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=BABYLM_SPLITS,
                        help='BabyLM dataset choice')

    # Get args
    args = parser.parse_args()

    # Load dataset (only json files containing tagged data)
    babylm_dataset = args.babylm_dataset
    json_ext = "_parsed.json"
    # babylm_data = glob(f"{BABYLM_DATA_PATH}/babylm_data/babylm_{babylm_dataset}/*{json_ext}")
    babylm_data = glob(f"babylm_data/babylm_{babylm_dataset}/*{json_ext}")
    print("babylm_data:", babylm_data)

    # Get perturbation, affect, and filter functions.
    # These are module-level names on purpose: process_line() reads them as globals.
    perturbation_function = PERTURBATIONS[args.perturbation_type]['perturbation_function']
    affect_function = PERTURBATIONS[args.perturbation_type]['affect_function']
    filter_function = PERTURBATIONS[args.perturbation_type]['filter_function']
    qwen_tokenizer = PERTURBATIONS[args.perturbation_type]['qwen_tokenizer']

    # Yj: why is babylm_dataset compared against "test" here?
    # BABYLM_SPLITS = ['100M', '10M', 'dev', 'test', 'unittest']
    # The "test" split is the only one whose affected and unaffected sentences
    # are written to separate files; every other split is written combined.
    if babylm_dataset == "test":

        # Iterate over files and do transform
        for file in babylm_data:
            print(file)
            # Context manager closes the handle even if the JSON is malformed
            with open(file) as f:
                data = json.load(f)

            # Perturb data iteratively
            results = [process_line(line) for line in tqdm.tqdm(data)]

            # zip(*[]) yields nothing to unpack, so handle an empty file explicitly
            if results:
                new_lines_affected, new_lines_unaffected, unaffected_sents = zip(
                    *results)
            else:
                new_lines_affected, new_lines_unaffected, unaffected_sents = (), (), ()
            new_lines_affected = flatten_list(new_lines_affected)
            new_lines_unaffected = flatten_list(new_lines_unaffected)
            unaffected_sents = flatten_list(unaffected_sents)

            # Name new file
            new_file_affected = os.path.basename(
                file).replace(json_ext, "_affected.test")
            new_file_unaffected = os.path.basename(
                file).replace(json_ext, "_unaffected.test")
            file_unaffected_sents = os.path.basename(
                file).replace(json_ext, "_unaffected_sents.test")

            # Create directory
            data_write_directory = f"{BABYLM_DATA_PATH}/Qwen_perturbed_data/{MODEL_NAME}"
            directory_affected = f"{data_write_directory}/babylm_{args.perturbation_type}/babylm_test_affected/"
            os.makedirs(directory_affected, exist_ok=True)
            directory_unaffected = f"{data_write_directory}/babylm_{args.perturbation_type}/babylm_test_unaffected/"
            os.makedirs(directory_unaffected, exist_ok=True)
            directory_unaffected_sents = f"{data_write_directory}/babylm_{args.perturbation_type}/babylm_test_unaffected_sents/"
            os.makedirs(directory_unaffected_sents, exist_ok=True)

            # Write files
            write_file(directory_affected,
                       new_file_affected, new_lines_affected)
            write_file(directory_unaffected,
                       new_file_unaffected, new_lines_unaffected)
            write_file(directory_unaffected_sents,
                       file_unaffected_sents, unaffected_sents)

    else:
        # Yj: BABYLM_SPLITS = ['100M', '10M', 'dev', 'test', 'unittest']
        # Iterate over files and do transform
        for file in babylm_data:
            print(file)
            with open(file) as f:
                data = json.load(f)

            # Perturb data iteratively
            results = [process_line(line) for line in tqdm.tqdm(data)]

            if results:
                new_lines_affected, new_lines_unaffected, _ = zip(
                    *results)
            else:
                new_lines_affected, new_lines_unaffected = (), ()

            new_lines_affected = flatten_list(new_lines_affected)
            new_lines_unaffected = flatten_list(new_lines_unaffected)

            # Combine affected and unaffected sentences
            new_lines = new_lines_unaffected + new_lines_affected

            # Name new file
            if babylm_dataset == "dev":
                new_file = os.path.basename(file).replace(json_ext, ".dev")
            elif babylm_dataset == 'unittest':
                new_file = os.path.basename(file).replace(json_ext, ".test")

                # Print strings for unittest: each token-id line is followed
                # by its decoded text
                new_lines_decoded = [qwen_tokenizer.decode(
                    [int(tok) for tok in line.split()]) + "\n" for line in new_lines]
                new_lines_with_strings = []
                for tokens, decoded in zip(new_lines, new_lines_decoded):
                    new_lines_with_strings.append(tokens)
                    new_lines_with_strings.append(decoded)
                new_lines = new_lines_with_strings

            else:
                # '10M' and '100M' are the training sets
                new_file = os.path.basename(file).replace(json_ext, ".train")

            # Create directory and write file
            directory = f"{BABYLM_DATA_PATH}/Perturbed_data/{MODEL_NAME}/babylm_{args.perturbation_type}/babylm_{babylm_dataset}/"
            print("directory:", directory)
            os.makedirs(directory, exist_ok=True)
            write_file(directory, new_file, new_lines)
|
data/tag.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# tag.py
|
| 2 |
+
# Author: Julie Kallini
|
| 3 |
+
|
| 4 |
+
# For importing utils
|
| 5 |
+
import sys
|
| 6 |
+
sys.path.append("..")
|
| 7 |
+
|
| 8 |
+
import pytest
|
| 9 |
+
import glob
|
| 10 |
+
import tqdm
|
| 11 |
+
import os
|
| 12 |
+
import argparse
|
| 13 |
+
import stanza
|
| 14 |
+
import json
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Every file found in the BabyLM data folders (relative to the working directory).
test_all_files = sorted(glob.glob("babylm_data/babylm_*/*"))
test_original_files = [f for f in test_all_files if ".json" not in f]
test_json_files = [f for f in test_all_files if "_parsed.json" in f]

# Pair each original file with the "<stem>_parsed.json" written next to it (the
# naming used by the __main__ block of this script). Pairing by name instead of
# by list position keeps one missing JSON file from shifting every later pair.
_parsed_json_files = set(test_json_files)
test_cases = [
    (original, os.path.splitext(original)[0] + "_parsed.json")
    for original in test_original_files
    if os.path.splitext(original)[0] + "_parsed.json" in _parsed_json_files
]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@pytest.mark.parametrize("original_file, json_file", test_cases)
def test_equivalent_lines(original_file, json_file):
    """Check that the tagged JSON holds exactly the text of the original file.

    Both sides are compared with all whitespace removed, so only differences
    in the non-whitespace characters fail the test.

    Parameters:
    - original_file (str): path to the original text file.
    - json_file (str): path to the JSON file; a list of lines, each with a
      "sent_annotations" list whose items carry "sent_text".
    """

    # Read the original file and remove all whitespace
    with open(original_file) as infile:
        original_data = "".join(infile.read().split())

    with open(json_file) as infile:
        json_lines = json.load(infile)

    # Concatenate every sentence text in file order, then remove whitespace.
    # A single join avoids the quadratic cost of repeated string +=.
    json_data = "".join(
        sent["sent_text"]
        for line in json_lines
        for sent in line["sent_annotations"]
    )
    json_data = "".join(json_data.split())

    # Test equivalence
    assert (original_data == json_data)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def __get_constituency_parse(sent, nlp):
    """Return a bracketed constituency parse of `sent`, or None if parsing fails.

    Parameters:
    - sent: object whose `.text` attribute is the sentence string.
    - nlp: callable invoked as `nlp(sent.text)`; its result must expose
      `.sentences`, each item having a `.constituency` attribute.

    Returns:
    - str or None: "(ROOT <tree> <tree> ...)" with one tree per sentence the
      parser found, or None when the parser raised an exception.
    """

    # Parsing is best-effort. Catch ordinary errors only, so that Ctrl-C and
    # SystemExit still stop a long tagging run.
    try:
        parse_doc = nlp(sent.text)
    except Exception:
        return None

    # The parser may re-segment the text into several sentences
    parse_trees = [str(parsed.constituency) for parsed in parse_doc.sentences]

    # Join parse trees and add ROOT
    constituency_parse = "(ROOT " + " ".join(parse_trees) + ")"
    return constituency_parse
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        prog='Tag BabyLM dataset',
        description='Tag BabyLM dataset using Stanza')
    parser.add_argument('path', type=argparse.FileType('r'),
                        nargs='+', help="Path to file(s)")
    parser.add_argument('-p', '--parse', action='store_true',
                        help="Include constituency parse")

    # Get args
    args = parser.parse_args()

    # Init Stanza NLP tools
    nlp1 = stanza.Pipeline(
        lang='en',
        processors='tokenize, pos, lemma',
        package="default_accurate",
        use_gpu=True)

    # If constituency parse is needed, init second Stanza parser
    nlp2 = None
    if args.parse:
        nlp2 = stanza.Pipeline(lang='en',
                               processors='tokenize,pos,constituency',
                               package="default_accurate",
                               use_gpu=True)

    # Number of input lines joined into one text before it is sent to Stanza
    # BATCH_SIZE = 5000
    BATCH_SIZE = 100

    # Iterate over BabyLM files
    for file in args.path:

        print(file.name)
        lines = file.readlines()
        # argparse opened the handle for us; release it once the text is read
        file.close()

        # Strip lines and join text
        print("Concatenating lines...")
        lines = [l.strip() for l in lines]
        line_batches = [lines[i:i + BATCH_SIZE]
                        for i in range(0, len(lines), BATCH_SIZE)]
        text_batches = [" ".join(l) for l in line_batches]

        # Iterate over text batches in file and track annotations
        line_annotations = []
        print("Segmenting and parsing text batches...")
        for text in tqdm.tqdm(text_batches):
            # Tokenize text with stanza
            doc = nlp1(text)

            # Iterate over sents in the batch and track annotations
            sent_annotations = []
            for sent in doc.sentences:

                # Iterate over words in sent and track annotations
                word_annotations = []
                for token, word in zip(sent.tokens, sent.words):
                    wa = {
                        'id': word.id,
                        'text': word.text,
                        'lemma': word.lemma,
                        'upos': word.upos,
                        'xpos': word.xpos,
                        'feats': word.feats,
                        'start_char': token.start_char,
                        'end_char': token.end_char
                    }
                    word_annotations.append(wa)  # Track word annotation

                # Get constituency parse if needed
                if args.parse:
                    constituency_parse = __get_constituency_parse(sent, nlp2)
                    sa = {
                        'sent_text': sent.text,
                        'constituency_parse': constituency_parse,
                        'word_annotations': word_annotations,
                    }
                else:
                    sa = {
                        'sent_text': sent.text,
                        'word_annotations': word_annotations,
                    }
                sent_annotations.append(sa)  # Track sent annotation

            la = {
                'sent_annotations': sent_annotations
            }
            line_annotations.append(la)  # Track line annotation

        # Write annotations to file as a JSON, next to the input file
        print("Writing JSON outfile...")
        ext = '_parsed.json' if args.parse else '.json'
        json_filename = os.path.splitext(file.name)[0] + ext
        with open(json_filename, "w") as outfile:
            json.dump(line_annotations, outfile, indent=4)
|
data/tag_1.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# tag.py
|
| 2 |
+
# Author: Julie Kallini
|
| 3 |
+
|
| 4 |
+
# For importing utils
|
| 5 |
+
import sys
|
| 6 |
+
sys.path.append("..")
|
| 7 |
+
|
| 8 |
+
import pytest
|
| 9 |
+
import glob
|
| 10 |
+
import tqdm
|
| 11 |
+
import os
|
| 12 |
+
import argparse
|
| 13 |
+
import stanza
|
| 14 |
+
import json
|
| 15 |
+
from transformers import AutoTokenizer
|
| 16 |
+
|
| 17 |
+
# Define the function to chunk text
|
| 18 |
+
def chunk_text(text, tokenizer, max_length=512):
    """Split `text` into pieces of at most `max_length` tokenizer tokens.

    The pieces are slices of the original string, so casing and characters the
    tokenizer cannot represent are preserved. (Decoding token ids back to text,
    as done previously, returns the tokenizer's normalised text instead.)
    A cut is moved back to the nearest token that follows whitespace when one
    exists inside the chunk, so words are not split unless a single
    whitespace-free run is longer than `max_length` tokens.

    Parameters:
    - text (str): text to split.
    - tokenizer: callable tokenizer. It is first called with
      `add_special_tokens=False, return_offsets_mapping=True` and must return a
      mapping with "offset_mapping" (a list of (start, end) character spans).
      If it cannot do that, the function falls back to the encode/decode round
      trip, which needs "input_ids" and `tokenizer.decode`.
    - max_length (int): maximum number of tokens per piece; must be >= 1.

    Returns:
    - list of str: the pieces in order, each stripped of surrounding
      whitespace. Text without any token yields a single stripped piece.

    Raises:
    - ValueError: if `max_length` is smaller than 1.
    """
    if max_length < 1:
        raise ValueError("max_length must be at least 1")

    try:
        encoding = tokenizer(text, add_special_tokens=False,
                             return_offsets_mapping=True)
        offsets = encoding['offset_mapping']
    except (NotImplementedError, TypeError, KeyError):
        # No character offsets available: keep the original (lossy) behaviour
        tokens = tokenizer(text)['input_ids']
        chunks = [tokens[i:i + max_length] for i in range(0, len(tokens), max_length)]
        return [tokenizer.decode(chunk, skip_special_tokens=True) for chunk in chunks]

    if not offsets:
        return [text.strip()]

    boundaries = [0]  # character index at which each piece starts
    first = 0         # token index at which the current piece starts
    while first + max_length < len(offsets):
        cut = first + max_length
        # Walk back to a token that is preceded by whitespace, if there is one
        probe = cut
        while probe > first + 1 and not text[offsets[probe][0] - 1].isspace():
            probe -= 1
        if text[offsets[probe][0] - 1].isspace():
            cut = probe
        boundaries.append(offsets[cut][0])
        first = cut

    ends = boundaries[1:] + [len(text)]
    return [text[start:end].strip() for start, end in zip(boundaries, ends)]
|
| 22 |
+
|
| 23 |
+
# Test case for checking equivalence of original and parsed files
|
| 24 |
+
# Every file found in the BabyLM data folders (relative to the working directory).
test_all_files = sorted(glob.glob("babylm_data/babylm_*/*"))
test_original_files = [f for f in test_all_files if ".json" not in f]
test_json_files = [f for f in test_all_files if "_parsed.json" in f]

# Pair each original file with the "<stem>_parsed.json" written next to it (the
# naming used by the __main__ block of this script). Pairing by name instead of
# by list position keeps one missing JSON file from shifting every later pair.
_parsed_json_files = set(test_json_files)
test_cases = [
    (original, os.path.splitext(original)[0] + "_parsed.json")
    for original in test_original_files
    if os.path.splitext(original)[0] + "_parsed.json" in _parsed_json_files
]
|
| 28 |
+
|
| 29 |
+
@pytest.mark.parametrize("original_file, json_file", test_cases)
def test_equivalent_lines(original_file, json_file):
    """Check that the tagged JSON holds exactly the text of the original file.

    Both sides are compared with all whitespace removed, so only differences
    in the non-whitespace characters fail the test.

    Parameters:
    - original_file (str): path to the original text file.
    - json_file (str): path to the JSON file; a list of lines, each with a
      "sent_annotations" list whose items carry "sent_text".
    """

    # Read the original file and remove all whitespace
    with open(original_file) as infile:
        original_data = "".join(infile.read().split())

    with open(json_file) as infile:
        json_lines = json.load(infile)

    # Concatenate every sentence text in file order, then remove whitespace.
    # A single join avoids the quadratic cost of repeated string +=.
    json_data = "".join(
        sent["sent_text"]
        for line in json_lines
        for sent in line["sent_annotations"]
    )
    json_data = "".join(json_data.split())

    # Test equivalence
    assert (original_data == json_data)
|
| 47 |
+
|
| 48 |
+
# Constituency parsing function
|
| 49 |
+
def __get_constituency_parse(sent, nlp):
    """Return a bracketed constituency parse of `sent`, or None if parsing fails.

    Parameters:
    - sent: object whose `.text` attribute is the sentence string.
    - nlp: callable invoked as `nlp(sent.text)`; its result must expose
      `.sentences`, each item having a `.constituency` attribute.

    Returns:
    - str or None: "(ROOT <tree> <tree> ...)" with one tree per sentence the
      parser found, or None when the parser raised an exception.
    """

    # Parsing is best-effort. Catch ordinary errors only, so that Ctrl-C and
    # SystemExit still stop a long tagging run.
    try:
        parse_doc = nlp(sent.text)
    except Exception:
        return None

    # The parser may re-segment the text into several sentences
    parse_trees = [str(parsed.constituency) for parsed in parse_doc.sentences]

    # Join parse trees and add ROOT
    constituency_parse = "(ROOT " + " ".join(parse_trees) + ")"
    return constituency_parse
|
| 63 |
+
|
| 64 |
+
# Main function
|
| 65 |
+
if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        prog='Tag BabyLM dataset',
        description='Tag BabyLM dataset using Stanza')
    parser.add_argument('path', type=argparse.FileType('r'),
                        nargs='+', help="Path to file(s)")
    parser.add_argument('-p', '--parse', action='store_true',
                        help="Include constituency parse")

    # Get args
    args = parser.parse_args()

    # Init Stanza NLP tools
    nlp1 = stanza.Pipeline(
        lang='en',
        processors='tokenize, pos, lemma',
        package="default_accurate",
        use_gpu=True)

    # If constituency parse is needed, init second Stanza parser
    nlp2 = None
    if args.parse:
        nlp2 = stanza.Pipeline(lang='en',
                               processors='tokenize,pos,constituency',
                               package="default_accurate",
                               use_gpu=True)

    # Number of input lines joined into one text before chunking
    BATCH_SIZE = 100

    # Tokenizer for splitting long text
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # Iterate over BabyLM files
    for file in args.path:

        print(file.name)
        lines = file.readlines()
        # argparse opened the handle for us; release it once the text is read
        file.close()

        # Strip lines and join text
        print("Concatenating lines...")
        lines = [l.strip() for l in lines]
        line_batches = [lines[i:i + BATCH_SIZE]
                        for i in range(0, len(lines), BATCH_SIZE)]
        text_batches = [" ".join(l) for l in line_batches]

        # Iterate over text batches in file and track annotations
        line_annotations = []
        print("Segmenting and parsing text batches...")
        for text in tqdm.tqdm(text_batches):
            # Split the text into chunks if it exceeds the max length
            text_chunks = chunk_text(text, tokenizer)

            # Iterate over each chunk; every chunk becomes one entry of
            # line_annotations
            for chunk in text_chunks:
                # Tokenize text with stanza
                doc = nlp1(chunk)

                # Iterate over sentences in the chunk and track annotations
                sent_annotations = []
                for sent in doc.sentences:

                    # Iterate over words in the sentence and track annotations
                    word_annotations = []
                    for token, word in zip(sent.tokens, sent.words):
                        wa = {
                            'id': word.id,
                            'text': word.text,
                            'lemma': word.lemma,
                            'upos': word.upos,
                            'xpos': word.xpos,
                            'feats': word.feats,
                            'start_char': token.start_char,
                            'end_char': token.end_char
                        }
                        word_annotations.append(wa)  # Track word annotation

                    # Get constituency parse if needed
                    if args.parse:
                        constituency_parse = __get_constituency_parse(sent, nlp2)
                        sa = {
                            'sent_text': sent.text,
                            'constituency_parse': constituency_parse,
                            'word_annotations': word_annotations,
                        }
                    else:
                        sa = {
                            'sent_text': sent.text,
                            'word_annotations': word_annotations,
                        }
                    sent_annotations.append(sa)  # Track sent annotation

                la = {
                    'sent_annotations': sent_annotations
                }
                line_annotations.append(la)  # Track line annotation

        # Write annotations to file as a JSON, next to the input file
        print("Writing JSON outfile...")
        ext = '_parsed.json' if args.parse else '.json'
        json_filename = os.path.splitext(file.name)[0] + ext
        with open(json_filename, "w") as outfile:
            json.dump(line_annotations, outfile, indent=4)
|
data/tag_distributed.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# the files can be processed on different gpus, each file is processed on a gpu
|
| 2 |
+
import torch
|
| 3 |
+
import torch.distributed as dist
|
| 4 |
+
import sys
|
| 5 |
+
sys.path.append("..")
|
| 6 |
+
|
| 7 |
+
import pytest
|
| 8 |
+
import glob
|
| 9 |
+
import tqdm
|
| 10 |
+
import os
|
| 11 |
+
import argparse
|
| 12 |
+
import stanza
|
| 13 |
+
import json
|
| 14 |
+
from transformers import AutoTokenizer
|
| 15 |
+
|
| 16 |
+
def chunk_text(text, tokenizer, max_length=512):
    """Split `text` into pieces of at most `max_length` tokenizer tokens.

    The pieces are slices of the original string, so casing and characters the
    tokenizer cannot represent are preserved. (Decoding token ids back to text,
    as done previously, returns the tokenizer's normalised text instead.)
    A cut is moved back to the nearest token that follows whitespace when one
    exists inside the chunk, so words are not split unless a single
    whitespace-free run is longer than `max_length` tokens.

    Parameters:
    - text (str): text to split.
    - tokenizer: callable tokenizer. It is first called with
      `add_special_tokens=False, return_offsets_mapping=True` and must return a
      mapping with "offset_mapping" (a list of (start, end) character spans).
      If it cannot do that, the function falls back to the encode/decode round
      trip, which needs "input_ids" and `tokenizer.decode`.
    - max_length (int): maximum number of tokens per piece; must be >= 1.

    Returns:
    - list of str: the pieces in order, each stripped of surrounding
      whitespace. Text without any token yields a single stripped piece.

    Raises:
    - ValueError: if `max_length` is smaller than 1.
    """
    if max_length < 1:
        raise ValueError("max_length must be at least 1")

    try:
        encoding = tokenizer(text, add_special_tokens=False,
                             return_offsets_mapping=True)
        offsets = encoding['offset_mapping']
    except (NotImplementedError, TypeError, KeyError):
        # No character offsets available: keep the original (lossy) behaviour
        tokens = tokenizer(text)['input_ids']
        chunks = [tokens[i:i + max_length] for i in range(0, len(tokens), max_length)]
        return [tokenizer.decode(chunk, skip_special_tokens=True) for chunk in chunks]

    if not offsets:
        return [text.strip()]

    boundaries = [0]  # character index at which each piece starts
    first = 0         # token index at which the current piece starts
    while first + max_length < len(offsets):
        cut = first + max_length
        # Walk back to a token that is preceded by whitespace, if there is one
        probe = cut
        while probe > first + 1 and not text[offsets[probe][0] - 1].isspace():
            probe -= 1
        if text[offsets[probe][0] - 1].isspace():
            cut = probe
        boundaries.append(offsets[cut][0])
        first = cut

    ends = boundaries[1:] + [len(text)]
    return [text[start:end].strip() for start, end in zip(boundaries, ends)]
|
| 20 |
+
|
| 21 |
+
def init_distributed_mode():
    """Join the NCCL process group and select this process's CUDA device.

    Returns:
    - int: the global rank of this process.
    """
    dist.init_process_group(backend='nccl')
    rank = dist.get_rank()
    # Pick the GPU by node-local rank. torchrun exports LOCAL_RANK; on a
    # multi-node job the global rank can exceed the number of local GPUs.
    # Without LOCAL_RANK, fall back to the global rank as before.
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    torch.cuda.set_device(local_rank)
    return rank
|
| 26 |
+
|
| 27 |
+
def run_on_gpu(rank, args, tokenizer, nlp1, nlp2):
    """Tag this rank's share of the input files and write one JSON per file.

    Parameters:
    - rank (int): rank of this process in the process group.
    - args: parsed arguments; `args.path` is the list of open input files and
      `args.parse` tells whether to add constituency parses.
    - tokenizer: tokenizer handed to chunk_text().
    - nlp1: pipeline called on every chunk; its result provides `.sentences`
      with `.tokens` / `.words`.
    - nlp2: pipeline used for constituency parses (only used if `args.parse`).

    Note:
    - `BATCH_SIZE` is read from the global scope.
    - Output goes next to each input file: "<stem>_parsed.json" when
      `args.parse` is set, "<stem>.json" otherwise.
    """
    print(f"Running on Rank {rank}, using GPU {torch.cuda.current_device()}")
    print(f"Rank {rank}, GPU {torch.cuda.current_device()} started")

    # Round-robin assignment: every file goes to exactly one rank and the load
    # stays balanced even when the file count is not a multiple of the world size.
    gpu_files = args.path[rank::dist.get_world_size()]

    for file in gpu_files:
        print(f"GPU {rank}: Processing {file.name}")
        lines = [l.strip() for l in file.readlines()]
        # argparse opened the handle for us; release it once the text is read
        file.close()

        line_batches = [lines[i:i + BATCH_SIZE] for i in range(0, len(lines), BATCH_SIZE)]
        text_batches = [" ".join(l) for l in line_batches]

        line_annotations = []
        for text in tqdm.tqdm(text_batches, desc=f"GPU {rank}"):
            text_chunks = chunk_text(text, tokenizer)
            # Every chunk becomes one entry of line_annotations
            for chunk in text_chunks:
                doc = nlp1(chunk)
                sent_annotations = []
                for sent in doc.sentences:
                    word_annotations = []
                    for token, word in zip(sent.tokens, sent.words):
                        wa = {
                            'id': word.id,
                            'text': word.text,
                            'lemma': word.lemma,
                            'upos': word.upos,
                            'xpos': word.xpos,
                            'feats': word.feats,
                            'start_char': token.start_char,
                            'end_char': token.end_char
                        }
                        word_annotations.append(wa)

                    sa = {
                        'sent_text': sent.text,
                        'word_annotations': word_annotations
                    }
                    if args.parse:
                        sa['constituency_parse'] = __get_constituency_parse(sent, nlp2)

                    sent_annotations.append(sa)
                line_annotations.append({'sent_annotations': sent_annotations})

        # Choose the extension first. Written inline, the conditional applied
        # to the whole concatenation and produced a bare '.json' when not parsing.
        ext = '_parsed.json' if args.parse else '.json'
        json_filename = os.path.splitext(file.name)[0] + ext
        with open(json_filename, "w") as outfile:
            json.dump(line_annotations, outfile, indent=4)
|
| 77 |
+
|
| 78 |
+
def __get_constituency_parse(sent, nlp):
    """Return a bracketed constituency parse of `sent`, or None if parsing fails.

    Parameters:
    - sent: object whose `.text` attribute is the sentence string.
    - nlp: callable invoked as `nlp(sent.text)`; its result must expose
      `.sentences`, each item having a `.constituency` attribute.

    Returns:
    - str or None: "(ROOT <tree> <tree> ...)" with one tree per sentence the
      parser found, or None when the parser raised an exception.
    """
    # Parsing is best-effort. Catch ordinary errors only, so that Ctrl-C and
    # SystemExit still stop a long tagging run.
    try:
        parse_doc = nlp(sent.text)
    except Exception:
        return None
    parse_trees = [str(parsed.constituency) for parsed in parse_doc.sentences]
    return "(ROOT " + " ".join(parse_trees) + ")"
|
| 85 |
+
|
| 86 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog='Tag BabyLM dataset',
        description='Tag BabyLM dataset using Stanza')
    parser.add_argument('path', type=argparse.FileType('r'),
                        nargs='+', help="Path to file(s)")
    parser.add_argument('-p', '--parse', action='store_true',
                        help="Include constituency parse")
    args = parser.parse_args()

    rank = init_distributed_mode()

    try:
        # Number of input lines joined into one text; read as a global by run_on_gpu()
        BATCH_SIZE = 1000
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        nlp1 = stanza.Pipeline(lang='en', processors='tokenize,pos,lemma', package="default_accurate", use_gpu=True)

        # The constituency pipeline is only built when parses were requested
        nlp2 = None
        if args.parse:
            nlp2 = stanza.Pipeline(lang='en', processors='tokenize,pos,constituency', package="default_accurate", use_gpu=True)

        run_on_gpu(rank, args, tokenizer, nlp1, nlp2)
    finally:
        # Release the process group even if tagging failed
        dist.destroy_process_group()
|
data/tag_single.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# single file can be split to some small files and run on different gpus
|
| 2 |
+
import torch
|
| 3 |
+
import torch.distributed as dist
|
| 4 |
+
import sys
|
| 5 |
+
sys.path.append("..")
|
| 6 |
+
|
| 7 |
+
import pytest
|
| 8 |
+
import glob
|
| 9 |
+
import tqdm
|
| 10 |
+
import os
|
| 11 |
+
import argparse
|
| 12 |
+
import stanza
|
| 13 |
+
import json
|
| 14 |
+
from transformers import AutoTokenizer
|
| 15 |
+
|
| 16 |
+
def chunk_text(text, tokenizer, max_length=512):
    """Split `text` into pieces of at most `max_length` tokenizer tokens.

    The pieces are slices of the original string, so casing and characters the
    tokenizer cannot represent are preserved. (Decoding token ids back to text,
    as done previously, returns the tokenizer's normalised text instead.)
    A cut is moved back to the nearest token that follows whitespace when one
    exists inside the chunk, so words are not split unless a single
    whitespace-free run is longer than `max_length` tokens.

    Parameters:
    - text (str): text to split.
    - tokenizer: callable tokenizer. It is first called with
      `add_special_tokens=False, return_offsets_mapping=True` and must return a
      mapping with "offset_mapping" (a list of (start, end) character spans).
      If it cannot do that, the function falls back to the encode/decode round
      trip, which needs "input_ids" and `tokenizer.decode`.
    - max_length (int): maximum number of tokens per piece; must be >= 1.

    Returns:
    - list of str: the pieces in order, each stripped of surrounding
      whitespace. Text without any token yields a single stripped piece.

    Raises:
    - ValueError: if `max_length` is smaller than 1.
    """
    if max_length < 1:
        raise ValueError("max_length must be at least 1")

    try:
        encoding = tokenizer(text, add_special_tokens=False,
                             return_offsets_mapping=True)
        offsets = encoding['offset_mapping']
    except (NotImplementedError, TypeError, KeyError):
        # No character offsets available: keep the original (lossy) behaviour
        tokens = tokenizer(text)['input_ids']
        chunks = [tokens[i:i + max_length] for i in range(0, len(tokens), max_length)]
        return [tokenizer.decode(chunk, skip_special_tokens=True) for chunk in chunks]

    if not offsets:
        return [text.strip()]

    boundaries = [0]  # character index at which each piece starts
    first = 0         # token index at which the current piece starts
    while first + max_length < len(offsets):
        cut = first + max_length
        # Walk back to a token that is preceded by whitespace, if there is one
        probe = cut
        while probe > first + 1 and not text[offsets[probe][0] - 1].isspace():
            probe -= 1
        if text[offsets[probe][0] - 1].isspace():
            cut = probe
        boundaries.append(offsets[cut][0])
        first = cut

    ends = boundaries[1:] + [len(text)]
    return [text[start:end].strip() for start, end in zip(boundaries, ends)]
|
| 20 |
+
|
| 21 |
+
def init_distributed_mode():
    """Join the NCCL process group and select this process's CUDA device.

    Returns:
    - int: the global rank of this process.
    """
    dist.init_process_group(backend='nccl')
    rank = dist.get_rank()
    # Pick the GPU by node-local rank. torchrun exports LOCAL_RANK; on a
    # multi-node job the global rank can exceed the number of local GPUs.
    # Without LOCAL_RANK, fall back to the global rank as before.
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    torch.cuda.set_device(local_rank)
    return rank
|
| 26 |
+
|
| 27 |
+
def process_single_file(file, rank, tokenizer, nlp1, nlp2, shard=True):
    """Tag one input file (or this rank's slice of it) and write the result as JSON.

    Parameters:
    - file: open text file; `file.name` determines the output path.
    - rank (int): rank of this process in the process group.
    - tokenizer: tokenizer handed to chunk_text().
    - nlp1: pipeline called on every chunk; its result provides `.sentences`
      with `.tokens` / `.words`.
    - nlp2: pipeline used for constituency parses, or None to skip them.
    - shard (bool): if True (default), process only this rank's contiguous
      slice of the lines and write "<stem>_rank<rank>.json". If False, process
      the whole file and write "<stem>_parsed.json" (with nlp2) or
      "<stem>.json" (without).

    Returns:
    - str: path of the JSON file that was written.

    Note:
    - `BATCH_SIZE` is read from the global scope.
    """
    print(f"GPU {rank}: Processing {file.name}")
    lines = [l.strip() for l in file.readlines()]
    # The handle is not needed once the text is in memory
    file.close()

    if shard:
        # Split the lines across ranks by line count (ceiling division, so
        # every line is covered; trailing ranks may get fewer or no lines)
        num_lines = len(lines)
        num_gpus = dist.get_world_size()

        lines_per_gpu = (num_lines + num_gpus - 1) // num_gpus
        start_idx = rank * lines_per_gpu
        end_idx = min(start_idx + lines_per_gpu, num_lines)
        gpu_lines = lines[start_idx:end_idx]
    else:
        gpu_lines = lines

    line_batches = [gpu_lines[i:i + BATCH_SIZE] for i in range(0, len(gpu_lines), BATCH_SIZE)]
    text_batches = [" ".join(l) for l in line_batches]

    line_annotations = []
    for text in tqdm.tqdm(text_batches, desc=f"GPU {rank}"):
        text_chunks = chunk_text(text, tokenizer)
        # Every chunk becomes one entry of line_annotations
        for chunk in text_chunks:
            doc = nlp1(chunk)
            sent_annotations = []
            for sent in doc.sentences:
                word_annotations = []
                for token, word in zip(sent.tokens, sent.words):
                    wa = {
                        'id': word.id,
                        'text': word.text,
                        'lemma': word.lemma,
                        'upos': word.upos,
                        'xpos': word.xpos,
                        'feats': word.feats,
                        'start_char': token.start_char,
                        'end_char': token.end_char
                    }
                    word_annotations.append(wa)

                sa = {
                    'sent_text': sent.text,
                    'word_annotations': word_annotations
                }
                # nlp2 is only built when parses were requested, so it doubles
                # as the flag and avoids reaching for the global `args`
                if nlp2 is not None:
                    sa['constituency_parse'] = __get_constituency_parse(sent, nlp2)

                sent_annotations.append(sa)
            line_annotations.append({'sent_annotations': sent_annotations})

    stem = os.path.splitext(file.name)[0]
    if shard:
        # Per-rank partial output; merged later by rank 0
        out_filename = stem + f'_rank{rank}.json'
    else:
        out_filename = stem + ('_parsed.json' if nlp2 is not None else '.json')
    with open(out_filename, "w") as outfile:
        json.dump(line_annotations, outfile, indent=4)

    return out_filename

def merge_files(temp_files, output_file):
    """Concatenate the JSON lists in `temp_files` into `output_file`.

    Parameters:
    - temp_files (list of str): paths of JSON files, each holding a list;
      they are merged in the given order.
    - output_file (str): path of the merged JSON file.

    The temporary files are deleted only after the merged file has been
    written, so a failure part-way through does not lose data.
    """
    merged_data = []
    for file in temp_files:
        with open(file, "r") as infile:
            merged_data.extend(json.load(infile))

    with open(output_file, "w") as outfile:
        json.dump(merged_data, outfile, indent=4)

    # Delete the temporary files
    for file in temp_files:
        os.remove(file)

def run_on_gpu(rank, args, tokenizer, nlp1, nlp2):
    """Tag the input files across all ranks of the process group.

    With a single input file, its lines are split across the ranks and rank 0
    merges the per-rank results into "<stem>_merged.json". With several input
    files, each file is processed whole by exactly one rank.

    Parameters:
    - rank (int): rank of this process in the process group.
    - args: parsed arguments; `args.path` is the list of open input files.
    - tokenizer, nlp1, nlp2: forwarded to process_single_file().
    """
    print(f"Running on Rank {rank}, using GPU {torch.cuda.current_device()}")

    world_size = dist.get_world_size()
    if len(args.path) == 1:
        process_single_file(args.path[0], rank, tokenizer, nlp1, nlp2)
        dist.barrier()  # Wait until every rank has written its part
        if rank == 0:
            # Merge the parts of ALL ranks, in rank order (= line order).
            # NOTE(review): assumes every rank writes to a filesystem that
            # rank 0 can read - confirm for multi-node runs.
            stem = os.path.splitext(args.path[0].name)[0]
            temp_files = [stem + f'_rank{r}.json' for r in range(world_size)]
            final_output = stem + '_merged.json'
            merge_files(temp_files, final_output)
    else:
        # Round-robin assignment of whole files. shard=False keeps
        # process_single_file from slicing each file by rank as well.
        for file in args.path[rank::world_size]:
            process_single_file(file, rank, tokenizer, nlp1, nlp2, shard=False)
|
| 111 |
+
|
| 112 |
+
def __get_constituency_parse(sent, nlp):
    """Return one bracketed parse string for ``sent``.

    ``sent.text`` is run through ``nlp``; the ``constituency`` attribute of
    every sentence in the result is converted to a string and the pieces are
    joined with single spaces under one ``(ROOT ...)`` node.

    Args:
        sent: Object exposing the sentence string as ``.text``.
        nlp: Callable that takes the text and returns an object with a
            ``.sentences`` iterable.

    Returns:
        The combined parse string, or ``None`` if ``nlp`` raised an error.
    """
    try:
        parse_doc = nlp(sent.text)
    except Exception:
        # Best effort: a sentence the parser cannot handle yields no parse.
        # Only Exception is caught so Ctrl-C and SystemExit still stop the job.
        return None
    parse_trees = [str(parsed.constituency) for parsed in parse_doc.sentences]
    return "(ROOT " + " ".join(parse_trees) + ")"
|
| 119 |
+
|
| 120 |
+
if __name__ == "__main__":
    # Command line: one or more input files plus an optional parse switch.
    parser = argparse.ArgumentParser(
        prog='Tag BabyLM dataset',
        description='Tag BabyLM dataset using Stanza',
    )
    parser.add_argument(
        'path',
        type=argparse.FileType('r'),
        nargs='+',
        help="Path to file(s)",
    )
    parser.add_argument(
        '-p', '--parse',
        action='store_true',
        help="Include constituency parse",
    )
    args = parser.parse_args()

    # Join the process group; the returned rank identifies this worker.
    rank = init_distributed_mode()

    # Not referenced in this block; presumably read as a module-level global
    # by the processing code -- confirm before removing.
    BATCH_SIZE = 1000
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # Tagging pipeline, always built.
    nlp1 = stanza.Pipeline(
        lang='en',
        processors='tokenize,pos,lemma',
        package="default_accurate",
        use_gpu=True,
    )

    # Constituency pipeline, built only when --parse was requested.
    nlp2 = (
        stanza.Pipeline(
            lang='en',
            processors='tokenize,pos,constituency',
            package="default_accurate",
            use_gpu=True,
        )
        if args.parse
        else None
    )

    run_on_gpu(rank, args, tokenizer, nlp1, nlp2)
|
perplexities/perplexity_results/Qwen2.5-0.5B/reverse_full/Qwen2.5-0.5B_seed1_test_reverse_full_checkpoint-1000.csv
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Perplexity
|
| 2 |
+
12239897.0
|
perplexities/perplexity_results/Qwen2.5-0.5B/reverse_full/Qwen2.5-0.5B_seed1_test_reverse_full_checkpoint-10000.csv
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Perplexity
|
| 2 |
+
100086656.0
|
perplexities/perplexity_results/Qwen2.5-0.5B/reverse_full/Qwen2.5-0.5B_seed1_test_reverse_full_checkpoint-11500.csv
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Perplexity
|
| 2 |
+
86934072.0
|
perplexities/perplexity_results/Qwen2.5-0.5B/reverse_full/Qwen2.5-0.5B_seed1_test_reverse_full_checkpoint-1500.csv
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Perplexity
|
| 2 |
+
221982080.0
|
perplexities/perplexity_results/Qwen2.5-0.5B/reverse_full/Qwen2.5-0.5B_seed1_test_reverse_full_checkpoint-2000.csv
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Perplexity
|
| 2 |
+
389647168.0
|