Spaces:
Sleeping
Sleeping
| from huggingface_hub import hf_hub_download | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss | |
| from torch.utils.data import DataLoader | |
| import re | |
| import numpy as np | |
| import os | |
| import pandas as pd | |
| import copy | |
| import transformers, datasets | |
| from transformers.modeling_outputs import TokenClassifierOutput | |
| from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack | |
| from transformers.utils.model_parallel_utils import assert_device_map, get_device_map | |
| from transformers import T5EncoderModel, T5Tokenizer | |
| from transformers.models.esm.modeling_esm import EsmPreTrainedModel, EsmModel | |
| from transformers import AutoTokenizer | |
| from transformers import TrainingArguments, Trainer, set_seed | |
| from transformers import DataCollatorForTokenClassification | |
| from dataclasses import dataclass | |
| from typing import Dict, List, Optional, Tuple, Union | |
| # for custom DataCollator | |
| from transformers.data.data_collator import DataCollatorMixin | |
| from transformers.tokenization_utils_base import PreTrainedTokenizerBase | |
| from transformers.utils import PaddingStrategy | |
| from datasets import Dataset | |
| from scipy.special import expit | |
| #import peft | |
| #from peft import get_peft_config, PeftModel, PeftConfig, inject_adapter_in_model, LoraConfig | |
# Prediction-head selection flags, read by both token-classification models
# below.  They are checked in the order cnn_head -> ffn_head ->
# transformer_head and the first True one wins; if all are False a single
# nn.Linear classifier is used.
cnn_head = True  # set True for Rostlab/prot_t5_xl_half_uniref50-enc
ffn_head = False
transformer_head = False
# True: inject the in-file LoRALinear adapters into T5 models.
# False: use peft's LoRA injection instead.
# Only set True for Rostlab/prot_t5_xl_half_uniref50-enc.
custom_lora = True
class ClassConfig:
    """Hyper-parameters for the per-token classification head.

    Attributes:
        dropout_rate: dropout probability applied to the encoder output.
        num_labels: number of output classes per token.
    """

    def __init__(self, dropout=0.2, num_labels=3):
        self.num_labels = num_labels
        self.dropout_rate = dropout
class T5EncoderForTokenClassification(T5PreTrainedModel):
    """T5 encoder stack followed by a per-token classification head.

    The head architecture is fixed at construction time by the module-level
    flags ``cnn_head``, ``ffn_head`` and ``transformer_head``. They are checked
    in that order, and if none is True a single ``nn.Linear`` is used.

    Args:
        config: T5 configuration used to build the encoder stack.
        class_config: head hyper-parameters (``dropout_rate``, ``num_labels``).
    """

    # Attribute names a prediction head may live under. Which of them exist
    # depends on the head flag that was active in ``__init__``.
    _HEAD_MODULE_NAMES = ("cnn", "ffn", "transformer_encoder", "classifier")

    def __init__(self, config: T5Config, class_config: ClassConfig):
        super().__init__(config)
        self.num_labels = class_config.num_labels
        self.config = config

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        # Encoder-only stack: no key/value cache and no decoder.
        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        self.dropout = nn.Dropout(class_config.dropout_rate)

        if cnn_head:
            self.cnn = nn.Conv1d(config.hidden_size, 512, kernel_size=3, padding=1)
            self.classifier = nn.Linear(512, class_config.num_labels)
        elif ffn_head:
            # Multi-layer feed-forward network (FFN) head.
            self.ffn = nn.Sequential(
                nn.Linear(config.hidden_size, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(256, class_config.num_labels),
            )
        elif transformer_head:
            # The encoder output is laid out (batch, seq, hidden), as the CNN
            # branch's permute(0, 2, 1) also assumes. batch_first=True makes
            # the layer attend along the sequence; the default
            # batch_first=False would attend across the batch instead.
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=config.hidden_size, nhead=8, batch_first=True
            )
            self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
            self.classifier = nn.Linear(config.hidden_size, class_config.num_labels)
        else:
            # Default: a single linear classification layer.
            self.classifier = nn.Linear(config.hidden_size, class_config.num_labels)

        self.post_init()

        # Legacy Hugging Face model-parallel state.
        self.model_parallel = False
        self.device_map = None

    def _head_modules(self):
        """Return the prediction-head submodules that exist on this instance."""
        heads = []
        for name in self._HEAD_MODULE_NAMES:
            head = getattr(self, name, None)
            if head is not None:
                heads.append(head)
        return heads

    def parallelize(self, device_map=None):
        """Spread the encoder blocks over GPUs (legacy HF model-parallel API).

        ``device_map`` is validated with ``assert_device_map`` and handed to
        ``T5Stack.parallelize``. When omitted, ``get_device_map`` builds one
        over all visible CUDA devices. Every head module is placed on the
        encoder's first device. Previously only ``classifier`` was moved,
        which failed for the FFN head and left the CNN/transformer heads
        behind.
        """
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        for head in self._head_modules():
            head.to(self.encoder.first_device)
        self.model_parallel = True

    def deparallelize(self):
        """Undo ``parallelize``: move the encoder and the head back to CPU."""
        self.encoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        for head in self._head_modules():
            head.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)

    def get_encoder(self):
        return self.encoder

    def _prune_heads(self, heads_to_prune):
        """Prune self-attention heads: ``{block_index: [head indices]}``.

        T5Stack keeps its layers in ``block``, and each block's self-attention
        sits at ``layer[0].SelfAttention``. The previous
        ``encoder.layer[...].attention`` path does not exist on T5Stack.
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode the inputs and return per-token logits, plus a loss when ``labels`` is given.

        Args:
            attention_mask: positions equal to 1 are attended to and contribute
                to the loss. ``None`` treats every position as active.
            labels: per-token class ids. Positions equal to -100 are ignored.
                Any numeric dtype is accepted; values are cast to long before
                the loss is computed.

        Returns:
            A ``TokenClassifierOutput``, or, when ``return_dict`` is False, a
            tuple ``([loss,] logits, *extra encoder outputs)``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        if self.model_parallel:
            # The head was placed on the encoder's first device by
            # `parallelize`. The encoder output may end on another GPU.
            sequence_output = sequence_output.to(self.encoder.first_device)
        sequence_output = self.dropout(sequence_output)

        if cnn_head:
            # Conv1d expects (batch, channels, seq); swap axes in and back out.
            cnn_input = sequence_output.permute(0, 2, 1)
            cnn_output = F.relu(self.cnn(cnn_input))
            logits = self.classifier(cnn_output.permute(0, 2, 1))
        elif ffn_head:
            logits = self.ffn(sequence_output)
        elif transformer_head:
            logits = self.classifier(self.transformer_encoder(sequence_output))
        else:
            logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            flat_labels = labels.reshape(-1)
            if attention_mask is not None:
                # Masked-out positions are relabelled -100 so they are dropped below.
                active = attention_mask.reshape(-1) == 1
                flat_labels = torch.where(active, flat_labels, torch.full_like(flat_labels, -100))
            flat_logits = logits.reshape(-1, self.num_labels)
            keep = flat_labels != -100
            valid_logits = flat_logits[keep]
            valid_labels = flat_labels[keep].to(valid_logits.device).long()
            loss = loss_fct(valid_logits, valid_labels)

        if not return_dict:
            # T5Stack drops None entries (such as the unused cache) from its
            # tuple output, so everything after index 0 is an optional extra
            # (hidden states / attentions). `outputs[2:]` silently lost the
            # first of them. This relies on the encoder running with
            # use_cache=False, as configured in __init__. TODO confirm for
            # encoders swapped in from elsewhere.
            output = (logits,) + tuple(outputs[1:])
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class CustomLoRAConfig:
    """Settings consumed by ``modify_with_lora`` / ``LoRALinear``.

    ``lora_modules``, ``lora_layers`` and ``trainable_param_names`` are
    regular expressions matched with ``re.fullmatch`` against module,
    child-layer and parameter names respectively (see the Python ``re``
    module documentation for the syntax).
    """

    def __init__(self):
        # Adapter sizes / initialisation.
        self.lora_rank = 4
        self.lora_scaling_rank = 1
        self.lora_init_scale = 0.01
        # Which modules get adapters, and which of their children are wrapped.
        self.lora_modules = ".*SelfAttention|.*EncDecAttention"
        self.lora_layers = "q|k|v|o"
        # Parameters that stay trainable after the encoder is frozen.
        self.trainable_param_names = ".*layer_norm.*|.*lora_[ab].*"
class LoRALinear(nn.Module):
    """Drop-in replacement for an ``nn.Linear`` carrying LoRA-style adapters.

    The wrapped layer's ``weight`` and ``bias`` Parameters are reused (shared,
    not copied). Two optional adapters are added:

    * an additive low-rank update ``lora_b @ lora_a / rank``, created only
      when ``rank > 0``;
    * a multiplicative scaling ``multi_lora_b @ multi_lora_a / scaling_rank``
      applied element-wise to the weight, created only when ``scaling_rank``
      is truthy.

    With ``init_scale >= 0``, ``lora_b`` starts at zero and ``multi_lora_b``
    at ones, while ``multi_lora_a`` is ones plus Gaussian noise times
    ``init_scale``. A negative ``init_scale`` adds noise to the ``b``
    matrices as well.
    """

    def __init__(self, linear_layer, rank, scaling_rank, init_scale):
        super().__init__()
        in_dim = linear_layer.in_features
        out_dim = linear_layer.out_features
        self.in_features = in_dim
        self.out_features = out_dim
        self.rank = rank
        self.scaling_rank = scaling_rank
        # Share the pretrained parameters with the wrapped layer.
        self.weight = linear_layer.weight
        self.bias = linear_layer.bias
        noisy_b = init_scale < 0

        if self.rank > 0:
            self.lora_a = nn.Parameter(torch.randn(rank, in_dim) * init_scale)
            if noisy_b:
                lora_b_init = torch.randn(out_dim, rank) * init_scale
            else:
                lora_b_init = torch.zeros(out_dim, rank)
            self.lora_b = nn.Parameter(lora_b_init)

        if self.scaling_rank:
            k = self.scaling_rank
            self.multi_lora_a = nn.Parameter(torch.ones(k, in_dim) + torch.randn(k, in_dim) * init_scale)
            if noisy_b:
                multi_b_init = torch.ones(out_dim, k) + torch.randn(out_dim, k) * init_scale
            else:
                multi_b_init = torch.ones(out_dim, k)
            self.multi_lora_b = nn.Parameter(multi_b_init)

    def forward(self, input):
        if self.scaling_rank != 1 or self.rank != 0:
            # General path: fold both adapters into one effective weight.
            weight = self.weight
            if self.scaling_rank:
                weight = weight * torch.matmul(self.multi_lora_b, self.multi_lora_a) / self.scaling_rank
            if self.rank:
                weight = weight + torch.matmul(self.lora_b, self.lora_a) / self.rank
            return F.linear(input, weight, self.bias)

        # Rank-1 scaling only ((IA)^3-style): scale the activations instead of
        # materialising a scaled weight matrix.
        # NOTE(review): each scaling vector is applied only while it requires
        # grad, so a frozen, trained scaling vector is ignored here. Confirm
        # this is intended.
        if self.multi_lora_a.requires_grad:
            scaled_input = input * self.multi_lora_a.flatten()
        else:
            scaled_input = input
        result = F.linear(scaled_input, self.weight, self.bias)
        if self.multi_lora_b.requires_grad:
            result = result * self.multi_lora_b.flatten()
        return result

    def extra_repr(self):
        template = "in_features={}, out_features={}, bias={}, rank={}, scaling_rank={}"
        has_bias = self.bias is not None
        return template.format(self.in_features, self.out_features, has_bias, self.rank, self.scaling_rank)
def modify_with_lora(transformer, config):
    """Swap matching ``nn.Linear`` children for ``LoRALinear`` wrappers, in place.

    A child is replaced when its parent's qualified module name fully matches
    ``config.lora_modules`` and its own attribute name fully matches
    ``config.lora_layers``. Adapter sizes come from ``config.lora_rank``,
    ``config.lora_scaling_rank`` and ``config.lora_init_scale``.

    Returns:
        The same ``transformer`` object, for chaining.

    Raises:
        AssertionError: if a selected child is not an ``nn.Linear``.
    """
    # Snapshot both module listings first: setattr below mutates the tree.
    for module_name, parent in dict(transformer.named_modules()).items():
        if not re.fullmatch(config.lora_modules, module_name):
            continue
        for child_name, child in dict(parent.named_children()).items():
            if not re.fullmatch(config.lora_layers, child_name):
                continue
            assert isinstance(
                child, nn.Linear
            ), f"LoRA can only be applied to torch.nn.Linear, but {child} is {type(child)}."
            wrapped = LoRALinear(child, config.lora_rank, config.lora_scaling_rank, config.lora_init_scale)
            setattr(parent, child_name, wrapped)
    return transformer
def load_T5_model_classification(checkpoint, num_labels, half_precision, full = False, deepspeed=True):
    """Build a T5-encoder token classifier, LoRA-adapted unless ``full`` is set.

    Args:
        checkpoint: model id or path. It must contain "ankh", "prot_t5" or
            "ProstT5"; the substring selects how backbone and tokenizer load.
        num_labels: number of per-token output classes.
        half_precision: together with ``deepspeed``, load prot_t5/ProstT5
            weights as float16 and move them to CUDA.
        full: if True, return the classifier right after assembly, without
            LoRA injection or parameter freezing.
        deepspeed: see ``half_precision``.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        ValueError: if ``checkpoint`` names none of the supported families
            (previously this surfaced as an UnboundLocalError).
    """
    if "ankh" in checkpoint:
        model = T5EncoderModel.from_pretrained(checkpoint)
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    elif "prot_t5" in checkpoint or "ProstT5" in checkpoint:
        if half_precision and deepspeed:
            # Half-precision weights can be loaded directly (thanks to
            # @pawel-rezo for pointing that out).
            tokenizer = T5Tokenizer.from_pretrained(checkpoint, do_lower_case=False)
            model = T5EncoderModel.from_pretrained(checkpoint, torch_dtype=torch.float16).to(torch.device('cuda'))
        else:
            model = T5EncoderModel.from_pretrained(checkpoint)
            tokenizer = T5Tokenizer.from_pretrained(checkpoint)
    else:
        raise ValueError(
            f"Unsupported checkpoint {checkpoint!r}: expected 'ankh', 'prot_t5' or 'ProstT5' in its name."
        )

    # Wrap the pretrained encoder in the token-classification model by
    # reusing its embedding and encoder modules; only the head is new.
    class_model = T5EncoderForTokenClassification(model.config, ClassConfig(num_labels=num_labels))
    class_model.shared = model.shared
    class_model.encoder = model.encoder
    model = class_model

    if full:
        return model, tokenizer

    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("T5_Classfier\nTrainable Parameter: " + str(params))

    if custom_lora:
        # The in-file LoRA layers give better predictions but need more memory.
        config = CustomLoRAConfig()
        model = modify_with_lora(model, config)
        # Freeze embeddings and encoder, then re-enable the LoRA and
        # layer-norm parameters selected by the config regex.
        for param in model.shared.parameters():
            param.requires_grad = False
        for param in model.encoder.parameters():
            param.requires_grad = False
        for param_name, param in model.named_parameters():
            if re.fullmatch(config.trainable_param_names, param_name):
                param.requires_grad = True
    else:
        # NOTE(review): LoraConfig / inject_adapter_in_model come from `peft`,
        # whose import is commented out at the top of this file. As written,
        # this branch fails with NameError unless those names are provided
        # elsewhere.
        peft_config = LoraConfig(
            r=4, lora_alpha=1, bias="all", target_modules=["q", "k", "v", "o"]
        )
        model = inject_adapter_in_model(peft_config, model)

    # Unfreeze every prediction-head parameter. Only `classifier` used to be
    # handled, which raised AttributeError for the FFN head (it has no
    # `classifier`) and could leave CNN/transformer-head weights frozen.
    for head_name in ("cnn", "ffn", "transformer_encoder", "classifier"):
        head = getattr(model, head_name, None)
        if head is not None:
            for param in head.parameters():
                param.requires_grad = True

    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("T5_LoRA_Classfier\nTrainable Parameter: " + str(params) + "\n")
    return model, tokenizer
class EsmForTokenClassificationCustom(EsmPreTrainedModel):
    """ESM encoder (without pooler) followed by a per-token classification head.

    The head architecture is fixed at construction time by the module-level
    flags ``cnn_head``, ``ffn_head`` and ``transformer_head``. They are checked
    in that order, and if none is True a single ``nn.Linear`` is used.
    """

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    # Head weights are new when loading a plain ESM checkpoint; don't warn about them.
    _keys_to_ignore_on_load_missing = [r"position_ids", r"cnn", r"ffn", r"transformer"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.esm = EsmModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        if cnn_head:
            self.cnn = nn.Conv1d(config.hidden_size, 512, kernel_size=3, padding=1)
            self.classifier = nn.Linear(512, config.num_labels)
        elif ffn_head:
            # Multi-layer feed-forward network (FFN) head.
            self.ffn = nn.Sequential(
                nn.Linear(config.hidden_size, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(256, config.num_labels),
            )
        elif transformer_head:
            # Sequence output is laid out (batch, seq, hidden), as the CNN
            # branch's transpose(1, 2) also assumes. Without batch_first=True
            # the layer would attend across the batch instead of the sequence.
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=config.hidden_size, nhead=8, batch_first=True
            )
            self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        else:
            # Default: a single linear classification layer.
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        """Encode the inputs and return per-token logits, plus a loss when ``labels`` is given.

        Args:
            attention_mask: positions equal to 1 contribute to the loss.
                ``None`` treats every position as active.
            labels: per-token class ids. Positions equal to -100 are ignored.
                Any numeric dtype is accepted; values are cast to long before
                the loss is computed.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = self.dropout(outputs[0])

        if cnn_head:
            # NOTE(review): unlike the T5 model's CNN head there is no ReLU
            # between the convolution and the classifier. Left unchanged so
            # existing checkpoints keep producing the same logits.
            conv_out = self.cnn(sequence_output.transpose(1, 2))
            logits = self.classifier(conv_out.transpose(1, 2))
        elif ffn_head:
            logits = self.ffn(sequence_output)
        elif transformer_head:
            logits = self.classifier(self.transformer_encoder(sequence_output))
        else:
            logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            flat_labels = labels.reshape(-1)
            if attention_mask is not None:
                # Masked-out positions are relabelled -100 so they are dropped below.
                active = attention_mask.reshape(-1) == 1
                flat_labels = torch.where(active, flat_labels, torch.full_like(flat_labels, -100))
            flat_logits = logits.reshape(-1, self.num_labels)
            keep = flat_labels != -100
            valid_logits = flat_logits[keep]
            # Follow the logits' device. The previous hard-coded 'cuda:0'
            # failed on CPU-only hosts and on any other GPU.
            valid_labels = flat_labels[keep].to(valid_logits.device).long()
            loss = loss_fct(valid_logits, valid_labels)

        if not return_dict:
            # Presumably index 1 of EsmModel's tuple is the (disabled) pooler
            # slot, so extras start at index 2. TODO confirm per transformers version.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def _init_weights(self, module):
        """Initialise Linear/Conv1d layers, including the custom heads.

        Other module types (embeddings, layer norms, ...) are delegated to
        ESM's own initialiser instead of being silently skipped.
        """
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        else:
            super()._init_weights(module)
# based on transformers DataCollatorForTokenClassification
@dataclass
class DataCollatorForTokenClassificationESM(DataCollatorMixin):
    """
    Data collator that dynamically pads the inputs received, as well as the labels.

    Labels are laid out so that one label-pad slot precedes and one follows
    each sequence's labels. This assumes the tokenizer adds exactly one
    special token before and one after every sequence (as the ESM tokenizer
    is expected to), so each label lines up with its residue token. TODO
    confirm for other tokenizers.

    (``@dataclass`` was missing, which made ``tokenizer=...`` construction
    raise TypeError.)

    Args:
        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
            The tokenizer used for encoding the data.
        padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
              acceptable input length for the model if that argument is not provided.
            - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
        max_length (`int`, *optional*):
            Maximum length of the returned list and optionally padding length (see above).
        pad_to_multiple_of (`int`, *optional*):
            If set, pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.0 (Volta and newer).
        label_pad_token_id (`int`, *optional*, defaults to -100):
            The id to use when padding the labels (-100 is automatically ignored by PyTorch loss functions).
        return_tensors (`str`):
            The type of Tensor to return. Allowable values are "np", "pt" and "tf"; `torch_call` always produces
            PyTorch tensors.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    return_tensors: str = "pt"

    def torch_call(self, features):
        """Pad ``features`` (a list of dicts) into a batch of PyTorch tensors, labels included."""
        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None

        no_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]
        batch = self.tokenizer.pad(
            no_labels_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        if labels is None:
            return batch

        sequence_length = batch["input_ids"].shape[1]
        pad = self.label_pad_token_id

        def to_list(tensor_or_iterable):
            if isinstance(tensor_or_iterable, torch.Tensor):
                return tensor_or_iterable.tolist()
            return list(tensor_or_iterable)

        if self.tokenizer.padding_side == "right":
            # [start-token slot] + labels + [end-token slot and right padding]
            padded = [
                [pad] + to_list(label) + [pad] * (sequence_length - len(label) - 1) for label in labels
            ]
        else:
            # [left padding and start-token slot] + labels + [end-token slot].
            # The end-token slot used to be missing here, which shifted every
            # label one position to the right of its token.
            padded = [
                [pad] * (sequence_length - len(label) - 1) + to_list(label) + [pad] for label in labels
            ]

        # Kept as float, as before. The model classes in this file cast labels
        # to long before computing the loss.
        batch[label_name] = torch.tensor(padded, dtype=torch.float)
        return batch
| def _torch_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None): | |
| """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary.""" | |
| import torch | |
| # Tensorize if necessary. | |
| if isinstance(examples[0], (list, tuple, np.ndarray)): | |
| examples = [torch.tensor(e, dtype=torch.long) for e in examples] | |
| length_of_first = examples[0].size(0) | |
| # Check if padding is necessary. | |
| are_tensors_same_length = all(x.size(0) == length_of_first for x in examples) | |
| if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0): | |
| return torch.stack(examples, dim=0) | |
| # If yes, check if we have a `pad_token`. | |
| if tokenizer._pad_token is None: | |
| raise ValueError( | |
| "You are attempting to pad samples but the tokenizer you are using" | |
| f" ({tokenizer.__class__.__name__}) does not have a pad token." | |
| ) | |
| # Creating the full tensor and filling it with our data. | |
| max_length = max(x.size(0) for x in examples) | |
| if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): | |
| max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of | |
| result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id) | |
| for i, example in enumerate(examples): | |
| if tokenizer.padding_side == "right": | |
| result[i, : example.shape[0]] = example | |
| else: | |
| result[i, -example.shape[0] :] = example | |
| return result | |
def tolist(x):
    """Return ``x`` as a plain Python list; an input list is returned as the same object."""
    if not isinstance(x, list):
        # Objects exposing `.numpy()` (e.g. framework tensors) go through NumPy first.
        if hasattr(x, "numpy"):
            x = x.numpy()
        x = x.tolist()
    return x
# Load ESM2 models
def load_esm_model_classification(checkpoint, num_labels, half_precision, full=False, deepspeed=True):
    """Load an ESM token classifier, peft-LoRA-adapted unless ``full`` is set.

    Args:
        checkpoint: model id or path passed to ``from_pretrained``.
        num_labels: number of per-token output classes.
        half_precision: together with ``deepspeed``, load weights as float16.
        full: if True, return the model right after loading, without adapters.
        deepspeed: see ``half_precision``.

    Returns:
        ``(model, tokenizer)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    load_kwargs = {"num_labels": num_labels, "ignore_mismatched_sizes": True}
    if half_precision and deepspeed:
        load_kwargs["torch_dtype"] = torch.float16
    model = EsmForTokenClassificationCustom.from_pretrained(checkpoint, **load_kwargs)

    if full:
        return model, tokenizer

    # NOTE(review): LoraConfig / inject_adapter_in_model come from `peft`,
    # whose import is commented out at the top of this file. As written, this
    # path fails with NameError unless those names are provided elsewhere.
    peft_config = LoraConfig(
        r=4, lora_alpha=1, bias="all", target_modules=["query", "key", "value", "dense"]
    )
    model = inject_adapter_in_model(peft_config, model)
    # model.gradient_checkpointing_enable()

    # Re-enable gradients for every prediction-head parameter. The adapter
    # injection may freeze non-adapter weights, and previously only
    # `classifier` was unfrozen, leaving e.g. the CNN head's weights
    # untrained (and raising AttributeError for the FFN head, which has no
    # `classifier`).
    for head_name in ("cnn", "ffn", "transformer_encoder", "classifier"):
        head = getattr(model, head_name, None)
        if head is not None:
            for param in head.parameters():
                param.requires_grad = True
    return model, tokenizer
def load_model(checkpoint, max_length):
    """Build the classifier for ``checkpoint`` and load its fine-tuned ``cpt.pth`` weights.

    Checkpoints whose name contains "esm" use the ESM loader; everything else
    goes through the T5 loader. The fine-tuned state dict is downloaded from
    the same Hub repo (e.g. ``'ThorbenF/prot_t5_xl_uniref50'``) and loaded on
    CPU. ``max_length`` is accepted for API compatibility but not used here.

    Returns:
        ``(model, tokenizer)``.
    """
    # Fixed inference-time options: adapter-wrapped (not full), fp32, no deepspeed.
    full, deepspeed, mixed = False, False, False
    num_labels = 2
    print(checkpoint, num_labels, mixed, full, deepspeed)

    loader = load_esm_model_classification if "esm" in checkpoint else load_T5_model_classification
    model, tokenizer = loader(checkpoint, num_labels, mixed, full, deepspeed)

    weights_path = hf_hub_download(repo_id=checkpoint, filename="cpt.pth")
    state_dict = torch.load(weights_path, map_location=torch.device('cpu'), weights_only=True)
    model.load_state_dict(state_dict)
    return model, tokenizer