Spaces:
Configuration error
Configuration error
| from src.DataManager.base import BaseDataManager | |
| from src.DataManager.utils import imread_rgb, imwrite_rgb | |
| import numpy as np | |
| from pathlib import Path | |
class ImageDataManager(BaseDataManager):
    """Serve input images one at a time and save the processed results.

    Images are read sequentially with :meth:`get`. The result for the most
    recently read image is written with :meth:`save` into
    ``<output_dir>/img`` as ``swap_<original file name>``.
    """

    # Suffixes (compared case-insensitively) collected when src_data is a directory.
    _IMAGE_SUFFIXES = frozenset({".jpg", ".jpeg", ".png"})

    def __init__(self, src_data: Path, output_dir: Path):
        """Collect the input images and prepare the output directory.

        Args:
            src_data: Either a single image file, or a directory whose
                ``.jpg`` / ``.jpeg`` / ``.png`` files (any letter case) are
                used in sorted path order. ``str`` paths are accepted too.
            output_dir: Root output directory. Results are written to its
                ``img`` subdirectory; missing directories are created.

        Raises:
            AssertionError: If no input image was found.
        """
        src_data = Path(src_data)
        output_dir = Path(output_dir)

        # parents=True also creates output_dir (and any missing ancestors),
        # so a nested, not-yet-existing output path works.
        self.output_dir: Path = output_dir / "img"
        self.output_dir.mkdir(parents=True, exist_ok=True)

        if src_data.is_file():
            self.data_paths = [src_data]
        elif src_data.is_dir():
            # Sorted for a deterministic processing order. Suffixes are matched
            # case-insensitively so files such as "IMG_01.JPG" are not skipped.
            # is_file() keeps directories named like "x.jpg" out of the list.
            self.data_paths = sorted(
                path
                for path in src_data.iterdir()
                if path.is_file() and path.suffix.lower() in self._IMAGE_SUFFIXES
            )
        else:
            self.data_paths = []

        if not self.data_paths:
            # Raised explicitly rather than via `assert` so the check is not
            # stripped under `python -O`. AssertionError is kept so existing
            # callers that expect it still work.
            raise AssertionError("Data must be supplied!")

        self.data_paths_iter = iter(self.data_paths)
        # Index into data_paths of the image most recently returned by get();
        # -1 means get() has not been called yet.
        self.last_idx = -1

    def __len__(self):
        """Return the number of input images."""
        return len(self.data_paths)

    def get(self) -> np.ndarray:
        """Read and return the next input image (loaded via ``imread_rgb``).

        Raises:
            StopIteration: Once every input image has already been returned.
        """
        img_path = next(self.data_paths_iter)
        # Advanced only after next() succeeds, so it always names a valid image.
        self.last_idx += 1
        return imread_rgb(img_path)

    def save(self, img: np.ndarray):
        """Write ``img`` as the result for the image last returned by :meth:`get`.

        The file is saved in ``self.output_dir`` as ``swap_<source file name>``.

        Raises:
            RuntimeError: If called before any image was read with :meth:`get`.
        """
        if self.last_idx < 0:
            # Without this guard, index -1 would silently reuse the *last*
            # input's file name.
            raise RuntimeError(
                "save() called before get(); no source image to name the output after"
            )
        filename = "swap_" + Path(self.data_paths[self.last_idx]).name
        imwrite_rgb(self.output_dir / filename, img)