"""Token-classification wrappers (T5 encoder / ESM) with optional LoRA adapters.

Provides model classes with selectable prediction heads, LoRA injection
helpers, an ESM-specific data collator and loaders that restore fine-tuned
weights (``cpt.pth``) from the Hugging Face Hub.
"""
import copy
import gc
import os
import re
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.utils.data import DataLoader
from scipy.special import expit

from huggingface_hub import hf_hub_download
import datasets
import transformers
from datasets import Dataset
from transformers import AutoTokenizer
from transformers import DataCollatorForTokenClassification
from transformers import T5EncoderModel, T5Tokenizer
from transformers import TrainingArguments, Trainer, set_seed
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.esm.modeling_esm import EsmPreTrainedModel, EsmModel
from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map

# for custom DataCollator
from transformers.data.data_collator import DataCollatorMixin
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy

# peft (LoraConfig, inject_adapter_in_model) is imported lazily inside
# _inject_peft_lora, so it is only required on the code paths that use it.

# Prediction-head selection. These flags are read when a model is constructed
# and again on every forward pass. The first True flag in the order
# cnn -> ffn -> transformer wins; with all False a single linear layer is used.
cnn_head = True  # set True for Rostlab/prot_t5_xl_half_uniref50-enc
ffn_head = False
transformer_head = False
# Use the in-file LoRA implementation (LoRALinear) for T5 models instead of peft.
custom_lora = True  # only True for Rostlab/prot_t5_xl_half_uniref50-enc

# Attribute names of the optional prediction-head modules; which of them exist
# on a given model depends on the head flags above.
_HEAD_MODULE_NAMES = ("cnn", "ffn", "transformer_encoder", "classifier")


def _head_modules(model):
    """Yield ``(name, module)`` for every prediction-head module present on *model*."""
    for name in _HEAD_MODULE_NAMES:
        module = getattr(model, name, None)
        if module is not None:
            yield name, module


def _unfreeze_head(model):
    """Mark all parameters of every prediction-head module as trainable."""
    for _, module in _head_modules(model):
        for param in module.parameters():
            param.requires_grad = True


def _count_trainable(model):
    """Return the number of scalar parameters with ``requires_grad=True``."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def _token_classification_loss(logits, labels, attention_mask, num_labels):
    """Cross-entropy over positions that are attended to and not labelled -100.

    Args:
        logits: per-token scores; flattened to ``(-1, num_labels)``.
        labels: per-token targets with -100 marking ignored positions (may be
            float, as produced by DataCollatorForTokenClassificationESM).
        attention_mask: 1 for real tokens; ``None`` treats every position as
            active.
        num_labels: size of the last logits dimension.
    """
    flat_labels = labels.reshape(-1)
    if attention_mask is None:
        active = torch.ones_like(flat_labels, dtype=torch.bool)
    else:
        active = attention_mask.reshape(-1) == 1
    active_labels = torch.where(active, flat_labels, torch.full_like(flat_labels, -100))
    flat_logits = logits.reshape(-1, num_labels)
    keep = (active_labels != -100).to(flat_logits.device)
    valid_logits = flat_logits[keep]
    # Labels may arrive as floats and/or on another device; CrossEntropyLoss
    # needs int64 class indices on the same device as the logits.
    valid_labels = active_labels[keep.to(active_labels.device)].to(valid_logits.device).long()
    return CrossEntropyLoss()(valid_logits, valid_labels)


def _inject_peft_lora(model, target_modules):
    """Inject a peft LoRA adapter (r=4, lora_alpha=1, bias="all") into *model*."""
    # Imported here: the module-level peft import is disabled, which previously
    # left these names undefined and made the peft code paths raise NameError.
    from peft import LoraConfig, inject_adapter_in_model

    peft_config = LoraConfig(r=4, lora_alpha=1, bias="all", target_modules=target_modules)
    return inject_adapter_in_model(peft_config, model)


class ClassConfig:
    """Hyper-parameters of the token-classification head used with T5 encoders."""

    def __init__(self, dropout=0.2, num_labels=3):
        self.dropout_rate = dropout
        self.num_labels = num_labels


class T5EncoderForTokenClassification(T5PreTrainedModel):
    """T5 encoder stack followed by a per-token classification head.

    The head architecture is chosen by the module-level flags ``cnn_head``,
    ``ffn_head`` and ``transformer_head``.
    """

    def __init__(self, config: T5Config, class_config: ClassConfig):
        super().__init__(config)
        self.num_labels = class_config.num_labels
        self.config = config

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        self.dropout = nn.Dropout(class_config.dropout_rate)

        # Initialize different heads based on the module-level flags
        if cnn_head:
            self.cnn = nn.Conv1d(config.hidden_size, 512, kernel_size=3, padding=1)
            self.classifier = nn.Linear(512, class_config.num_labels)
        elif ffn_head:
            # Multi-layer feed-forward network (FFN) head
            self.ffn = nn.Sequential(
                nn.Linear(config.hidden_size, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(256, class_config.num_labels),
            )
        elif transformer_head:
            # NOTE(review): nn.TransformerEncoderLayer defaults to
            # batch_first=False, but forward() feeds (batch, seq, hidden)
            # tensors, so attention would run across the batch axis — confirm
            # before enabling this head.
            encoder_layer = nn.TransformerEncoderLayer(d_model=config.hidden_size, nhead=8)
            self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
            self.classifier = nn.Linear(config.hidden_size, class_config.num_labels)
        else:
            # Default classification head
            self.classifier = nn.Linear(config.hidden_size, class_config.num_labels)

        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    def parallelize(self, device_map=None):
        """Spread encoder blocks over GPUs and put the head on the first device."""
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        # Move every head module, not only `classifier`: the cnn/transformer
        # heads have extra layers, and the ffn head has no `classifier`.
        for _, module in _head_modules(self):
            module.to(self.encoder.first_device)
        self.model_parallel = True

    def deparallelize(self):
        """Undo parallelize(): move encoder and head back to CPU."""
        self.encoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        for _, module in _head_modules(self):
            module.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)

    def get_encoder(self):
        return self.encoder

    def _prune_heads(self, heads_to_prune):
        """Prune attention heads; ``heads_to_prune`` maps block index -> head ids."""
        # T5Stack keeps its blocks in `.block`; each block's self-attention is
        # `layer[0].SelfAttention` (the previous `.layer[i].attention` path
        # does not exist on T5 and raised AttributeError).
        for layer, heads in heads_to_prune.items():
            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode the sequence, apply the selected head and optionally compute the loss.

        The loss ignores positions whose attention mask is not 1 and positions
        labelled -100.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = self.dropout(outputs[0])

        # Forward pass through the selected head
        if cnn_head:
            # Conv1d expects channels-first: swap to (batch, hidden, seq) and back.
            cnn_output = F.relu(self.cnn(sequence_output.permute(0, 2, 1)))
            logits = self.classifier(cnn_output.permute(0, 2, 1))
        elif ffn_head:
            logits = self.ffn(sequence_output)
        elif transformer_head:
            logits = self.classifier(self.transformer_encoder(sequence_output))
        else:
            logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss = _token_classification_loss(logits, labels, attention_mask, self.num_labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


# Modifies an existing transformer and introduces the LoRA layers
class CustomLoRAConfig:
    """Settings consumed by modify_with_lora and load_T5_model_classification.

    lora_modules and lora_layers are regular expressions matched with
    ``re.fullmatch`` against module and child names; trainable_param_names
    selects the parameters left trainable after the encoder is frozen.
    """

    def __init__(self):
        self.lora_rank = 4
        self.lora_init_scale = 0.01
        self.lora_modules = ".*SelfAttention|.*EncDecAttention"
        self.lora_layers = "q|k|v|o"
        self.trainable_param_names = ".*layer_norm.*|.*lora_[ab].*"
        self.lora_scaling_rank = 1

        # lora_modules and lora_layers are specified with regular expressions
        # see https://www.w3schools.com/python/python_regex.asp for reference


class LoRALinear(nn.Module):
    """Drop-in replacement for nn.Linear with additive and multiplicative LoRA terms.

    Shares ``weight``/``bias`` with the wrapped layer. With ``rank > 0`` it adds
    ``lora_b @ lora_a / rank`` to the weight; with ``scaling_rank > 0`` it
    multiplies the weight element-wise by ``multi_lora_b @ multi_lora_a / scaling_rank``.
    A negative ``init_scale`` also randomly initialises ``lora_b`` / ``multi_lora_b``.
    """

    def __init__(self, linear_layer, rank, scaling_rank, init_scale):
        super().__init__()
        self.in_features = linear_layer.in_features
        self.out_features = linear_layer.out_features
        self.rank = rank
        self.scaling_rank = scaling_rank
        self.weight = linear_layer.weight
        self.bias = linear_layer.bias
        if self.rank > 0:
            self.lora_a = nn.Parameter(torch.randn(rank, linear_layer.in_features) * init_scale)
            if init_scale < 0:
                self.lora_b = nn.Parameter(torch.randn(linear_layer.out_features, rank) * init_scale)
            else:
                # Zero init keeps the wrapped layer's output unchanged at start.
                self.lora_b = nn.Parameter(torch.zeros(linear_layer.out_features, rank))
        if self.scaling_rank:
            self.multi_lora_a = nn.Parameter(
                torch.ones(self.scaling_rank, linear_layer.in_features)
                + torch.randn(self.scaling_rank, linear_layer.in_features) * init_scale
            )
            if init_scale < 0:
                self.multi_lora_b = nn.Parameter(
                    torch.ones(linear_layer.out_features, self.scaling_rank)
                    + torch.randn(linear_layer.out_features, self.scaling_rank) * init_scale
                )
            else:
                self.multi_lora_b = nn.Parameter(torch.ones(linear_layer.out_features, self.scaling_rank))

    def forward(self, input):
        if self.scaling_rank == 1 and self.rank == 0:
            # parsimonious implementation for ia3 and lora scaling
            if self.multi_lora_a.requires_grad:
                hidden = F.linear((input * self.multi_lora_a.flatten()), self.weight, self.bias)
            else:
                hidden = F.linear(input, self.weight, self.bias)
            if self.multi_lora_b.requires_grad:
                hidden = hidden * self.multi_lora_b.flatten()
            return hidden
        # general implementation for lora (adding and scaling)
        weight = self.weight
        if self.scaling_rank:
            weight = weight * torch.matmul(self.multi_lora_b, self.multi_lora_a) / self.scaling_rank
        if self.rank:
            weight = weight + torch.matmul(self.lora_b, self.lora_a) / self.rank
        return F.linear(input, weight, self.bias)

    def extra_repr(self):
        return "in_features={}, out_features={}, bias={}, rank={}, scaling_rank={}".format(
            self.in_features, self.out_features, self.bias is not None, self.rank, self.scaling_rank
        )


def modify_with_lora(transformer, config):
    """Replace matching nn.Linear children with LoRALinear in place and return *transformer*.

    A child is replaced when its parent's module name fully matches
    ``config.lora_modules`` and its own name fully matches ``config.lora_layers``.
    """
    for m_name, module in dict(transformer.named_modules()).items():
        if re.fullmatch(config.lora_modules, m_name):
            for c_name, layer in dict(module.named_children()).items():
                if re.fullmatch(config.lora_layers, c_name):
                    assert isinstance(
                        layer, nn.Linear
                    ), f"LoRA can only be applied to torch.nn.Linear, but {layer} is {type(layer)}."
                    setattr(
                        module,
                        c_name,
                        LoRALinear(layer, config.lora_rank, config.lora_scaling_rank, config.lora_init_scale),
                    )
    return transformer


def load_T5_model_classification(checkpoint, num_labels, half_precision, full=False, deepspeed=True):
    """Build a T5EncoderForTokenClassification from a pretrained T5 encoder checkpoint.

    Args:
        checkpoint: Hub id / path; must contain "ankh", "prot_t5" or "ProstT5".
        num_labels: number of per-token classes.
        half_precision: with ``deepspeed``, load prot_t5/ProstT5 in float16 on CUDA.
        full: if True, return the model without LoRA and without freezing.
        deepspeed: see ``half_precision``.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        ValueError: if the checkpoint family is not recognised.
    """
    # Load model and tokenizer
    if "ankh" in checkpoint:
        model = T5EncoderModel.from_pretrained(checkpoint, resume_download=True)
        tokenizer = AutoTokenizer.from_pretrained(checkpoint, resume_download=True)
    elif "prot_t5" in checkpoint or "ProstT5" in checkpoint:
        # possible to load the half precision model (thanks to @pawel-rezo for pointing that out)
        if half_precision and deepspeed:
            # NOTE(review): the encoder lands on CUDA in float16 while the new
            # head below is created on CPU in float32 — confirm callers move
            # and cast the assembled model.
            tokenizer = T5Tokenizer.from_pretrained(checkpoint, do_lower_case=False, resume_download=True)
            model = T5EncoderModel.from_pretrained(
                checkpoint, torch_dtype=torch.float16, resume_download=True
            ).to(torch.device('cuda'))
        else:
            model = T5EncoderModel.from_pretrained(checkpoint, resume_download=True)
            tokenizer = T5Tokenizer.from_pretrained(checkpoint, resume_download=True)
    else:
        # Previously fell through and failed later with UnboundLocalError.
        raise ValueError(
            f"Unsupported T5 checkpoint {checkpoint!r}: expected 'ankh', 'prot_t5' or 'ProstT5' in the name."
        )

    # Create new Classifier model with PT5 dimensions
    class_config = ClassConfig(num_labels=num_labels)
    class_model = T5EncoderForTokenClassification(model.config, class_config)

    # Set encoder and embedding weights to checkpoint weights
    class_model.shared = model.shared
    class_model.encoder = model.encoder

    # Delete the checkpoint model and clear memory
    del model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    model = class_model
    del class_model

    if full:
        return model, tokenizer

    # Print number of trainable parameters
    print("T5_Classfier\nTrainable Parameter: " + str(_count_trainable(model)))

    if custom_lora:
        # the linear CustomLoRAConfig allows better quality predictions, but more memory is needed
        config = CustomLoRAConfig()
        model = modify_with_lora(model, config)

        # Freeze Embeddings and Encoder (except LoRA and layer norms)
        for param in model.shared.parameters():
            param.requires_grad = False
        for param in model.encoder.parameters():
            param.requires_grad = False
        for param_name, param in model.named_parameters():
            if re.fullmatch(config.trainable_param_names, param_name):
                param.requires_grad = True
    else:
        model = _inject_peft_lora(model, ["q", "k", "v", "o"])

    # Unfreeze the whole prediction head (cnn / ffn / transformer layers and
    # classifier); unfreezing only `classifier` crashed for the ffn head.
    _unfreeze_head(model)

    # Print trainable Parameter
    print("T5_LoRA_Classfier\nTrainable Parameter: " + str(_count_trainable(model)) + "\n")

    return model, tokenizer


class EsmForTokenClassificationCustom(EsmPreTrainedModel):
    """ESM backbone followed by a per-token head selected by the module-level flags."""

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"cnn", r"ffn", r"transformer"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.esm = EsmModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        if cnn_head:
            self.cnn = nn.Conv1d(config.hidden_size, 512, kernel_size=3, padding=1)
            self.classifier = nn.Linear(512, config.num_labels)
        elif ffn_head:
            # Multi-layer feed-forward network (FFN) as an alternative head
            self.ffn = nn.Sequential(
                nn.Linear(config.hidden_size, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(256, config.num_labels),
            )
        elif transformer_head:
            # Transformer layer as an alternative head.
            # NOTE(review): batch_first defaults to False here too; see the
            # matching note in T5EncoderForTokenClassification.
            encoder_layer = nn.TransformerEncoderLayer(d_model=config.hidden_size, nhead=8)
            self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        else:
            # Default classification head
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        """Encode with ESM, apply the selected head and optionally compute the loss."""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = self.dropout(outputs[0])

        if cnn_head:
            # Unlike the T5 CNN head no ReLU is applied here; left unchanged to
            # stay compatible with weights trained against this head.
            sequence_output = self.cnn(sequence_output.transpose(1, 2)).transpose(1, 2)
            logits = self.classifier(sequence_output)
        elif ffn_head:
            logits = self.ffn(sequence_output)
        elif transformer_head:
            # Apply transformer encoder for the transformer head
            logits = self.classifier(self.transformer_encoder(sequence_output))
        else:
            logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # Labels now follow the logits' device instead of a hard-coded
            # 'cuda:0', which failed on CPU and on other GPUs.
            loss = _token_classification_loss(logits, labels, attention_mask, self.num_labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def _init_weights(self, module):
        if isinstance(module, nn.Linear) or isinstance(module, nn.Conv1d):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()


# based on transformers DataCollatorForTokenClassification
@dataclass
class DataCollatorForTokenClassificationESM(DataCollatorMixin):
    """
    Data collator that will dynamically pad the inputs received, as well as the labels.

    With right padding, one `label_pad_token_id` is prepended to every label
    sequence so labels line up with tokens after the leading special token;
    the tail (including the trailing special token) is padded too. Labels are
    returned as a float tensor.

    Args:
        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
            The tokenizer used for encoding the data.
        padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:

            - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
              acceptable input length for the model if that argument is not provided.
            - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
        max_length (`int`, *optional*):
            Maximum length of the returned list and optionally padding length (see above).
        pad_to_multiple_of (`int`, *optional*):
            If set will pad the sequence to a multiple of the provided value.

            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
        label_pad_token_id (`int`, *optional*, defaults to -100):
            The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
        return_tensors (`str`):
            The type of Tensor to return. Allowable values are "np", "pt" and "tf". Note that `torch_call` always
            produces PyTorch tensors.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    return_tensors: str = "pt"

    def torch_call(self, features):
        import torch

        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None

        no_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]

        batch = self.tokenizer.pad(
            no_labels_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        if labels is None:
            return batch

        sequence_length = batch["input_ids"].shape[1]
        padding_side = self.tokenizer.padding_side

        def to_list(tensor_or_iterable):
            if isinstance(tensor_or_iterable, torch.Tensor):
                return tensor_or_iterable.tolist()
            return list(tensor_or_iterable)

        if padding_side == "right":
            # changed from the stock collator to pad the special tokens at the
            # beginning and end of the sequence
            batch[label_name] = [
                [self.label_pad_token_id]
                + to_list(label)
                + [self.label_pad_token_id] * (sequence_length - len(label) - 1)
                for label in labels
            ]
        else:
            batch[label_name] = [
                [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
            ]

        batch[label_name] = torch.tensor(batch[label_name], dtype=torch.float)
        return batch


def _torch_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
    """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
    import torch

    # Tensorize if necessary.
    if isinstance(examples[0], (list, tuple, np.ndarray)):
        examples = [torch.tensor(e, dtype=torch.long) for e in examples]

    length_of_first = examples[0].size(0)

    # Check if padding is necessary.
    are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
    if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
        return torch.stack(examples, dim=0)

    # If yes, check if we have a `pad_token`.
    if tokenizer._pad_token is None:
        raise ValueError(
            "You are attempting to pad samples but the tokenizer you are using"
            f" ({tokenizer.__class__.__name__}) does not have a pad token."
        )

    # Creating the full tensor and filling it with our data.
    max_length = max(x.size(0) for x in examples)
    if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
        max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
    result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
    for i, example in enumerate(examples):
        if tokenizer.padding_side == "right":
            result[i, : example.shape[0]] = example
        else:
            result[i, -example.shape[0] :] = example
    return result


def tolist(x):
    """Return *x* as a Python list (lists pass through; tensors/arrays are converted)."""
    if isinstance(x, list):
        return x
    elif hasattr(x, "numpy"):  # Checks for TF tensors without needing the import
        x = x.numpy()
    return x.tolist()


# load ESM2 models
def load_esm_model_classification(checkpoint, num_labels, half_precision, full=False, deepspeed=True):
    """Load an EsmForTokenClassificationCustom plus tokenizer, optionally with peft LoRA.

    Returns ``(model, tokenizer)``. With ``full=True`` no adapter is injected.
    """
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    if half_precision and deepspeed:
        model = EsmForTokenClassificationCustom.from_pretrained(
            checkpoint, num_labels=num_labels, ignore_mismatched_sizes=True, torch_dtype=torch.float16
        )
    else:
        model = EsmForTokenClassificationCustom.from_pretrained(
            checkpoint, num_labels=num_labels, ignore_mismatched_sizes=True
        )

    if full:
        return model, tokenizer

    model = _inject_peft_lora(model, ["query", "key", "value", "dense"])

    # model.gradient_checkpointing_enable()

    # Unfreeze the whole prediction head, not only `classifier`
    _unfreeze_head(model)

    return model, tokenizer


def load_model(checkpoint, max_length):
    """Build the model for *checkpoint* and load its fine-tuned ``cpt.pth`` weights.

    Args:
        checkpoint: Hub repo id; "esm" in the name selects the ESM loader,
            otherwise the T5 loader is used. The repo must contain ``cpt.pth``.
        max_length: currently unused; kept for API compatibility.

    Returns:
        ``(model, tokenizer)``.
    """
    full = False
    deepspeed = False
    mixed = False
    num_labels = 2

    print(checkpoint, num_labels, mixed, full, deepspeed)

    # Determine model type and load accordingly
    if "esm" in checkpoint:
        model, tokenizer = load_esm_model_classification(checkpoint, num_labels, mixed, full, deepspeed)
    else:
        model, tokenizer = load_T5_model_classification(checkpoint, num_labels, mixed, full, deepspeed)

    # Download the file
    local_file = hf_hub_download(repo_id=checkpoint, filename="cpt.pth")

    # Load the state dict onto CPU; weights_only=True restricts unpickling to
    # tensors and plain containers.
    state_dict = torch.load(local_file, map_location=torch.device('cpu'), weights_only=True)
    model.load_state_dict(state_dict)

    # Clear state_dict from memory immediately after loading
    del state_dict
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return model, tokenizer