| import torch | |
| from torch.utils.data import Subset | |
| from .pl import PocketLigandPairDataset | |
| import random | |
| def get_dataset(config, *args, **kwargs): | |
| name = config.name | |
| root = config.path | |
| if name == 'pl': | |
| dataset = PocketLigandPairDataset(root, *args, **kwargs) | |
| else: | |
| raise NotImplementedError('Unknown dataset: %s' % name) | |
| if 'split' in config: | |
| split_by_name = torch.load(config.split) | |
| split = {k: [dataset.name2id[n] for n in names if n in dataset.name2id] for k, names in split_by_name.items()} | |
| subsets = {k:Subset(dataset, indices=v) for k, v in split.items()} | |
| return dataset, subsets | |
| else: | |
| return dataset | |