| | |
| | |
| |
|
| | |
| | import sys |
| | sys.path.append("..") |
| |
|
| | from utils import PERTURBATIONS |
| | import argparse |
| | import pandas as pd |
| | import os |
| | from tqdm import tqdm |
| | from nltk import Tree |
| | from numpy.random import default_rng |
| |
|
| |
|
| | def get_span(tokens, sub_tokens): |
| | for i in range(len(tokens)): |
| | if tokens[i:i+len(sub_tokens)] == sub_tokens: |
| | start_idx, end_idx = i, i + len(sub_tokens) |
| | return start_idx, end_idx |
| |
|
| |
|
| | def extract_phrases(tree, categories): |
| | """Extract phrases that belong to the specified categories.""" |
| | results = [] |
| |
|
| | for subtree in tree.subtrees(): |
| | if subtree.label() in categories: |
| | words = subtree.leaves() |
| | if len(words) > 1: |
| | phrase = ' '.join(words) |
| | results.append((phrase, subtree.label())) |
| |
|
| | return results |
| |
|
| |
|
| | def process_file(file): |
| | """Read file, extract phrases and return results.""" |
| | results = [] |
| |
|
| | categories = ["NP", "VP", "ADJP", "ADVP", "PP"] |
| |
|
| | |
| | with open(file, 'r') as f: |
| | lines = f.readlines() |
| |
|
| | |
| | for i in tqdm(range(0, len(lines) - 1, 2)): |
| | sentence = lines[i].strip() |
| |
|
| | if len(sentence.split()) < 5: |
| | continue |
| |
|
| | tree_str = lines[i + 1].strip() |
| | tree = Tree.fromstring(tree_str) |
| |
|
| | phrases = extract_phrases(tree, categories) |
| | for phrase, category in phrases: |
| | if phrase in sentence: |
| | results.append((sentence, phrase, category)) |
| |
|
| | return results |
| |
|
| |
|
| | if __name__ == "__main__": |
| |
|
| | parser = argparse.ArgumentParser( |
| | prog='Get phrase spans for edge probing', |
| | description='Get spans of text from constituency parses for edge probing experiments') |
| | parser.add_argument('perturbation_type', |
| | default='all', |
| | const='all', |
| | nargs='?', |
| | choices=PERTURBATIONS.keys(), |
| | help='Perturbation function used to transform BabyLM dataset') |
| |
|
| | |
| | args = parser.parse_args() |
| |
|
| | |
| | perturbation_class = None |
| | if "reverse" in args.perturbation_type: |
| | perturbation_class = "reverse" |
| | elif "hop" in args.perturbation_type: |
| | perturbation_class = "hop" |
| | else: |
| | raise Exception("Perturbation class not implemented") |
| |
|
| | |
| | print("Extracting phrases from constituency parses") |
| | data = process_file( |
| | f"test_constituency_parses/{perturbation_class}_parses.test") |
| |
|
| | |
| | SAMPLE_SIZE = 10000 |
| | RANDOM_STATE = 62 |
| | rng = default_rng(RANDOM_STATE) |
| | rng.shuffle(data) |
| |
|
| | |
| | tokenizer = PERTURBATIONS[args.perturbation_type]["gpt2_tokenizer"] |
| | span_sample_data = [] |
| | print("Getting spans of tokens for constituents") |
| | for sentence, phrase, category in tqdm(data): |
| |
|
| | |
| | tokens = tokenizer.encode(sentence) |
| | if len(tokens) > 1024: |
| | continue |
| | sub_tokens = tokenizer.encode(phrase) |
| |
|
| | span = get_span(tokens, sub_tokens) |
| |
|
| | |
| | if span is None: |
| | sub_tokens = tokenizer.encode(" " + phrase) |
| | span = get_span(tokens, sub_tokens) |
| |
|
| | if span is not None: |
| | start_idx, end_idx = span |
| | rev_start_index, rev_end_index = len(tokens) - end_idx, len(tokens) - start_idx |
| | span_sample_data.append( |
| | (sentence, phrase, category, " ".join([str(t) for t in tokens]), |
| | start_idx, end_idx, rev_start_index, rev_end_index)) |
| |
|
| | |
| | sample_df = pd.DataFrame(data=span_sample_data, columns=[ |
| | "Sentence", "Phrase", "Category", |
| | "Sentence Tokens", "Start Index", "End Index", |
| | "Rev Start Index", "Rev End Index"]) |
| | final_sample_df = sample_df.groupby('Category', group_keys=False).apply( |
| | lambda x: x.sample(SAMPLE_SIZE // 5, random_state=RANDOM_STATE)) |
| | final_sample_df = final_sample_df.sample(frac=1).reset_index(drop=True) |
| |
|
| | |
| | phrases_directory = f"phrase_data/" |
| | if not os.path.exists(phrases_directory): |
| | os.makedirs(phrases_directory) |
| | phrases_file = phrases_directory + f"{perturbation_class}_phrase_data.csv" |
| | final_sample_df.to_csv(phrases_file, index=False) |
| |
|