""" # you can run the following command to make DB2TOKCNT readable autopep8 --in-place --aggressive --aggressive finetune/scripts/parse_mixture.py This script is used to parse the mixture of the pretraining data input: path to the yaml file output: a megatron style data mixture string """ import os import sys import argparse import yaml import re EXAMPLE_LOG_STRING = """Zarr-based strategies will not be registered because of missing packages Counting tokens in ./mmap/example.bin 0%| | 0/597667 [00:00= 1: mixture_str += repeat_str( f"1 {mmap_path_without_ext} ", int(repeat_times)) else: # weight is less than 1 mixture_str += f"{repeat_times} {mmap_path_without_ext} " tokcnt = DB2TOKCNT[mmap_path] if isinstance(tokcnt, str): assert tokcnt.endswith("B"), f"invalid tokcnt: {tokcnt}" tokcnt = float(tokcnt.replace("B", "")) * 10**9 total_tokcnt += tokcnt * repeat_times else: assert isinstance(tokcnt, int), f"invalid tokcnt: {tokcnt}" total_tokcnt += tokcnt * repeat_times # total iter count total_iter = total_tokcnt / (cfg["GLOBAL_BATCH_SIZE"] * cfg["SEQ_LEN"]) # into string x.xxxB total_tokcnt /= 1e9 total_tokcnt = f"{total_tokcnt:.3f}B" return mixture_str, total_tokcnt, total_iter def parse_mixture_from_cfg(cfg): keys = list(cfg.keys()) # find keys ends with _ROUND rounds = [k for k in keys if k.endswith("_ROUND")] def repeat_str(s, n): return "".join([s for _ in range(n)]) total_tokcnt = 0 mixture_str = "" for r in rounds: repeat_times = float(r.replace("_ROUND", "")) mmap_paths = sorted(set(cfg[r])) for mmap_path in mmap_paths: mmap_path_without_ext = os.path.splitext(mmap_path)[0] tokcnt = DB2TOKCNT[mmap_path] if isinstance(tokcnt, str): assert tokcnt.endswith("B"), f"invalid tokcnt: {tokcnt}" tokcnt = float(tokcnt.replace("B", "")) * 10**9 total_tokcnt += tokcnt * repeat_times else: assert isinstance(tokcnt, int), f"invalid tokcnt: {tokcnt}" total_tokcnt += tokcnt * repeat_times mixture_str += f"{int(tokcnt * repeat_times)} {mmap_path_without_ext} " # total iter count total_iter = total_tokcnt / (cfg["GLOBAL_BATCH_SIZE"] * cfg["SEQ_LEN"]) # into string x.xxxB total_tokcnt /= 1e9 total_tokcnt = f"{total_tokcnt:.3f}B" return mixture_str, total_tokcnt, total_iter if __name__ == "__main__": args = parse_args() cfg = load_yaml(args.cfg) print(f"[INFO] Loaded cfg from {args.cfg}") TOKEN_COUNT_LOG_DIR = cfg["TOKEN_COUNT_LOG_DIR"] print(f"[INFO] TOKEN_COUNT_LOG_DIR: {TOKEN_COUNT_LOG_DIR}") get_tokcnts_from_logs(TOKEN_COUNT_LOG_DIR, by_billions=args.by_billions) print(f"[INFO] DB2TOKCNT reloaded from the logs in {TOKEN_COUNT_LOG_DIR}\n") mixture_str, total_tokcnt, total_iter = parse_mixture_from_cfg(cfg) print(f"[CRITICAL] DATA_PATH **(copy to the training script)**:\n{mixture_str}\n") print(f"[CRITICAL] TRAIN_ITERS **(copy to the training script)**:\n{total_iter}\n") print(f"[INFO] Total token count: {total_tokcnt}")