| import os | |
| import shutil | |
| from typing import List, Tuple | |
| from PIL import Image | |
| from datasets import load_dataset | |
# Root directory that prepare_samples() writes exported samples into.
SAMPLES_DIR = "samples"

# Load the "v1" configuration of the RGBD-SOD/test dataset (train split),
# caching it under the relative directory ./data. Note: this executes at
# import time.
dataset = load_dataset(
    "RGBD-SOD/test",
    "v1",
    split="train",
    cache_dir="data",
)
def prepare_samples(rows=None, samples_dir=None) -> List[List[str]]:
    """Export samples to disk and return the paths of the written files.

    For every sample, the directory ``<samples_dir>/<name>`` is deleted (if
    present) and recreated. The sample's images are then saved into it as
    ``rgb.jpg``, ``depth.jpg`` and ``gt.png``.

    Args:
        rows: Iterable of samples to export. Each sample is a mapping with
            keys ``"rgb"``, ``"depth"``, ``"gt"`` (objects exposing
            ``save(path)``, e.g. PIL images) and ``"name"`` (str, used as
            the sub-directory name). Defaults to the module-level
            ``dataset``.
        samples_dir: Root output directory. Defaults to ``SAMPLES_DIR``.

    Returns:
        One ``[rgb_path, depth_path, gt_path]`` list per sample, in
        iteration order.
    """
    # Resolve defaults at call time so later rebinding of the module
    # globals is still honoured, exactly as with the original no-arg form.
    if rows is None:
        rows = dataset
    if samples_dir is None:
        samples_dir = SAMPLES_DIR

    samples: List[List[str]] = []
    for sample in rows:
        rgb: Image.Image = sample["rgb"]
        depth: Image.Image = sample["depth"]
        gt: Image.Image = sample["gt"]
        name: str = sample["name"]

        # Start from an empty directory so files from an earlier run
        # never linger next to the freshly exported ones.
        dir_path = os.path.join(samples_dir, name)
        shutil.rmtree(dir_path, ignore_errors=True)
        os.makedirs(dir_path, exist_ok=True)

        rgb_path = os.path.join(dir_path, "rgb.jpg")
        rgb.save(rgb_path)
        depth_path = os.path.join(dir_path, "depth.jpg")
        depth.save(depth_path)
        # The ground truth goes to PNG (lossless), unlike the JPEG rgb/depth.
        gt_path = os.path.join(dir_path, "gt.png")
        gt.save(gt_path)

        # Rows stay lists (as the code always produced); the annotation
        # was corrected to match instead, since callers may require lists.
        samples.append([rgb_path, depth_path, gt_path])
    return samples