import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_from_disk  # only load_from_disk is needed; importing datasets.Dataset would shadow torch's Dataset
import sys
import pytorch_lightning as pl
from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
from functools import partial
import re
from tqdm import tqdm
import os
import pdb


class DynamicBatchingDataset(Dataset):
    """Dataset whose rows are pre-batched groups: each entry holds the input_ids and
    attention_mask tensors for one dynamically sized batch plus its SMILES labels,
    so the dataloaders below use batch_size=1."""

    def __init__(self, dataset_dict, tokenizer):
        print('Initializing dataset...')
        self.dataset_dict = {
            'attention_mask': [torch.tensor(item) for item in tqdm(dataset_dict['attention_mask'])],
            'input_ids': [torch.tensor(item) for item in dataset_dict['input_ids']],
            'labels': dataset_dict['labels']
        }
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.dataset_dict['attention_mask'])

    def __getitem__(self, idx):
        if isinstance(idx, int):
            return {
                'input_ids': self.dataset_dict['input_ids'][idx],
                'attention_mask': self.dataset_dict['attention_mask'][idx],
                'labels': self.dataset_dict['labels'][idx]
            }
        elif isinstance(idx, list):
            return {
                'input_ids': [self.dataset_dict['input_ids'][i] for i in idx],
                'attention_mask': [self.dataset_dict['attention_mask'][i] for i in idx],
                'labels': [self.dataset_dict['labels'][i] for i in idx]
            }
        else:
            raise ValueError(f"Expected idx to be int or list, but got {type(idx)}")


class CustomDataModule(pl.LightningDataModule):
    def __init__(self, dataset_path, tokenizer):
        super().__init__()
        self.dataset = load_from_disk(dataset_path)
        self.tokenizer = tokenizer
        self.dataset_path = dataset_path

    def peptide_bond_mask(self, smiles_list):
        """
        Returns a mask with shape (batch_size, seq_length) that has 1 at the locations
        of recognized bonds in the positions dictionary and 0 elsewhere.

        Args:
            smiles_list: List of peptide SMILES strings (batch of SMILES strings).

        Returns:
            torch.Tensor: A mask of shape (batch_size, seq_length) with 1s at bond positions.
        """
        batch_size = len(smiles_list)
        max_seq_length = 1035  # hard-coded maximum SMILES character length for the mask
        mask = torch.zeros((batch_size, max_seq_length), dtype=torch.int)

        # Regex patterns for backbone-like bonds; earlier patterns take priority
        # because their matched character spans are marked as used.
        bond_patterns = [
            (r'OC\(=O\)', 'ester'),
            (r'N\(C\)C\(=O\)', 'n_methyl'),
            (r'N[12]C\(=O\)', 'peptide'),
            (r'NC\(=O\)', 'peptide'),
            (r'C\(=O\)N\(C\)', 'n_methyl'),
            (r'C\(=O\)N[12]?', 'peptide')
        ]

        for batch_idx, smiles in enumerate(smiles_list):
            positions = []
            used = set()

            for pattern, bond_type in bond_patterns:
                for match in re.finditer(pattern, smiles):
                    # Skip matches that overlap a span already claimed by an earlier pattern.
                    if not any(p in range(match.start(), match.end()) for p in used):
                        positions.append({
                            'start': match.start(),
                            'end': match.end(),
                            'type': bond_type,
                            'pattern': match.group()
                        })
                        used.update(range(match.start(), match.end()))

            # Set the mask to 1 over every recognized bond's character span.
            for pos in positions:
                mask[batch_idx, pos['start']:pos['end']] = 1

        return mask

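    # Illustrative example (hypothetical SMILES, not taken from the dataset):
    # for smiles_list = ["CC(N)C(=O)NC(C)C(=O)O"], the only match is the 'peptide'
    # pattern C\(=O\)N[12]? on the substring "C(=O)N" at characters 5-10, so
    # peptide_bond_mask sets mask[0, 5:11] = 1 and leaves all other positions 0.
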
    def peptide_token_mask(self, smiles_list, token_lists):
        """
        Returns a mask with shape (batch_size, num_tokens) that has 1 for tokens
        where any part of the token overlaps with a peptide bond, and 0 elsewhere.

        Args:
            smiles_list: List of peptide SMILES strings (batch of SMILES strings).
            token_lists: List of tokenized SMILES strings (split into tokens).

        Returns:
            torch.Tensor: A mask of shape (batch_size, num_tokens) with 1s for peptide bond tokens.
        """
        batch_size = len(smiles_list)
        token_seq_length = max(len(tokens) for tokens in token_lists)
        tokenized_masks = torch.zeros((batch_size, token_seq_length), dtype=torch.int)
        atomwise_masks = self.peptide_bond_mask(smiles_list)

        for batch_idx, atomwise_mask in enumerate(atomwise_masks):
            token_seq = token_lists[batch_idx]
            atom_idx = 0  # character offset into the SMILES string

            for token_idx, token in enumerate(token_seq):
                # Skip the first and last tokens (presumably start/end special tokens,
                # which do not consume SMILES characters).
                if token_idx != 0 and token_idx != len(token_seq) - 1:
                    # Mark the token if any character it covers lies on a bond.
                    if torch.sum(atomwise_mask[atom_idx:atom_idx + len(token)]) >= 1:
                        tokenized_masks[batch_idx][token_idx] = 1
                    atom_idx += len(token)

        return tokenized_masks

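    # Illustrative example (hypothetical token split, continuing the SMILES above):
    # with tokens ['[CLS]', 'CC(N)', 'C(=O)N', 'C(C)C(=O)O', '[EOS]'], the first and
    # last tokens are skipped, 'CC(N)' covers characters 0-4 (all zero in the bond
    # mask), and 'C(=O)N' covers characters 5-10, which are marked, so the token
    # mask for this row is [0, 0, 1, 0, 0].
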
    def collate_fn(self, batch):
        # batch_size=1 in the dataloaders, so the single element is itself a
        # pre-batched group produced by DynamicBatchingDataset.
        item = batch[0]

        # Recover the per-sequence token splits and mark the tokens that overlap
        # a recognized peptide bond in the label SMILES.
        token_array = self.tokenizer.get_token_split(item['input_ids'])
        bond_mask = self.peptide_token_mask(item['labels'], token_array)

        return {
            'input_ids': item['input_ids'],
            'attention_mask': item['attention_mask'],
            'bond_mask': bond_mask
        }

    def _train_dataset(self):
        train_dataset = DynamicBatchingDataset(self.dataset['train'], tokenizer=self.tokenizer)
        return train_dataset

    def _val_dataset(self):
        val_dataset = DynamicBatchingDataset(self.dataset['val'], tokenizer=self.tokenizer)
        return val_dataset

    def train_dataloader(self):
        train_dataset = self._train_dataset()
        return DataLoader(
            train_dataset,
            batch_size=1,
            collate_fn=self.collate_fn,
            shuffle=True,
            num_workers=12,
            pin_memory=True
        )

    def val_dataloader(self):
        val_dataset = self._val_dataset()
        return DataLoader(
            val_dataset,
            batch_size=1,
            collate_fn=self.collate_fn,
            num_workers=8,
            pin_memory=True
        )


class RectifyDataModule(pl.LightningDataModule):
    def __init__(self, dataset_path):
        super().__init__()
        self.dataset_path = dataset_path

    def collate_fn(self, batch):
        return {
            'source_ids': torch.tensor(batch[0]['source_ids']),
            'target_ids': torch.tensor(batch[0]['target_ids']),
            'bond_mask': torch.tensor(batch[0]['bond_mask']),
        }

    def train_dataloader(self):
        train_dataset = load_from_disk(os.path.join(self.dataset_path, 'train'))
        return DataLoader(
            train_dataset,
            batch_size=1,
            collate_fn=self.collate_fn,
            num_workers=12,
            pin_memory=True
        )

    def val_dataloader(self):
        val_dataset = load_from_disk(os.path.join(self.dataset_path, 'validation'))
        return DataLoader(
            val_dataset,
            batch_size=1,
            collate_fn=self.collate_fn,
            num_workers=8,
            pin_memory=True
        )
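

# Minimal usage sketch (not part of the original module). The tokenizer constructor
# arguments and the dataset path are placeholders: the local SMILES_SPE_Tokenizer is
# assumed to take a vocabulary file and an SPE merges file, as in the SmilesPE
# package it mirrors.
if __name__ == "__main__":
    tokenizer = SMILES_SPE_Tokenizer('path/to/vocab.txt', 'path/to/spe_merges.txt')  # hypothetical paths
    datamodule = CustomDataModule('path/to/prebatched_dataset', tokenizer)  # hypothetical path

    # Inspect one pre-batched group from the training dataloader.
    loader = datamodule.train_dataloader()
    first = next(iter(loader))
    print(first['bond_mask'].shape)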