# Source: ICLR_FLAG/utils/datasets/__init__.py
# Author: zaixizhang — commit 10efe81 ("renew"), 692 bytes.
import torch
from torch.utils.data import Subset
from .pl import PocketLigandPairDataset
import random
def get_dataset(config, *args, **kwargs):
    """Construct the dataset described by ``config`` and, if requested, its splits.

    ``config`` must support both attribute access (``config.name``,
    ``config.path``, ``config.split``) and ``in`` membership tests
    (presumably an EasyDict-style mapping — confirm against callers).

    Args:
        config: Dataset configuration. ``config.name`` selects the dataset
            class; only ``'pl'`` (``PocketLigandPairDataset``) is handled here.
            ``config.path`` is passed as the dataset root. When ``config``
            contains a ``'split'`` entry, ``config.split`` is a file path
            loaded with ``torch.load``.
        *args: Forwarded positionally to the dataset constructor.
        **kwargs: Forwarded as keywords to the dataset constructor.

    Returns:
        The dataset on its own when ``config`` has no ``'split'`` entry.
        Otherwise a ``(dataset, subsets)`` tuple, where ``subsets`` maps each
        key of the loaded split mapping to a ``torch.utils.data.Subset``.

    Raises:
        NotImplementedError: If ``config.name`` is not a supported dataset.
    """
    # Both fields are read up front, name first, before any validation.
    dataset_name, dataset_root = config.name, config.path

    if dataset_name != 'pl':
        raise NotImplementedError('Unknown dataset: %s' % dataset_name)
    dataset = PocketLigandPairDataset(dataset_root, *args, **kwargs)

    if 'split' not in config:
        return dataset

    # The loaded object maps a split key to an iterable of entry names.
    names_per_split = torch.load(config.split)
    subsets = {}
    for split_key, member_names in names_per_split.items():
        # Names unknown to the dataset are skipped silently rather than
        # raising an error.
        member_indices = [
            dataset.name2id[member]
            for member in member_names
            if member in dataset.name2id
        ]
        subsets[split_key] = Subset(dataset, indices=member_indices)
    return dataset, subsets