| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.utils.data import Dataset, DataLoader |
| | import lightning as L |
| | from pathlib import Path |
| | import pandas as pd |
| | from models.plm import get_model |
| | from models.polybert import PolyEncoder, polymer2psmiles |
| | from argparse import Namespace as Args |
| | from sklearn.model_selection import KFold |
| | from tqdm import tqdm |
| | from torch.utils.data import WeightedRandomSampler |
| |
|
| |
|
class EnzymeDataset(Dataset):
    """Dataset of (category, enzyme sequence, degradation label, sequence_id, polymer_id) rows.

    On construction, ensures every enzyme sequence has a cached PLM embedding
    tensor at cache/protein/<plm>/<sequence_id>.pt, encoding any missing ones
    with the requested protein language model. ``__getitem__`` returns the raw
    CSV tuple; embedding tensors are loaded elsewhere (presumably by the
    collate/model code — not visible here).
    """

    def __init__(self, csv_file: str, plm: str):
        """
        Args:
            csv_file: CSV with columns ``category``, ``sequence``,
                ``degradation``, ``sequence_id``, ``polymer_id``.
            plm: name of the protein language model used for the cache
                subdirectory and, when needed, for encoding.
        """
        # Build the row tuples up front; sequences are upper-cased once here.
        self.data_list = [
            (row['category'], row['sequence'].upper(), row['degradation'],
             row['sequence_id'], row['polymer_id'])
            for _, row in pd.read_csv(csv_file).iterrows()
        ]

        # mkdir(parents=True) creates cache/ and cache/protein/ as needed —
        # no per-level mkdir chain required.
        cache_dir = Path('cache')
        protein_dir = cache_dir / 'protein' / plm
        protein_dir.mkdir(parents=True, exist_ok=True)
        (cache_dir / 'polymer').mkdir(parents=True, exist_ok=True)

        # Only load the (CUDA) model if at least one embedding is missing;
        # loading it is expensive, so skip entirely on a warm cache.
        if not all((protein_dir / f"{seqid}.pt").exists()
                   for _, _, _, seqid, _ in self.data_list):
            plm_func = get_model(plm, 'cuda')
            for _, seq, _, seqid, _ in tqdm(self.data_list, desc='Encoding enzyme sequences'):
                seq_path = protein_dir / f'{seqid}.pt'
                if not seq_path.exists():
                    torch.save(plm_func([seq]), seq_path)

    def __len__(self):
        """Number of rows loaded from the CSV."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return the raw (category, sequence, degradation, sequence_id, polymer_id) tuple."""
        return self.data_list[idx]
| |
|
| |
|
| | class EnzymeDataModule(L.LightningDataModule): |
| | def __init__(self, args: Args): |
| | super().__init__() |
| | self.args = args |
| | self.train_csv = args.train_csv |
| | self.test_csv = args.test_csv |
| | self.batch_size = args.batch_size |
| | self.num_workers = args.num_workers |
| | self.plm = args.plm |
| |
|
| | self.train_val_set = EnzymeDataset(self.train_csv, self.plm) |
| | self.test_set = EnzymeDataset(self.test_csv, self.plm) |
| |
|
| | self.kfold = KFold( |
| | n_splits=args.nfolds, shuffle=True, |
| | random_state=self.args.seed) |
| | self.indices = list(range(len(self.train_val_set))) |
| | self.splits = list(self.kfold.split(self.indices)) |
| |
|
| | def setup_k_fold(self, fold_idx): |
| | train_idx, val_idx = self.splits[fold_idx] |
| |
|
| | self.train_set = torch.utils.data.Subset( |
| | self.train_val_set, train_idx) |
| | self.val_set = torch.utils.data.Subset( |
| | self.train_val_set, val_idx) |
| | self.sampler = self.data_sampler() |
| |
|
| | def data_sampler(self): |
| | |
| | if hasattr(self, 'train_set'): |
| | |
| | indices = self.train_set.indices if hasattr( |
| | self.train_set, 'indices') else range(len(self.train_set)) |
| | labels = [self.train_val_set[i][2] for i in indices] |
| | |
| | label_counts = pd.Series(labels).value_counts() |
| | weights = [1.0 / label_counts[label] for label in labels] |
| | sampler = WeightedRandomSampler( |
| | weights, num_samples=len(weights), replacement=True) |
| | return sampler |
| | else: |
| | raise AttributeError( |
| | 'train_set not initialized. Call setup_k_fold first.') |
| |
|
| | def train_dataloader(self): |
| | return DataLoader( |
| | self.train_set, batch_size=self.batch_size, |
| | |
| | num_workers=self.num_workers, |
| | sampler=self.sampler, |
| | ) |
| |
|
| | def val_dataloader(self): |
| | return DataLoader( |
| | self.val_set, batch_size=self.batch_size, |
| | shuffle=False, num_workers=self.num_workers,) |
| |
|
| | def test_dataloader(self): |
| | return DataLoader( |
| | self.test_set, batch_size=self.batch_size, |
| | shuffle=False, num_workers=self.num_workers) |
| |
|