| |
| import torch |
| from typing import Tuple, List, Dict |
| from transformers import AutoTokenizer, EsmForSequenceClassification |
| from peft import PeftModel |
| from Bio import SeqIO |
| from io import StringIO |
| |
|
|
def load_model(
    model_name: str, path_model: str, device: str
) -> Tuple[torch.nn.Module, AutoTokenizer]:
    """
    Build an inference-ready ESM classifier with a LoRA adapter attached.

    The base checkpoint ``facebook/<model_name>`` is loaded with a single
    output label, the PEFT adapter found at ``path_model`` is applied on top
    of it, and the combined model is moved to ``device``, switched to eval
    mode and has gradients disabled on every parameter.

    Args:
        model_name (str): Name of the base ESM model without the ``facebook/``
            prefix (e.g., 'esm2_t33_650M_UR50D').
        path_model (str): Path to the fine-tuned LoRA adapter; it is converted
            with ``str()`` before being handed to PEFT.
        device (str): Device the base model and the adapted model are moved
            to via ``.to(device)``.

    Returns:
        Tuple[torch.nn.Module, AutoTokenizer]: The adapted model and the
        tokenizer loaded from the same base checkpoint.
    """
    checkpoint = f"facebook/{model_name}"

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    backbone = EsmForSequenceClassification.from_pretrained(
        checkpoint, num_labels=1
    )
    backbone = backbone.to(device)

    # NOTE(review): is_local / local_files_only are forwarded to PEFT as-is;
    # presumably they keep adapter resolution offline — confirm against the
    # installed peft version.
    adapted = PeftModel.from_pretrained(
        backbone,
        str(path_model),
        is_local=True,
        local_files_only=True,
    )
    model = adapted.to(device)

    # Inference only: eval mode, and no parameter tracks gradients.
    model.eval()
    model.requires_grad_(False)

    return model, tokenizer
|
|
def parse_fasta_string(fasta_str: str) -> List[Dict[str, str]]:
    """
    Convert FASTA-formatted text into a list of record dictionaries.

    Args:
        fasta_str (str): FASTA content held in memory as a single string.

    Returns:
        List[Dict[str, str]]: One dict per parsed record, in input order,
        with the record identifier under "id" and the sequence (converted
        with ``str()``) under "sequence".
    """
    # SeqIO.parse expects a file-like handle, so wrap the text in a buffer.
    with StringIO(fasta_str) as buffer:
        records = list(SeqIO.parse(buffer, "fasta"))

    return [
        {"id": record.id, "sequence": str(record.seq)}
        for record in records
    ]
|
|