import numpy as np
import torch
import torch.nn as nn
import copy
import re

from transformers import T5Config, T5PreTrainedModel, T5EncoderModel, T5Tokenizer
from transformers.models.t5.modeling_t5 import T5Stack
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map

from models.enm_adaptor_heads import (
    ENMAdaptedAttentionClassifier,
    ENMAdaptedDirectClassifier,
    ENMAdaptedConvClassifier,
    ENMNoAdaptorClassifier,
)
from utils.lora_utils import LoRAConfig, modify_with_lora

class T5EncoderForTokenClassification(T5PreTrainedModel):

    def __init__(self, config: T5Config, class_config):
        super().__init__(config)
        self.num_labels = class_config.num_labels
        self.config = config
        self.add_pearson_loss = class_config.add_pearson_loss
        self.add_sse_loss = class_config.add_sse_loss

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        self.dropout = nn.Dropout(class_config.dropout_rate)

        if class_config.adaptor_architecture == 'attention':
            self.classifier = ENMAdaptedAttentionClassifier(
                config.hidden_size, class_config.num_labels,
                class_config.enm_embed_dim, class_config.enm_att_heads)
        elif class_config.adaptor_architecture == 'direct':
            self.classifier = ENMAdaptedDirectClassifier(config.hidden_size, class_config.num_labels)
        elif class_config.adaptor_architecture == 'conv':
            self.classifier = ENMAdaptedConvClassifier(
                config.hidden_size, class_config.num_labels,
                class_config.kernel_size, class_config.enm_embed_dim, class_config.num_layers)
        elif class_config.adaptor_architecture == 'no-adaptor':
            self.classifier = ENMNoAdaptorClassifier(config.hidden_size, class_config.num_labels)
        else:
            raise ValueError('Only attention, direct, conv and no-adaptor architectures are supported for the adaptor.')

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None
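
    # Note: `class_config` is expected to provide the attributes read in __init__
    # above: num_labels, add_pearson_loss, add_sse_loss, dropout_rate,
    # adaptor_architecture and, depending on the chosen architecture,
    # enm_embed_dim, enm_att_heads, kernel_size and num_layers. A hedged usage
    # sketch with illustrative values is given at the bottom of this file.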

    def parallelize(self, device_map=None):
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.classifier = self.classifier.to(self.encoder.first_device)
        self.model_parallel = True

    def deparallelize(self):
        self.encoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)

    def get_encoder(self):
        return self.encoder

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}.
        See base class PreTrainedModel.
        """
        # T5Stack keeps its transformer layers in `block`; the self-attention
        # module of each block is `layer[0].SelfAttention`.
        for layer, heads in heads_to_prune.items():
            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)

    def forward(
        self,
        enm_vals=None,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)

        # TODO: check that the enm_vals are padded properly and that the sequence limit
        # (in the transformer) is indeed 512
        logits = self.classifier(sequence_output, enm_vals, attention_mask)

        # The loss is not computed inside the model; `labels` (and the
        # add_pearson_loss / add_sse_loss flags) are expected to be consumed by
        # the surrounding training code.
        loss = None

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
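

# Expected call pattern for T5EncoderForTokenClassification (a sketch; the exact
# semantics of enm_vals depend on the adaptor heads in models/enm_adaptor_heads.py):
#   out = model(enm_vals=enm_vals,            # per-residue ENM values, padded like input_ids (assumption)
#               input_ids=input_ids,
#               attention_mask=attention_mask)
#   out.logits  # shape: (batch, seq_len, num_labels)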


def PT5_classification_model(half_precision, class_config):
    # Load ProtT5 (PT5) and its tokenizer.
    # It is also possible to load the half-precision model (thanks to @pawel-rezo for pointing that out).
    if not half_precision:
        model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50", local_files_only=False)
        tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", local_files_only=False)
    elif half_precision and torch.cuda.is_available():
        tokenizer = T5Tokenizer.from_pretrained(
            "Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False, local_files_only=False
        )
        model = T5EncoderModel.from_pretrained(
            "Rostlab/prot_t5_xl_half_uniref50-enc", torch_dtype=torch.float16, local_files_only=False
        ).to(torch.device('cuda'))
    else:
        raise ValueError('Half precision can be run on GPU only.')

    # Create a new classifier model with PT5 dimensions
    class_model = T5EncoderForTokenClassification(model.config, class_config)

    # Set encoder and embedding weights to checkpoint weights
    class_model.shared = model.shared
    class_model.encoder = model.encoder

    # Delete the checkpoint model
    model = class_model
    del class_model

    # Print number of trainable parameters
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print("ProtT5_Classifier\nTrainable parameters: " + str(params))

    # Load the LoRA configuration
    config = LoRAConfig('configs/lora_config.yaml')

    # Add LoRA layers
    model = modify_with_lora(model, config)

    # Freeze embeddings and encoder (except LoRA parameters)
    for param_name, param in model.shared.named_parameters():
        param.requires_grad = False
    for param_name, param in model.encoder.named_parameters():
        param.requires_grad = False
    for param_name, param in model.named_parameters():
        if re.fullmatch(config.trainable_param_names, param_name):
            param.requires_grad = True

    # Print number of trainable parameters after adding LoRA
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print("ProtT5_LoRA_Classifier\nTrainable parameters: " + str(params) + "\n")

    return model, tokenizer
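

# A minimal usage sketch (illustrative only: the class_config values and the
# dummy ENM values below are assumptions, not part of the repository's API).
if __name__ == "__main__":
    from types import SimpleNamespace

    example_class_config = SimpleNamespace(
        num_labels=1,                 # assumed per-residue output dimension
        dropout_rate=0.1,
        add_pearson_loss=False,
        add_sse_loss=False,
        adaptor_architecture='conv',  # one of: attention, direct, conv, no-adaptor
        enm_embed_dim=64,
        enm_att_heads=4,
        kernel_size=5,
        num_layers=2,
    )

    model, tokenizer = PT5_classification_model(half_precision=False, class_config=example_class_config)

    # ProtT5 expects space-separated residues
    seq = "M K T A Y I A K Q R"
    batch = tokenizer(seq, return_tensors="pt", padding=True)

    # Dummy per-residue ENM values, padded to the tokenized length (assumption)
    enm_vals = torch.zeros_like(batch["input_ids"], dtype=torch.float)

    with torch.no_grad():
        out = model(enm_vals=enm_vals,
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"])
    print(out.logits.shape)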