Spaces:
Sleeping
Sleeping
| import os | |
| import random | |
| import logging | |
| from typing import Optional, Union | |
| import torch | |
| from datasets import load_dataset, concatenate_datasets, Dataset | |
| from transformers import AutoTokenizer | |
| from rdkit import Chem | |
| from protac_splitter.evaluation import split_prediction | |
def randomize_smiles_dataset(
    batch: dict,
    repeat: int = 1,
    prob: float = 0.5,
    apply_to_text: bool = True,
    apply_to_labels: bool = False,
) -> dict:
    """ Randomize SMILES in a batch of data.

    With probability `prob`, each (text, label) pair is replaced by `repeat`
    entries whose SMILES are re-written in a random (non-canonical) atom
    order; otherwise the pair is kept unchanged. Invalid SMILES are kept
    as-is and logged.

    Args:
        batch (dict): Batch of data with "text" and "labels" keys.
        repeat (int, optional): Number of times to repeat the randomization. Defaults to 1.
        prob (float, optional): Probability of randomizing SMILES. Defaults to 0.5.
        apply_to_text (bool, optional): Whether to apply randomization to text. Defaults to True.
        apply_to_labels (bool, optional): Whether to apply randomization to labels. Defaults to False.

    Returns:
        dict: Randomized batch of data (may contain more rows than the input
            when `repeat` > 1).
    """
    new_texts, new_labels = [], []
    for text, label in zip(batch["text"], batch["labels"]):
        # NOTE: Chem.MolFromSmiles returns None for invalid SMILES rather
        # than raising, so an explicit None check is required; relying on
        # try/except alone would let None reach MolToSmiles and crash there.
        try:
            mol_text = Chem.MolFromSmiles(text)
            mol_label = Chem.MolFromSmiles(label)
        except Exception:
            mol_text = mol_label = None
        if mol_text is None or mol_label is None:
            logging.error("Failed to convert SMILES to Mol!")
            new_texts.append(text)
            new_labels.append(label)
            continue
        if random.random() < prob:
            if apply_to_text:
                rand_texts = [Chem.MolToSmiles(mol_text, canonical=False, doRandom=True) for _ in range(repeat)]
            else:
                rand_texts = [text] * repeat
            if apply_to_labels:
                rand_labels = [Chem.MolToSmiles(mol_label, canonical=False, doRandom=True) for _ in range(repeat)]
            else:
                rand_labels = [label] * repeat
            new_texts.extend(rand_texts)
            new_labels.extend(rand_labels)
        else:
            new_texts.append(text)
            new_labels.append(label)
    return {"text": new_texts, "labels": new_labels}
def process_data_to_model_inputs(
    batch,
    tokenizer: Union[AutoTokenizer, str] = "seyonec/ChemBERTa-zinc-base-v1",
    encoder_max_length: int = 512,
    decoder_max_length: int = 512,
):
    """Tokenize a batch's "text" and "labels" fields into model inputs.

    Adds "input_ids" and "attention_mask" from the encoder-side text and
    replaces "labels" with the tokenized label ids.

    Args:
        batch: Batch dict with "text" and "labels" string fields.
        tokenizer: Tokenizer instance, or a model name to load one from.
        encoder_max_length: Truncation length for the encoder input.
        decoder_max_length: Truncation length for the label sequence.

    Returns:
        The batch dict, updated in place with tokenized fields.
    """
    if isinstance(tokenizer, str):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer)
    encoded_text = tokenizer(batch["text"], truncation=True, max_length=encoder_max_length)
    encoded_labels = tokenizer(batch["labels"], truncation=True, max_length=decoder_max_length)
    batch["input_ids"] = encoded_text.input_ids
    batch["attention_mask"] = encoded_text.attention_mask
    # Because BERT automatically shifts the labels, the labels correspond
    # exactly to `decoder_input_ids`. Masking PAD tokens to -100 for the loss
    # (see `ignore_index` in nn.CrossEntropyLoss) is already handled by
    # DataCollatorForSeq2Seq, so it is not done here.
    batch["labels"] = encoded_labels.input_ids.copy()
    return batch
def get_fragments_in_labels(labels: str, linkers_only_as_labels: bool = True) -> Optional[str]:
    """ Get the fragment SMILES to use as labels.

    Args:
        labels (str): The labels string containing all fragments.
        linkers_only_as_labels (bool, optional): Whether to return only the linker. Defaults to True.

    Returns:
        Optional[str]: The linker SMILES when `linkers_only_as_labels` is
            True, otherwise the E3 and POI ligands joined as "e3.poi";
            None when a required fragment is missing.
    """
    # NOTE: the original annotation (list[str]) was wrong — this returns a
    # single dot-joined SMILES string, or None.
    ligands = split_prediction(labels)
    if linkers_only_as_labels:
        # dict.get already yields None when the linker is absent.
        return ligands.get("linker")
    if None in ligands.values():
        return None
    return f"{ligands['e3']}.{ligands['poi']}"
def load_tokenized_dataset(
    dataset_dir: str,
    dataset_config: str = 'default',
    tokenizer: Union[AutoTokenizer, str] = "seyonec/ChemBERTa-zinc-base-v1",
    batch_size: int = 512,
    encoder_max_length: int = 512,
    decoder_max_length: int = 512,
    token: Optional[str] = None,
    num_proc_map: int = 1,
    randomize_smiles: bool = False,
    randomize_smiles_prob: float = 0.5,
    randomize_smiles_repeat: int = 1,
    randomize_text: bool = True,
    randomize_labels: bool = False,
    cache_dir: Optional[str] = None,
    all_fragments_as_labels: bool = True,
    linkers_only_as_labels: bool = False,
    causal_language_modeling: bool = False,
    train_size_ratio: float = 1.0,
) -> Dataset:
    """ Load dataset and tokenize it.

    Args:
        dataset_dir (str): The directory of the dataset or the name of the data on the Hugging Face Hub.
        dataset_config (str, optional): The configuration of the dataset. Defaults to 'default'.
        tokenizer (AutoTokenizer | str, optional): The tokenizer to use for tokenization. If a string, the tokenizer will be loaded using `AutoTokenizer.from_pretrained(tokenizer)`. Defaults to "seyonec/ChemBERTa-zinc-base-v1".
        batch_size (int, optional): The batch size for tokenization. Defaults to 512.
        encoder_max_length (int, optional): The maximum length of the encoder input sequence. Defaults to 512.
        decoder_max_length (int, optional): The maximum length of the decoder input sequence. Defaults to 512.
        token (Optional[str], optional): The Hugging Face API token. Defaults to None.
        num_proc_map (int, optional): The number of processes to use for mapping. Defaults to 1.
        randomize_smiles (bool, optional): Whether to randomize SMILES. Defaults to False.
        randomize_smiles_prob (float, optional): The probability of randomizing SMILES. Defaults to 0.5.
        randomize_smiles_repeat (int, optional): The number of times to repeat the randomization. Defaults to 1.
        randomize_text (bool, optional): Whether to randomize text. Defaults to True.
        randomize_labels (bool, optional): Whether to randomize labels. Defaults to False.
        cache_dir (Optional[str], optional): The directory to cache the dataset. Defaults to None.
        all_fragments_as_labels (bool, optional): Whether to get all fragments in the labels. Defaults to True.
        linkers_only_as_labels (bool, optional): Whether to get only the linkers in the labels. Defaults to False.
        causal_language_modeling (bool, optional): Whether to use causal language modeling. Defaults to False.
        train_size_ratio (float, optional): The ratio of the training dataset to use. Must be in (0, 1]. Defaults to 1.0.

    Raises:
        ValueError: If `train_size_ratio` is not in the interval (0, 1].

    Returns:
        Dataset: The tokenized dataset.
    """
    if isinstance(tokenizer, str):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer)
    if os.path.exists(dataset_dir):
        # NOTE: We need a different argument to load a dataset from disk:
        dataset = load_dataset(
            dataset_dir,
            data_dir=dataset_config,
        )
        print(f"Dataset loaded from disk at: \"{dataset_dir}\". Length: {dataset.num_rows}")
    else:
        dataset = load_dataset(
            dataset_dir,
            dataset_config,
            token=token,
            cache_dir=cache_dir,
        )
        print(f"Dataset loaded from hub. Length: {dataset.num_rows}")
    if 0 < train_size_ratio < 1.0:
        # Reduce the size of the training dataset but just selecting a fraction of the samples
        dataset["train"] = dataset["train"].select(range(int(train_size_ratio * dataset["train"].num_rows)))
        print(f"Reduced training dataset size to {train_size_ratio}. Length: {dataset.num_rows}")
    elif train_size_ratio <= 0 or train_size_ratio > 1.0:
        # NOTE: Previously train_size_ratio == 0 slipped through both
        # branches and silently kept the full dataset; an empty training
        # split is never intended, so 0 is rejected as well.
        raise ValueError("train_size_ratio must be between 0 and 1.")
    if not all_fragments_as_labels:
        dataset = dataset.map(
            lambda x: {
                "text": x["text"],
                "labels": get_fragments_in_labels(x["labels"], linkers_only_as_labels),
            },
            batched=False,
            num_proc=num_proc_map,
            load_from_cache_file=True,
            desc="Getting fragments in labels",
        )
        # Filter out the samples with None labels
        dataset = dataset.filter(lambda x: x["labels"] is not None)
        if linkers_only_as_labels:
            print(f"Set labels to linkers only. Length: {dataset.num_rows}")
        else:
            print(f"Set labels to E3 and WH only. Length: {dataset.num_rows}")
    if randomize_smiles:
        # Only the training split is augmented; eval/test stay canonical.
        dataset["train"] = dataset["train"].map(
            randomize_smiles_dataset,
            batched=True,
            batch_size=batch_size,
            fn_kwargs={
                "repeat": randomize_smiles_repeat,
                "prob": randomize_smiles_prob,
                "apply_to_text": randomize_text,
                "apply_to_labels": randomize_labels,
            },
            num_proc=num_proc_map,
            load_from_cache_file=True,
            desc="Randomizing SMILES",
        )
        print(f"Randomized SMILES in dataset. Length: {dataset.num_rows}")
    if causal_language_modeling:
        # For CLM the model sees "text.labels" as one sequence; "." is the
        # SMILES fragment separator.
        dataset = dataset.map(
            lambda x: {
                "text": x["text"] + "." + x["labels"],
                "labels": x["labels"],
            },
            batched=False,
            num_proc=num_proc_map,
            load_from_cache_file=True,
            desc="Setting labels to text",
        )
        print(f"Appended labels to text. Length: {dataset.num_rows}")
    # NOTE: Remove the "labels" column if causal language modeling, since the
    # DataCollatorForLM will automatically set the labels to the input_ids.
    dataset = dataset.map(
        process_data_to_model_inputs,
        batched=True,
        batch_size=batch_size,
        remove_columns=["text", "labels"] if causal_language_modeling else ["text"],
        fn_kwargs={
            "tokenizer": tokenizer,
            "encoder_max_length": encoder_max_length,
            "decoder_max_length": decoder_max_length,
        },
        num_proc=num_proc_map,
        load_from_cache_file=True,
        desc="Tokenizing dataset",
    )
    print(f"Tokenized dataset. Length: {dataset.num_rows}")
    return dataset
def load_trl_dataset(
    tokenizer: Union[AutoTokenizer, str] = "seyonec/ChemBERTa-zinc-base-v1",
    token: Optional[str] = None,
    max_length: int = 512,
    dataset_name: str = "ailab-bio/PROTAC-Splitter-Dataset",
    ds_config: str = "standard",
    ds_unalabeled: Optional[str] = None,
) -> Dataset:
    """ Load and tokenize the dataset used for TRL (RL fine-tuning).

    Args:
        tokenizer (AutoTokenizer | str, optional): The tokenizer, or a model name to load one from. Defaults to "seyonec/ChemBERTa-zinc-base-v1".
        token (Optional[str], optional): The Hugging Face API token. Defaults to None.
        max_length (int, optional): Padding/truncation length for the queries. Defaults to 512.
        dataset_name (str, optional): Dataset name on the Hugging Face Hub. Defaults to "ailab-bio/PROTAC-Splitter-Dataset".
        ds_config (str, optional): Configuration for the labeled training data. Defaults to "standard".
        ds_unalabeled (Optional[str], optional): Configuration of additional un-labeled data to concatenate. Defaults to None.

    Returns:
        Dataset: Dataset with "query" (text) and "input_ids" columns.
    """
    if isinstance(tokenizer, str):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer)
    # Load training data
    train_dataset = load_dataset(
        dataset_name,
        ds_config,
        split="train",
        token=token,
    )
    train_dataset = train_dataset.rename_column("text", "query")
    train_dataset = train_dataset.remove_columns(["labels"])
    if ds_unalabeled is not None:
        # Load un-labelled data
        unlabeled_dataset = load_dataset(
            dataset_name,
            ds_unalabeled,
            split="train",
            token=token,
        )
        unlabeled_dataset = unlabeled_dataset.rename_column("text", "query")
        unlabeled_dataset = unlabeled_dataset.remove_columns(["labels"])
        # Concatenate datasets row-wise
        dataset = concatenate_datasets([train_dataset, unlabeled_dataset])
    else:
        dataset = train_dataset
    def tokenize(sample, tokenizer, max_length=512):
        # NOTE: truncation=True was missing before, so queries longer than
        # max_length were padded but never truncated (inconsistent with the
        # tokenization elsewhere in this module).
        input_ids = tokenizer.encode(sample["query"], padding="max_length", truncation=True, max_length=max_length)
        return {"input_ids": input_ids, "query": sample["query"]}
    return dataset.map(lambda x: tokenize(x, tokenizer, max_length), batched=False)
def data_collator_for_trl(batch):
    """Collate TRL samples into parallel lists of id tensors and query strings.

    Args:
        batch: List of sample dicts, each with "input_ids" and "query" keys.

    Returns:
        dict: {"input_ids": list of 1-D LongTensors, "query": list of str}.
    """
    input_ids = []
    queries = []
    for sample in batch:
        input_ids.append(torch.tensor(sample["input_ids"]))
        queries.append(sample["query"])
    return {"input_ids": input_ids, "query": queries}