import copy
import io
import json
import logging
import multiprocessing
import os
import subprocess
import sys
import time
from itertools import cycle, islice

import fsspec
import numpy as np
import torch

from typing import List, Optional
from tqdm import tqdm

from open_lm.distributed import is_master


def remote_sync_s3(local_dir, remote_dir):
    # Skip epoch_latest, which can change while the sync is in progress.
    result = subprocess.run(
        ["aws", "s3", "sync", local_dir, remote_dir, "--exclude", "*epoch_latest.pt"],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if result.returncode != 0:
        logging.error(f"Error: Failed to sync with S3 bucket: {result.stderr.decode('utf-8')}")
        return False

    logging.info("Successfully synced with S3 bucket.")
    return True


def remote_sync_fsspec(local_dir, remote_dir):
    # Note: this copies objects one at a time via fsspec mappers, which can be slow
    # for large checkpoints; prefer the "s3" protocol where possible.
    a = fsspec.get_mapper(local_dir)
    b = fsspec.get_mapper(remote_dir)

    for k in a:
        # Skip epoch_latest, which can change while the sync is in progress.
        if "epoch_latest.pt" in k:
            continue

        logging.info(f"Attempting to sync {k}")
        if k in b and len(a[k]) == len(b[k]):
            logging.debug(f"Skipping remote sync for {k}.")
            continue

        try:
            b[k] = a[k]
            logging.info(f"Successful sync for {k}.")
        except Exception as e:
            logging.error(f"Error during remote sync for {k}: {e}")
            return False

    return True


def remote_sync(local_dir, remote_dir, protocol):
    logging.info("Starting remote sync.")
    if protocol == "s3":
        return remote_sync_s3(local_dir, remote_dir)
    elif protocol == "fsspec":
        return remote_sync_fsspec(local_dir, remote_dir)
    else:
        logging.error(f"Unknown remote sync protocol: {protocol}")
        return False


def remote_sync_with_expon_backoff(sync_every, local_dir, remote_dir, protocol, max_retries=6):
    """Attempt a remote sync with exponential backoff, sleeping sync_every * 2**i seconds before attempt i."""
    for i in range(max_retries):
        time.sleep(sync_every * 2**i)
        success = remote_sync(local_dir, remote_dir, protocol)
        if success:
            return True

    return False
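
# For example, with sync_every=60 and max_retries=6, remote_sync_with_expon_backoff
# sleeps 60, 120, 240, 480, 960, and 1920 seconds before successive attempts, giving
# up after the sixth failure.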


def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol):
    while True:
        remote_sync_with_expon_backoff(sync_every, local_dir, remote_dir, protocol)


def start_sync_process(sync_every, local_dir, remote_dir, protocol):
    # Note: this only creates the process; the caller is responsible for starting
    # (and later terminating) it.
    p = multiprocessing.Process(
        target=keep_running_remote_sync,
        args=(sync_every, local_dir, remote_dir, protocol),
    )
    return p
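
# Example lifecycle (a sketch; the directories below are hypothetical):
#
#   p = start_sync_process(300, "/tmp/checkpoints", "s3://my-bucket/checkpoints", "s3")
#   p.start()
#   ...  # training loop
#   terminate_sync_process(p)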


def terminate_sync_process(p: multiprocessing.Process):
    if p is not None and p.is_alive():
        logging.info("Terminating remote sync process.")
        p.terminate()


def pt_save(pt_obj, file_path):
    of = fsspec.open(file_path, "wb")
    with of as f:
        # Save to the opened fsspec file object (not the path), so that remote
        # filesystems are handled correctly.
        torch.save(pt_obj, f)


def _pt_load_s3_cp(file_path, map_location=None):
    # Stream the checkpoint from S3 to stdout with the AWS CLI. An argument list is
    # used (rather than shell=True) so paths with special characters are passed safely.
    cmd = ["aws", "s3", "cp", file_path, "-"]
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout, stderr = proc.communicate()
    if proc.returncode != 0:
        raise Exception(f"Failed to fetch model from s3. stderr: {stderr.decode()}")
    return torch.load(io.BytesIO(stdout), map_location=map_location)


def pt_load(file_path, map_location=None):
    if file_path.startswith("s3://"):
        logging.info("Loading remote checkpoint, which may take a bit.")
        return _pt_load_s3_cp(file_path, map_location)
    of = fsspec.open(file_path, "rb")
    with of as f:
        out = torch.load(f, map_location=map_location)
    return out
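
# Example round trip (a sketch; the path and model below are hypothetical):
#
#   pt_save({"epoch": 3, "state_dict": model.state_dict()}, "s3://my-bucket/ckpt/epoch_3.pt")
#   checkpoint = pt_load("s3://my-bucket/ckpt/epoch_3.pt", map_location="cpu")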


def check_exists(file_path):
    try:
        with fsspec.open(file_path):
            pass
    except FileNotFoundError:
        return False
    return True


def get_metadata_file(path, shard_shuffle_seed=None):
    of = fsspec.open(path, "rb")
    with of as f:
        out = f.read()
    # The manifest is JSON lines; skip empty lines so a missing trailing newline is harmless.
    out = [json.loads(o) for o in out.decode("utf-8").splitlines() if o]
    if shard_shuffle_seed is not None:
        rng_gen = np.random.default_rng(shard_shuffle_seed)
        rng_gen.shuffle(out)
    return out
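
# Based on the keys read elsewhere in this file, a manifest entry looks roughly like
# the following (illustrative values only):
#
#   {"shard": "shard-000000", "num_sequences": 8192}
#
# with "num_chunks" appearing in place of "num_sequences" in older manifests.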


def get_shards_for_chunk(num_samples, chunk, path, shard_shuffle_seed):
    """Get a chunk of shards to train on.

    Chunks are groups of shards with samples roughly equal to the number of samples
    that will be seen during training. This function uses the dataset manifest
    to split the shards into chunks, and assigns shards to each chunk.
    """
    metadata = get_metadata_file(path, shard_shuffle_seed=shard_shuffle_seed)
    shard_list = []
    curr_shard_list = []
    chunk_count_list = []
    curr_chunk_count = 0
    for m in metadata:
        try:
            curr_chunk_count += m["num_sequences"]
        except KeyError:
            curr_chunk_count += m["num_chunks"]

        curr_shard_list.append(m["shard"])
        if curr_chunk_count >= num_samples:
            shard_list.append(curr_shard_list)
            chunk_count_list.append(curr_chunk_count)
            curr_shard_list = []
            curr_chunk_count = 0

    # Append any remaining shards as a final, possibly smaller, chunk.
    if len(curr_shard_list) > 0:
        shard_list.append(curr_shard_list)
        chunk_count_list.append(curr_chunk_count)

    # Wrap around with modulo so chunk indices past the end reuse earlier chunks.
    return (
        shard_list[chunk % len(shard_list)],
        chunk_count_list[chunk % len(chunk_count_list)],
    )
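
# For example, with num_samples=2500 and a manifest of four shards holding 1000
# sequences each, the chunks are ["shard_0", "shard_1", "shard_2"] (3000 samples) and
# ["shard_3"] (1000 samples); chunk=2 wraps around to the first chunk again.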


def enough_shards(shard_lists: List[List[str]], min_shards_needed: int):
    for sl in shard_lists:
        if len(sl) < min_shards_needed:
            return False
    return True


def enough_samples(num_samples_per_source: List[List[int]], needed_samples_per_source: List[int]):
    for i, number_per_shard in enumerate(num_samples_per_source):
        if sum(number_per_shard) < needed_samples_per_source[i]:
            return False
    return True


def source_exhausted(paths, shard_list_per_source):
    for i, source in enumerate(paths):
        data = get_metadata_file(source)
        if len(data) < len(shard_list_per_source[i]):
            return True
    return False


def count_small_shards(path, ratio=0.9):
    """Count the number of shards with significantly fewer sequences than the largest shard.

    Small shards are defined as those whose size is less than a given ratio (default 90%)
    of the size of the largest shard.
    """
    shard_sizes = []
    data = get_metadata_file(path)
    for item in data:
        try:
            shard_sizes.append(item["num_sequences"])
        except KeyError:
            shard_sizes.append(item["num_chunks"])

    shard_sizes = np.array(shard_sizes)

    return np.sum(shard_sizes < ratio * max(shard_sizes))
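
# For example, shard sizes [1000, 950, 800] with ratio=0.9 give a cutoff of 900, so
# count_small_shards reports 1 small shard (the one with 800 sequences).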


def are_sources_imbalanced_with_each_other(paths, ratio=2):
    median_shard_size_per_source = []
    for p in paths:
        shard_sizes = []
        data = get_metadata_file(p)
        for item in data:
            try:
                shard_sizes.append(item["num_sequences"])
            except KeyError:
                shard_sizes.append(item["num_chunks"])

        median_shard_size_per_source.append(np.median(shard_sizes))

    return max(median_shard_size_per_source) > ratio * min(median_shard_size_per_source)


def log_num_checkpoints(total_steps, args):
    """Log the number of checkpoints that will be made.

    This function counts the number of checkpoints to be made, and logs that number, printing out a warning if that
    number is different than expected.
    """
    steps_done = 0
    tokens_seen = 0
    next_shard_per_source = [0 for _ in range(len(args.dataset_manifest))] if args.dataset_manifest is not None else 0
    checkpoints_made = 0

    if is_master(args):
        logging.info("Precounting number of steps / tokens seen per checkpoint:")

    while steps_done < total_steps:
        _, num_samples_per_source, next_shard_per_source = get_string_for_epoch(
            args.train_num_samples,
            next_shard_per_source,
            args.dataset_manifest,
            args.train_data_mix_weights,
            args.workers,
            args.world_size,
            multi_epoch=args.multiple_data_passes,
            shard_shuffle_seed=args.shard_shuffle_seed,
        )
        steps_epoch = sum(
            [(n // (args.workers * args.global_batch_size)) * args.workers for n in num_samples_per_source]
        )
        steps_done += steps_epoch
        if steps_done > total_steps:
            steps_done = total_steps
        tokens_seen = steps_done * args.global_batch_size * args.seq_len
        checkpoints_made += 1

        if is_master(args):
            logging.info(f"==> Checkpoint {checkpoints_made}, steps {steps_done}, tokens seen {tokens_seen}")

    if is_master(args):
        logging.info(
            f"Number of checkpoints to be made: {checkpoints_made}. "
            "The number will be greater in case of unexpected failures leading to the use of more shards."
        )

        if checkpoints_made != args.epochs:
            logging.warning(
                f"{args.epochs} checkpoints were requested, but {checkpoints_made} will be made. Checkpointing for "
                f"the desired number of epochs is best effort, and depends on the number of workers and GPUs used, "
                f"as well as the size of the shards themselves."
            )

    return


def get_string_for_epoch(
    num_samples: int,
    starting_points: List[int],
    paths: List[str],
    weights: Optional[List[float]],
    num_workers_per_gpu: int,
    world_size: int,
    multi_epoch=False,
    shard_shuffle_seed=None,
):
    """See _single_epoch_string for full docstring."""
    if multi_epoch:
        return _multi_epoch_string(
            num_samples, starting_points, paths, weights, num_workers_per_gpu, world_size, shard_shuffle_seed
        )
    else:
        return _single_epoch_string(
            num_samples, starting_points, paths, weights, num_workers_per_gpu, world_size, shard_shuffle_seed
        )


def _multi_epoch_string(
    num_samples: int,
    starting_shard_per_source: List[int],
    paths: List[str],
    weights: Optional[List[float]],
    num_workers_per_gpu: int,
    world_size: int,
    shard_shuffle_seed: Optional[int],
):
    """Return the string for training the shards, while allowing multiple passes over the dataset."""
    num_sources = len(paths)
    total_shards_per_source = [len(get_metadata_file(p, shard_shuffle_seed=None)) for p in paths]
    pass_idx = starting_shard_per_source[0] // total_shards_per_source[0]

    assert all(
        [starting_shard_per_source[i] // total_shards_per_source[i] == pass_idx for i in range(num_sources)]
    ), "Passes across sources are not synced."

    retries = 3

    while retries > 0:
        try:
            starting_shard_per_source_single = [
                starting_shard_per_source[i] % total_shards_per_source[i] for i in range(num_sources)
            ]
            shard_strings_per_source, num_samples_per_source, next_shard_per_source = _single_epoch_string(
                num_samples=num_samples,
                starting_shard_per_source=starting_shard_per_source_single,
                paths=paths,
                weights=weights,
                num_workers_per_gpu=num_workers_per_gpu,
                world_size=world_size,
                shard_shuffle_seed=shard_shuffle_seed + pass_idx if shard_shuffle_seed is not None else None,
            )
            next_shard_per_source = [
                next_shard_per_source[i] + pass_idx * total_shards_per_source[i] for i in range(num_sources)
            ]
            return shard_strings_per_source, num_samples_per_source, next_shard_per_source
        except IndexError:
            # The current pass ran out of shards before enough samples were gathered,
            # so start a new pass over the dataset and try again.
            pass_idx += 1
            starting_shard_per_source = [pass_idx * total_shards_per_source[i] for i in range(num_sources)]
            retries -= 1

    raise ValueError(
        "Multiple passes over the dataset did not allow for a valid shard string to be created. "
        "Try decreasing the number of tokens between checkpoints."
    )
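
# For example, with a single source of 10 shards and starting_shard_per_source=[23],
# pass_idx is 2 and the pass-local starting shard is 23 % 10 = 3; the returned
# next_shard_per_source is shifted back up by pass_idx * 10 so it stays monotonic
# across passes.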


def _single_epoch_string(
    num_samples: int,
    starting_shard_per_source: List[int],
    paths: List[str],
    weights: Optional[List[float]],
    num_workers_per_gpu: int,
    world_size: int,
    shard_shuffle_seed: Optional[int],
):
    """Retrieve shards to train on for a particular checkpoint.

    Currently, only a single source is fully supported.

    Args:
        num_samples: Total number of samples required.
        starting_shard_per_source: First shard per source that has not been consumed yet.
        paths: Paths to source manifests.
        weights: Weighting between sources. If None, it is assumed to be uniform.
        num_workers_per_gpu: Number of workers per gpu process.
        world_size: Total number of gpus used for training.
        shard_shuffle_seed: Seed to shuffle shards before checkpoint assignment.
    """
    num_sources = len(paths)

    if num_sources > 1:
        logging.warning(
            "Multiple sources are not fully supported as of now. It is advised to combine the data into a single "
            "source, by using datapreprocess/ray/tokenize_shuffle.py. Best effort will be made to mix data at the "
            "desired ratio."
        )
        if are_sources_imbalanced_with_each_other(paths):
            logging.warning(
                "Sources contain highly imbalanced shards (the largest median shard size of a source is >2x the "
                "smallest median size of a source). This will lead to deteriorated performance (less frequent "
                "checkpoints, data being skipped, and inaccurate mixing). It is STRONGLY advised to combine into "
                "one source."
            )

    for path in paths:
        num_small_shards = count_small_shards(path)
        if num_small_shards > 0:
            logging.warning(
                f"Source defined by {path} contains {num_small_shards} shards that are smaller than 90% the size of "
                f"the largest shard. These shards might cause deterioration in performance, with more samples being "
                f"skipped than necessary. It is advised to make the shards more uniform."
            )

    if weights is None:
        weights = [1.0 / num_sources for _ in range(num_sources)]

    assert len(weights) == num_sources, "One weight is needed per source."

    needed_samples_per_source = [int(np.ceil(weights[i] * num_samples / sum(weights))) for i in range(num_sources)]

    manifests = [get_metadata_file(path, shard_shuffle_seed=shard_shuffle_seed) for path in paths]

    shard_strings_per_source = []
    next_shard_per_source = copy.deepcopy(starting_shard_per_source)
    shard_list_per_source = [[] for _ in range(num_sources)]
    num_samples_per_source = [[] for _ in range(num_sources)]

    total_num_workers = num_workers_per_gpu * world_size
    while not enough_shards(shard_list_per_source, total_num_workers) or not enough_samples(
        num_samples_per_source, needed_samples_per_source
    ):
        try:
            for i in range(num_sources):
                # Add shards incrementally until every source has enough shards and samples.
                shard_name = manifests[i][next_shard_per_source[i]]["shard"]
                try:
                    num_samples_shard = manifests[i][next_shard_per_source[i]]["num_sequences"]
                except KeyError:
                    num_samples_shard = manifests[i][next_shard_per_source[i]]["num_chunks"]

                shard_list_per_source[i].append(shard_name)
                num_samples_per_source[i].append(num_samples_shard)

                next_shard_per_source[i] += 1

        except IndexError as e:
            logging.error(
                "Number of shards requested for a single epoch is more than the number of shards available. This "
                "means that the amount of data requested to train on is more than the dataloader can serve. This "
                "can either happen because there is not enough data to begin with, or because data is being skipped "
                "due to rounding errors. To alleviate the latter, consider making more uniform shards, and using "
                "fewer workers/GPUs. This will allow for better use of the dataset."
            )
            raise e

    for i in range(num_sources):
        # Truncate to a multiple of the number of workers, so that every worker sees
        # the same number of shards. Shards tend to hold a similar number of samples,
        # so the shards dropped here should be a small loss.
        num_multiples = len(shard_list_per_source[i]) // total_num_workers

        shard_list_per_source[i] = shard_list_per_source[i][: num_multiples * total_num_workers]
        num_samples_per_source[i] = num_samples_per_source[i][: num_multiples * total_num_workers]

        # Put back unused shards, so the next checkpoint starts from the first shard
        # that was not actually consumed.
        next_shard_per_source[i] = starting_shard_per_source[i] + len(shard_list_per_source[i])

    num_samples_per_source = [sum(n) for n in num_samples_per_source]

    for i, source_path in enumerate(paths):
        # Combine the shards of each source into a single webdataset-style string.
        shard_list_source = shard_list_per_source[i]
        shard_root_source = "/".join(source_path.split("/")[:-1]) + "/"
        if len(shard_list_source) == 1:
            shard_string_source = shard_root_source + shard_list_source[0] + ".tar"
        else:
            shard_string_source = shard_root_source + "{" + ",".join(shard_list_source) + "}.tar"
        if source_path.startswith("s3"):
            shard_string_source = f"pipe:aws s3 cp {shard_string_source} -"
        shard_strings_per_source.append(shard_string_source)

    return shard_strings_per_source, num_samples_per_source, next_shard_per_source
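
# For example, for a manifest at s3://my-bucket/data/manifest.jsonl (hypothetical) with
# selected shards ["shard-000000", "shard-000001"], the resulting shard string is:
#
#   pipe:aws s3 cp s3://my-bucket/data/{shard-000000,shard-000001}.tar -
#
# which webdataset-style loaders expand using brace notation.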