| import webdataset as wds | |
| from torch.utils.data import IterableDataset | |
| from PIL import Image | |
| import numpy as np | |
| import cv2 | |
class MultiWebDataset(IterableDataset):
    """Iterable dataset that turns WebDataset samples into collages.

    Each sample is expected to provide the entries ``bg.jpg``, ``obj0.png``,
    ``mask0.png``, ``obj1.png`` and ``mask1.png``.  The images are converted
    to NumPy arrays and handed to ``construct_collage_fn``; whatever that
    callable returns is yielded unchanged.
    """

    def __init__(
        self,
        urls,
        construct_collage_fn,
        shuffle_size=0,
        seed=0,
        decode_mode="pil",
    ):
        """Store the configuration; shards are only opened in ``__iter__``.

        Args:
            urls: Shard specification passed straight to ``wds.WebDataset``.
            construct_collage_fn: Callable invoked as
                ``fn(bg, obj0, obj1, mask0, mask1)`` with NumPy arrays.
            shuffle_size: Size of the sample shuffle buffer; a falsy or
                non-positive value disables sample-level shuffling.
            seed: Stored on the instance only.
            decode_mode: Image spec handed to ``WebDataset.decode``.  The
                conversion helpers accept PIL images and NumPy arrays, and
                the mask threshold of 127 assumes a 0-255 value range, so
                modes such as ``"pil"`` or ``"rgb8"`` are the intended
                choices (float or torch modes are not handled).
        """
        super().__init__()
        self.urls = urls
        self.shuffle_size = shuffle_size
        # NOTE(review): nothing in this class reads ``seed`` yet; shuffling
        # is therefore not seeded by it.
        self.seed = seed
        self.decode_mode = decode_mode
        self.construct_collage_fn = construct_collage_fn

    def _to_rgb_np(self, img):
        """Return ``img`` as an ``(H, W, 3)`` NumPy array.

        Args:
            img: A PIL image or a NumPy array.  Arrays are assumed to be in
                RGB(A) channel order already.

        Returns:
            PIL images are converted with ``convert("RGB")``.  Grey arrays
            (``(H, W)`` or ``(H, W, 1)``) have their single channel repeated
            three times, RGBA arrays lose the alpha channel (as a view), and
            any other array is returned as-is.

        Raises:
            TypeError: If ``img`` is neither a PIL image nor a NumPy array.
        """
        if isinstance(img, np.ndarray):
            if img.ndim == 2:
                # Same result as cv2's GRAY2RGB, but valid for every dtype.
                return np.repeat(img[:, :, np.newaxis], 3, axis=2)
            if img.ndim == 3 and img.shape[2] == 1:
                return np.repeat(img, 3, axis=2)
            if img.ndim == 3 and img.shape[2] == 4:
                return img[:, :, :3]
            return img
        if isinstance(img, Image.Image):
            return np.array(img.convert("RGB"))
        raise TypeError(f"Unsupported image type: {type(img)}")

    def _to_mask_np(self, img):
        """Return ``img`` as a binary ``uint8`` mask holding only 0 and 255.

        Args:
            img: A PIL image or a NumPy array.  Multi-channel arrays are
                assumed to be in RGB(A) order, matching ``_to_rgb_np``.

        Returns:
            A 2-D ``uint8`` array that is 255 where the grey value is
            greater than 127 and 0 elsewhere.

        Raises:
            TypeError: If ``img`` is neither a PIL image nor a NumPy array.
        """
        if isinstance(img, np.ndarray):
            if img.ndim == 3 and img.shape[2] == 1:
                # cvtColor rejects single-channel input; just drop the axis.
                gray = img[:, :, 0]
            elif img.ndim == 3:
                # RGB weights keep this consistent with PIL's convert("L")
                # used for PIL inputs below.
                gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            else:
                gray = img
        elif isinstance(img, Image.Image):
            gray = np.array(img.convert("L"))
        else:
            raise TypeError(f"Unsupported mask type: {type(img)}")
        return (gray > 127).astype(np.uint8) * 255

    def __iter__(self):
        """Yield one collage per sample read from the shards.

        A fresh WebDataset pipeline is built on every call: shards are
        shuffled, samples are optionally shuffled through a buffer of
        ``shuffle_size``, images are decoded according to ``decode_mode``,
        and the five image entries are converted and passed to
        ``construct_collage_fn``.
        """
        # empty_check=False is forwarded as in the original pipeline;
        # presumably it tolerates workers that end up with no shards —
        # TODO confirm against the installed webdataset version.
        ds = wds.WebDataset(self.urls, shardshuffle=True, empty_check=False)
        if self.shuffle_size and self.shuffle_size > 0:
            ds = ds.shuffle(self.shuffle_size)
        # Honour the configured decode mode instead of a hard-coded "pil".
        ds = ds.decode(self.decode_mode)
        ds = ds.rename(
            bg="bg.jpg",
            obj0="obj0.png",
            mask0="mask0.png",
            obj1="obj1.png",
            mask1="mask1.png",
        )
        for sample in ds:
            bg_np = self._to_rgb_np(sample["bg"])
            obj0_np = self._to_rgb_np(sample["obj0"])
            obj1_np = self._to_rgb_np(sample["obj1"])
            mask0_np = self._to_mask_np(sample["mask0"])
            mask1_np = self._to_mask_np(sample["mask1"])
            yield self.construct_collage_fn(
                bg_np, obj0_np, obj1_np, mask0_np, mask1_np
            )