# tag.py
# Author: Julie Kallini

# For importing utils
import sys
sys.path.append("..")

import pytest
import glob
import tqdm
import os
import argparse
import stanza
import json
from transformers import AutoTokenizer


# Split text into chunks of at most max_length tokens; the tokenizer is used
# only to measure and split by token count, then each chunk is decoded back
# to plain text
def chunk_text(text, tokenizer, max_length=512):
    tokens = tokenizer(text)['input_ids']
    chunks = [tokens[i:i + max_length]
              for i in range(0, len(tokens), max_length)]
    return [tokenizer.decode(chunk, skip_special_tokens=True)
            for chunk in chunks]


# Test case for checking equivalence of original and parsed files
test_all_files = sorted(glob.glob("babylm_data/babylm_*/*"))
test_original_files = [f for f in test_all_files if ".json" not in f]
test_json_files = [f for f in test_all_files if "_parsed.json" in f]
test_cases = list(zip(test_original_files, test_json_files))


@pytest.mark.parametrize("original_file, json_file", test_cases)
def test_equivalent_lines(original_file, json_file):
    # Read the original file and remove all whitespace
    with open(original_file) as f:
        original_data = "".join(f.readlines())
    original_data = "".join(original_data.split())

    # Concatenate sentence texts from the JSON file and remove all whitespace
    with open(json_file) as f:
        json_lines = json.load(f)
    json_data = ""
    for line in json_lines:
        for sent in line["sent_annotations"]:
            json_data += sent["sent_text"]
    json_data = "".join(json_data.split())

    # Test equivalence
    assert original_data == json_data


# Constituency parsing function
def __get_constituency_parse(sent, nlp):
    # Try parsing the sentence; return None if the parser fails
    try:
        parse_doc = nlp(sent.text)
    except Exception:
        return None

    # Get the constituency parse tree of each parsed sentence
    parse_trees = [str(s.constituency) for s in parse_doc.sentences]

    # Join parse trees and add ROOT
    constituency_parse = "(ROOT " + " ".join(parse_trees) + ")"
    return constituency_parse


# Main function
if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        prog='Tag BabyLM dataset',
        description='Tag BabyLM dataset using Stanza')
    parser.add_argument('path', type=argparse.FileType('r'),
                        nargs='+', help="Path to file(s)")
    parser.add_argument('-p', '--parse', action='store_true',
                        help="Include constituency parse")

    # Get args
    args = parser.parse_args()

    # Init Stanza NLP tools
    nlp1 = stanza.Pipeline(
        lang='en',
        processors='tokenize,pos,lemma',
        package="default_accurate",
        use_gpu=True)

    # If constituency parse is needed, init second Stanza pipeline
    if args.parse:
        nlp2 = stanza.Pipeline(
            lang='en',
            processors='tokenize,pos,constituency',
            package="default_accurate",
            use_gpu=True)

    BATCH_SIZE = 100

    # Tokenizer for splitting long text
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # Iterate over BabyLM files
    for file in args.path:
        print(file.name)
        lines = file.readlines()

        # Strip lines and join text into batches
        print("Concatenating lines...")
        lines = [l.strip() for l in lines]
        line_batches = [lines[i:i + BATCH_SIZE]
                        for i in range(0, len(lines), BATCH_SIZE)]
        text_batches = [" ".join(l) for l in line_batches]

        # Iterate over text batches and track annotations
        line_annotations = []
        print("Segmenting and parsing text batches...")
        for text in tqdm.tqdm(text_batches):

            # Split the text into chunks if it exceeds the max length
            text_chunks = chunk_text(text, tokenizer)

            # Iterate over each chunk
            for chunk in text_chunks:

                # Tokenize text with Stanza
                doc = nlp1(chunk)

                # Iterate over sentences in the chunk and track annotations
                sent_annotations = []
                for sent in doc.sentences:

                    # Iterate over words in the sentence and track annotations
                    word_annotations = []
                    for token, word in zip(sent.tokens, sent.words):
                        wa = {
                            'id': word.id,
                            'text': word.text,
                            'lemma': word.lemma,
                            'upos': word.upos,
                            'xpos': word.xpos,
                            'feats': word.feats,
                            'start_char': token.start_char,
                            'end_char': token.end_char
                        }
                        word_annotations.append(wa)  # Track word annotation

                    # Get constituency parse if needed
                    if args.parse:
                        constituency_parse = __get_constituency_parse(
                            sent, nlp2)
                        sa = {
                            'sent_text': sent.text,
                            'constituency_parse': constituency_parse,
                            'word_annotations': word_annotations,
                        }
                    else:
                        sa = {
                            'sent_text': sent.text,
                            'word_annotations': word_annotations,
                        }
                    sent_annotations.append(sa)  # Track sent annotation

                la = {
                    'sent_annotations': sent_annotations
                }
                line_annotations.append(la)  # Track line annotation

        # Write annotations to file as JSON
        print("Writing JSON outfile...")
        ext = '_parsed.json' if args.parse else '.json'
        json_filename = os.path.splitext(file.name)[0] + ext
        with open(json_filename, "w") as outfile:
            json.dump(line_annotations, outfile, indent=4)
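
# Example usage (a sketch; the file paths are illustrative and should point
# at your local copy of the BabyLM data):
#
#   python tag.py babylm_data/babylm_10M/*.train            # POS/lemma tags only
#   python tag.py babylm_data/babylm_10M/*.train --parse    # also constituency parses
#
# Afterwards, the parametrized test above can be run with pytest to check
# that the JSON output preserves the original text exactly:
#
#   pytest tag.py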