# DELM/src/models/dataset.py
# author: xushijie — "add app" (commit 21f308b)
from argparse import Namespace as Args
from collections import Counter
from pathlib import Path

import lightning as L
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from tqdm import tqdm

from models.plm import get_model
from models.polybert import PolyEncoder, polymer2psmiles
class EnzymeDataset(Dataset):
    """Enzyme-polymer degradation dataset backed by a CSV file.

    Each item is the 5-tuple
    ``(category, sequence, degradation, sequence_id, polymer_id)``,
    with ``sequence`` upper-cased.

    On construction, protein-language-model (PLM) embeddings for every
    sequence are precomputed once and cached under
    ``cache/protein/<plm>/<sequence_id>.pt`` so later runs skip the PLM.

    Args:
        csv_file: path to a CSV with columns ``category``, ``sequence``,
            ``degradation``, ``sequence_id`` and ``polymer_id``.
        plm: protein language model name, forwarded to ``get_model``.
        device: device the PLM runs on when embeddings must be computed.
            Defaults to ``'cuda'`` (the previously hard-coded value) for
            backward compatibility; pass ``'cpu'`` on GPU-less machines.
    """

    def __init__(self, csv_file: str, plm: str, device: str = 'cuda'):
        df = pd.read_csv(csv_file)
        self.data_list = [
            (row['category'], row['sequence'].upper(), row['degradation'],
             row['sequence_id'], row['polymer_id'])
            for _, row in df.iterrows()
        ]
        # mkdir(parents=True) creates intermediate directories, replacing
        # the previous chain of separate mkdir calls.
        protein_dir = Path('cache', 'protein', plm)
        protein_dir.mkdir(parents=True, exist_ok=True)
        Path('cache', 'polymer').mkdir(parents=True, exist_ok=True)
        # Only load the (heavy) PLM if at least one embedding is missing.
        if not all((protein_dir / f'{seqid}.pt').exists()
                   for _, _, _, seqid, _ in self.data_list):
            plm_func = get_model(plm, device)
            for _, seq, _, seqid, _ in tqdm(
                    self.data_list, desc='Encoding enzyme sequences'):
                seq_path = protein_dir / f'{seqid}.pt'
                if not seq_path.exists():
                    torch.save(plm_func([seq]), seq_path)

    def __len__(self):
        """Number of (enzyme, polymer) records."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return the raw 5-tuple at position ``idx``."""
        return self.data_list[idx]
class EnzymeDataModule(L.LightningDataModule):
    """Lightning data module with k-fold cross-validation over ``EnzymeDataset``.

    The train CSV forms a train+val pool that is split into ``args.nfolds``
    folds up front; call :meth:`setup_k_fold` to select the active fold
    before requesting the train/val dataloaders. The train loader uses a
    class-balanced ``WeightedRandomSampler``.

    Fields read from ``args``: ``train_csv``, ``test_csv``, ``batch_size``,
    ``num_workers``, ``plm``, ``nfolds``, ``seed``.
    """

    def __init__(self, args: Args):
        super().__init__()
        self.args = args
        self.train_csv = args.train_csv
        self.test_csv = args.test_csv
        self.batch_size = args.batch_size
        self.num_workers = args.num_workers
        self.plm = args.plm
        # Full train+val pool; per-fold Subsets are carved out in
        # setup_k_fold(). NOTE(review): datasets are built eagerly here
        # (embedding cache is populated at construction time).
        self.train_val_set = EnzymeDataset(self.train_csv, self.plm)
        self.test_set = EnzymeDataset(self.test_csv, self.plm)
        self.kfold = KFold(
            n_splits=args.nfolds, shuffle=True,
            random_state=self.args.seed)
        self.indices = list(range(len(self.train_val_set)))
        # Materialize every fold split once so setup_k_fold(i) is cheap and
        # deterministic across repeated calls.
        self.splits = list(self.kfold.split(self.indices))

    def setup_k_fold(self, fold_idx):
        """Activate fold ``fold_idx``: build train/val Subsets and the sampler."""
        train_idx, val_idx = self.splits[fold_idx]
        self.train_set = torch.utils.data.Subset(self.train_val_set, train_idx)
        self.val_set = torch.utils.data.Subset(self.train_val_set, val_idx)
        self.sampler = self.data_sampler()

    def data_sampler(self):
        """Build a class-balanced ``WeightedRandomSampler`` for the train split.

        Each example is weighted by the inverse frequency of its label
        (tuple index 2, the ``degradation`` column), so rare classes are
        drawn more often.

        Raises:
            AttributeError: if called before :meth:`setup_k_fold`.
        """
        if not hasattr(self, 'train_set'):
            raise AttributeError(
                'train_set not initialized. Call setup_k_fold first.')
        # train_set is normally a Subset carrying .indices into the pool;
        # fall back to plain positions otherwise.
        indices = getattr(self.train_set, 'indices',
                          range(len(self.train_set)))
        labels = [self.train_val_set[i][2] for i in indices]
        counts = Counter(labels)
        weights = [1.0 / counts[label] for label in labels]
        return WeightedRandomSampler(
            weights, num_samples=len(weights), replacement=True)

    def train_dataloader(self):
        # shuffle stays off: DataLoader forbids shuffle together with a sampler.
        return DataLoader(
            self.train_set, batch_size=self.batch_size,
            num_workers=self.num_workers,
            sampler=self.sampler,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_set, batch_size=self.batch_size,
            shuffle=False, num_workers=self.num_workers,)

    def test_dataloader(self):
        return DataLoader(
            self.test_set, batch_size=self.batch_size,
            shuffle=False, num_workers=self.num_workers)