import sys
import io
import os
import re
import json
import tarfile
from functools import partial

import webdataset as wds
from webdataset import ResampledShards, DataPipeline, tarfile_to_samples
from webdataset.filters import pipelinefilter
from webdataset.tariterators import group_by_keys  # url_opener is re-implemented below
from webdataset.handlers import reraise_exception
from webdataset.gopen import gopen_schemes, gopen
def pytorch_worker_info(group=None):  # sourcery skip: use-contextlib-suppress
    """Return node and worker info for PyTorch and some distributed environments."""
    rank = 0
    world_size = 1
    worker = 0
    num_workers = 1
    try:
        import torch.distributed

        if torch.distributed.is_available() and torch.distributed.is_initialized():
            group = group or torch.distributed.group.WORLD
            rank = torch.distributed.get_rank(group=group)
            world_size = torch.distributed.get_world_size(group=group)
    except ModuleNotFoundError:
        pass
    try:
        import torch.utils.data

        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            worker = worker_info.id
            num_workers = worker_info.num_workers
    except ModuleNotFoundError:
        pass
    return rank, world_size, worker, num_workers
def pytorch_worker_seed(group=None):
    """Compute a distinct, deterministic RNG seed for each worker and node."""
    rank, world_size, worker, num_workers = pytorch_worker_info(group=group)
    return rank * 1000 + worker


def worker_seed_sat(group=None, seed=0):
    return pytorch_worker_seed(group=group) + seed * 23
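# Illustrative sketch (not part of the pipeline): how the seeds above separate
# ranks and workers. Assuming pytorch_worker_info() reports rank 3 with two
# dataloader workers, and a user seed of 7:
#
#   worker 0:  3 * 1000 + 0 + 7 * 23 = 3161
#   worker 1:  3 * 1000 + 1 + 7 * 23 = 3162
#
# so every (rank, worker, seed) combination gets a distinct deterministic seed.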
class ConfiguredResampledShards(ResampledShards):
    def __init__(self, urls, seed, nshards=sys.maxsize, deterministic=True):
        from sat.helpers import print_rank0

        try:
            from megatron.core.parallel_state import get_data_parallel_group

            group = get_data_parallel_group()
            print_rank0("Using megatron data parallel group.")
        except Exception:  # megatron absent, or its parallel state not initialized
            from sat.mpu import get_data_parallel_group

            try:
                group = get_data_parallel_group()
                print_rank0("Using sat data parallel group.")
            except AssertionError:
                group = None
                print_rank0("No data parallel group is specified!")
        worker_seed_sat_this = partial(worker_seed_sat, group=group, seed=seed)
        super().__init__(urls, nshards, worker_seed_sat_this, deterministic)
class SimpleDistributedWebDataset(DataPipeline):
    def __init__(self, path, process_fn, seed, *, shuffle_buffer=1000):
        # Set shuffle_buffer = 1 to disable shuffling; under model parallelism it must
        # be disabled, or ranks in the same model-parallel group would see different samples.
        try:
            from sat.mpu import get_model_parallel_world_size

            if get_model_parallel_world_size() > 1:
                shuffle_buffer = 1
        except Exception:
            pass
        super().__init__(
            ConfiguredResampledShards(path, seed),  # many shards are recommended, or sampling is uneven
            tarfile_to_samples(),
            wds.shuffle(shuffle_buffer),
            process_fn,
        )
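# Minimal usage sketch for SimpleDistributedWebDataset (hypothetical paths and
# process_fn; not part of this module). process_fn receives an iterator of raw
# webdataset sample dicts and yields whatever the training loop consumes:
#
#   def process_fn(samples):
#       for s in samples:
#           yield {"text": s["txt"].decode("utf-8")}
#
#   dataset = SimpleDistributedWebDataset("/data/shards-{0000..0099}.tar", process_fn, seed=42)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=2)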
def tar_file_iterator_with_meta(
    fileobj, meta_names, skip_meta=r"__[^/]*__($|/)", suffix=None, handler=reraise_exception, meta_stream=None
):
    """Iterate over a tar stream, yielding dict(fname=..., data=...) for each member.

    :param fileobj: byte stream suitable for tarfile
    :param meta_names: keys to extract for each sample from the meta file
    :param skip_meta: regexp for keys that are skipped entirely (Default value = r"__[^/]*__($|/)")
    :param suffix: optional string appended to the content of .txt members
    :param meta_stream: optional open stream of the '.meta.jsonl' file; if None, it is
        looked up next to the tar file on the local filesystem
    """
    stream = tarfile.open(fileobj=fileobj, mode="r|*")
    data_dir, filename = fileobj.name.rsplit("/", 1)
    meta_data = {}  # {id: {meta_name: meta_value, meta_name2: meta_value2, ...}}
    if meta_stream is None:
        meta_file_name = filename.split(".")[0] + ".meta.jsonl"
        meta_path = os.path.join(data_dir, meta_file_name)
        if os.path.exists(meta_path):
            meta_stream = open(meta_path, "r")
    else:
        meta_file_name = meta_stream.name
    if meta_stream is not None:
        for lineno, line in enumerate(meta_stream):
            try:
                item = json.loads(line)
            except Exception:
                from sat.helpers import print_rank0

                print_rank0(f"Error in loading jsonl {meta_file_name}, lineno {lineno}: {line}", level="DEBUG")
                continue
            if item["key"] not in meta_data:
                meta_data[item["key"]] = {}
            for meta_name in meta_names:
                if meta_name in item:
                    meta_data[item["key"]][meta_name] = item[meta_name]
        meta_stream.close()
    try:
        for tarinfo in stream:
            fname = tarinfo.name
            try:
                if not tarinfo.isreg():
                    continue
                if fname is None:
                    continue
                if "/" not in fname and fname.startswith("__") and fname.endswith("__"):
                    # skipping metadata for now
                    continue
                if skip_meta is not None and re.match(skip_meta, fname):
                    continue
                if fname.endswith(".txt") and suffix is not None:
                    data = (stream.extractfile(tarinfo).read().decode() + suffix).encode()
                else:
                    data = stream.extractfile(tarinfo).read()
                result = dict(fname=fname, data=data)
                yield result

                if fname.endswith(".id"):
                    # emit one extra pseudo-file per requested meta field, keyed by the sample id
                    fid = fname.split(".")[0]
                    if "-$#%@&" in fid:
                        sfid = fid.split("-$#%@&")[0]
                    else:
                        sfid = fid
                    meta_data_fid = meta_data.get(sfid, {})
                    for meta_name in meta_names:
                        meta_fname = fid + "." + meta_name
                        meta = meta_data_fid.get(meta_name, None)
                        yield dict(fname=meta_fname, data=meta)
                stream.members = []  # drop processed members so the streaming tarfile does not accumulate them
            except Exception as exn:
                if hasattr(exn, "args") and len(exn.args) > 0:
                    exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:]
                if handler(exn):
                    continue
                else:
                    break
    except Exception as exn:
        print(exn)
    del stream
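# Example of what the iterator above yields (hypothetical data). Given a tar
# containing "000123.jpg" and "000123.id", meta_names=["caption"], and a sidecar
# meta.jsonl line {"key": "000123", "caption": "a cat"}, the stream would be:
#
#   {"fname": "000123.jpg", "data": b"..."}
#   {"fname": "000123.id", "data": b"000123"}
#   {"fname": "000123.caption", "data": "a cat"}   # injected from the meta file
#
# group_by_keys later merges these entries into one sample keyed by "000123".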
def tar_file_expander_with_meta(data, meta_names, handler=reraise_exception):
    """Expand a stream of open tar files into a stream of tar file contents.

    Yields dicts with "fname", "data", and "__url__" keys.
    """
    for source in data:
        try:
            assert isinstance(source, dict)
            assert "stream" in source
            url = source["url"]
            for sample in tar_file_iterator_with_meta(source["stream"], meta_names, meta_stream=source["meta_stream"]):
                assert isinstance(sample, dict) and "data" in sample and "fname" in sample
                sample["__url__"] = url
                yield sample
        except Exception as exn:
            exn.args = exn.args + (source.get("stream"), source.get("url"))
            if handler(exn):
                continue
            else:
                break
def url_opener(
    data,
    handler=reraise_exception,
    **kw,
):
    """Open URLs and yield a stream of url+stream pairs.

    This overrides webdataset's url_opener so that a `meta_stream` attached by the
    gopen handler (see gopen_boto3 below) is carried along with the sample.

    Args:
        data: iterator over dict(url=...)
        handler: exception handler.
        kw: keyword arguments for gopen.gopen.

    Yields:
        a stream of url+stream pairs.
    """
    for sample in data:
        assert isinstance(sample, dict), sample
        assert "url" in sample
        url = sample["url"]
        try:
            stream = gopen(url, **kw)
            if hasattr(stream, "meta_stream"):
                meta_stream = stream.meta_stream
                del stream.meta_stream
            else:
                meta_stream = None
            sample.update(stream=stream, meta_stream=meta_stream)
            yield sample
        except Exception as exn:
            exn.args = exn.args + (url,)
            if handler(exn):
                continue
            else:
                break
def tarfile_samples_with_meta(src, meta_names, handler=reraise_exception):
    streams = url_opener(src, handler=handler)
    files = tar_file_expander_with_meta(streams, meta_names, handler)
    samples = group_by_keys(files, handler=handler)
    return samples
class MetaDistributedWebDataset(DataPipeline):
    """WebDataset with meta information files.

    Extra format:
        In the webdataset (tar), each sample contains a '.id' file;
        next to each tar file there is a '.meta.jsonl' file with the same base name.
        The '.meta.jsonl' file contains one json object per line, each with a 'key'
        field matching the sample's '.id'.
    """

    def __init__(
        self, path, process_fn, seed, *, meta_names=(), nshards=sys.maxsize, shuffle_buffer=1000, include_dirs=None
    ):
        # os.environ['WDS_SHOW_SEED'] = '1'
        import torch

        # torch.distributed must already be initialized: rank 0 resolves the shard
        # list and broadcasts it to all other ranks.
        if torch.distributed.get_rank() == 0:
            if include_dirs is not None:  # e.g. "/webdatasets/A,/webdatasets/C"
                other_paths = []
                include_dirs = include_dirs.split(",")
                for include_dir in include_dirs:
                    # "dir*3" repeats every shard under dir three times (upweighting)
                    if "*" in include_dir:
                        include_dir, n = include_dir.split("*")
                        n = int(n)
                    else:
                        n = 1
                    for cur_dir, dirs, files in os.walk(include_dir):
                        for f in files:
                            if f.endswith("tar") and os.path.getsize(os.path.join(cur_dir, f)) > 0:
                                other_paths.extend([os.path.join(cur_dir, f)] * n)
                from braceexpand import braceexpand

                if len(path) > 0:  # not ""
                    path = list(braceexpand(path)) + other_paths
                else:
                    path = other_paths
            path = [path]
        else:
            path = [None]
        torch.distributed.broadcast_object_list(path, src=0)
        path = path[0]

        tarfile_samples = partial(tarfile_samples_with_meta, meta_names=meta_names)
        tarfile_to_samples = pipelinefilter(tarfile_samples)

        # Under model parallelism, shuffle_buffer must be 1 to disable shuffling,
        # or ranks in the same model-parallel group would see different samples.
        try:
            from sat.mpu import get_model_parallel_world_size

            if get_model_parallel_world_size() > 1:
                shuffle_buffer = 1
        except Exception:
            pass
        super().__init__(
            ConfiguredResampledShards(path, seed, nshards=nshards),
            tarfile_to_samples(),
            wds.shuffle(shuffle_buffer),
            process_fn,
        )
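# Minimal usage sketch for MetaDistributedWebDataset (hypothetical paths; not part
# of this module). torch.distributed must be initialized first, since the shard
# list is broadcast from rank 0; include_dirs can upweight a directory via "*n":
#
#   dataset = MetaDistributedWebDataset(
#       path="",                              # or a brace pattern like "/data/part-{0000..0099}.tar"
#       process_fn=process_fn,
#       seed=42,
#       meta_names=["caption"],
#       include_dirs="/webdatasets/A,/webdatasets/B*2",  # shards in B are repeated twice
#   )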
# rclone support
from webdataset.gopen import Pipe


def gopen_rclone(url, mode="rb", bufsize=1024 * 1024 * 32):
    """Open a URL via `rclone`.

    :param url: rclone url, e.g. rclone://remote:bucket1/foo.tar; the remote must be
        configured in rclone beforehand.
    :param mode: file mode
    :param bufsize: buffer size
    """
    url = url.replace("rclone://", "")
    if mode[0] == "r":
        cmd = f"rclone cat '{url}'"
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 23],
        )  # skipcq: BAN-B604
    elif mode[0] == "w":
        cmd = f"rclone rcat '{url}'"  # rcat streams stdin to the remote; `rclone cp -` is not valid
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 26],
        )  # skipcq: BAN-B604
    else:
        raise ValueError(f"{mode}: unknown mode")
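# Usage sketch: with an rclone remote named "data" configured (e.g. via
# `rclone config`; the remote name is hypothetical), a shard can be addressed as
# "rclone://data:bucket1/foo.tar" and opened through the scheme table registered
# at the bottom of this file:
#
#   stream = gopen("rclone://data:bucket1/foo.tar")  # spawns `rclone cat 'data:bucket1/foo.tar'`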
def gopen_boto3(url, mode="rb", bufsize=8192 * 2):
    """Open a URL via the boto3 API.

    :param url: boto3 url, e.g. boto3://bucket1/foo.tar; credentials are read from the
        S3_ENDPOINT_URL, S3_ACCESS_KEY_ID and S3_SECRET_ACCESS_KEY environment variables.
    :param mode: file mode
    :param bufsize: buffer size
    """
    import boto3

    # boto3.set_stream_logger('botocore', level='DEBUG')
    if url.startswith("boto3://"):
        url = url.replace("boto3://", "")
        need_meta = False
    else:
        url = url.replace("metaboto3://", "")
        need_meta = True
    endpoint_url = os.environ.get("S3_ENDPOINT_URL", None)
    access_key = os.environ.get("S3_ACCESS_KEY_ID", None)
    secret_key = os.environ.get("S3_SECRET_ACCESS_KEY", None)
    if mode[0] == "r":
        s3_client = boto3.client(
            "s3", endpoint_url=endpoint_url, aws_access_key_id=access_key, aws_secret_access_key=secret_key
        )
        bucket, key = url.split("/", 1)
        if need_meta:
            # download the sidecar meta jsonl into memory
            meta_file_key = key.split(".")[0] + ".meta.jsonl"
            meta_stream = io.BytesIO()
            s3_client.download_fileobj(bucket, meta_file_key, meta_stream)
            meta_stream.seek(0)
            meta_stream.name = meta_file_key
        else:
            meta_stream = None
        # data tar stream
        response = s3_client.get_object(Bucket=bucket, Key=key)  # Range optional
        response["Body"].name = key  # read by tar_file_iterator_with_meta via fileobj.name
        response["Body"].meta_stream = meta_stream
        return response["Body"]
    else:
        raise ValueError(f"{mode}: unknown mode")
gopen_schemes["rclone"] = gopen_rclone
gopen_schemes["boto3"] = gopen_boto3
gopen_schemes["metaboto3"] = gopen_boto3
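# Usage sketch for the S3 schemes (hypothetical bucket and keys). Credentials come
# from environment variables; "metaboto3://" additionally fetches the sidecar
# "<name>.meta.jsonl" so the pipeline can attach meta fields to each sample:
#
#   export S3_ENDPOINT_URL=https://s3.example.com
#   export S3_ACCESS_KEY_ID=...
#   export S3_SECRET_ACCESS_KEY=...
#
#   dataset = MetaDistributedWebDataset(
#       "metaboto3://bucket1/shards/part-0000.tar", process_fn, seed=42,
#       meta_names=["caption"],
#   )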