Spaces:
Sleeping
Sleeping
| from torch.utils.data import Dataset | |
| # Modified FilteredImageDataset class with Pterygium filtering | |
class FilteredImageDataset(Dataset):
    """Dataset wrapper that hides every sample of the excluded classes.

    The wrapped dataset must expose ``classes`` (sequence of class names),
    ``class_to_idx`` (name -> integer label) and ``samples`` (sequence of
    ``(item, label)`` pairs), and must support ``len()`` and integer
    indexing that returns ``(img, label)``.

    Labels of the remaining classes are renumbered contiguously from 0,
    following their order in ``dataset.classes``.

    Attributes:
        dataset: The wrapped dataset, stored by reference (not copied).
        excluded_classes: Class names that were filtered out.
        orig_classes: ``dataset.classes`` before filtering.
        orig_class_to_idx: ``dataset.class_to_idx`` before filtering.
        indices: Positions in the wrapped dataset of the kept samples.
        classes: Remaining class names, in their original order.
        class_to_idx: Remaining class name -> new label.
        idx_to_class: New label -> remaining class name.
        target_mapping: Original label -> new label, kept classes only.
    """

    def __init__(self, dataset, excluded_classes=None):
        """Create a filtered dataset that excludes specific classes.

        Args:
            dataset: Original dataset (ImageFolder or similar).
            excluded_classes: Class names to exclude
                (e.g., ["Pterygium"]). A single name given as a plain
                string is treated as a one-element list. ``None`` or an
                empty collection excludes nothing.
        """
        # A bare string would turn the `in` checks below into substring
        # tests (e.g. "Pter" in "Pterygium"), silently dropping unrelated
        # classes, so wrap it in a list first.
        if isinstance(excluded_classes, str):
            excluded_classes = [excluded_classes]

        self.dataset = dataset
        self.excluded_classes = excluded_classes or []
        # Built once so the per-sample membership test is O(1).
        excluded = frozenset(self.excluded_classes)

        # Get original class information
        self.orig_classes = dataset.classes
        self.orig_class_to_idx = dataset.class_to_idx

        # Positions of the samples to keep (excluding specified classes)
        self.indices = [
            idx
            for idx, (_, target) in enumerate(dataset.samples)
            if self.orig_classes[target] not in excluded
        ]

        # New class mapping without the excluded classes
        self.classes = [c for c in self.orig_classes if c not in excluded]
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.idx_to_class = {v: k for k, v in self.class_to_idx.items()}

        # Mapping from old label indices to new label indices
        self.target_mapping = {
            old_idx: self.class_to_idx[old_class]
            for old_class, old_idx in self.orig_class_to_idx.items()
            if old_class in self.class_to_idx
        }

        print(f"Filtered out classes: {self.excluded_classes}")
        print(f"Remaining classes: {self.classes}")
        print(
            f"Original dataset size: {len(dataset)}, Filtered dataset size: {len(self.indices)}"
        )

    def __getitem__(self, index):
        """Get item from the filtered dataset with remapped class labels.

        Args:
            index: Position within the filtered dataset.

        Returns:
            Tuple ``(img, new_target)``: ``img`` is whatever the wrapped
            dataset returns for the sample, ``new_target`` is the label
            renumbered for the remaining classes.

        Raises:
            IndexError: If ``index`` is outside the filtered dataset.
            KeyError: If the wrapped dataset returns a label that is not
                a key of ``target_mapping``.
        """
        orig_idx = self.indices[index]
        img, old_target = self.dataset[orig_idx]
        # Remap target to new class index
        return img, self.target_mapping[old_target]

    def __len__(self):
        """Return the number of samples in the filtered dataset."""
        return len(self.indices)

    # Allow transform to be updated
    def set_transform(self, transform):
        """Update the transform for the dataset.

        The transform is assigned to the wrapped dataset's ``transform``
        attribute, so anything else holding a reference to that dataset
        sees the change as well.
        """
        self.dataset.transform = transform