Robust_vlm / data_utils /neg_dataset.py
Yaning1001's picture
Add files using upload-large-folder tool
64470b3 verified
import os
import random
import torch
import torchvision
from torchvision.datasets.folder import default_loader
class CustomImageFolderWithNegativeSample(torchvision.datasets.ImageFolder):
def __init__(self, root, transform=None):
super(CustomImageFolderWithNegativeSample, self).__init__(root, transform=transform)
self.imgs = self.samples
def __getitem__(self, index):
# positive
path, target = self.imgs[index]
img = self.loader(path)
if self.transform is not None:
img = self.transform(img)
# negative data
all_classes = list(range(len(self.classes)))
all_classes.remove(target) # remove positive
negative_class = random.choice(all_classes) # choose a negative class
negative_indices = [i for i, (_, class_idx) in enumerate(self.imgs) if class_idx == negative_class] # all the classes
negative_index = random.choice(negative_indices) # choose a negative sample
negative_path, negative_target = self.imgs[negative_index]
negative_img = self.loader(negative_path)
if self.transform is not None:
negative_img = self.transform(negative_img)
return img, target, negative_img, negative_target
# # 使用示例
# imagenet_root = 'path_to_imagenet'
# IN_aug_type = torchvision.transforms.Compose([
# # 在这里定义你的数据增强
# ])
# train_dataset = CustomImageFolderWithNegativeSample(
# os.path.join(imagenet_root, 'train'),
# transform=IN_aug_type
# )