Diffusers
Safetensors
icip_source_2 / midi / data / multi_object.py
hansQAQ's picture
Upload folder using huggingface_hub
278bf35 verified
import json
import os
import random
from dataclasses import dataclass, field
import cv2
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from ..utils.config import parse_structured
from ..utils.typing import *
def _parse_object_list_single(object_list_path: str):
all_objects = []
if object_list_path.endswith(".json"):
with open(object_list_path) as f:
all_objects = json.loads(f.read())
else:
raise NotImplementedError
return all_objects
def _parse_object_list(object_list_path: Union[str, List[str]]):
all_objects = []
if isinstance(object_list_path, str):
object_list_path = [object_list_path]
for object_list_path_ in object_list_path:
all_objects += _parse_object_list_single(object_list_path_)
return all_objects
def _parse_scene_list_single(scene_list_path: str, root_data_dir: str):
all_scenes = []
if scene_list_path.endswith(".json"):
with open(scene_list_path) as f:
for p in json.loads(f.read()):
all_scenes.append(os.path.join(root_data_dir, p))
elif scene_list_path.endswith(".txt"):
with open(scene_list_path) as f:
for p in f.readlines():
p = p.strip()
all_scenes.append(os.path.join(root_data_dir, p))
else:
raise NotImplementedError
return all_scenes
def _parse_scene_list(
scene_list_path: Union[str, List[str]], root_data_dir: Union[str, List[str]]
):
all_scenes = []
if isinstance(scene_list_path, str):
scene_list_path = [scene_list_path]
if isinstance(root_data_dir, str):
root_data_dir = [root_data_dir]
for scene_list_path_, root_data_dir_ in zip(scene_list_path, root_data_dir):
all_scenes += _parse_scene_list_single(scene_list_path_, root_data_dir_)
return all_scenes
def random_morphological_transform(
    mask: torch.Tensor, max_kernel_size: int = 5, p_dilation: float = 0.5
) -> torch.Tensor:
    """
    Perturb a mask with a single random dilation or erosion.
    :param mask: [H, W] 0-1 mask
    :param max_kernel_size: Upper bound of the sampled kernel size. Even draws are
        bumped to the next odd value, so the kernel actually used can be
        ``max_kernel_size + 1`` when ``max_kernel_size`` is even.
    :param p_dilation: Probability of dilating; erosion is applied otherwise
    :return: Perturbed mask as a float tensor on the device of ``mask``
    """
    binary = mask.cpu().numpy().astype(np.uint8)

    # Sample a size in [1, max_kernel_size], then force it to be odd.
    size = np.random.randint(1, max_kernel_size + 1)
    size += 1 - size % 2
    structuring_element = np.ones((size, size), np.uint8)

    morph_op = cv2.dilate if np.random.rand() < p_dilation else cv2.erode
    binary = morph_op(binary, structuring_element, iterations=1)

    return torch.from_numpy(binary).float().to(mask.device)
@dataclass
class MultiObjectDataModuleConfig:
    """Configuration of the multi-object scene dataset and its data loaders."""

    # Scene / object lists (a single path or a list of paths)
    scene_list: Any = ""
    object_list: Any = ""

    # Surface
    surface_root_dir: Any = ""
    surface_suffix: str = "npy"
    num_surface_samples_per_object: int = 20480
    return_scene: bool = False
    max_num_instances: Optional[int] = None
    padding: bool = True
    num_instances_per_batch: Optional[int] = 10

    # Image input
    image_root_dir: Any = ""
    image_prefix: Any = "render"
    image_suffix: str = "webp"
    idmap_prefix: str = "semantic"
    idmap_suffix: str = "png"
    background_color: Union[str, float] = "white"
    image_names: List[str] = field(default_factory=list)
    height: int = 768
    width: int = 768
    use_scene_image: bool = True
    remove_scene_bg: bool = False

    # Data processing
    skip_small_object: bool = False
    small_image_proportion: float = 0.005  # (16/224)^2

    ## Mask perturbation
    morph_perturb: bool = False
    max_kernel_size: int = 5
    p_dilation: float = 0.5
    return_crop_padded: bool = False
    height_crop_padded: int = 224
    width_crop_padded: int = 224

    # Mix data
    do_mix: bool = False
    do_mix_prob: float = 0.5
    mix_length: int = 80000
    mix_scene_list: str = ""
    mix_image_root_dir: str = ""
    mix_surface_root_dir: str = ""
    mix_surface_suffix: str = "npy"
    mix_image_prefix: str = "render_opaque"
    mix_image_names: List[str] = field(default_factory=list)
    mix_image_suffix: str = "webp"

    # Optional (start, end) slices of the scene list used by each split
    train_indices: Optional[Tuple[Any, Any]] = None
    val_indices: Optional[Tuple[Any, Any]] = None
    test_indices: Optional[Tuple[Any, Any]] = None

    # Loader settings; `repeat` only multiplies the training split
    repeat: int = 1
    batch_size: int = 1
    eval_batch_size: int = 1
    num_workers: int = 16
class MultiObjectDataset(Dataset):
    """
    Dataset of multi-object scenes.
    Each item bundles, for one scene, the sampled surface points of its objects,
    one masked image per object and the scene image. When mixing is enabled,
    training items may instead be assembled from independent single objects.
    """

    # Number of replacement samples `__getitem__` may draw after a failure
    # before giving up, so a broken dataset fails fast instead of recursing.
    _MAX_RETRIES = 100

    def __init__(self, cfg: Any, split: str = "train") -> None:
        """
        :param cfg: Configuration object (fields of `MultiObjectDataModuleConfig`)
        :param split: One of "train", "val", "test"
        :raises ValueError: If the scene list and the object list differ in length
        """
        super().__init__()
        assert split in ["train", "val", "test"]
        self.cfg: MultiObjectDataModuleConfig = cfg

        # The same scene list is resolved twice: once against the surface root
        # and once against the image root.
        self.all_scenes = _parse_scene_list(
            self.cfg.scene_list, self.cfg.surface_root_dir
        )
        self.all_objects = _parse_object_list(self.cfg.object_list)
        if len(self.all_scenes) != len(self.all_objects):
            raise ValueError(
                f"Number of scenes and objects must be the same, got {len(self.all_scenes)} scenes and {len(self.all_objects)} object lists."
            )
        self.all_images = _parse_scene_list(
            self.cfg.scene_list, self.cfg.image_root_dir
        )

        self.split = split
        if self.split == "train" and self.cfg.train_indices is not None:
            self.indices = (self.cfg.train_indices[0], self.cfg.train_indices[1])
        elif self.split == "val" and self.cfg.val_indices is not None:
            self.indices = (self.cfg.val_indices[0], self.cfg.val_indices[1])
        elif self.split == "test" and self.cfg.test_indices is not None:
            self.indices = (self.cfg.test_indices[0], self.cfg.test_indices[1])
        else:
            self.indices = (0, len(self.all_scenes))

        # Only the training split is repeated.
        repeat = self.cfg.repeat if self.split == "train" else 1
        start, end = self.indices[0], self.indices[1]
        self.all_scenes = self.all_scenes[start:end] * repeat
        self.all_objects = self.all_objects[start:end] * repeat
        self.all_images = self.all_images[start:end] * repeat

        if self.cfg.do_mix:
            self.mix_all_scenes = _parse_scene_list(
                self.cfg.mix_scene_list, self.cfg.mix_surface_root_dir
            )[: self.cfg.mix_length]
            self.mix_all_images = _parse_scene_list(
                self.cfg.mix_scene_list, self.cfg.mix_image_root_dir
            )[: self.cfg.mix_length]

    def __len__(self):
        """Return the number of scenes in this split (after repetition)."""
        return len(self.all_scenes)

    def get_bg_color(self, bg_color):
        """
        Resolve a background colour specification to an RGB float32 array.
        :param bg_color: "white", "black", "gray", "random", "random_gray",
            a number (used for all three channels) or a list/tuple of values
        :return: ``np.ndarray`` of dtype float32
        :raises NotImplementedError: If the specification is not recognised
        """
        if bg_color == "white":
            bg_color = np.array([1.0, 1.0, 1.0], dtype=np.float32)
        elif bg_color == "black":
            bg_color = np.array([0.0, 0.0, 0.0], dtype=np.float32)
        elif bg_color == "gray":
            bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32)
        elif bg_color == "random":
            # np.random.rand yields float64; cast so this branch matches the
            # others and does not promote the composited images to double.
            bg_color = np.random.rand(3).astype(np.float32)
        elif bg_color == "random_gray":
            bg_color = random.uniform(0.3, 0.7)
            bg_color = np.array([bg_color] * 3, dtype=np.float32)
        elif isinstance(bg_color, (int, float)) and not isinstance(bg_color, bool):
            bg_color = np.array([bg_color] * 3, dtype=np.float32)
        elif isinstance(bg_color, (list, tuple)):
            bg_color = np.array(bg_color, dtype=np.float32)
        else:
            raise NotImplementedError(f"Unsupported background color: {bg_color!r}")
        return bg_color

    def load_surface(self, path, num_pc: int = 20480):
        """
        Load surface points with normals and draw a random subset of them.
        :param path: ``.npy`` file holding a dict with "surface_points" and
            "surface_normals", or an ``.obj`` / ``.glb`` mesh that is sampled
        :param num_pc: Number of points to draw, without replacement
        :return: Float tensor [num_pc, 6] of positions concatenated with normals
        :raises NotImplementedError: For any other file extension
        :raises ValueError: Raised by NumPy when fewer than ``num_pc`` points exist
        """
        if path.endswith(".npy"):
            # NOTE(security): allow_pickle=True unpickles the file content, which
            # can execute arbitrary code; only load files from trusted sources.
            data = np.load(path, allow_pickle=True).tolist()
            surface = data["surface_points"]  # Nx3
            normal = data["surface_normals"]  # Nx3
        elif path.endswith(".obj") or path.endswith(".glb"):
            import trimesh

            n_surf_sample = 500000
            scene = trimesh.load(path, process=False, force="scene")
            meshes = []
            for node_name in scene.graph.nodes_geometry:
                transform, geom_name = scene.graph[node_name][:2]
                geometry = scene.geometry[geom_name]
                if isinstance(geometry, trimesh.Trimesh):
                    # Transform a copy: several nodes may share one geometry, and
                    # an in-place transform would accumulate across those nodes.
                    placed = geometry.copy()
                    placed.apply_transform(transform)
                    meshes.append(placed)
            mesh = trimesh.util.concatenate(meshes)
            surface, face_indices = trimesh.sample.sample_surface(
                mesh, n_surf_sample, sample_color=False
            )
            normal = mesh.face_normals[face_indices]
        else:
            raise NotImplementedError(f"Unsupported file format: {path}")

        rng = np.random.default_rng()
        ind = rng.choice(surface.shape[0], num_pc, replace=False)
        surface = torch.FloatTensor(surface[ind])
        normal = torch.FloatTensor(normal[ind])
        return torch.cat([surface, normal], dim=-1)

    def load_image(
        self,
        path,
        height,
        width,
        background_color,
        rescale: bool = False,
        return_mask: bool = False,
        remove_bg: bool = False,
        idmap_path: Optional[str] = None,
    ):
        """
        Load an image, resize it and composite it onto a background colour.
        RGBA images are blended with their alpha channel. Otherwise, when
        ``remove_bg`` is set and an id map is given, pixels whose id is 0 are
        replaced by the background. In every other case the image is returned
        unchanged together with an all-ones mask.
        :param path: Image file path
        :param height: Target height
        :param width: Target width
        :param background_color: Colour broadcast against the [H, W, 3] image
        :param rescale: If True, map values from [0, 1] to [-1, 1]
        :param return_mask: If True, also return the [H, W] foreground mask
        :param remove_bg: Use the id map to remove the background of non-RGBA images
        :param idmap_path: Path of the id map used when ``remove_bg`` is set
        :return: The image, or ``(image, mask)`` when ``return_mask`` is True
        """
        image_pil = Image.open(path).resize((width, height))
        image = torch.from_numpy(np.array(image_pil)).float() / 255.0
        if image_pil.mode == "RGBA":
            alpha = image[:, :, 3:4]
            image_bg = image[:, :, :3] * alpha + background_color * (1 - alpha)
            mask = (image[:, :, 3] > 0.5).float()
        elif remove_bg and idmap_path is not None:
            # Nearest-neighbour resizing keeps the integer ids intact.
            id_map = torch.from_numpy(
                np.array(Image.open(idmap_path).resize((width, height), Image.NEAREST))
            )
            mask = (id_map > 0).float()
            mask_ = mask.unsqueeze(-1).repeat(1, 1, 3)
            image_bg = image * mask_ + background_color * (1 - mask_)
        else:
            image_bg = image
            mask = torch.ones_like(image[:, :, 0]).float()

        if rescale:
            image_bg = image_bg * 2.0 - 1.0
        if return_mask:
            return image_bg, mask
        return image_bg

    def load_parts(
        self,
        rgb_path: str,
        idmap_path: str,
        indexes: List[int],
        height: int,
        width: int,
        background_color: torch.Tensor,
        skip_small_object: bool = False,
        small_image_proportion: float = 0.005,
        morph_perturb: bool = False,  # Whether to apply morphological perturbation
        max_kernel_size: int = 5,
        p_dilation: float = 0.5,
    ):
        """
        Cut one masked image per requested id out of a scene image.
        :param rgb_path: Scene image path
        :param idmap_path: Id map path; pixel value ``idx`` marks object ``idx``
        :param indexes: Ids to extract, in output order
        :param height: Target height
        :param width: Target width
        :param background_color: Colour used outside each object's mask
        :param skip_small_object: Skip objects covering too few pixels
        :param small_image_proportion: Area fraction at or below which an object
            counts as small
        :param morph_perturb: Randomly dilate/erode each mask before use
        :param max_kernel_size: Forwarded to `random_morphological_transform`
        :param p_dilation: Forwarded to `random_morphological_transform`
        :return: ``(rgb_list, mask_list, success_list)``. The first two hold only
            the kept objects; ``success_list`` has one bool per requested id.
        """
        rgb_image = self.load_image(rgb_path, height, width, background_color)
        id_map = torch.from_numpy(
            np.array(Image.open(idmap_path).resize((width, height), Image.NEAREST))
        )
        height, width, _ = rgb_image.shape

        rgb_list, mask_list, success_list = [], [], []
        for idx in indexes:
            mask = (id_map == idx).float()
            if (
                skip_small_object
                and mask.sum() <= small_image_proportion * height * width
            ):
                success_list.append(False)
                continue
            if morph_perturb:
                mask = random_morphological_transform(
                    mask, max_kernel_size=max_kernel_size, p_dilation=p_dilation
                )
            mask_3c = mask.unsqueeze(-1).repeat(1, 1, 3)
            part_rgb = rgb_image * mask_3c + background_color * (1 - mask_3c)
            rgb_list.append(part_rgb)
            mask_list.append(mask)
            success_list.append(True)
        return rgb_list, mask_list, success_list

    def crop_and_pad(self, rgbs, masks, height, width, padding_ratio=0.1):
        """
        Crop every object to its mask bounding box, pad to a square and resize.
        :param rgbs: List of [H, W, 3] images
        :param masks: List of [H, W] masks; pixels equal to 1 define the box
        :param height: Output height
        :param width: Output width
        :param padding_ratio: Extra margin on every side, as a fraction of the
            square's side length
        :return: ``(cropped_rgbs, cropped_masks)`` with [3, height, width] images
            and [height, width] masks. RGB padding is filled with 1, mask
            padding with 0.
        """
        cropped_rgbs, cropped_masks = [], []
        for rgb, mask in zip(rgbs, masks):
            rgb = rgb.permute(2, 0, 1)

            # crop
            coords = torch.nonzero(mask == 1)
            y_min, x_min = coords.min(dim=0).values
            y_max, x_max = coords.max(dim=0).values
            cropped_rgb = rgb[:, y_min : y_max + 1, x_min : x_max + 1]
            cropped_mask = mask[y_min : y_max + 1, x_min : x_max + 1]
            h, w = cropped_rgb.shape[1:]

            # padding: extend the shorter axis up to the longer one. An odd
            # difference puts the spare pixel on the far side, so the canvas is
            # exactly square and the resize below keeps the aspect ratio.
            pad_left = pad_right = pad_top = pad_bottom = 0
            if w > h:
                pad_top = (w - h) // 2
                pad_bottom = (w - h) - pad_top
                side = w
            else:
                pad_left = (h - w) // 2
                pad_right = (h - w) - pad_left
                side = h
            margin = int(side * padding_ratio)
            padding_size = (
                pad_left + margin,
                pad_right + margin,
                pad_top + margin,
                pad_bottom + margin,
            )  # left, right, top, bottom
            padded_rgb = F.pad(cropped_rgb, padding_size, mode="constant", value=1)
            padded_mask = F.pad(cropped_mask, padding_size, mode="constant", value=0)

            # resize
            padded_rgb = F.interpolate(
                padded_rgb.unsqueeze(0), (height, width), mode="bilinear"
            )[0]
            padded_mask = F.interpolate(
                padded_mask.unsqueeze(0).unsqueeze(0), (height, width), mode="nearest"
            )[0][0]

            cropped_rgbs.append(padded_rgb)
            cropped_masks.append(padded_mask)
        return cropped_rgbs, cropped_masks

    def _getitem_scene(self, index):
        """
        Build the sample of one multi-object scene.
        :param index: Position in the scene list of this split
        :return: Dict with "id", "num_instances", "surface", "rgb", "mask",
            "rgb_scene" and, depending on the config, "surface_scene",
            "rgb_crop_padded" and "mask_crop_padded"
        """
        background_color = torch.as_tensor(self.get_bg_color(self.cfg.background_color))

        # Surface
        scene = self.all_scenes[index]
        scene_objects = self.all_objects[index]
        surfaces = []
        for scene_object in scene_objects:
            surface_path = os.path.join(
                scene, f"{scene_object}.{self.cfg.surface_suffix}"
            )
            surface = self.load_surface(
                surface_path, self.cfg.num_surface_samples_per_object
            )
            surfaces.append(surface)
        surfaces = torch.stack(surfaces)  # (num_instances, num_points, 6)
        num_instances = surfaces.shape[0]

        # Image: training draws a random view / prefix, evaluation uses the first
        image_dir = self.all_images[index]
        image_name = (
            random.choice(self.cfg.image_names)
            if self.split == "train"
            else self.cfg.image_names[0]
        )
        image_prefix = (
            [self.cfg.image_prefix]
            if isinstance(self.cfg.image_prefix, str)
            else self.cfg.image_prefix
        )
        image_prefix = (
            random.choice(image_prefix) if self.split == "train" else image_prefix[0]
        )
        image_path = os.path.join(
            image_dir, f"{image_prefix}_{image_name}.{self.cfg.image_suffix}"
        )
        idmap_path = (
            os.path.join(
                image_dir,
                f"{self.cfg.idmap_prefix}_{image_name}.{self.cfg.idmap_suffix}",
            )
            .replace("_controlnet", "")
            .replace("_inpaint", "")
        )

        # Load image and parts
        rgb_scene = (
            self.load_image(
                image_path,
                height=self.cfg.height,
                width=self.cfg.width,
                background_color=background_color,
                remove_bg=self.cfg.remove_scene_bg,
                idmap_path=idmap_path,
            )
            .unsqueeze(0)
            .repeat(num_instances, 1, 1, 1)
            .permute(0, 3, 1, 2)
        )
        # Object i of the scene is looked up under id i + 1 in the id map.
        rgbs, masks, success_list = self.load_parts(
            image_path,
            idmap_path,
            list(range(1, num_instances + 1)),
            self.cfg.height,
            self.cfg.width,
            background_color,
            skip_small_object=self.cfg.skip_small_object,
            small_image_proportion=self.cfg.small_image_proportion,
            morph_perturb=self.cfg.morph_perturb,
            max_kernel_size=self.cfg.max_kernel_size,
            p_dilation=self.cfg.p_dilation,
        )
        if len(rgbs) == 0:
            # Every object was skipped: fall back to another random sample.
            return self._getitem(random.randint(0, self.__len__() - 1))
        rgb = torch.stack(rgbs).permute(0, 3, 1, 2)
        mask = torch.stack(masks).unsqueeze(1)

        # Update `surfaces`, `num_instances`, `rgb_scene` according to `success_list`
        success_list = torch.tensor(success_list, dtype=torch.bool)
        surfaces = surfaces[success_list]
        num_instances = surfaces.shape[0]
        rgb_scene = rgb_scene[success_list]

        if (
            self.cfg.max_num_instances is not None
            and num_instances > self.cfg.max_num_instances
        ):
            indices = torch.randperm(num_instances)[: self.cfg.max_num_instances]
            surfaces = surfaces[indices]
            rgb = rgb[indices]
            mask = mask[indices]
            rgb_scene = rgb_scene[indices]
            # Keep the per-object lists aligned with the subsampled tensors;
            # they feed `crop_and_pad` below and must match `rgb` / `mask`.
            kept = indices.tolist()
            rgbs = [rgbs[i] for i in kept]
            masks = [masks[i] for i in kept]
            num_instances = self.cfg.max_num_instances

        # Scene id
        scene_id = "-".join(image_dir.split("/")[-2:])

        rv = {
            "id": scene_id,
            "num_instances": num_instances,
            "surface": surfaces,
            "rgb": rgb,
            "mask": mask,
            "rgb_scene": (
                rgb_scene if self.cfg.use_scene_image else torch.zeros_like(rgb_scene)
            ),
        }
        if self.cfg.return_scene:
            # All object point sets merged into one cloud for the whole scene.
            surface_scene = surfaces.view(-1, *surfaces.shape[2:])
            rv.update({"surface_scene": surface_scene})
        if self.cfg.return_crop_padded:
            cropped_rgbs, cropped_masks = self.crop_and_pad(
                rgbs, masks, self.cfg.height_crop_padded, self.cfg.width_crop_padded
            )
            cropped_rgb = torch.stack(cropped_rgbs)
            cropped_mask = torch.stack(cropped_masks).unsqueeze(1)
            rv.update(
                {"rgb_crop_padded": cropped_rgb, "mask_crop_padded": cropped_mask}
            )

        if self.cfg.padding:
            # Bring every per-instance tensor to exactly `num_instances_per_batch`
            # entries: duplicate random instances when short, subsample otherwise.
            keys = [
                "surface",
                "rgb",
                "mask",
                "rgb_scene",
                "rgb_crop_padded",
                "mask_crop_padded",
            ]
            if num_instances < self.cfg.num_instances_per_batch:
                pad = self.cfg.num_instances_per_batch - num_instances
                indices = torch.randint(
                    0, num_instances, (pad,), device=surfaces.device
                )
                updated_dict = {
                    k: torch.cat([v, v[indices]]) if k in keys else v
                    for k, v in rv.items()
                }
            else:
                indices = torch.randperm(num_instances, device=surfaces.device)
                indices = indices[: self.cfg.num_instances_per_batch]
                updated_dict = {
                    k: v[indices] if k in keys else v for k, v in rv.items()
                }
            updated_dict.update({"num_instances": self.cfg.num_instances_per_batch})
            rv = updated_dict
        return rv

    def _getitem_mix(self, index):
        """
        Build a sample from independently drawn single objects of the mix data.
        :param index: Unused; objects are drawn at random
        :return: Dict with "id", "num_instances" (always 1), "surface", "rgb",
            "mask" and "rgb_scene", each tensor holding
            ``num_instances_per_batch`` objects
        """
        background_color = torch.as_tensor(self.get_bg_color(self.cfg.background_color))

        surfaces, rgbs, masks = [], [], []
        indexes = torch.randint(
            0, len(self.mix_all_scenes), (self.cfg.num_instances_per_batch,)
        ).tolist()
        for i in indexes:
            scene = self.mix_all_scenes[i]

            # Surface
            surface_path = f"{scene}.{self.cfg.mix_surface_suffix}"
            surface = self.load_surface(
                surface_path, self.cfg.num_surface_samples_per_object
            )
            surfaces.append(surface)

            # Image
            image_dir = self.mix_all_images[i]
            image_name = (
                random.choice(self.cfg.mix_image_names)
                if self.split == "train"
                else self.cfg.mix_image_names[0]
            )
            image_path = os.path.join(
                image_dir,
                f"{self.cfg.mix_image_prefix}_{image_name}.{self.cfg.mix_image_suffix}",
            )
            rgb, mask = self.load_image(
                image_path,
                height=self.cfg.height,
                width=self.cfg.width,
                background_color=background_color,
                return_mask=True,
            )
            rgbs.append(rgb)
            masks.append(mask)

        surfaces = torch.stack(surfaces)  # (num_instances, num_points, 6)
        rgb = torch.stack(rgbs).permute(0, 3, 1, 2)
        mask = torch.stack(masks).unsqueeze(1)

        # Scene ID
        scene_id = self.mix_all_images[indexes[0]].split("/")[-1]

        rv = {
            "id": scene_id,
            "num_instances": 1,
            "surface": surfaces,
            "rgb": rgb,
            "mask": mask,
            "rgb_scene": (
                rgb if self.cfg.use_scene_image else torch.zeros_like(rgb)
            ),  # here single object is the scene
        }
        return rv

    def _getitem(self, index):
        """Return a mix sample (training only, with probability ``do_mix_prob``) or a scene sample."""
        if (
            self.split == "train"
            and self.cfg.do_mix
            and random.random() < self.cfg.do_mix_prob
        ):
            return self._getitem_mix(index)
        return self._getitem_scene(index)

    def __getitem__(self, index):
        """
        Return the sample at ``index``, substituting a random one on failure.
        A failing sample is reported (message and traceback) and replaced by a
        randomly drawn index, at most ``_MAX_RETRIES`` times.
        :raises RuntimeError: If every attempt failed
        """
        import traceback

        last_error = None
        for _ in range(self._MAX_RETRIES + 1):
            try:
                return self._getitem(index)
            except Exception as e:
                # An out-of-range index raises IndexError on this lookup and
                # propagates to the caller, which is the intended behaviour.
                print(f"Error in {self.all_scenes[index]}: {e}")
                traceback.print_exc()
                last_error = e
                index = random.randint(0, self.__len__() - 1)
        raise RuntimeError(
            f"Failed to load a valid sample after {self._MAX_RETRIES} retries."
        ) from last_error

    def collate(self, batch):
        """
        Collate samples and flatten the per-sample instance axis into the batch axis.
        :param batch: List of dicts produced by ``__getitem__``
        :return: Collated dict in which the per-instance tensors go from
            [B, N, ...] to [B * N, ...], plus "num_instances_per_batch"
        """
        batch = torch.utils.data.default_collate(batch)
        packed_keys = (
            "surface",
            "rgb",
            "mask",
            "rgb_scene",
            "rgb_crop_padded",
            "mask_crop_padded",
        )
        for k in batch.keys():
            if k in packed_keys:
                t = batch[k]
                batch[k] = t.view(-1, *t.shape[2:])
        batch["num_instances_per_batch"] = self.cfg.num_instances_per_batch
        return batch
class MultiObjectDataModule(pl.LightningDataModule):
    """Lightning data module exposing `MultiObjectDataset` for every split."""

    cfg: MultiObjectDataModuleConfig

    def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None:
        super().__init__()
        self.cfg = parse_structured(MultiObjectDataModuleConfig, cfg)

    def setup(self, stage=None) -> None:
        """Instantiate the datasets required by ``stage`` (all of them when it is None)."""
        # (attribute to set, dataset split, stages for which it is built)
        plan = (
            ("train_dataset", "train", (None, "fit")),
            ("val_dataset", "val", (None, "fit", "validate")),
            ("test_dataset", "test", (None, "test", "predict")),
        )
        for attribute, split, stages in plan:
            if stage in stages:
                setattr(self, attribute, MultiObjectDataset(self.cfg, split))

    def prepare_data(self):
        """Nothing to download or preprocess."""
        pass

    def _make_loader(self, dataset, batch_size: int, shuffle: bool) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader that uses the dataset's own collate function."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=self.cfg.num_workers,
            shuffle=shuffle,
            collate_fn=dataset.collate,
        )

    def train_dataloader(self) -> DataLoader:
        """Shuffled loader over the training split."""
        return self._make_loader(self.train_dataset, self.cfg.batch_size, True)

    def val_dataloader(self) -> DataLoader:
        """Sequential loader over the validation split."""
        return self._make_loader(self.val_dataset, self.cfg.eval_batch_size, False)

    def test_dataloader(self) -> DataLoader:
        """Sequential loader over the test split."""
        return self._make_loader(self.test_dataset, self.cfg.eval_batch_size, False)

    def predict_dataloader(self) -> DataLoader:
        """Prediction reuses the test loader."""
        return self.test_dataloader()
if __name__ == "__main__":
    # Smoke test: build the data module from the training config, pull one test
    # batch and dump its image tensors to `tmp/` for visual inspection.
    import torchvision
    from omegaconf import OmegaConf

    config_file = "configs/scenediff/training.yaml"
    data_cfg = OmegaConf.load(config_file)["data"]
    cfg: MultiObjectDataModuleConfig = MultiObjectDataModuleConfig(**data_cfg)
    data_module = MultiObjectDataModule(cfg)
    data_module.setup()

    # save_image does not create missing directories, so make sure the output
    # folder exists before writing into it.
    os.makedirs("tmp", exist_ok=True)

    for batch in data_module.test_dataloader():
        print(batch["num_instances"])
        for key in [
            "rgb",
            "mask",
            "rgb_scene",
            # "rgb_crop_padded",
            # "mask_crop_padded",
        ]:
            print(key, batch[key].shape, batch[key].min(), batch[key].max())
            torchvision.utils.save_image(
                batch[key], f"tmp/{key}.png", nrow=4, normalize=True
            )
        # Additionally write every object image on its own.
        for key in ["rgb"]:
            for i in range(batch[key].shape[0]):
                torchvision.utils.save_image(
                    batch[key][i], f"tmp/{key}_{i}.png", normalize=True
                )
        # Only the first batch is inspected.
        break