| | from typing import * |
| | from abc import abstractmethod |
| | import os |
| | import json |
| | import torch |
| | import numpy as np |
| | import pandas as pd |
| | from PIL import Image |
| | from torch.utils.data import Dataset |
| |
|
| |
|
class StandardDatasetBase(Dataset):
    """
    Base class for standard datasets.

    Every root directory is expected to contain a ``metadata.csv`` file with
    a ``sha256`` column. Subclasses decide which rows to keep
    (``filter_metadata``) and how one sample is loaded (``get_instance``).

    Args:
        roots (str): comma-separated paths to the dataset roots
    """

    # Upper bound on how many samples ``__getitem__`` tries before giving up,
    # so a systematically broken dataset fails with a clear error instead of
    # recursing/looping without end.
    _MAX_LOAD_ATTEMPTS = 1000

    def __init__(self,
        roots: str,
    ):
        super().__init__()
        # Tolerate whitespace around the separators ("a, b").
        self.roots = [root.strip() for root in roots.split(',')]
        self.instances = []
        self.metadata = pd.DataFrame()

        self._stats = {}
        frames = []
        for root in self.roots:
            # normpath drops a trailing separator; without it basename()
            # returns '' for "path/to/root/" and such roots share one key.
            key = os.path.basename(os.path.normpath(root))
            self._stats[key] = {}
            metadata = pd.read_csv(os.path.join(root, 'metadata.csv'))
            self._stats[key]['Total'] = len(metadata)
            metadata, stats = self.filter_metadata(metadata)
            self._stats[key].update(stats)
            self.instances.extend([(root, sha256) for sha256 in metadata['sha256'].values])
            # Non-inplace: the filtered frame may be a slice of the CSV frame.
            frames.append(metadata.set_index('sha256'))
        # Concatenate once instead of growing the frame on every iteration.
        if frames:
            self.metadata = pd.concat(frames)

    @abstractmethod
    def filter_metadata(self, metadata: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, int]]:
        """
        Select the usable rows of one root's metadata.

        Args:
            metadata: the rows read from that root's ``metadata.csv``

        Returns:
            The filtered frame (it must still contain the ``sha256`` column)
            and a dict of statistics that is merged into the per-root stats
            shown by ``__str__``.
        """
        pass

    @abstractmethod
    def get_instance(self, root: str, instance: str) -> Dict[str, Any]:
        """
        Load one sample.

        Args:
            root: the dataset root the sample belongs to
            instance: the sample's ``sha256`` value

        Returns:
            A dict describing the sample.
        """
        pass

    def __len__(self):
        """Return the number of instances kept across all roots."""
        return len(self.instances)

    def __getitem__(self, index) -> Dict[str, Any]:
        """
        Load the sample at ``index``; if loading fails, print the error and
        fall back to randomly chosen samples.

        Raises:
            IndexError: if ``index`` is out of range.
            RuntimeError: if ``_MAX_LOAD_ATTEMPTS`` consecutive loads fail.
        """
        # Resolve the index outside the try block: an out-of-range index must
        # raise IndexError so that plain iteration over the dataset stops.
        root, instance = self.instances[index]
        last_error = None
        for _ in range(self._MAX_LOAD_ATTEMPTS):
            try:
                return self.get_instance(root, instance)
            except Exception as e:
                print(e)
                last_error = e
                root, instance = self.instances[np.random.randint(0, len(self))]
        raise RuntimeError(
            f'Failed to load an instance after {self._MAX_LOAD_ATTEMPTS} attempts'
        ) from last_error

    def __str__(self):
        """Return a multi-line summary: class name, size and per-root stats."""
        lines = [
            self.__class__.__name__,
            f' - Total instances: {len(self)}',
            f' - Sources:',
        ]
        for key, stats in self._stats.items():
            lines.append(f' - {key}:')
            lines.extend(f' - {k}: {v}' for k, v in stats.items())
        return '\n'.join(lines)
| |
|
| |
|
class TextConditionedMixin:
    """
    Mixin that attaches a text condition to every sample.

    It reads the ``captions`` metadata column, which is parsed with
    ``json.loads``, and stores one randomly chosen caption per sample
    under ``pack['cond']``.
    """

    def __init__(self, roots, **kwargs):
        """Initialise the next class in the MRO, then parse all captions."""
        super().__init__(roots, **kwargs)
        # The second element of every instance entry is its sha256 key.
        sha_ids = (entry[1] for entry in self.instances)
        self.captions = {
            sha: json.loads(self.metadata.loc[sha]['captions'])
            for sha in sha_ids
        }

    def filter_metadata(self, metadata):
        """Keep only rows whose ``captions`` value is not NA; record the count."""
        filtered, stats = super().filter_metadata(metadata)
        has_captions = filtered['captions'].notna()
        filtered = filtered[has_captions]
        stats['With captions'] = len(filtered)
        return filtered, stats

    def get_instance(self, root, instance):
        """Load the sample and set ``cond`` to one of its captions at random."""
        pack = super().get_instance(root, instance)
        candidates = self.captions[instance]
        pack['cond'] = np.random.choice(candidates)
        return pack
| | |
| | |
class ImageConditionedMixin:
    """
    Mixin that attaches an image condition to every sample.

    For each sample a random frame listed in
    ``<root>/renders_cond/<instance>/transforms.json`` is loaded, cropped
    around the non-zero region of its fourth channel, resized to a square
    and stored under ``pack['cond']``.

    Args:
        roots: forwarded to the next class in the MRO
        image_size (int): side length of the square condition image
    """

    def __init__(self, roots, *, image_size=518, **kwargs):
        # Set before super().__init__() so the attribute already exists while
        # the rest of the initialisation chain runs.
        self.image_size = image_size
        super().__init__(roots, **kwargs)

    def filter_metadata(self, metadata):
        """Keep only rows whose ``cond_rendered`` flag is true; record the count."""
        metadata, stats = super().filter_metadata(metadata)
        # eq(True) treats missing values as "not rendered"; indexing with the
        # raw column raises when it contains NaN (object dtype).
        rendered = metadata['cond_rendered'].eq(True)
        metadata = metadata[rendered]
        stats['Cond rendered'] = len(metadata)
        return metadata, stats

    def get_instance(self, root, instance):
        """
        Load the sample and set ``cond`` to a randomly chosen conditioning view.

        Returns:
            The sample dict with ``cond`` set to a float tensor of shape
            ``(3, image_size, image_size)`` with values in ``[0, 1]``,
            multiplied by the image's fourth channel.
        """
        pack = super().get_instance(root, instance)

        image_root = os.path.join(root, 'renders_cond', instance)
        with open(os.path.join(image_root, 'transforms.json')) as f:
            frames = json.load(f)['frames']
        frame = frames[np.random.randint(len(frames))]

        image_path = os.path.join(image_root, frame['file_path'])
        # Context manager closes the underlying file even if cropping fails.
        with Image.open(image_path) as source:
            # Channel index 3 — presumably the alpha channel of an RGBA render.
            alpha = np.array(source.getchannel(3))
            rows, cols = alpha.nonzero()
            # Bounding box of the non-zero pixels as (left, top, right, bottom).
            bbox = [cols.min(), rows.min(), cols.max(), rows.max()]
            center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]
            hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
            # Square crop around the box centre, enlarged by a fixed margin.
            aug_size_ratio = 1.2
            aug_hsize = hsize * aug_size_ratio
            aug_bbox = [
                int(center[0] - aug_hsize),
                int(center[1] - aug_hsize),
                int(center[0] + aug_hsize),
                int(center[1] + aug_hsize),
            ]
            # crop() returns a new, fully loaded image that outlives the file.
            image = source.crop(aug_bbox)

        image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
        alpha = image.getchannel(3)
        image = image.convert('RGB')
        image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
        alpha = torch.tensor(np.array(alpha)).float() / 255.0
        # Scale the colour channels by the fourth channel.
        image = image * alpha.unsqueeze(0)
        pack['cond'] = image

        return pack
| | |