Spaces:
Runtime error
Runtime error
| import os | |
| from pathlib import Path | |
| import cv2 | |
| from PIL import Image | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from sklearn.decomposition import PCA | |
| import imagehash | |
| from typing import Callable | |
| from datetime import datetime as dt | |
| from abc import ABC, abstractmethod | |
| _DATASET_AVG_MEAN = 129.38489987766278 | |
| _DATASET_AVG_STD = 54.084109207654805 | |
| def save_to_file(location: str = './extracted_paths.txt') -> Callable: | |
| def outer_wrapper(fn: Callable) -> Callable: | |
| def inner_wrapper(*args, **kwargs): | |
| paths: list[str] = fn(*args, **kwargs) | |
| if kwargs.get('to_file'): | |
| with open(location, 'a') as file: | |
| file.write('\nFiles to remove [TIMESTAMP {}]:\n'.format(dt.now().strftime('%Y%m%d%H%M%S'))) | |
| for p in paths: | |
| file.write(f'{p}\n') | |
| return paths | |
| return inner_wrapper | |
| return outer_wrapper | |
| def visualize(show_limit: int = -1) -> Callable: | |
| def outer_wrapper(fn: Callable) -> Callable: | |
| def inner_wrapper(*args, **kwargs): | |
| paths: list[str] = fn(*args, **kwargs) | |
| if kwargs.get('visualize_'): | |
| if show_limit != -1: | |
| paths = paths[:show_limit] | |
| num_cols = 8 | |
| num_rows = len(paths) // num_cols + 1 | |
| fig = plt.figure(figsize=(8, 8)) | |
| for i, path in enumerate(paths, start=1): | |
| plt.subplot(num_rows, num_cols, i) | |
| plt.imshow(Image.open(path), cmap='gray') | |
| plt.title(f'{Path(path).parent.name}', fontsize=7) | |
| plt.axis('off') | |
| fig.tight_layout() | |
| plt.tight_layout() | |
| fig.subplots_adjust(hspace=0.6, top=0.97) | |
| plt.show() | |
| return paths | |
| return inner_wrapper | |
| return outer_wrapper | |
| class DataFilter(ABC): | |
| def __init__(self): | |
| self.paths = [] | |
| def extract(self, data_dir: str | Path, visualize_: bool, to_file: bool) -> list[str]: | |
| pass | |
| def clear(self) -> None: | |
| pass | |
| def filter(self) -> bool: | |
| pass | |
| def _load_data(dir_: str) -> tuple[list[np.ndarray], list[str], list[str]]: | |
| images = [] | |
| class_names = [] | |
| paths = [] | |
| for path in Path(dir_).glob('**/*.jpg'): | |
| label = path.parent.name | |
| image = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE) | |
| if image is not None and label is not None: | |
| images.append(np.array(image)) | |
| class_names.append(label) | |
| paths.append(str(path)) | |
| return images, class_names, paths | |
| class DataFilterCompose(DataFilter): | |
| def __init__(self, components: list[DataFilter]): | |
| super().__init__() | |
| self.components = components | |
| def build(components: list[DataFilter]) -> DataFilter: | |
| return DataFilterCompose(components) | |
| def extract(self, data_dir: str | Path, visualize_: bool, to_file: bool) -> list[str]: | |
| extracted_paths = [] | |
| for component in self.components: | |
| cur_extracted_paths = component.extract(data_dir, | |
| visualize_=visualize_, | |
| to_file=to_file) | |
| extracted_paths += cur_extracted_paths | |
| self.paths += extracted_paths | |
| return extracted_paths | |
| def clear(self) -> None: | |
| for component in self.components: | |
| component.clear() | |
| def filter(self): | |
| for component in self.components: | |
| component.filter() | |
| def add_component(self, component: DataFilter, position: int) -> None: | |
| self.components.insert(position, component) | |
| def rm_component(self, position: int) -> None: | |
| self.components.pop(position) | |
| class StatsDataFilter(DataFilter): | |
| _OPTIM_MEAN_THRESH = 107 | |
| _OPTIM_STD_THRESH = 51 | |
| def __init__(self, data_avg_mean: float = None, data_avg_std: float = None, console_output: bool = False): | |
| super().__init__() | |
| self.data_avg_mean = data_avg_mean | |
| self.data_avg_std = data_avg_std | |
| self.console_output = console_output | |
| def extract(self, data_dir: str | Path, visualize_: bool, to_file: bool) -> list[str]: | |
| if self.data_avg_mean is None or self.data_avg_std is None: | |
| stats = self._compute_dataset_stats(data_dir) | |
| self.data_avg_mean = stats['avg_mean'] | |
| self.data_avg_std = stats['avg_std'] | |
| extracted_paths = self._extract_outliers_by_stats( | |
| data_dir, | |
| self.data_avg_mean, | |
| self.data_avg_std, | |
| StatsDataFilter._OPTIM_MEAN_THRESH, | |
| StatsDataFilter._OPTIM_STD_THRESH, | |
| self.console_output) | |
| self.paths += extracted_paths | |
| return extracted_paths | |
| def clear(self) -> None: | |
| self.paths.clear() | |
| if self.console_output: | |
| print(f'[{self.__class__.__name__}]: Paths memory cleared.') | |
| def filter(self) -> bool: | |
| has_error = False | |
| for path in self.paths: | |
| if not Path(path).exists(): | |
| has_error = True | |
| continue | |
| os.remove(path) | |
| if self.console_output: | |
| print(f'[{self.__class__.__name__}]: Removed {path}') | |
| return has_error | |
| def _extract_outliers_by_stats(cls, | |
| data_root: str | Path, | |
| dataset_avg_mean: float, | |
| dataset_avg_std: float, | |
| mean_thresh: float, | |
| std_thresh: float, | |
| console_output: bool = False) -> list[str]: | |
| outlier_paths = [] | |
| count = 0 | |
| _, _, paths = StatsDataFilter._load_data(data_root) | |
| total_len = len(paths) | |
| for path in iter(paths): | |
| img = cv2.imread(path, cv2.IMREAD_GRAYSCALE) | |
| if abs(dataset_avg_mean - np.mean(img)) > mean_thresh or abs( | |
| dataset_avg_std - np.std(img)) > std_thresh: | |
| outlier_paths.append(path) | |
| if console_output: | |
| count += 1 | |
| print(f'[{cls.__name__}]: Computed {count}/{total_len} images ({count / total_len * 100:.2f}%)') | |
| return outlier_paths | |
| def _compute_dataset_stats(data_dir: str) -> dict[str, float]: | |
| img_paths = list(Path(data_dir).glob('**/*.jpg')) | |
| num_images = len(img_paths) | |
| mean_sum = 0 | |
| std_sum = 0 | |
| for img_path in img_paths: | |
| img = cv2.imread(str(img_path), cv2.IMREAD_GRAYSCALE) | |
| img_mean = np.mean(img) | |
| img_std = np.std(img) | |
| mean_sum += img_mean | |
| std_sum += img_std | |
| avg_mean = mean_sum / num_images | |
| avg_std = std_sum / num_images | |
| stats_dict = { | |
| 'avg_mean': avg_mean, | |
| 'avg_std': avg_std, | |
| } | |
| return stats_dict | |
| class PcaDataFilter(DataFilter): | |
| _OPTIM_NUM_COMPONENTS = 4 | |
| _OPTIM_ERROR_THRESH = 87 | |
| def __init__(self, console_output: bool = False): | |
| super().__init__() | |
| self.console_output = console_output | |
| def extract(self, data_dir: str | Path, visualize_: bool, to_file: bool) -> list[str]: | |
| extracted_paths = self._extract_outliers_with_pca(data_dir) | |
| self.paths += extracted_paths | |
| return extracted_paths | |
| def clear(self) -> None: | |
| self.paths.clear() | |
| if self.console_output: | |
| print(f'[{self.__class__.__name__}]: Paths memory cleared.') | |
| def filter(self) -> bool: | |
| has_error = False | |
| for path in self.paths: | |
| if not Path(path).exists(): | |
| has_error = True | |
| continue | |
| os.remove(path) | |
| if self.console_output: | |
| print(f'[{self.__class__.__name__}]: Removed {path}') | |
| return has_error | |
| def _extract_outliers_with_pca(dir_: str | Path) -> list[str]: | |
| x, _, img_paths = PcaDataFilter._load_data(dir_) | |
| x = np.array(x) | |
| num_samples, height, width = x.shape | |
| X_flattened = x.reshape(num_samples, height * width) | |
| outlier_indices = PcaDataFilter._detect_outliers_with_pca(X_flattened, | |
| PcaDataFilter._OPTIM_NUM_COMPONENTS, | |
| PcaDataFilter._OPTIM_ERROR_THRESH) | |
| img_paths_to_remove = [img_paths[i] for i in outlier_indices.tolist()] | |
| return img_paths_to_remove | |
| def _detect_outliers_with_pca(orig_data: np.ndarray, | |
| num_components: int, | |
| error_thresh: float) -> np.ndarray: | |
| pca = PCA(n_components=num_components) | |
| X_reduced = pca.fit_transform(orig_data) | |
| X_reconstructed = pca.inverse_transform(X_reduced) | |
| reconstruction_errors = np.sqrt(np.mean((orig_data - X_reconstructed) ** 2, axis=1)) | |
| outlier_indices = np.where(reconstruction_errors > error_thresh)[0] | |
| return outlier_indices | |
| class DHashDuplicateFilter(DataFilter): | |
| def __init__(self, hash_size: int = 8, console_output: bool = False): | |
| super().__init__() | |
| self.hash_size = hash_size | |
| self.console_output = console_output | |
| def extract(self, data_dir: str | Path, visualize_: bool, to_file: bool) -> list[str]: | |
| _, _, paths = self._load_data(data_dir) | |
| hashes = set() | |
| duplicates = [] | |
| for path in paths: | |
| hash_ = imagehash.dhash(Image.open(path), self.hash_size) | |
| if hash_ in hashes: | |
| duplicates.append(path) | |
| if self.console_output: | |
| print(f'[{self.__class__.__name__}]: Duplicate found at {path}') | |
| else: | |
| hashes.add(hash_) | |
| self.paths += duplicates | |
| return duplicates | |
| def clear(self) -> None: | |
| self.paths.clear() | |
| if self.console_output: | |
| print(f'[{self.__class__.__name__}]: Paths memory cleared.') | |
| def filter(self) -> bool: | |
| has_error = False | |
| for path in self.paths: | |
| if not Path(path).exists(): | |
| has_error = True | |
| continue | |
| os.remove(path) | |
| if self.console_output: | |
| print(f'[{self.__class__.__name__}]: Removed {path}') | |
| return has_error | |
| if __name__ == '__main__': | |
| dataset_dir = Path('./dataset') | |
| stats_filter = StatsDataFilter(_DATASET_AVG_MEAN, _DATASET_AVG_STD, True) | |
| pca_filter = PcaDataFilter(console_output=True) | |
| duplicate_filter = DHashDuplicateFilter(console_output=True) | |
| compose = DataFilterCompose.build([ | |
| stats_filter, | |
| pca_filter, | |
| duplicate_filter | |
| ]) | |
| # You may set the value of visualize_ or to_file parameters to True | |
| # to plot extracted images or save paths to a file. | |
| stats_filter.extract(dataset_dir, visualize_=False, to_file=False) | |
| # WARNING: uncommenting the line below will irreversibly remove dataset files | |
| # compose.filter() | |