| from pathlib import Path |
| from typing import Optional |
| import numpy as np |
| import PIL.Image as Image |
| from torch.utils.data import Dataset |
| from vhap.util.log import get_logger |
|
|
|
|
| logger = get_logger(__name__) |
|
|
|
|
| class ImageFolderDataset(Dataset): |
| def __init__( |
| self, |
| image_folder: Path, |
| background_folder: Optional[Path]=None, |
| background_fname2camId=lambda x: x, |
| image_fname2camId=lambda x: x, |
| ): |
| """ |
| Args: |
| root_folder: Path to dataset with the following directory layout |
| <image_folder>/ |
| |---xx.jpg |
| |---... |
| """ |
| super().__init__() |
| self.image_fname2camId = image_fname2camId |
| self.background_foler = background_folder |
|
|
| logger.info(f"Initializing dataset from folder {image_folder}") |
|
|
| self.image_paths = sorted(list(image_folder.glob('*.jpg'))) |
|
|
| if background_folder is not None: |
| self.backgrounds = {} |
| background_paths = sorted(list((image_folder / background_folder).glob('*.jpg'))) |
|
|
| for background_path in background_paths: |
| bg = np.array(Image.open(background_path)) |
| cam_id = background_fname2camId(background_path.name) |
| self.backgrounds[cam_id] = bg |
| |
| def __len__(self): |
| return len(self.image_paths) |
|
|
| def __getitem__(self, i): |
| image_path = self.image_paths[i] |
| cam_id = self.image_fname2camId(image_path.name) |
| rgb = np.array(Image.open(image_path)) |
| item = { |
| "rgb": rgb, |
| 'image_path': str(image_path), |
| } |
|
|
| if self.background_foler is not None: |
| item['background'] = self.backgrounds[cam_id] |
|
|
| return item |
|
|
|
|
| if __name__ == "__main__": |
| from tqdm import tqdm |
| from torch.utils.data import DataLoader |
|
|
| dataset = ImageFolderDataset( |
| image_folder='./xx', |
| img_to_tensor=True, |
| ) |
|
|
| print(len(dataset)) |
|
|
| sample = dataset[0] |
| print(sample.keys()) |
| print(sample["rgb"].shape) |
|
|
| dataloader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=1) |
| for item in tqdm(dataloader): |
| pass |
|
|