Spaces:
Sleeping
Sleeping
File size: 552 Bytes
dcacefd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
import torch
from torch.utils.data import Subset
from .pl import PocketLigandPairDataset
def get_dataset(config, *args, **kwargs):
    """Build the dataset selected by ``config`` and, optionally, its split subsets.

    Args:
        config: Dataset config object. ``config.name`` selects the dataset
            class and ``config.path`` is passed to that class as its first
            positional argument (``root``). If ``'split' in config`` is true,
            ``config.split`` is loaded with ``torch.load``.
        *args: Forwarded unchanged to the dataset constructor.
        **kwargs: Forwarded unchanged to the dataset constructor.

    Returns:
        ``dataset`` alone when ``config`` has no ``split`` entry; otherwise a
        tuple ``(dataset, subsets)`` where ``subsets`` maps every key of the
        loaded split object to a ``torch.utils.data.Subset`` of ``dataset``
        built from the corresponding value as indices.

    Raises:
        NotImplementedError: If ``config.name`` is not a supported dataset name.
    """
    name = config.name
    root = config.path
    if name == 'pl':
        dataset = PocketLigandPairDataset(root, *args, **kwargs)
    else:
        # Pass a 1-tuple to '%' so a tuple-valued name is formatted as a single
        # value instead of being unpacked into several format arguments
        # (which would raise TypeError rather than NotImplementedError).
        raise NotImplementedError('Unknown dataset: %s' % (name,))

    if 'split' in config:
        # torch.load deserializes via pickle; config.split should only point
        # at trusted files.
        # NOTE(review): assumes the loaded object is a mapping of split name ->
        # index sequence (e.g. train/val/test); confirm against the split files.
        split = torch.load(config.split)
        subsets = {key: Subset(dataset, indices=idx) for key, idx in split.items()}
        return dataset, subsets
    return dataset
|