| | from transformers import GPT2Config, AutoTokenizer, GPT2Config |
| | from transformers import PretrainedConfig, PreTrainedModel |
| | import transformers |
| | from typing import Optional, Tuple, Callable, List |
| | import torch |
| | import torch.nn as nn |
| | from transformers.modeling_utils import PreTrainedModel, PretrainedConfig |
| | from .utils import CABlock, _GPT2LMHeadModel |
| | from .configuration_prot2text import Prot2TextConfig |
| | from transformers.generation.configuration_utils import GenerationConfig |
| | from transformers.generation.logits_process import LogitsProcessorList |
| | from transformers.generation.stopping_criteria import StoppingCriteriaList |
| |
|
| |
|
class Prot2TextModel(PreTrainedModel):
    """Generate free-text protein descriptions from amino-acid sequences.

    The model pairs an ESM protein encoder (built when ``config.esm`` is true)
    with a project-local ``_GPT2LMHeadModel`` decoder.  ESM hidden states are
    projected to the decoder width by ``to_embedding`` and handed to the
    decoder as ``encoder_hidden_states`` (its cross-attention inputs).

    Checkpoints with ``config.prot2text_version == '1.0'`` require the ESM
    input to be exactly 1021 tokens long, and ``forward`` runs them without a
    cross-attention mask.
    """

    config_class = Prot2TextConfig
    # Missing-weight warnings for keys matching "transformer" are suppressed on load.
    _keys_to_ignore_on_load_missing = [r"transformer"]
    base_model_prefix = "decoder"

    def __init__(self, config):
        """Build the decoder and, depending on ``config``, the ESM encoder.

        Args:
            config: A ``Prot2TextConfig``.  Fields read here: ``gpt_config``
                (passed to ``GPT2Config.from_dict``), ``esm``, ``esm_config``
                (passed to ``PretrainedConfig.from_dict``), ``cross_esm_graph``
                and ``rgcn``.
        """
        super().__init__(config)

        self.gpt_config = GPT2Config.from_dict(config.gpt_config)
        self.decoder = _GPT2LMHeadModel(self.gpt_config)

        if config.esm:
            self.esm_config = PretrainedConfig.from_dict(config.esm_config)
            self.esm = transformers.EsmModel(self.esm_config)
            # Projects ESM hidden states to the decoder's embedding width (n_embd).
            self.to_embedding = nn.Linear(self.esm_config.hidden_size, self.gpt_config.n_embd)
            if config.cross_esm_graph and config.rgcn:
                # NOTE(review): these modules are never used by forward() in this
                # class; presumably kept so checkpoints containing them still
                # load -- confirm before removing.
                self.h = nn.ModuleList([CABlock(self.gpt_config, layer_idx=i) for i in range(4)])
                self.ln_f = nn.LayerNorm(self.gpt_config.n_embd, eps=self.gpt_config.layer_norm_epsilon)

        # Deliberately bind the exact config object passed in; do not assume
        # super().__init__ stored this same object.
        self.config = config

    def get_encoder(self):
        """Return the ESM encoder, or ``None`` if the model was built without one.

        This class never assigns a ``self.encoder`` attribute, so the encoder
        is the ESM sub-model created when ``config.esm`` is true.
        """
        return getattr(self, "esm", None)

    def get_decoder(self):
        """Return the GPT-2 LM-head decoder."""
        return self.decoder

    def get_input_embeddings(self):
        """Return the token-embedding module (``wte``).

        A ``transformer`` attribute on this model takes precedence if one
        exists; otherwise the decoder's ``transformer.wte`` is returned.
        """
        if hasattr(self, "transformer"):
            return self.transformer.wte
        return self.decoder.transformer.wte

    def warm_up(self, gpt_model=None, esm_model=None):
        """Replace the sub-models with pretrained weights.

        Args:
            gpt_model: Name or path for ``_GPT2LMHeadModel.from_pretrained``.
                The decoder is loaded with ``add_cross_attention=True`` and
                ``use_cache=False``.  Its token embeddings are then resized to
                ``self.gpt_config.vocab_size``, and ``self.gpt_config`` is
                assigned as its config.
            esm_model: Name or path for ``transformers.EsmModel.from_pretrained``.
                ``self.esm_config`` is not updated when this is loaded.
        """
        if esm_model is not None:
            self.esm = transformers.EsmModel.from_pretrained(esm_model)
        if gpt_model is not None:
            self.decoder = _GPT2LMHeadModel.from_pretrained(gpt_model, add_cross_attention=True, use_cache=False)
            self.decoder.resize_token_embeddings(self.gpt_config.vocab_size)
            self.decoder.config = self.gpt_config

    def forward(self,
                encoder_input_ids: Optional[torch.LongTensor] = None,
                edge_index: Optional[torch.LongTensor] = None,
                batch: Optional[torch.LongTensor] = None,
                x: Optional[torch.FloatTensor] = None,
                edge_type: Optional[torch.LongTensor] = None,
                decoder_input_ids: Optional[torch.LongTensor] = None,
                past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
                past_key_values_graph_esm: Optional[Tuple[Tuple[torch.Tensor]]] = None,
                decoder_attention_mask: Optional[torch.FloatTensor] = None,
                attention_mask: Optional[torch.FloatTensor] = None,
                token_type_ids: Optional[torch.LongTensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                head_mask: Optional[torch.FloatTensor] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                encoder_hidden_states: Optional[torch.Tensor] = None,
                encoder_attention_mask: Optional[torch.FloatTensor] = None,
                labels: Optional[torch.LongTensor] = None,
                use_cache: Optional[bool] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                get_graph_emb: Optional[bool] = False,
                **delete_args,
                ):
        """Encode the protein with ESM and run the decoder on top of it.

        Args:
            encoder_input_ids: ESM token ids for the protein.  Required.  For
                ``prot2text_version == '1.0'``, ``size(1)`` must be 1021.
            attention_mask: Mask for ``encoder_input_ids``.  It is given to ESM
                and reused as the decoder's cross-attention mask, except for
                version '1.0', where the cross-attention mask is ``None``.
            decoder_input_ids: Decoder token ids.  A 3-D tensor gets
                ``squeeze(0)`` applied (presumably a size-1 leading batch axis
                added by a collator -- confirm).
            decoder_attention_mask, past_key_values, token_type_ids,
            position_ids, head_mask, inputs_embeds, labels, use_cache,
            output_attentions, output_hidden_states, return_dict:
                Forwarded to the decoder.  ``use_cache`` and ``return_dict``
                default to ``gpt_config.use_cache`` and
                ``gpt_config.use_return_dict``.
            get_graph_emb: If true, return the projected ESM states instead of
                running the decoder.
            edge_index, batch, x, edge_type, past_key_values_graph_esm,
            encoder_hidden_states, encoder_attention_mask, **delete_args:
                Accepted for interface compatibility and ignored.

        Returns:
            The projected encoder states when ``get_graph_emb`` is true,
            otherwise the decoder's output.

        Raises:
            ValueError: If the model has no ESM encoder, if
                ``encoder_input_ids`` is missing, or if a version '1.0' input
                is not 1021 tokens long.
        """
        use_cache = use_cache if use_cache is not None else self.gpt_config.use_cache
        return_dict = return_dict if return_dict is not None else self.gpt_config.use_return_dict

        if decoder_input_ids is not None and decoder_input_ids.dim() == 3:
            decoder_input_ids = decoder_input_ids.squeeze(0)

        if not self.config.esm:
            # This class has no other encoder to condition the decoder on.
            # Without ESM, `graph_emb` was previously never bound (UnboundLocalError).
            raise ValueError("Prot2TextModel.forward requires an ESM encoder (config.esm=True)")
        if encoder_input_ids is None:
            raise ValueError("`encoder_input_ids` is required to run the ESM encoder")
        if self.config.prot2text_version == '1.0' and encoder_input_ids.size(1) != 1021:
            raise ValueError("For this version of the model you need to PAD/Truncate the amino acid sequence for the ESM model to 1021")

        # Always ask ESM for a ModelOutput.  The caller's `return_dict` governs
        # only the decoder output, and `.last_hidden_state` does not exist on
        # the tuple ESM returns when return_dict=False.
        esm_output = self.esm(input_ids=encoder_input_ids, attention_mask=attention_mask, return_dict=True)
        graph_emb = self.to_embedding(esm_output.last_hidden_state)

        if self.config.prot2text_version == '1.0':
            # Version 1.0 runs decoder cross-attention without a mask.
            attention_mask = None

        if get_graph_emb:
            return graph_emb

        return self.decoder(input_ids=decoder_input_ids,
                            past_key_values=past_key_values,
                            attention_mask=decoder_attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask,
                            inputs_embeds=inputs_embeds,
                            encoder_hidden_states=graph_emb,
                            encoder_attention_mask=attention_mask,
                            labels=labels,
                            use_cache=use_cache,
                            output_attentions=output_attentions,
                            output_hidden_states=output_hidden_states,
                            return_dict=return_dict,
                            )

    @torch.no_grad()
    def generate_protein_description(self,
                                     protein_sequence=None,
                                     tokenizer=None,
                                     device='cpu'
                                     ):
        """Generate a text description for one amino-acid sequence.

        The sequence is tokenized with the tokenizer named by
        ``config.esm_model_name`` (truncated/padded to 1021 tokens) and encoded
        through ``forward(get_graph_emb=True)``.  Decoding uses
        ``self.decoder.generate``, starting from a single
        ``tokenizer.bos_token_id``.  As a side effect, the model is moved to
        ``device`` in place.

        Args:
            protein_sequence: Amino-acid sequence string.
            tokenizer: Decoder tokenizer, used for the BOS id and for
                ``batch_decode``.
            device: Device to run on.

        Returns:
            The decoded text with ``<|stop_token|>`` and ``<|graph_token|>``
            removed.

        Raises:
            ValueError: If the model has no ESM encoder, or if
                ``protein_sequence`` or ``tokenizer`` is missing.
        """
        if not self.config.esm:
            raise ValueError("generate_protein_description requires an ESM encoder (config.esm=True)")
        if protein_sequence is None:
            raise ValueError(
                "The model you are trying to use is based only on protein sequence, please provide an amino-acid protein_sequence"
            )
        if tokenizer is None:
            raise ValueError("A decoder `tokenizer` is required to build the prompt and decode the output")

        esm_tokenizer = AutoTokenizer.from_pretrained(self.config.esm_model_name)
        seq = esm_tokenizer([protein_sequence], add_special_tokens=True, truncation=True, max_length=1021,
                            padding='max_length', return_tensors="pt")
        encoder_input_ids = seq['input_ids']
        inputs = {
            'encoder_input_ids': encoder_input_ids,
            'attention_mask': seq['attention_mask'],
            # Decoder prompt: a single BOS token per sequence.
            'decoder_input_ids': torch.full((encoder_input_ids.size(0), 1), tokenizer.bos_token_id,
                                            dtype=encoder_input_ids.dtype),
        }

        self.to(device)
        inputs = {k: v.to(device=device, non_blocking=True) for k, v in inputs.items()}
        encoder_state = {'hidden_states': self(**inputs, get_graph_emb=True, output_attentions=True)}
        # NOTE(review): unlike `generate`, no encoder_attention_mask is passed
        # here, so the decoder may cross-attend to ESM padding positions.
        # Confirm this matches how the checkpoint was evaluated.
        output_ids = self.decoder.generate(input_ids=inputs['decoder_input_ids'], encoder_outputs=encoder_state,
                                           use_cache=True)
        generated = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

        return generated[0].replace('<|stop_token|>', '').replace('<|graph_token|>', '')

    @torch.no_grad()
    def generate(self,
                 inputs: Optional[torch.Tensor] = None,
                 generation_config: Optional[GenerationConfig] = None,
                 logits_processor: Optional[LogitsProcessorList] = None,
                 stopping_criteria: Optional[StoppingCriteriaList] = None,
                 prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
                 synced_gpus: Optional[bool] = None,
                 assistant_model: Optional["PreTrainedModel"] = None,
                 streamer: Optional["BaseStreamer"] = None,
                 **kwargs,
                 ):
        """Generate decoder tokens conditioned on the ESM encoder states.

        ``kwargs`` are first passed to ``forward(get_graph_emb=True)`` to
        compute the encoder states.  The model-input keys listed in
        ``dropped_keys`` (including ``max_length``) are then removed, and the
        remaining ``kwargs`` are forwarded to ``self.decoder.generate``.

        Keys read from ``kwargs``:
            decoder_input_ids: Decoder prompt ids (required).
            decoder_attention_mask: Optional mask for the prompt.  Forwarded to
                the decoder as ``attention_mask``.
            attention_mask: Optional ESM mask.  Forwarded to the decoder as
                ``encoder_attention_mask``.

        ``inputs`` is accepted for signature compatibility and ignored.

        Returns:
            Whatever ``self.decoder.generate`` returns.
        """
        encoder_state = self(**kwargs, get_graph_emb=True)
        input_ids = kwargs['decoder_input_ids']
        decoder_attention_mask = kwargs.get('decoder_attention_mask')
        # NOTE(review): forward() drops the cross-attention mask for
        # prot2text_version '1.0', but it is passed here for every version.
        # Confirm this is intended.
        encoder_attention_mask = kwargs.get('attention_mask')
        if (encoder_attention_mask is not None and not self.config.cross_esm_graph
                and self.config.rgcn and self.config.esm):
            # Prepend one always-visible position (presumably for a graph
            # embedding in the RGCN variant).  ones_like keeps dtype and device;
            # the previous `.to(tensor.get_device())` got -1 on CPU and failed.
            prefix = torch.ones_like(encoder_attention_mask[:, :1])
            encoder_attention_mask = torch.cat((prefix, encoder_attention_mask), dim=1)

        # Model inputs already consumed above (or dataset fields) that the
        # decoder's generate() must not receive.
        dropped_keys = ('edge_index', 'edge_type', 'x', 'encoder_input_ids', 'decoder_input_ids',
                        'decoder_attention_mask', 'batch', 'attention_mask', 'max_length',
                        '_num_nodes', 'node_id', 'name', 'sequence', 'distance_matrix', 'distance',
                        'coordinates', 'ptr', 'num_nodes')
        for key in dropped_keys:
            kwargs.pop(key, None)
        kwargs['encoder_attention_mask'] = encoder_attention_mask

        return self.decoder.generate(input_ids=input_ids,
                                     attention_mask=decoder_attention_mask,
                                     generation_config=generation_config,
                                     logits_processor=logits_processor,
                                     stopping_criteria=stopping_criteria,
                                     prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                                     synced_gpus=synced_gpus,
                                     assistant_model=assistant_model,
                                     streamer=streamer,
                                     encoder_outputs={'hidden_states': encoder_state, 'attentions': 0},
                                     **kwargs
                                     )