""" WebDataset distributed utility functions, pipeline helper functions and sampler classes. """ from torch.utils.data import IterableDataset import torch import math import random import os import logging import braceexpand def log_and_continue(exn): """Call in an exception handler to ignore any exception, issue a warning, and continue.""" if "No images in sample" in str(exn) or "Only one image in sample" in str(exn): return True if isinstance(exn, FileNotFoundError) or "FileNotFoundError" in str(type(exn)): if os.environ.get("RANK", "0") == "0": logging.warning(f"Handling webdataset FileNotFoundError: {exn}. Ignoring and continuing.") return True logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.") return True # Distributed environment detection and shard allocation def pytorch_worker_info(group=None): """Return node and worker info for PyTorch and some distributed environments.""" rank = 0 world_size = 1 worker = 0 num_workers = 1 if "RANK" in os.environ and "WORLD_SIZE" in os.environ: rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) else: try: import torch.distributed if torch.distributed.is_available() and torch.distributed.is_initialized(): group = group or torch.distributed.group.WORLD rank = torch.distributed.get_rank(group=group) world_size = torch.distributed.get_world_size(group=group) except ModuleNotFoundError: pass if "WORKER" in os.environ and "NUM_WORKERS" in os.environ: worker = int(os.environ["WORKER"]) num_workers = int(os.environ["NUM_WORKERS"]) else: try: import torch.utils.data worker_info = torch.utils.data.get_worker_info() if worker_info is not None: worker = worker_info.id num_workers = worker_info.num_workers except ModuleNotFoundError: pass return rank, world_size, worker, num_workers def is_multi_node_environment(): """ check if in a multi-process (world_size > 1) environment. """ try: import torch.distributed as dist if dist.is_available() and dist.is_initialized(): if dist.get_world_size() > 1: return True except Exception: pass world_size = int(os.environ.get("WORLD_SIZE", os.environ.get("SLURM_NTASKS", "1"))) nnodes = int(os.environ.get("NNODES", os.environ.get("SLURM_NNODES", "1"))) if nnodes > 1: return True return world_size > 1 def split_data_by_node(urls, strategy="interleaved"): """split shards between nodes, even if the data is stored locally, it is recommended to use it to avoid duplicate training.""" print('*'*80) print("split_data_by_node ing..................") gpus_per_node = torch.cuda.device_count() rank, world_size, worker, num_workers = pytorch_worker_info() print("rank: {}, world_size: {}, worker: {}, num_workers: {}, gpus_per_node: {}".format( rank, world_size, worker, num_workers, gpus_per_node)) node_rank = rank // gpus_per_node node_world_size = world_size // gpus_per_node if len(urls) < node_world_size: print(f"Warning: Only {len(urls)} shards but {node_world_size} nodes. " f"All nodes will use all shards to avoid empty assignment.") print(f"Node {node_rank} has {len(urls)} URLs of {len(urls)} total.") print('*'*80) return urls if strategy == "chunk": urls_per_node = math.ceil(len(urls) / node_world_size) start_idx = node_rank * urls_per_node end_idx = min(start_idx + urls_per_node, len(urls)) node_urls = urls[start_idx:end_idx] elif strategy == "interleaved": node_urls = urls[node_rank::node_world_size] elif strategy == "shuffled_chunk": shuffled_urls = random.sample(urls, len(urls)) urls_per_node = math.ceil(len(shuffled_urls) / node_world_size) start_idx = node_rank * urls_per_node end_idx = min(start_idx + urls_per_node, len(urls)) node_urls = shuffled_urls[start_idx:end_idx] else: raise ValueError(f"Unknown strategy {strategy}") print(f"Node {node_rank} has {len(node_urls)} URLs of {len(urls)} total.") print('*'*80) return node_urls def get_dataset_size(shards, estimated_sample_per_shard=1000): """estimate the dataset size, based on the number of shards.""" if ',' in shards: shards_list = [] for pattern in shards.split(','): pattern = pattern.strip() if not pattern: continue shards_list.extend(list(braceexpand.braceexpand(pattern))) else: shards_list = list(braceexpand.braceexpand(shards)) num_shards = len(shards_list) total_size = num_shards * estimated_sample_per_shard print(f"Estimating dataset size: {total_size} samples ({num_shards} shards * {estimated_sample_per_shard} samples/shard)") return total_size, num_shards # Pipeline helper functions (module level, supports pickle/spawn) def nodesplitter_identity(urls): return urls def handle_reconstruction_task(sample, handler=log_and_continue): in_key = None if "in.png" in sample: in_key = "in.png" elif "in.jpg" in sample: in_key = "in.jpg" out_key = None if "out.png" in sample: out_key = "out.png" elif "out.jpg" in sample: out_key = "out.jpg" if in_key and not out_key: if in_key == "in.png": sample["out.png"] = sample["in.png"] else: sample["out.jpg"] = sample["in.jpg"] return sample def extract_fields_to_tuple(sample, handler=log_and_continue): in_img = sample.get("in.png") or sample.get("in.jpg") out_img = sample.get("out.png") or sample.get("out.jpg") if out_img is None and in_img is not None: out_img = in_img sample_type = sample.get("type", None) return (in_img, out_img, sample_type) def identity_function(x, handler=log_and_continue): return x def has_input_image(sample): return "in.png" in sample or "in.jpg" in sample class WeightedRoundRobinSampler(IterableDataset): def __init__(self, pipelines, weights): super().__init__() if len(weights) != len(pipelines): raise ValueError(f"number of weights ({len(weights)}) must be equal to the number of pipelines ({len(pipelines)})") self.pipelines = pipelines self.weights = weights total_weight = sum(weights) normalized_weights = [w / total_weight for w in weights] max_decimal_places = max(len(str(w).split('.')[-1]) if '.' in str(w) else 0 for w in normalized_weights) scale_factor = 10 ** max_decimal_places int_weights = [int(w * scale_factor) for w in normalized_weights] def gcd(a, b): while b: a, b = b, a % b return a def gcd_list(nums): result = nums[0] for num in nums[1:]: result = gcd(result, num) return result common_divisor = gcd_list(int_weights) int_weights = [w // common_divisor for w in int_weights] self.sampling_sequence = [] for i, weight in enumerate(int_weights): self.sampling_sequence.extend([i] * weight) def __iter__(self): import itertools iterators = [iter(p) for p in self.pipelines] sequence_iter = itertools.cycle(self.sampling_sequence) active = [True] * len(iterators) while True: if not any(active): break idx = next(sequence_iter) if active[idx]: try: yield next(iterators[idx]) except StopIteration: active[idx] = False if not any(active): break continue class StrictProportionalBatchSampler(IterableDataset): """ a strictly proportional batch sampler (适用于 resampled=True) ensure that the samples in each batch are strictly allocated according to the weight ratio """ def __init__(self, pipelines, weights, batch_size): super().__init__() if len(weights) != len(pipelines): raise ValueError(f"number of weights ({len(weights)}) must be equal to the number of pipelines ({len(pipelines)})") self.pipelines = pipelines self.weights = weights self.batch_size = batch_size total_weight = sum(weights) normalized_weights = [w / total_weight for w in weights] self.samples_per_pipeline = [] float_counts = [batch_size * w for w in normalized_weights] int_counts = [round(c) for c in float_counts] current_sum = sum(int_counts) diff = batch_size - current_sum if diff != 0: errors = [(float_counts[i] - int_counts[i], i) for i in range(len(int_counts))] errors.sort(reverse=(diff > 0)) for _ in range(abs(diff)): _, idx = errors.pop(0) int_counts[idx] += 1 if diff > 0 else -1 self.samples_per_pipeline = int_counts weight_strs = [f"{w*100:.1f}%" for w in normalized_weights] sample_strs = [f"{count}" for count in self.samples_per_pipeline] actual_ratios = [f"{count/batch_size*100:.1f}%" for count in self.samples_per_pipeline] print(f"Strict proportional batch sampling enabled:") print(f" Target weights: {' : '.join(weight_strs)}") print(f" Actual samples per batch: {' : '.join(sample_strs)} (total={batch_size})") print(f" Actual ratios: {' : '.join(actual_ratios)}") def __iter__(self): import random as _random iterators = [iter(p) for p in self.pipelines] while True: batch_samples = [] for idx, count in enumerate(self.samples_per_pipeline): for _ in range(count): sample = next(iterators[idx]) batch_samples.append(sample) _random.shuffle(batch_samples) normalized_samples = [] for sample in batch_samples: if len(sample) == 3: normalized_samples.append((sample[0], sample[1], sample[2], None)) elif len(sample) == 4: normalized_samples.append(sample) else: raise ValueError(f"Unexpected sample length: {len(sample)}") batch_transposed = list(zip(*normalized_samples)) batch_results = [] for idx, items in enumerate(batch_transposed): if idx < 3: filtered_items = [item for item in items if item is not None] if len(filtered_items) != len(items): raise ValueError(f"Found None in tensor items at index {idx}") batch_results.append(torch.stack(list(filtered_items))) else: type_list = list(items) batch_results.append(type_list) yield tuple(batch_results)