from torch.utils.data import Dataset


# Dataset wrapper that drops the given classes (e.g. "Pterygium") and remaps
# the remaining labels to a contiguous range 0..N-1
class FilteredImageDataset(Dataset):
    def __init__(self, dataset, excluded_classes=None):
        """
        Create a filtered dataset that excludes specific classes.

        Args:
            dataset: Original dataset exposing `classes`, `class_to_idx` and
                `samples` (e.g. torchvision ImageFolder)
            excluded_classes: List of class names to exclude (e.g., ["Pterygium"])
        """
        self.dataset = dataset
        self.excluded_classes = excluded_classes or []

        # Get original class information
        self.orig_classes = dataset.classes
        self.orig_class_to_idx = dataset.class_to_idx

        # Create indices of samples to keep (excluding specified classes)
        self.indices = []
        for idx, (_, target) in enumerate(dataset.samples):
            class_name = self.orig_classes[target]
            if class_name not in self.excluded_classes:
                self.indices.append(idx)

        # Create new class mapping without excluded classes
        remaining_classes = [
            c for c in self.orig_classes if c not in self.excluded_classes
        ]
        self.classes = remaining_classes
        self.class_to_idx = {cls: idx for idx, cls in enumerate(remaining_classes)}
        self.idx_to_class = {v: k for k, v in self.class_to_idx.items()}

        # Create a mapping from old indices to new indices
        self.target_mapping = {}
        for old_class, old_idx in self.orig_class_to_idx.items():
            if old_class in self.class_to_idx:
                self.target_mapping[old_idx] = self.class_to_idx[old_class]

        print(f"Filtered out classes: {self.excluded_classes}")
        print(f"Remaining classes: {self.classes}")
        print(
            f"Original dataset size: {len(dataset)}, Filtered dataset size: {len(self.indices)}"
        )

    def __getitem__(self, index):
        """Get item from the filtered dataset with remapped class labels."""
        orig_idx = self.indices[index]
        img, old_target = self.dataset[orig_idx]

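        # Remap the original label to its new contiguous index. This assumes the
        # wrapped dataset has no target_transform; otherwise old_target may not
        # be a key of target_mapping.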
        new_target = self.target_mapping[old_target]

        return img, new_target

    def __len__(self):
        """Return the number of samples in the filtered dataset."""
        return len(self.indices)

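    # Allow the transform to be updated after construction
    def set_transform(self, transform):
        """Replace the transform of the wrapped dataset.

        Note: this mutates ``self.dataset`` in place, so any other wrapper
        sharing the same underlying dataset also sees the new transform.
        """
        self.dataset.transform = transform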
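

# --- Sketch: the filtering/remapping logic above, without torch ---
# A minimal sketch that assumes an ImageFolder-style layout. `samples` is a list
# of (path, class_index) pairs, and `classes` is the class-name list indexed by
# class_index. `compute_filter_mapping` is a hypothetical helper and is not part
# of torchvision. It reproduces what FilteredImageDataset.__init__ computes, so
# the label remapping can be checked on toy data without loading any images.
def compute_filter_mapping(classes, samples, excluded_classes=()):
    """Return (kept sample indices, remaining classes, old->new label map)."""
    excluded = set(excluded_classes)
    # Indices of samples whose class is not excluded, in original order.
    kept = [
        i for i, (_, target) in enumerate(samples) if classes[target] not in excluded
    ]
    # Remaining classes keep their relative order and get labels 0..N-1.
    remaining = [c for c in classes if c not in excluded]
    new_idx = {c: i for i, c in enumerate(remaining)}
    # Map each surviving old label to its new contiguous label.
    mapping = {old: new_idx[c] for old, c in enumerate(classes) if c in new_idx}
    return kept, remaining, mapping


if __name__ == "__main__":
    # Toy data: the class names (other than "Pterygium") and paths are hypothetical.
    toy_classes = ["Cataract", "Normal", "Pterygium", "Uveitis"]
    toy_samples = [("a.jpg", 0), ("b.jpg", 2), ("c.jpg", 3), ("d.jpg", 1)]
    kept, remaining, mapping = compute_filter_mapping(
        toy_classes, toy_samples, ["Pterygium"]
    )
    print(kept)       # [0, 2, 3]
    print(remaining)  # ['Cataract', 'Normal', 'Uveitis']
    print(mapping)    # {0: 0, 1: 1, 3: 2}
    # Old label 3 ("Uveitis") becomes 2 because "Pterygium" (old label 2) is gone.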