import torch
from torch.utils.data import DataLoader

from src.interface.data_interface import DInterface_base
from src.data.protein_dataset import ProteinDataset
from src.model.pretrain_model_interface import PretrainModelInterface


class DInterface(DInterface_base):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.save_hyperparameters()

    def setup(self, stage=None):
        # Lightning's setup hook is intentionally a no-op here; dataset
        # construction happens in data_setup(), which callers invoke directly.
        pass

    def data_setup(self, target="all"):
        # For adapter fine-tuning, a pretrained model interface is built so
        # each ProteinDataset split can attach precomputed embeddings to its
        # examples.
        pretrain_model_interface = None
        if self.hparams.finetune_type == "adapter":
            pretrain_model_interface = PretrainModelInterface(self.hparams.pretrain_model_name, batch_size=self.hparams.pretrain_batch_size, max_length=self.hparams.seq_len, sequence_only=self.hparams.sequence_only, task_type=self.hparams.task_type)
        # Keyword arguments shared by every split.
        dataset_kwargs = dict(pretrain_model_interface=pretrain_model_interface, task_name=self.task_name, task_type=self.hparams.task_type, num_classes=self.hparams.num_classes)
        if target == "all":
            self.train_set = ProteinDataset(self.hparams.train_data_path, self.hparams.pretrain_model_name, self.hparams.seq_len, **dataset_kwargs)
            self.val_set = ProteinDataset(self.hparams.val_data_path, self.hparams.pretrain_model_name, self.hparams.seq_len, **dataset_kwargs)
            self.test_set = ProteinDataset(self.hparams.test_data_path, self.hparams.pretrain_model_name, self.hparams.seq_len, **dataset_kwargs)
        elif target == "test":
            self.test_set = ProteinDataset(self.hparams.test_data_path, self.hparams.pretrain_model_name, self.hparams.seq_len, **dataset_kwargs)

    # No sampler is passed below: when training with DDP, PyTorch Lightning
    # wraps each loader's sampler in a DistributedSampler automatically.
    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.hparams.batch_size, shuffle=True,
                          num_workers=self.hparams.num_workers, pin_memory=True, collate_fn=self.data_process_fn)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.hparams.batch_size, shuffle=False,
                          num_workers=self.hparams.num_workers, pin_memory=True, collate_fn=self.data_process_fn)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.hparams.batch_size, shuffle=False,
                          num_workers=self.hparams.num_workers, pin_memory=True, collate_fn=self.data_process_fn)
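
    # A rough sketch of manual sampler handling, in case Lightning's automatic
    # DDP sampler replacement were ever disabled. This is an assumption-laden
    # illustration, not part of the current interface:
    #
    #   from torch.utils.data import DistributedSampler
    #   sampler = DistributedSampler(self.train_set, shuffle=True)  # shuffle moves into the sampler
    #   loader = DataLoader(self.train_set, batch_size=self.hparams.batch_size,
    #                       sampler=sampler, num_workers=self.hparams.num_workers,
    #                       pin_memory=True, collate_fn=self.data_process_fn)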

    def data_process_fn(self, data_list):
        # Custom collate function. For adapter fine-tuning, each example is a
        # dict of precomputed tensors that are stacked into batch tensors here;
        # for every other fine-tuning mode the raw list is returned unchanged.
        if self.hparams.finetune_type == 'adapter':
            name_list, mask_list, label_list, embedding_list, smiles = [], [], [], [], []
            for data in data_list:
                name_list.append(data['name'])
                mask_list.append(data['attention_mask'])
                label_list.append(data['label'])
                embedding_list.append(data['embedding'])
                # 'smiles' features are only present for ligand-aware tasks;
                # the stacking below assumes every example in a batch has them.
                if data.get('smiles') is not None:
                    smiles.append(data['smiles'])
            return {
                'name': name_list,
                'attention_mask': torch.stack(mask_list, dim=0),
                'label': torch.stack(label_list, dim=0),
                'embedding': torch.stack(embedding_list, dim=0),
                'smiles': torch.stack(smiles, dim=0) if len(smiles) > 0 else None,
            }
        else:
            return data_list
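

# Minimal usage sketch (assumptions only): the hyperparameter names below are
# the ones this module actually reads, but every value is a placeholder; the
# real entry point builds DInterface from the project's configuration.
if __name__ == "__main__":
    dm = DInterface(
        finetune_type="adapter",
        pretrain_model_name="esm2_t33_650M_UR50D",  # placeholder checkpoint name
        pretrain_batch_size=8,               # placeholder
        seq_len=512,                         # placeholder
        sequence_only=True,                  # placeholder
        task_type="classification",          # placeholder
        task_name="demo_task",               # placeholder
        num_classes=2,                       # placeholder
        train_data_path="data/train.csv",    # placeholder paths
        val_data_path="data/valid.csv",
        test_data_path="data/test.csv",
        batch_size=16,
        num_workers=4,
    )
    dm.data_setup(target="all")
    batch = next(iter(dm.train_dataloader()))
    print(batch["embedding"].shape, batch["label"].shape)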