import webdataset as wds
from torch.utils.data import IterableDataset
from PIL import Image
import numpy as np
import cv2


class MultiWebDataset(IterableDataset):
    """Iterable dataset that streams WebDataset shards and yields collages.

    Each sample is expected to carry the entries ``bg.jpg``, ``obj0.png``,
    ``mask0.png``, ``obj1.png`` and ``mask1.png``. The images are converted
    to numpy arrays and handed to ``construct_collage_fn``; whatever that
    callable returns is yielded unchanged.

    Args:
        urls: Shard specification forwarded to ``wds.WebDataset``.
        construct_collage_fn: Callable invoked as
            ``fn(bg, obj0, obj1, mask0, mask1)`` with the converted arrays.
        shuffle_size: Sample-level shuffle buffer size; shuffling is skipped
            when it is falsy or not positive.
        seed: Stored on the instance. NOTE: it is not consumed by the
            pipeline built in ``__iter__``.
        decode_mode: Image decode specification passed to ``.decode()``.
            A falsy value falls back to ``"pil"``.
    """

    def __init__(
        self,
        urls,
        construct_collage_fn,
        shuffle_size=0,
        seed=0,
        decode_mode="pil",
    ):
        super().__init__()
        self.urls = urls
        self.shuffle_size = shuffle_size
        self.seed = seed
        self.decode_mode = decode_mode
        self.construct_collage_fn = construct_collage_fn

    @staticmethod
    def _as_uint8(arr):
        """Return ``arr`` as a uint8 array.

        uint8 input is returned as-is. Boolean arrays map True to 255.
        Floating arrays whose maximum is <= 1.0 are assumed to be normalised
        to [0, 1] (as webdataset's float decode modes produce -- TODO confirm
        for other producers) and are scaled by 255; all other values are
        clipped to [0, 255] before the cast.
        """
        if arr.dtype == np.uint8:
            return arr
        if arr.dtype == np.bool_:
            return arr.astype(np.uint8) * 255
        if np.issubdtype(arr.dtype, np.floating):
            scale = 255.0 if arr.size and arr.max() <= 1.0 else 1.0
            return np.clip(np.rint(arr * scale), 0, 255).astype(np.uint8)
        return np.clip(arr, 0, 255).astype(np.uint8)

    def _to_rgb_np(self, img):
        """Convert a PIL image or numpy array to an RGB uint8 numpy array.

        Grayscale arrays (``(H, W)`` or ``(H, W, 1)``) are expanded to three
        channels and a fourth (alpha) channel is dropped. Arrays of any other
        shape are returned after the uint8 conversion only.

        Raises:
            TypeError: If ``img`` is neither a PIL image nor an ndarray.
        """
        if isinstance(img, Image.Image):
            return np.array(img.convert("RGB"))
        if not isinstance(img, np.ndarray):
            raise TypeError(f"Unsupported image type: {type(img)}")

        arr = self._as_uint8(img)
        if arr.ndim == 3 and arr.shape[2] == 1:
            # Treat a trailing singleton channel as plain grayscale.
            arr = arr[:, :, 0]
        if arr.ndim == 2:
            return cv2.cvtColor(np.ascontiguousarray(arr), cv2.COLOR_GRAY2RGB)
        if arr.ndim == 3 and arr.shape[2] == 4:
            return arr[:, :, :3]
        return arr

    def _to_mask_np(self, img):
        """Convert a PIL image or numpy array to a binary uint8 mask.

        The input is reduced to a single channel and thresholded: values
        above 127 become 255, everything else becomes 0.

        Raises:
            TypeError: If ``img`` is neither a PIL image nor an ndarray.
        """
        if isinstance(img, Image.Image):
            gray = np.array(img.convert("L"))
        elif isinstance(img, np.ndarray):
            gray = self._as_uint8(img)
            if gray.ndim == 3:
                if gray.shape[2] == 1:
                    gray = gray[:, :, 0]
                else:
                    # Arrays are treated as RGB(A), consistent with
                    # _to_rgb_np; any alpha channel is dropped first.
                    rgb = np.ascontiguousarray(gray[:, :, :3])
                    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
        else:
            raise TypeError(f"Unsupported mask type: {type(img)}")

        return (gray > 127).astype(np.uint8) * 255

    def __iter__(self):
        """Build the WebDataset pipeline and yield one collage per sample."""
        ds = wds.WebDataset(self.urls, shardshuffle=True, empty_check=False)
        if self.shuffle_size and self.shuffle_size > 0:
            ds = ds.shuffle(self.shuffle_size)

        # Honour the configured decode mode instead of always forcing "pil".
        ds = ds.decode(self.decode_mode or "pil")
        ds = ds.rename(
            bg="bg.jpg",
            obj0="obj0.png",
            mask0="mask0.png",
            obj1="obj1.png",
            mask1="mask1.png",
        )

        for sample in ds:
            bg_np = self._to_rgb_np(sample["bg"])
            obj0_np = self._to_rgb_np(sample["obj0"])
            obj1_np = self._to_rgb_np(sample["obj1"])
            mask0_np = self._to_mask_np(sample["mask0"])
            mask1_np = self._to_mask_np(sample["mask1"])

            yield self.construct_collage_fn(
                bg_np, obj0_np, obj1_np, mask0_np, mask1_np
            )