Spaces:
Sleeping
Sleeping
File size: 2,501 Bytes
6e8e8fb 61c2d3f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
import warnings

from torch.utils.data import Dataset
# FilteredImageDataset: drops samples of excluded classes (e.g. Pterygium) and renumbers the remaining labels
class FilteredImageDataset(Dataset):
    """Wrap an ImageFolder-style dataset, dropping samples of excluded classes.

    The remaining classes are renumbered contiguously from 0, in the order
    they appear in ``dataset.classes``, and every label returned by
    ``__getitem__`` is remapped to that new numbering.

    The wrapped dataset must expose ``classes``, ``class_to_idx`` and
    ``samples`` (a sequence of ``(path, target)`` pairs) and support
    indexing that returns ``(img, target)``.
    """

    def __init__(self, dataset, excluded_classes=None, *, verbose=True):
        """
        Create a filtered dataset that excludes specific classes.

        Args:
            dataset: Original dataset (ImageFolder or similar).
            excluded_classes: Class names to exclude (e.g. ``["Pterygium"]``).
                A single string is treated as one class name; ``None`` keeps
                every class. Names not present in ``dataset.classes`` trigger
                a ``UserWarning``.
            verbose: If True (default), print a summary of the filtering.
        """
        self.dataset = dataset

        # Normalize to a list of names. A bare string must be wrapped: used
        # directly, ``c in "Pterygium"`` is a substring test and would also
        # exclude classes such as "Pter".
        if excluded_classes is None:
            excluded_classes = []
        elif isinstance(excluded_classes, str):
            excluded_classes = [excluded_classes]
        self.excluded_classes = list(excluded_classes)
        excluded = set(self.excluded_classes)

        # Get original class information
        self.orig_classes = dataset.classes
        self.orig_class_to_idx = dataset.class_to_idx

        # A typo in an excluded class name would otherwise be a silent no-op.
        unknown = [c for c in self.excluded_classes if c not in self.orig_class_to_idx]
        if unknown:
            warnings.warn(
                f"Excluded classes not present in the dataset: {unknown}",
                stacklevel=2,
            )

        # Create new class mapping without excluded classes
        self.classes = [c for c in self.orig_classes if c not in excluded]
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.idx_to_class = {idx: cls for cls, idx in self.class_to_idx.items()}

        # Old class index -> new class index, for kept classes only
        self.target_mapping = {
            old_idx: self.class_to_idx[old_class]
            for old_class, old_idx in self.orig_class_to_idx.items()
            if old_class in self.class_to_idx
        }

        # Keep exactly the samples whose label can be remapped, so that
        # __getitem__ never looks up a missing key. ``targets`` holds the
        # remapped label of each kept sample, in the same order as ``indices``.
        self.indices = []
        self.targets = []
        for idx, (_, target) in enumerate(dataset.samples):
            new_target = self.target_mapping.get(target)
            if new_target is not None:
                self.indices.append(idx)
                self.targets.append(new_target)

        if verbose:
            print(f"Filtered out classes: {self.excluded_classes}")
            print(f"Remaining classes: {self.classes}")
            print(
                f"Original dataset size: {len(dataset)}, Filtered dataset size: {len(self.indices)}"
            )

    def __getitem__(self, index):
        """Return ``(img, new_target)`` for the ``index``-th kept sample.

        ``img`` is whatever the wrapped dataset returns (after its transform);
        the label is remapped through ``target_mapping``.
        """
        orig_idx = self.indices[index]
        img, old_target = self.dataset[orig_idx]
        # NOTE(review): if the wrapped dataset applies a target_transform,
        # old_target may no longer be an original class index — confirm.
        return img, self.target_mapping[old_target]

    def __len__(self):
        """Return the number of samples in the filtered dataset."""
        return len(self.indices)

    def set_transform(self, transform):
        """Replace the transform on the wrapped dataset.

        This mutates ``self.dataset`` in place, so any other object holding
        the same dataset instance sees the new transform too.
        """
        self.dataset.transform = transform
|