PICS / datasets /webdataset.py
Hang Zhou
Upload folder using huggingface_hub
0103f17 verified
import webdataset as wds
from torch.utils.data import IterableDataset
from PIL import Image
import numpy as np
import cv2
class MultiWebDataset(IterableDataset):
    """Iterable dataset that streams collage samples from WebDataset shards.

    Each shard sample is expected to carry five entries: ``bg.jpg``,
    ``obj0.png``, ``mask0.png``, ``obj1.png`` and ``mask1.png``.  They are
    decoded, converted to numpy arrays and passed to ``construct_collage_fn``;
    whatever that callable returns is yielded.
    """

    def __init__(
        self,
        urls,
        construct_collage_fn,
        shuffle_size: int = 0,
        seed: int = 0,
        decode_mode: str = "pil",
    ):
        """Store the pipeline configuration; no shard is opened here.

        Args:
            urls: Shard specification forwarded to ``wds.WebDataset``.
            construct_collage_fn: Callable invoked as
                ``fn(bg, obj0, obj1, mask0, mask1)`` with numpy arrays.
            shuffle_size: Sample-level shuffle buffer size; a falsy or
                non-positive value disables sample shuffling.
            seed: Stored on the instance. NOTE: the pipeline built in
                ``__iter__`` does not currently consume it.
            decode_mode: Image spec passed to ``WebDataset.decode``.
        """
        super().__init__()
        self.urls = urls
        self.shuffle_size = shuffle_size
        self.seed = seed
        self.decode_mode = decode_mode
        self.construct_collage_fn = construct_collage_fn

    def _to_rgb_np(self, img):
        """Convert a decoded image to a numpy array with three channels.

        Numpy inputs are assumed to be in RGB(A) channel order.

        Args:
            img: A ``PIL.Image.Image`` or a ``numpy.ndarray`` shaped
                ``(H, W)``, ``(H, W, 1)``, ``(H, W, 3)`` or ``(H, W, 4)``.

        Returns:
            A numpy array; grayscale input is replicated to three channels
            and a fourth (alpha) channel is dropped.

        Raises:
            TypeError: If ``img`` is neither a PIL image nor an ndarray.
        """
        if isinstance(img, np.ndarray):
            if img.ndim == 2:
                # Replicate the single plane into three identical channels.
                return np.stack((img,) * 3, axis=-1)
            if img.ndim == 3:
                channels = img.shape[2]
                if channels == 1:
                    # (H, W, 1) is still grayscale; expand it like the 2-D case.
                    return np.repeat(img, 3, axis=2)
                if channels == 4:
                    # Drop the trailing alpha channel.
                    return img[:, :, :3]
            return img
        if isinstance(img, Image.Image):
            return np.array(img.convert("RGB"))
        raise TypeError(f"Unsupported image type: {type(img)}")

    def _to_mask_np(self, img):
        """Convert a decoded mask to a binary ``uint8`` array of 0 / 255.

        Args:
            img: A ``PIL.Image.Image`` or a ``numpy.ndarray`` (2-D, or 3-D
                with 1, 3 or 4 channels in RGB(A) order).

        Returns:
            A 2-D ``uint8`` array that is 255 where the grayscale value is
            strictly greater than 127 and 0 elsewhere.

        Raises:
            TypeError: If ``img`` is neither a PIL image nor an ndarray.
        """
        if isinstance(img, np.ndarray):
            if img.ndim == 3:
                channels = img.shape[2]
                if channels == 1:
                    # Already grayscale: just drop the channel axis.
                    m = img[:, :, 0]
                elif channels == 4:
                    m = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
                else:
                    # RGB order, consistent with _to_rgb_np's assumption.
                    m = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            else:
                m = img
        elif isinstance(img, Image.Image):
            m = np.array(img.convert("L"))
        else:
            raise TypeError(f"Unsupported mask type: {type(img)}")
        return (m > 127).astype(np.uint8) * 255

    def __iter__(self):
        """Build the WebDataset pipeline and yield one collage per sample.

        Yields:
            The return value of ``construct_collage_fn`` for each sample.
        """
        ds = wds.WebDataset(self.urls, shardshuffle=True, empty_check=False)
        if self.shuffle_size and self.shuffle_size > 0:
            ds = ds.shuffle(self.shuffle_size)
        # Honour the configured decoder instead of a hard-coded "pil".
        ds = ds.decode(self.decode_mode)
        ds = ds.rename(
            bg="bg.jpg",
            obj0="obj0.png",
            mask0="mask0.png",
            obj1="obj1.png",
            mask1="mask1.png",
        )
        for sample in ds:
            bg_np = self._to_rgb_np(sample["bg"])
            obj0_np = self._to_rgb_np(sample["obj0"])
            obj1_np = self._to_rgb_np(sample["obj1"])
            mask0_np = self._to_mask_np(sample["mask0"])
            mask1_np = self._to_mask_np(sample["mask1"])
            yield self.construct_collage_fn(
                bg_np, obj0_np, obj1_np, mask0_np, mask1_np
            )