| """ |
| Outputs all 13-grams found in The Pile. |
| |
| Loops through all documents and uses the logic found in janitor.py to extract 13-grams. |
| We bucket each 13-gram by hash into separate file buckets to allow easy parallel processing in the |
| next stage. We also include the current pile document_id with each ngram instance to allow the |
| filtering to exclude 13-grams that match more than 10 unique documents (done further down the pipeline). |
| |
| We didn't use lm_dataformat to output as it increases time 4x (slow jsonify) and makes |
| resuming hard (and we had the storage). |
| |
| Arguments |
| --------- |
| --working_directory (-dir) |
| Directory containing the pile distribution. An "output" subdirectory will be created underneath |
| to store the bucketed 13-grams, checkpoint and done files. Default: current directory |
| --n_value (-n) |
| n value in n-gram, added for later use if ever needed. Default: 13 |
| --bucket_count (-buckets) |
| Number of file buckets to use when generating 13grams. Default: 500 |
| """ |
|
|
| import argparse |
| import glob |
| import json |
| import logging |
| import os |
| import pickle |
| import signal |
| import sys |
| from pathlib import Path |
| from signal import SIGINT |
|
|
| from tqdm import tqdm |
| from tqdm_multiprocess.logger import setup_logger_tqdm |
|
|
| from lm_eval.decontamination.archiver import Reader, TextArchive |
| from lm_eval.decontamination.janitor import Janitor, word_ngrams |
|
|
|
|
# Cooperative shutdown flag: flipped to True by `handler` and polled by the
# processing loop, which stops only at points where it is safe to do so.
terminate = False

# Logger named after this module.
logger = logging.getLogger(__name__)


def handler(signal_received, frame):
    """Record a termination request instead of interrupting the process.

    Has the two-argument shape that ``signal.signal`` callbacks receive; both
    arguments are ignored. The only effect is setting the module-level
    ``terminate`` flag to True.
    """
    global terminate
    terminate = True
|
|
|
|
def yield_pile(start_offsets=None, checkpoint_offset=None):
    """Yield ``(global_offset, document)`` pairs from the pile archives, in order.

    Archives are the files matching ``pile/*.jsonl.zst*`` relative to the current
    working directory, visited in sorted filename order. ``global_offset`` is a
    running document counter across all archives.

    When both ``checkpoint_offset`` and ``start_offsets`` are given, archives that
    precede the last archive whose start offset is <= ``checkpoint_offset`` are
    skipped without being read, and the counter starts at that archive's start
    offset. Documents inside that archive that precede the checkpoint are still
    yielded; the caller is expected to discard them.

    Args:
        start_offsets: Global offset of the first document of each archive.
            Assumed to be ascending and aligned with the sorted file list
            (NOTE(review): confirm against how pile_statistics.json is built).
            May be None, in which case no archive is skipped.
        checkpoint_offset: Global document offset to resume from. Falsy values
            (None or 0) mean "start from the beginning".

    Yields:
        Tuples of ``(global_offset, document)``.

    Raises:
        FileNotFoundError: If the ``pile`` directory does not exist.
    """
    directory = "pile"

    if not os.path.exists(directory):
        print(
            "We expect the pile archives to be in the 'pile' directory, but this was not found."
        )
        raise FileNotFoundError("Pile directory not found.")

    files = sorted(glob.glob(os.path.join(directory, "*.jsonl.zst*")))

    pile_global_offset = 0
    start_file = 0
    # start_offsets defaults to None; without per-file offsets there is nothing
    # to fast-forward over, so fall back to reading from the first archive
    # instead of failing on enumerate(None).
    if checkpoint_offset and start_offsets:
        for file_i, start_offset in enumerate(start_offsets):
            if start_offset > checkpoint_offset:
                break
            # Last archive seen so far that starts at or before the checkpoint.
            start_file = file_i
            pile_global_offset = start_offset

    for file_i, file in enumerate(files):
        if file_i < start_file:
            logger.info(f"Skipping file {file}")
            continue
        logger.info(f"Reading from pile file: {file}")
        reader = Reader()
        for document in reader.read(file):
            yield (pile_global_offset, document)
            pile_global_offset += 1
|
|
|
|
| |
| |
| |
class Buckets:
    """A fixed set of bucket files that records are routed into by key hash.

    Each bucket is a ``TextArchive`` over ``<directory>/ngrams_<i>.bkt.txt``.
    The write position of every bucket can be checkpointed to
    ``<directory>/bucket_offsets.ckpt`` so that an interrupted run can resume:
    on construction each bucket is truncated back to its checkpointed offset.
    """

    def __init__(self, directory, num_buckets):
        """Open ``num_buckets`` bucket files in ``directory`` and rewind them.

        If a checkpoint file exists, each bucket is truncated at its saved
        offset; otherwise every bucket is truncated at offset 0.
        """
        self.bucket_files = [
            os.path.join(directory, f"ngrams_{i}.bkt.txt") for i in range(num_buckets)
        ]
        self.buckets = list(map(TextArchive, self.bucket_files))
        self.checkpoint_file = os.path.join(directory, "bucket_offsets.ckpt")

        if os.path.exists(self.checkpoint_file):
            # NOTE: pickle must only be used on trusted data; this file is
            # expected to have been written by save_checkpoint() below.
            with open(self.checkpoint_file, "rb") as checkpoint_fh:
                self.bucket_offsets = pickle.load(checkpoint_fh)
        else:
            self.bucket_offsets = [0] * len(self.buckets)

        # Discard anything past the last checkpointed position so that a
        # resumed run does not keep data written after that checkpoint.
        for i, offset in enumerate(self.bucket_offsets):
            bucket = self.buckets[i]
            bucket.fh.seek(offset)
            bucket.fh.truncate()

    def add_data(self, key, value):
        """Write ``value`` to the bucket selected by ``hash(key)``.

        The bucket index is ``hash(key) % number_of_buckets``, so the mapping is
        only stable across runs if the built-in hash is deterministic (the
        script entry point requires PYTHONHASHSEED=0 for this reason).
        """
        i = hash(key) % len(self.buckets)
        bucket = self.buckets[i]
        bucket.add_data(value)

    def save_checkpoint(self):
        """Flush every bucket and persist their current file positions."""
        for bucket in self.buckets:
            bucket.fh.flush()

        bucket_offsets = [bucket.fh.tell() for bucket in self.buckets]
        # Close the checkpoint file explicitly so it is fully written before
        # the caller records its own progress.
        with open(self.checkpoint_file, "wb") as checkpoint_fh:
            pickle.dump(bucket_offsets, checkpoint_fh)

    def close_buckets(self):
        """Call ``commit()`` on every bucket."""
        for bucket in self.buckets:
            bucket.commit()
|
|
|
|
def do_ngrams_in_buckets(n_value, working_directory, bucket_count):
    """Extract n-grams from every pile document and write them into hash buckets.

    Reads ``pile_statistics.json`` from the current working directory (keys
    ``"Document Count"`` and ``"File Start Offsets"``), then streams documents
    from ``yield_pile``. Each n-gram is written as ``"<ngram> <document offset>"``
    to a bucket under ``<working_directory>/output``.

    The run is resumable:
      * ``<working_directory>/output/ngram_buckets.done`` marks a finished run;
        if it exists the function returns immediately.
      * ``<working_directory>/pile_offset.ckpt`` stores the offset of the next
        document to process. It is rewritten, together with the bucket
        offsets, once every 1000 processed documents.
      * The module-level ``terminate`` flag is honoured only while skipping
        already-processed documents or immediately after a checkpoint.

    Args:
        n_value: Size of the n-grams to generate.
        working_directory: Directory holding the checkpoint file and the
            ``output`` subdirectory (created if missing).
        bucket_count: Number of bucket files to distribute n-grams over.
    """
    with open("pile_statistics.json", "r", encoding="utf-8") as stats_fh:
        pile_statistics = json.load(stats_fh)
    pile_document_count = pile_statistics["Document Count"]
    start_offsets = pile_statistics["File Start Offsets"]

    output_directory = os.path.join(working_directory, "output")
    os.makedirs(output_directory, exist_ok=True)

    logger.info(f"Generating {n_value}-grams and bucketing.")

    # Done file: a previous run completed, nothing left to do.
    done_file = os.path.join(output_directory, "ngram_buckets.done")
    if os.path.exists(done_file):
        logger.info("ngrams already generated and bucketed, skipping")
        return

    # Checkpoint: resume from the last recorded document offset, if any.
    checkpoint_file = os.path.join(working_directory, "pile_offset.ckpt")
    if os.path.exists(checkpoint_file):
        # NOTE: pickle must only be used on trusted data; this file is
        # expected to have been written by a previous run of this function.
        with open(checkpoint_file, "rb") as checkpoint_fh:
            checkpoint_offset = pickle.load(checkpoint_fh)
        iterate = True
    else:
        checkpoint_offset = 0
        iterate = False

    logger.info(f"Starting at pile document index {checkpoint_offset}")
    buckets = Buckets(output_directory, bucket_count)

    janitor = Janitor()
    batch_size = 1000
    batch_counter = 0

    with tqdm(total=checkpoint_offset, dynamic_ncols=True, unit="docs") as progress:
        for offset, document in yield_pile(start_offsets, checkpoint_offset):
            if iterate:
                # First document of a resumed run: credit the documents in the
                # archives that yield_pile skipped entirely.
                logger.info(f"Iterating to offset {checkpoint_offset} from {offset}")
                progress.update(offset)
                iterate = False

            if offset < checkpoint_offset:
                # Still fast-forwarding to the checkpoint; no n-grams have been
                # written in this run yet, so stopping here is safe.
                progress.update()

                if terminate:
                    return
                continue

            if offset == checkpoint_offset:
                # Reached the resume point: switch the bar to the full pile.
                progress.reset(total=pile_document_count)
                progress.update(checkpoint_offset)

            # Save a checkpoint every batch_size documents; termination is only
            # allowed right after one so bucket and pile offsets stay in sync.
            if batch_counter == batch_size:
                progress.update(batch_size)
                batch_counter = 0
                buckets.save_checkpoint()
                with open(checkpoint_file, "wb") as checkpoint_fh:
                    pickle.dump(offset, checkpoint_fh)
                if terminate:
                    buckets.close_buckets()
                    return

            ngrams = word_ngrams(janitor.normalize_string(document), n_value)
            for ngram in ngrams:
                buckets.add_data(ngram, f"{ngram} {offset}")

            batch_counter += 1

    buckets.close_buckets()
    Path(done_file).touch()
|
|
|
|
# Command-line interface; see the module docstring for what each option means.
parser = argparse.ArgumentParser(description="Generate 13 grams from Pile.")
parser.add_argument("-dir", "--working_directory", default="")
parser.add_argument("-n", "--n_value", type=int, default=13)
parser.add_argument("-buckets", "--bucket_count", type=int, default=500)

if __name__ == "__main__":
    version = 1.00
    print(f"Running version {version}")

    # n-grams are assigned to buckets with the built-in hash(), so hashing must
    # be deterministic for a resumed run to route n-grams to the same buckets.
    if os.environ.get("PYTHONHASHSEED") != "0":
        print("Please run 'export PYTHONHASHSEED=0' before running generate.")
        # Non-zero status so calling scripts can tell the run did not happen.
        sys.exit(1)

    # Handle SIGINT (ctrl-c) by setting a flag rather than raising mid-write.
    previous_signal_int = signal.signal(SIGINT, handler)

    logfile_path = "ngrams.log"
    setup_logger_tqdm(logfile_path)

    args = parser.parse_args()
    do_ngrams_in_buckets(args.n_value, args.working_directory, args.bucket_count)

    # Record the n actually used, not a hard-coded 13, so the metadata stays
    # correct when --n_value is overridden.
    info_dict = {"title": "dataset ngrams", "ngram_size": args.n_value}
    info_dict_path = os.path.join(args.working_directory, "info.json")
    with open(info_dict_path, "w", encoding="utf-8") as info_fh:
        json.dump(info_dict, info_fh)
|
|