| | import os
|
| | import random
|
| | import torch
|
| | import torchvision
|
| | from torchvision.datasets.folder import default_loader
|
| |
|
class CustomImageFolderWithNegativeSample(torchvision.datasets.ImageFolder):
    """``ImageFolder`` that pairs every sample with a negative from another class.

    Each item is ``(img, target, negative_img, negative_target)``. The negative
    is drawn by picking a class other than ``target`` uniformly at random (only
    classes that actually contain samples are eligible), then picking one of
    that class's samples uniformly at random. ``transform`` is applied to both
    images and ``target_transform`` (if given) to both targets.

    Args:
        root: Dataset root directory, forwarded to ``ImageFolder``.
        transform: Optional callable applied to the anchor and negative image.
        target_transform: Optional callable applied to both returned targets.
        **kwargs: Further ``ImageFolder`` keyword arguments (for example
            ``loader`` or ``is_valid_file``), forwarded unchanged.

    Note:
        The class -> sample-indices lookup used for negative sampling is built
        once in ``__init__``; changes made to ``self.imgs`` afterwards are not
        seen by it.
    """

    def __init__(self, root, transform=None, target_transform=None, **kwargs):
        super().__init__(
            root, transform=transform, target_transform=target_transform, **kwargs
        )
        # Keep ``imgs`` pointing at ``samples``, as the original code did.
        self.imgs = self.samples

        # Group sample indices by class once, instead of rescanning every
        # sample on each __getitem__ call (that scan made an epoch O(N^2)).
        self._indices_by_class = {}
        for idx, (_, class_idx) in enumerate(self.imgs):
            self._indices_by_class.setdefault(class_idx, []).append(idx)

    def __getitem__(self, index):
        """Return ``(img, target, negative_img, negative_target)`` for ``index``.

        Raises:
            ValueError: If no class other than this sample's class contains
                any samples, so no negative can be drawn.
        """
        path, target = self.imgs[index]
        img = self._load_image(path)

        negative_index = self._sample_negative_index(target)
        negative_path, negative_target = self.imgs[negative_index]
        negative_img = self._load_image(negative_path)

        # Transform targets only after sampling: negative selection must use
        # the raw integer class indices.
        if self.target_transform is not None:
            target = self.target_transform(target)
            negative_target = self.target_transform(negative_target)

        return img, target, negative_img, negative_target

    def _load_image(self, path):
        """Load ``path`` with ``self.loader`` and apply ``self.transform`` if set."""
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return img

    def _sample_negative_index(self, target):
        """Return the index of a random sample whose class differs from ``target``."""
        # Candidates stay in ascending class order so the random draws match
        # the original "remove target from range(num_classes)" selection.
        candidate_classes = [
            class_idx
            for class_idx in range(len(self.classes))
            if class_idx != target and self._indices_by_class.get(class_idx)
        ]
        if not candidate_classes:
            # random.choice would raise IndexError here; if the dataset is
            # iterated through the __getitem__ sequence protocol, that can
            # silently end a plain ``for`` loop, so fail loudly instead.
            raise ValueError(
                f"Cannot draw a negative sample for class {target}: no other "
                "class contains any samples."
            )
        negative_class = random.choice(candidate_classes)
        return random.choice(self._indices_by_class[negative_class])
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|