Instructions to use hansQAQ/icip_source_2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use hansQAQ/icip_source_2 with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("hansQAQ/icip_source_2", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
| import json | |
| import os | |
| import random | |
| from dataclasses import dataclass, field | |
| import cv2 | |
| import numpy as np | |
| import pytorch_lightning as pl | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| from torch.utils.data import DataLoader, Dataset | |
| from ..utils.config import parse_structured | |
| from ..utils.typing import * | |
| def _parse_object_list_single(object_list_path: str): | |
| all_objects = [] | |
| if object_list_path.endswith(".json"): | |
| with open(object_list_path) as f: | |
| all_objects = json.loads(f.read()) | |
| else: | |
| raise NotImplementedError | |
| return all_objects | |
| def _parse_object_list(object_list_path: Union[str, List[str]]): | |
| all_objects = [] | |
| if isinstance(object_list_path, str): | |
| object_list_path = [object_list_path] | |
| for object_list_path_ in object_list_path: | |
| all_objects += _parse_object_list_single(object_list_path_) | |
| return all_objects | |
| def _parse_scene_list_single(scene_list_path: str, root_data_dir: str): | |
| all_scenes = [] | |
| if scene_list_path.endswith(".json"): | |
| with open(scene_list_path) as f: | |
| for p in json.loads(f.read()): | |
| all_scenes.append(os.path.join(root_data_dir, p)) | |
| elif scene_list_path.endswith(".txt"): | |
| with open(scene_list_path) as f: | |
| for p in f.readlines(): | |
| p = p.strip() | |
| all_scenes.append(os.path.join(root_data_dir, p)) | |
| else: | |
| raise NotImplementedError | |
| return all_scenes | |
| def _parse_scene_list( | |
| scene_list_path: Union[str, List[str]], root_data_dir: Union[str, List[str]] | |
| ): | |
| all_scenes = [] | |
| if isinstance(scene_list_path, str): | |
| scene_list_path = [scene_list_path] | |
| if isinstance(root_data_dir, str): | |
| root_data_dir = [root_data_dir] | |
| for scene_list_path_, root_data_dir_ in zip(scene_list_path, root_data_dir): | |
| all_scenes += _parse_scene_list_single(scene_list_path_, root_data_dir_) | |
| return all_scenes | |
| def random_morphological_transform( | |
| mask: torch.Tensor, max_kernel_size: int = 5, p_dilation: float = 0.5 | |
| ) -> torch.Tensor: | |
| """ | |
| Randomly dilate or erode the mask | |
| :param mask: [H, W] 0-1 mask | |
| :param max_kernel_size: Maximum kernel size; controls the scale of the morphological operation | |
| :param p_dilation: Probability of dilation; if random sample > p_dilation, erosion is performed | |
| :return: Perturbed mask | |
| """ | |
| mask_np = mask.cpu().numpy().astype(np.uint8) | |
| kernel_size = np.random.randint(1, max_kernel_size + 1) | |
| if kernel_size % 2 == 0: | |
| kernel_size += 1 | |
| kernel = np.ones((kernel_size, kernel_size), np.uint8) | |
| if np.random.rand() < p_dilation: | |
| mask_np = cv2.dilate(mask_np, kernel, iterations=1) | |
| else: | |
| mask_np = cv2.erode(mask_np, kernel, iterations=1) | |
| return torch.from_numpy(mask_np).float().to(mask.device) | |
| class MultiObjectDataModuleConfig: | |
| scene_list: Any = "" | |
| object_list: Any = "" | |
| # Surface | |
| surface_root_dir: Any = "" | |
| surface_suffix: str = "npy" | |
| num_surface_samples_per_object: int = 20480 | |
| return_scene: bool = False | |
| max_num_instances: Optional[int] = None | |
| padding: bool = True | |
| num_instances_per_batch: Optional[int] = 10 | |
| # Image input | |
| image_root_dir: Any = "" | |
| image_prefix: Any = "render" | |
| image_suffix: str = "webp" | |
| idmap_prefix: str = "semantic" | |
| idmap_suffix: str = "png" | |
| background_color: Union[str, float] = "white" | |
| image_names: List[str] = field(default_factory=lambda: []) | |
| height: int = 768 | |
| width: int = 768 | |
| use_scene_image: bool = True | |
| remove_scene_bg: bool = False | |
| # Data processing | |
| skip_small_object: bool = False | |
| small_image_proportion: float = 0.005 # (16/224)^2 | |
| ## Mask perturbation | |
| morph_perturb: bool = False | |
| max_kernel_size: int = 5 | |
| p_dilation: float = 0.5 | |
| return_crop_padded: bool = False | |
| height_crop_padded: int = 224 | |
| width_crop_padded: int = 224 | |
| # Mix data | |
| do_mix: bool = False | |
| do_mix_prob: float = 0.5 | |
| mix_length: int = 80000 | |
| mix_scene_list: str = "" | |
| mix_image_root_dir: str = "" | |
| mix_surface_root_dir: str = "" | |
| mix_surface_suffix: str = "npy" | |
| mix_image_prefix: str = "render_opaque" | |
| mix_image_names: List[str] = field(default_factory=lambda: []) | |
| mix_image_suffix: str = "webp" | |
| train_indices: Optional[Tuple[Any, Any]] = None | |
| val_indices: Optional[Tuple[Any, Any]] = None | |
| test_indices: Optional[Tuple[Any, Any]] = None | |
| repeat: int = 1 | |
| batch_size: int = 1 | |
| eval_batch_size: int = 1 | |
| num_workers: int = 16 | |
| class MultiObjectDataset(Dataset): | |
| def __init__(self, cfg: Any, split: str = "train") -> None: | |
| super().__init__() | |
| assert split in ["train", "val", "test"] | |
| self.cfg: MultiObjectDataModuleConfig = cfg | |
| self.all_scenes = _parse_scene_list( | |
| self.cfg.scene_list, self.cfg.surface_root_dir | |
| ) | |
| self.all_objects = _parse_object_list(self.cfg.object_list) | |
| if len(self.all_scenes) != len(self.all_objects): | |
| raise ValueError( | |
| f"Number of scenes and objects must be the same, got {len(self.all_scenes)} scenes and {len(self.all_objects)} object lists." | |
| ) | |
| self.all_images = _parse_scene_list( | |
| self.cfg.scene_list, self.cfg.image_root_dir | |
| ) | |
| self.split = split | |
| self.indices = [] | |
| if self.split == "train" and self.cfg.train_indices is not None: | |
| self.indices = (self.cfg.train_indices[0], self.cfg.train_indices[1]) | |
| elif self.split == "val" and self.cfg.val_indices is not None: | |
| self.indices = (self.cfg.val_indices[0], self.cfg.val_indices[1]) | |
| elif self.split == "test" and self.cfg.test_indices is not None: | |
| self.indices = (self.cfg.test_indices[0], self.cfg.test_indices[1]) | |
| else: | |
| self.indices = (0, len(self.all_scenes)) | |
| repeat = self.cfg.repeat if self.split == "train" else 1 | |
| self.all_scenes = self.all_scenes[self.indices[0] : self.indices[1]] * repeat | |
| self.all_objects = self.all_objects[self.indices[0] : self.indices[1]] * repeat | |
| self.all_images = self.all_images[self.indices[0] : self.indices[1]] * repeat | |
| if self.cfg.do_mix: | |
| self.mix_all_scenes = _parse_scene_list( | |
| self.cfg.mix_scene_list, self.cfg.mix_surface_root_dir | |
| )[: self.cfg.mix_length] | |
| self.mix_all_images = _parse_scene_list( | |
| self.cfg.mix_scene_list, self.cfg.mix_image_root_dir | |
| )[: self.cfg.mix_length] | |
| def __len__(self): | |
| return len(self.all_scenes) | |
| def get_bg_color(self, bg_color): | |
| if bg_color == "white": | |
| bg_color = np.array([1.0, 1.0, 1.0], dtype=np.float32) | |
| elif bg_color == "black": | |
| bg_color = np.array([0.0, 0.0, 0.0], dtype=np.float32) | |
| elif bg_color == "gray": | |
| bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32) | |
| elif bg_color == "random": | |
| bg_color = np.random.rand(3) | |
| elif bg_color == "random_gray": | |
| bg_color = random.uniform(0.3, 0.7) | |
| bg_color = np.array([bg_color] * 3, dtype=np.float32) | |
| elif isinstance(bg_color, float): | |
| bg_color = np.array([bg_color] * 3, dtype=np.float32) | |
| elif isinstance(bg_color, list) or isinstance(bg_color, tuple): | |
| bg_color = np.array(bg_color, dtype=np.float32) | |
| else: | |
| raise NotImplementedError | |
| return bg_color | |
| def load_surface(self, path, num_pc: int = 20480): | |
| if path.endswith(".npy"): | |
| data = np.load(path, allow_pickle=True).tolist() | |
| surface = data["surface_points"] # Nx3 | |
| normal = data["surface_normals"] # Nx3 | |
| elif path.endswith(".obj") or path.endswith(".glb"): | |
| import trimesh | |
| n_surf_sample = 500000 | |
| scene = trimesh.load(path, process=False, force="scene") | |
| meshes = [] | |
| for node_name in scene.graph.nodes_geometry: | |
| geom_name = scene.graph[node_name][1] | |
| geometry = scene.geometry[geom_name] | |
| transform = scene.graph[node_name][0] | |
| if isinstance(geometry, trimesh.Trimesh): | |
| geometry.apply_transform(transform) | |
| meshes.append(geometry) | |
| mesh = trimesh.util.concatenate(meshes) | |
| surface, face_indices = trimesh.sample.sample_surface( | |
| mesh, n_surf_sample, sample_color=False | |
| ) | |
| normal = mesh.face_normals[face_indices] | |
| else: | |
| raise NotImplementedError(f"Unsupported file format: {path}") | |
| rng = np.random.default_rng() | |
| ind = rng.choice(surface.shape[0], num_pc, replace=False) | |
| surface = torch.FloatTensor(surface[ind]) | |
| normal = torch.FloatTensor(normal[ind]) | |
| surface = torch.cat([surface, normal], dim=-1) | |
| return surface | |
| def load_image( | |
| self, | |
| path, | |
| height, | |
| width, | |
| background_color, | |
| rescale: bool = False, | |
| return_mask: bool = False, | |
| remove_bg: bool = False, | |
| idmap_path: Optional[str] = None, | |
| ): | |
| image_pil = Image.open(path).resize((width, height)) | |
| image = torch.from_numpy(np.array(image_pil)).float() / 255.0 | |
| if image_pil.mode == "RGBA": | |
| image_bg = image[:, :, :3] * image[:, :, 3:4] + background_color * ( | |
| 1 - image[:, :, 3:4] | |
| ) | |
| mask = (image[:, :, 3] > 0.5).float() | |
| elif remove_bg and idmap_path is not None: | |
| id_map = torch.from_numpy( | |
| np.array(Image.open(idmap_path).resize((width, height), Image.NEAREST)) | |
| ) | |
| mask = (id_map > 0).float() | |
| mask_ = mask.unsqueeze(-1).repeat(1, 1, 3) | |
| image_bg = image * mask_ + background_color * (1 - mask_) | |
| else: | |
| image_bg = image | |
| mask = torch.ones_like(image[:, :, 0]).float() | |
| if rescale: | |
| image_bg = image_bg * 2.0 - 1.0 | |
| if return_mask: | |
| return image_bg, mask | |
| return image_bg | |
| def load_parts( | |
| self, | |
| rgb_path: str, | |
| idmap_path: str, | |
| indexes: List[int], | |
| height: int, | |
| width: int, | |
| background_color: torch.Tensor, | |
| skip_small_object: bool = False, | |
| small_image_proportion: float = 0.005, | |
| morph_perturb: bool = False, # Whether to apply morphological perturbation | |
| max_kernel_size: int = 5, | |
| p_dilation: float = 0.5, | |
| ): | |
| rgb_image = self.load_image(rgb_path, height, width, background_color) | |
| id_map = torch.from_numpy( | |
| np.array(Image.open(idmap_path).resize((width, height), Image.NEAREST)) | |
| ) | |
| height, width, _ = rgb_image.shape | |
| rgb_list, mask_list, success_list = [], [], [] | |
| for idx in indexes: | |
| mask = (id_map == idx).float() | |
| if ( | |
| skip_small_object | |
| and mask.sum() <= small_image_proportion * height * width | |
| ): | |
| success_list.append(False) | |
| continue | |
| if morph_perturb: | |
| mask = random_morphological_transform( | |
| mask, max_kernel_size=max_kernel_size, p_dilation=p_dilation | |
| ) | |
| mask_3c = mask.unsqueeze(-1).repeat(1, 1, 3) | |
| part_rgb = rgb_image * mask_3c + background_color * (1 - mask_3c) | |
| rgb_list.append(part_rgb) | |
| mask_list.append(mask) | |
| success_list.append(True) | |
| return rgb_list, mask_list, success_list | |
| def crop_and_pad(self, rgbs, masks, height, width, padding_ratio=0.1): | |
| cropped_rgbs, cropped_masks = [], [] | |
| for rgb, mask in zip(rgbs, masks): | |
| rgb = rgb.permute(2, 0, 1) | |
| # crop | |
| coords = torch.nonzero(mask == 1) | |
| y_min, x_min = coords.min(dim=0).values | |
| y_max, x_max = coords.max(dim=0).values | |
| cropped_rgb = rgb[:, y_min : y_max + 1, x_min : x_max + 1] | |
| cropped_mask = mask[y_min : y_max + 1, x_min : x_max + 1] | |
| h, w = cropped_rgb.shape[1:] | |
| # padding | |
| padding_size = [0, 0, 0, 0] # left, right, top, bottom | |
| if w > h: | |
| padding_size[2] = padding_size[3] = int((w - h) / 2) | |
| h = w | |
| else: | |
| padding_size[0] = padding_size[1] = int((h - w) / 2) | |
| w = h | |
| padding_size = tuple([s + int(w * padding_ratio) for s in padding_size]) | |
| padded_rgb = F.pad(cropped_rgb, padding_size, mode="constant", value=1) | |
| padded_mask = F.pad(cropped_mask, padding_size, mode="constant", value=0) | |
| # resize | |
| padded_rgb = F.interpolate( | |
| padded_rgb.unsqueeze(0), (height, width), mode="bilinear" | |
| )[0] | |
| padded_mask = F.interpolate( | |
| padded_mask.unsqueeze(0).unsqueeze(0), (height, width), mode="nearest" | |
| )[0][0] | |
| cropped_rgbs.append(padded_rgb) | |
| cropped_masks.append(padded_mask) | |
| return cropped_rgbs, cropped_masks | |
| def _getitem_scene(self, index): | |
| background_color = torch.as_tensor(self.get_bg_color(self.cfg.background_color)) | |
| # Surface | |
| scene = self.all_scenes[index] | |
| scene_objects = self.all_objects[index] | |
| surfaces = [] | |
| for scene_object in scene_objects: | |
| surface_path = os.path.join( | |
| scene, f"{scene_object}.{self.cfg.surface_suffix}" | |
| ) | |
| surface = self.load_surface( | |
| surface_path, self.cfg.num_surface_samples_per_object | |
| ) | |
| surfaces.append(surface) | |
| surfaces = torch.stack(surfaces) # (num_instances, num_points, 6) | |
| num_instances = surfaces.shape[0] | |
| # Image | |
| image_dir = self.all_images[index] | |
| image_name = ( | |
| random.choice(self.cfg.image_names) | |
| if self.split == "train" | |
| else self.cfg.image_names[0] | |
| ) | |
| image_prefix = ( | |
| [self.cfg.image_prefix] | |
| if isinstance(self.cfg.image_prefix, str) | |
| else self.cfg.image_prefix | |
| ) | |
| image_prefix = ( | |
| random.choice(image_prefix) if self.split == "train" else image_prefix[0] | |
| ) | |
| image_path = os.path.join( | |
| image_dir, f"{image_prefix}_{image_name}.{self.cfg.image_suffix}" | |
| ) | |
| idmap_path = ( | |
| os.path.join( | |
| image_dir, | |
| f"{self.cfg.idmap_prefix}_{image_name}.{self.cfg.idmap_suffix}", | |
| ) | |
| .replace("_controlnet", "") | |
| .replace("_inpaint", "") | |
| ) | |
| # Load image and parts | |
| rgb_scene = ( | |
| self.load_image( | |
| image_path, | |
| height=self.cfg.height, | |
| width=self.cfg.width, | |
| background_color=background_color, | |
| remove_bg=self.cfg.remove_scene_bg, | |
| idmap_path=idmap_path, | |
| ) | |
| .unsqueeze(0) | |
| .repeat(num_instances, 1, 1, 1) | |
| .permute(0, 3, 1, 2) | |
| ) | |
| rgbs, masks, success_list = self.load_parts( | |
| image_path, | |
| idmap_path, | |
| list(range(1, num_instances + 1)), | |
| self.cfg.height, | |
| self.cfg.width, | |
| background_color, | |
| skip_small_object=self.cfg.skip_small_object, | |
| small_image_proportion=self.cfg.small_image_proportion, | |
| morph_perturb=self.cfg.morph_perturb, | |
| max_kernel_size=self.cfg.max_kernel_size, | |
| p_dilation=self.cfg.p_dilation, | |
| ) | |
| if len(rgbs) == 0: | |
| return self._getitem(random.randint(0, self.__len__() - 1)) | |
| rgb = torch.stack(rgbs).permute(0, 3, 1, 2) | |
| mask = torch.stack(masks).unsqueeze(1) | |
| # Update `surfaces`, `num_instances`, `rgb_scene` according to `success_list` | |
| success_list = torch.tensor(success_list, dtype=torch.bool) | |
| surfaces = surfaces[success_list] | |
| num_instances = surfaces.shape[0] | |
| rgb_scene = rgb_scene[success_list] | |
| if self.cfg.max_num_instances is not None: | |
| if num_instances > self.cfg.max_num_instances: | |
| indices = torch.randperm(num_instances)[: self.cfg.max_num_instances] | |
| surfaces = surfaces[indices] | |
| rgb = rgb[indices] | |
| mask = mask[indices] | |
| rgb_scene = rgb_scene[indices] | |
| num_instances = self.cfg.max_num_instances | |
| # Scene id | |
| scene_id = "-".join(image_dir.split("/")[-2:]) | |
| rv = { | |
| "id": scene_id, | |
| "num_instances": num_instances, | |
| "surface": surfaces, | |
| "rgb": rgb, | |
| "mask": mask, | |
| "rgb_scene": ( | |
| rgb_scene if self.cfg.use_scene_image else torch.zeros_like(rgb_scene) | |
| ), | |
| } | |
| if self.cfg.return_scene: | |
| surface_scene = surfaces.view(-1, *surfaces.shape[2:]) | |
| rv.update({"surface_scene": surface_scene}) | |
| if self.cfg.return_crop_padded: | |
| cropped_rgbs, cropped_masks = self.crop_and_pad( | |
| rgbs, masks, self.cfg.height_crop_padded, self.cfg.width_crop_padded | |
| ) | |
| cropped_rgb = torch.stack(cropped_rgbs) | |
| cropped_mask = torch.stack(cropped_masks).unsqueeze(1) | |
| rv.update( | |
| {"rgb_crop_padded": cropped_rgb, "mask_crop_padded": cropped_mask} | |
| ) | |
| if self.cfg.padding: | |
| keys = [ | |
| "surface", | |
| "rgb", | |
| "mask", | |
| "rgb_scene", | |
| "rgb_crop_padded", | |
| "mask_crop_padded", | |
| ] | |
| if num_instances < self.cfg.num_instances_per_batch: | |
| pad = self.cfg.num_instances_per_batch - num_instances | |
| indices = torch.randint( | |
| 0, num_instances, (pad,), device=surfaces.device | |
| ) | |
| updated_dict = { | |
| k: torch.cat([v, v[indices]]) if k in keys else v | |
| for k, v in rv.items() | |
| } | |
| else: | |
| indices = torch.randperm(num_instances, device=surfaces.device) | |
| indices = indices[: self.cfg.num_instances_per_batch] | |
| updated_dict = { | |
| k: v[indices] if k in keys else v for k, v in rv.items() | |
| } | |
| updated_dict.update({"num_instances": self.cfg.num_instances_per_batch}) | |
| rv = updated_dict | |
| return rv | |
| def _getitem_mix(self, index): | |
| background_color = torch.as_tensor(self.get_bg_color(self.cfg.background_color)) | |
| surfaces, rgbs, masks = [], [], [] | |
| indexes = torch.randint( | |
| 0, len(self.mix_all_scenes), (self.cfg.num_instances_per_batch,) | |
| ) | |
| for i in indexes: | |
| scene = self.mix_all_scenes[i] | |
| # Surface | |
| surface_path = f"{scene}.{self.cfg.mix_surface_suffix}" | |
| surface = self.load_surface( | |
| surface_path, self.cfg.num_surface_samples_per_object | |
| ) | |
| surfaces.append(surface) | |
| # Image | |
| image_dir = self.mix_all_images[i] | |
| image_name = ( | |
| random.choice(self.cfg.mix_image_names) | |
| if self.split == "train" | |
| else self.cfg.mix_image_names[0] | |
| ) | |
| image_path = os.path.join( | |
| image_dir, | |
| f"{self.cfg.mix_image_prefix}_{image_name}.{self.cfg.mix_image_suffix}", | |
| ) | |
| rgb, mask = self.load_image( | |
| image_path, | |
| height=self.cfg.height, | |
| width=self.cfg.width, | |
| background_color=background_color, | |
| return_mask=True, | |
| ) | |
| rgbs.append(rgb) | |
| masks.append(mask) | |
| surfaces = torch.stack(surfaces) # (num_instances, num_points, 6) | |
| rgb = torch.stack(rgbs).permute(0, 3, 1, 2) | |
| mask = torch.stack(masks).unsqueeze(1) | |
| # Scene ID | |
| scene_id = self.mix_all_images[indexes[0]].split("/")[-1] | |
| rv = { | |
| "id": scene_id, | |
| "num_instances": 1, | |
| "surface": surfaces, | |
| "rgb": rgb, | |
| "mask": mask, | |
| "rgb_scene": ( | |
| rgb if self.cfg.use_scene_image else torch.zeros_like(rgb) | |
| ), # here single object is the scene | |
| } | |
| return rv | |
| def _getitem(self, index): | |
| if ( | |
| self.split == "train" | |
| and self.cfg.do_mix | |
| and random.random() < self.cfg.do_mix_prob | |
| ): | |
| return self._getitem_mix(index) | |
| else: | |
| return self._getitem_scene(index) | |
| def __getitem__(self, index): | |
| try: | |
| return self._getitem(index) | |
| except Exception as e: | |
| print(f"Error in {self.all_scenes[index]}: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return self.__getitem__(random.randint(0, self.__len__() - 1)) | |
| def collate(self, batch): | |
| batch = torch.utils.data.default_collate(batch) | |
| pack = lambda t: t.view(-1, *t.shape[2:]) | |
| for k in batch.keys(): | |
| if k in [ | |
| "surface", | |
| "rgb", | |
| "mask", | |
| "rgb_scene", | |
| "rgb_crop_padded", | |
| "mask_crop_padded", | |
| ]: | |
| batch[k] = pack(batch[k]) | |
| batch["num_instances_per_batch"] = self.cfg.num_instances_per_batch | |
| return batch | |
| class MultiObjectDataModule(pl.LightningDataModule): | |
| cfg: MultiObjectDataModuleConfig | |
| def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None: | |
| super().__init__() | |
| self.cfg = parse_structured(MultiObjectDataModuleConfig, cfg) | |
| def setup(self, stage=None) -> None: | |
| if stage in [None, "fit"]: | |
| self.train_dataset = MultiObjectDataset(self.cfg, "train") | |
| if stage in [None, "fit", "validate"]: | |
| self.val_dataset = MultiObjectDataset(self.cfg, "val") | |
| if stage in [None, "test", "predict"]: | |
| self.test_dataset = MultiObjectDataset(self.cfg, "test") | |
| def prepare_data(self): | |
| pass | |
| def train_dataloader(self) -> DataLoader: | |
| return DataLoader( | |
| self.train_dataset, | |
| batch_size=self.cfg.batch_size, | |
| num_workers=self.cfg.num_workers, | |
| shuffle=True, | |
| collate_fn=self.train_dataset.collate, | |
| ) | |
| def val_dataloader(self) -> DataLoader: | |
| return DataLoader( | |
| self.val_dataset, | |
| batch_size=self.cfg.eval_batch_size, | |
| num_workers=self.cfg.num_workers, | |
| shuffle=False, | |
| collate_fn=self.val_dataset.collate, | |
| ) | |
| def test_dataloader(self) -> DataLoader: | |
| return DataLoader( | |
| self.test_dataset, | |
| batch_size=self.cfg.eval_batch_size, | |
| num_workers=self.cfg.num_workers, | |
| shuffle=False, | |
| collate_fn=self.test_dataset.collate, | |
| ) | |
| def predict_dataloader(self) -> DataLoader: | |
| return self.test_dataloader() | |
| if __name__ == "__main__": | |
| import torchvision | |
| from omegaconf import OmegaConf | |
| config_file = "configs/scenediff/training.yaml" | |
| data_cfg = OmegaConf.load(config_file)["data"] | |
| cfg: MultiObjectDataModuleConfig = MultiObjectDataModuleConfig(**data_cfg) | |
| data_module = MultiObjectDataModule(cfg) | |
| data_module.setup() | |
| for batch in data_module.test_dataloader(): | |
| print(batch["num_instances"]) | |
| for key in [ | |
| "rgb", | |
| "mask", | |
| "rgb_scene", | |
| # "rgb_crop_padded", | |
| # "mask_crop_padded", | |
| ]: | |
| print(key, batch[key].shape, batch[key].min(), batch[key].max()) | |
| torchvision.utils.save_image( | |
| batch[key], f"tmp/{key}.png", nrow=4, normalize=True | |
| ) | |
| for key in ["rgb"]: | |
| for i in range(batch[key].shape[0]): | |
| torchvision.utils.save_image( | |
| batch[key][i], f"tmp/{key}_{i}.png", normalize=True | |
| ) | |
| break | |