from torch.utils.data import DataLoader

from src.interface.data_interface import DInterface_base
from src.data.proteingym_dataset import ProteinGYMDataset
from src.data.msa_dataset import MSADataset


class DInterface(DInterface_base):
    """Data interface exposing ProteinGym DMS and MSA datasets as dataloaders."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.save_hyperparameters()

    def setup(self, stage=None):
        # Dataset construction is done explicitly via data_setup() rather
        # than through the standard setup() hook.
        pass
    def data_setup(self, type="proteingym"):
        # Build the requested dataset from the paths stored in hparams.
        if type == "proteingym":
            self.mut_dataset = ProteinGYMDataset(
                dms_csv_dir=self.hparams.dms_csv_dir,
                dms_pdb_dir=self.hparams.dms_pdb_dir,
                dms_reference_csv_path=self.hparams.dms_reference_csv_path,
            )
        elif type == "msa":
            self.msa_dataset = MSADataset(
                msa_csv_path=self.hparams.msa_csv_path
            )
        else:
            raise ValueError(f"Unsupported dataset type: {type!r}")
    def train_dataloader(self):
        return DataLoader(self.mut_dataset, batch_size=1, shuffle=True,
                          num_workers=self.hparams.num_workers, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.mut_dataset, batch_size=1, shuffle=False,
                          num_workers=self.hparams.num_workers, pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.mut_dataset, batch_size=1, shuffle=False,
                          num_workers=self.hparams.num_workers, pin_memory=True)
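

if __name__ == "__main__":
    # Minimal usage sketch, not part of the training code. It assumes
    # DInterface_base behaves like a LightningDataModule (so the keyword
    # arguments below become self.hparams entries) and that the hyperparameter
    # names match those read above. All paths are hypothetical placeholders.
    dm = DInterface(
        dms_csv_dir="data/dms_csv",                    # hypothetical path
        dms_pdb_dir="data/dms_pdb",                    # hypothetical path
        dms_reference_csv_path="data/reference.csv",   # hypothetical path
        num_workers=4,
    )
    dm.data_setup(type="proteingym")  # builds self.mut_dataset
    batch = next(iter(dm.train_dataloader()))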