Spaces:
Running
Running
| """ Hugging Face utilities for model loading and pipeline creation. """ | |
| from typing import Optional, List, Dict, Union | |
| from datasets import Dataset | |
| from transformers import ( | |
| AutoTokenizer, | |
| EncoderDecoderModel, | |
| AutoModelForCausalLM, | |
| pipeline, | |
| GenerationConfig, | |
| ) | |
| from transformers.pipelines.pt_utils import KeyDataset | |
| from tqdm import tqdm | |
| import torch | |
| def get_encoder_decoder_model( | |
| pretrained_encoder: str = "seyonec/ChemBERTa-zinc-base-v1", | |
| pretrained_decoder: str = "seyonec/ChemBERTa-zinc-base-v1", | |
| max_length: Optional[int] = 512, | |
| tie_encoder_decoder: bool = False, | |
| ) -> EncoderDecoderModel: | |
| """ Get the EncoderDecoderModel model for the PROTAC splitter. | |
| Args: | |
| pretrained_encoder (str): The pretrained model to use for the encoder. Default: "seyonec/ChemBERTa-zinc-base-v1" | |
| pretrained_decoder (str): The pretrained model to use for the decoder. Default: "seyonec/ChemBERTa-zinc-base-v1" | |
| max_length (int): The maximum length of the input sequence. Default: 512 | |
| tie_encoder_decoder (bool): Whether to tie the encoder and decoder weights. Default: False | |
| Returns: | |
| EncoderDecoderModel: The EncoderDecoderModel model for the PROTAC splitter | |
| """ | |
| bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained( | |
| pretrained_encoder, | |
| pretrained_decoder, | |
| tie_encoder_decoder=tie_encoder_decoder, | |
| ) | |
| print(f"Number of parameters: {bert2bert.num_parameters():,}") | |
| tokenizer = AutoTokenizer.from_pretrained(pretrained_encoder) | |
| # Tokenizer-related configs | |
| bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id | |
| bert2bert.config.eos_token_id = tokenizer.sep_token_id | |
| bert2bert.config.pad_token_id = tokenizer.pad_token_id | |
| bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size | |
| # Generation configs | |
| # NOTE: See full list of configurations can be found here: https://huggingface.co/docs/transformers/v4.33.3/en/main_classes/text_generation#transformers.GenerationConfig | |
| bert2bert.encoder.config.max_length = max_length | |
| bert2bert.decoder.config.max_length = max_length | |
| def setup_gen(config): | |
| config.do_sample = True | |
| config.num_beams = 5 | |
| config.top_k = 20 | |
| config.max_length = 512 | |
| # config.max_new_tokens = 512 | |
| return config | |
| bert2bert.config = setup_gen(bert2bert.config) | |
| bert2bert.encoder.config = setup_gen(bert2bert.encoder.config) | |
| bert2bert.decoder.config = setup_gen(bert2bert.decoder.config) | |
| bert2bert.decoder.config.is_decoder = True | |
| bert2bert.generation_config = setup_gen(bert2bert.generation_config) | |
| # bert2bert.config.do_sample = True | |
| # bert2bert.config.num_beams = 5 | |
| # bert2bert.config.top_k = 20 | |
| # bert2bert.config.max_length=512 | |
| # bert2bert.config.max_new_tokens=512 | |
| # bert2bert.generation_config.max_new_tokens = 512 | |
| # bert2bert.generation_config.min_new_tokens = 512 | |
| # bert2bert.config.max_new_tokens = 514 | |
| # bert2bert.config.early_stopping = True | |
| # bert2bert.config.length_penalty = 2.0 | |
| # # bert2bert.config.no_repeat_ngram_size = 3 # Default: 0 | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| bert2bert.to(device) | |
| return bert2bert | |
| def get_causal_model( | |
| pretrained_model: str = "seyonec/ChemBERTa-zinc-base-v1", | |
| max_length: Optional[int] = 512, | |
| ) -> AutoModelForCausalLM: | |
| """ Get the causal language model for the PROTAC splitter. | |
| Args: | |
| pretrained_model (str): The pretrained model to use for the causal language model. Default: "seyonec/ChemBERTa-zinc-base-v1" | |
| max_length (int): The maximum length of the input sequence. Default: 512 | |
| Returns: | |
| AutoModelForCausalLM: The causal language model for the PROTAC splitter | |
| """ | |
| model = AutoModelForCausalLM.from_pretrained(pretrained_model, is_decoder=True) | |
| # model.is_decoder = True # It might not be necessary, but it's good to be explicit | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| model.to(device) | |
| return model | |
| # REF: https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/generation/configuration_utils.py#L71 | |
| GENERATION_STRATEGY_PARAMS = { | |
| "greedy": {"num_beams": 1, "do_sample": False}, | |
| "contrastive_search": {"penalty_alpha": 0.1, "top_k": 10}, | |
| "multinomial_sampling": {"num_beams": 1, "do_sample": True}, | |
| "beam_search_decoding": {"num_beams": 5, "do_sample": False, "num_return_sequences": 5}, | |
| "beam_search_multinomial_sampling": {"num_beams": 5, "do_sample": True, "num_return_sequences": 5}, | |
| "diverse_beam_search_decoding": {"num_beams": 5, "num_beam_groups": 5, "diversity_penalty": 1.0, "num_return_sequences": 5}, | |
| } | |
| def avail_generation_strategies() -> List[str]: | |
| """ Get the available generation strategies. """ | |
| return list(GENERATION_STRATEGY_PARAMS.keys()) | |
| def get_generation_config(generation_strategy: str) -> GenerationConfig: | |
| """ Get the generation config for the given generation strategy. """ | |
| return GenerationConfig( | |
| max_length=512, | |
| max_new_tokens=512, | |
| **GENERATION_STRATEGY_PARAMS[generation_strategy], | |
| ) | |
| def get_pipeline( | |
| model_name: str, | |
| token: str, | |
| is_causal_language_model: bool, | |
| generation_strategy: Optional[str] = None, | |
| num_return_sequences: int = 1, | |
| device: Optional[Union[int, str]] = None, | |
| ) -> pipeline: | |
| """ Get the pipeline for the given model name and generation strategy. | |
| """ | |
| device = device or ('cuda' if torch.cuda.is_available() else 'cpu') | |
| if is_causal_language_model and generation_strategy is None: | |
| print('Loading pipeline for causal language models...') | |
| tokenizer = AutoTokenizer.from_pretrained(model_name, token=token, padding_side='left') | |
| return pipeline( | |
| "text-generation", | |
| model=model_name, | |
| tokenizer=tokenizer, | |
| token=token, | |
| device=device, | |
| num_return_sequences=num_return_sequences, | |
| ) | |
| if is_causal_language_model and generation_strategy is not None: | |
| print('Loading pipeline for causal language models...') | |
| tokenizer = AutoTokenizer.from_pretrained(model_name, token=token, padding_side='left') | |
| return pipeline( | |
| "text-generation", | |
| model=model_name, | |
| tokenizer=tokenizer, | |
| token=token, | |
| device=device, | |
| generation_config=get_generation_config(generation_strategy), | |
| ) | |
| if not is_causal_language_model and generation_strategy is None: | |
| print('Loading pipeline for sequence-to-sequence models...') | |
| tokenizer = AutoTokenizer.from_pretrained(model_name, token=token) | |
| return pipeline( | |
| "text2text-generation", | |
| model=model_name, | |
| tokenizer=tokenizer, | |
| token=token, | |
| device=device, | |
| ) | |
| if not is_causal_language_model and generation_strategy is not None: | |
| print('Loading pipeline for sequence-to-sequence models...') | |
| tokenizer = AutoTokenizer.from_pretrained(model_name, token=token) | |
| return pipeline( | |
| "text2text-generation", | |
| model=model_name, | |
| tokenizer=tokenizer, | |
| token=token, | |
| device=device, | |
| generation_config=get_generation_config(generation_strategy), | |
| ) | |
| def run_causal_pipeline( | |
| pipe: pipeline, | |
| test_ds: Dataset, | |
| batch_size: int, | |
| smiles_column: str = 'prompt', | |
| ) -> List[Dict[str, str]]: | |
| """ Run the pipeline for causal language models and return the predictions. | |
| Args: | |
| pipe (pipeline): The pipeline object to use for generating predictions. | |
| test_ds (Dataset): The test dataset to generate predictions for. | |
| batch_size (int): The batch size to use for generating predictions. | |
| Returns: | |
| List[Dict[str, str]]: A list of dictionaries containing the predictions. | |
| """ | |
| preds = [] | |
| for pred in tqdm(pipe(KeyDataset(test_ds, smiles_column), batch_size=batch_size, max_length=512), total=len(test_ds) // batch_size): | |
| generated_text = [p['generated_text'] for p in pred] | |
| # Remove the prompt from the generated text | |
| generated_text = ['.'.join(t.split('.')[1:]) for t in generated_text] | |
| # Add the predictions to the list | |
| p = {f'pred_n{i}': t for i, t in enumerate(generated_text)} | |
| preds.append(p) | |
| return preds | |
| def run_seq2seq_pipeline( | |
| pipe: pipeline, | |
| test_ds: Dataset, | |
| batch_size: int, | |
| smiles_column: str = 'text', | |
| ) -> List[Dict[str, str]]: | |
| """ Run the pipeline for sequence-to-sequence models and return the predictions. | |
| Args: | |
| pipe (pipeline): The pipeline object to use for generating predictions. | |
| test_ds (Dataset): The test dataset to generate predictions for. | |
| batch_size (int): The batch size to use for generating predictions. | |
| Returns: | |
| List[Dict[str, str]]: A list of dictionaries containing the predictions. | |
| """ | |
| preds = [] | |
| for pred in tqdm(pipe(KeyDataset(test_ds, smiles_column), batch_size=batch_size, max_length=512), total=len(test_ds) // batch_size): | |
| p = {f'pred_n{i}': p['generated_text'] for i, p in enumerate(pred)} | |
| preds.append(p) | |
| return preds | |
| def run_pipeline( | |
| pipe: pipeline, | |
| test_ds: Dataset, | |
| batch_size: int, | |
| is_causal_language_model: bool, | |
| smiles_column: str = 'text', | |
| ) -> List[Dict[str, str]]: | |
| """ Run the pipeline and return the predictions. | |
| Args: | |
| pipe (pipeline): The pipeline object to use for generating predictions. | |
| test_ds (Dataset): The test dataset to generate predictions for. | |
| batch_size (int): The batch size to use for generating predictions. | |
| is_causal_language_model (bool): Whether the model is a causal language model or not. | |
| smiles_column (str): The column name in the dataset that contains the SMILES strings. Default: 'text' | |
| Returns: | |
| List[Dict[str, str]]: A list of dictionaries containing the beam-size predictions in the format: [{'pred_n0': 'prediction_0', 'pred_n1': 'prediction_1', ...}, ...] | |
| """ | |
| if is_causal_language_model: | |
| return run_causal_pipeline(pipe, test_ds, batch_size, smiles_column) | |
| else: | |
| return run_seq2seq_pipeline(pipe, test_ds, batch_size, smiles_column) |