# SAE/attacks/AIM/src/gat/datasets/imagenet.py
# (uploaded by Ttius in commit 998bb30, "Upload 192 files")
from pathlib import Path
from typing import List, Union
import torch
import torchvision
from .builder import DATASET_REGISTRY
from .transforms import resize_256_224, to_color, to_ts
@DATASET_REGISTRY.register()
def imagenet(
    data_root: Union[str, Path],
    is_train: bool = True,
    filter_class: Union[int, List[int], None] = None,
) -> torch.utils.data.Dataset:
    """Build an ImageNet ``ImageFolder`` dataset, optionally restricted to some classes.

    Args:
        data_root: Directory containing ``train/`` and ``val/`` sub-folders laid
            out in the ``torchvision.datasets.ImageFolder`` format. Any str or
            path-like object is accepted.
        is_train: Load ``data_root/train`` when True, otherwise ``data_root/val``.
        filter_class: A single class index or a list of class indices to keep.
            Indices are the labels ``ImageFolder`` assigns (see its
            ``class_to_idx``). ``None`` or an empty list keeps every class.
            Kept samples retain their original label indices.

    Returns:
        The ``ImageFolder`` dataset, with ``samples``, ``imgs`` and ``targets``
        all restricted to the requested classes when ``filter_class`` is given.
    """
    split_dir = Path(data_root) / ('train' if is_train else 'val')

    # Each helper returns a list of transforms (they are concatenated with `+`)
    # that is composed into one pipeline; see .transforms for what each does.
    _transforms = resize_256_224() + to_color() + to_ts()
    _ds = torchvision.datasets.ImageFolder(
        split_dir,
        transform=torchvision.transforms.Compose(_transforms),
    )

    if isinstance(filter_class, int):
        filter_class = [filter_class]
    if filter_class:
        # Build the lookup set once: membership is tested for every sample.
        keep = set(filter_class)
        _ds.samples = [sample for sample in _ds.samples if sample[1] in keep]
        # ImageFolder also exposes `imgs` (an alias of `samples`) and `targets`
        # (labels derived from `samples`). Refresh both so consumers such as
        # samplers see the same filtered view as __getitem__/__len__.
        _ds.imgs = _ds.samples
        _ds.targets = [target for _, target in _ds.samples]
    return _ds