# tmprot / helpers.py
# GitLab CI
# Latest changes
# b8c7219
############################################ IMPORTS ###########################################################
import torch
from typing import Tuple, List, Dict
from transformers import AutoTokenizer, EsmForSequenceClassification
from peft import PeftModel
from Bio import SeqIO
from io import StringIO
##################################################################################################################
def load_model(
    model_name: str, path_model: str, device: str
) -> Tuple[torch.nn.Module, AutoTokenizer]:
    """
    Load an ESM base model with a fine-tuned PEFT LoRA adapter for inference.

    The pre-trained checkpoint ``facebook/<model_name>`` is loaded as a
    sequence-classification model with a single output (``num_labels=1``).
    The LoRA adapter found at ``path_model`` is applied on top of it on the
    fly via ``PeftModel.from_pretrained``. The resulting model is moved to
    ``device``, switched to eval mode, and has all of its parameters frozen.

    Args:
        model_name (str): Name of the base ESM model, without the
            ``facebook/`` prefix (e.g. 'esm2_t33_650M_UR50D').
        path_model (str): Local path to the fine-tuned LoRA adapter. It is
            passed through ``str()``, so a path-like object is accepted too.
        device (str): Device the model is moved to (e.g. 'cpu', 'cuda').

    Returns:
        Tuple[torch.nn.Module, AutoTokenizer]: The frozen, eval-mode model and
        the tokenizer of the base ESM model.
    """
    esm_model = f"facebook/{model_name}"
    tokenizer = AutoTokenizer.from_pretrained(esm_model)

    # Single-output head (num_labels=1): one scalar per input sequence.
    base_model = EsmForSequenceClassification.from_pretrained(
        esm_model, num_labels=1
    ).to(device)

    # The adapter is read from local files only, never fetched from the Hub.
    model = PeftModel.from_pretrained(
        base_model, str(path_model), is_local=True, local_files_only=True
    ).to(device)

    # Inference only: eval() disables dropout, and requires_grad_(False)
    # freezes every parameter (base weights and LoRA weights alike).
    model.eval()
    model.requires_grad_(False)
    return model, tokenizer
def parse_fasta_string(fasta_str: str) -> List[Dict[str, str]]:
    """
    Parse FASTA-formatted text held in memory into a list of records.

    Args:
        fasta_str (str): FASTA content as a string.

    Returns:
        List[Dict[str, str]]: One dict per record, in input order, with keys
        ``"id"`` (the Biopython ``SeqRecord.id`` of the record) and
        ``"sequence"`` (the record's sequence as a plain string).
    """
    with StringIO(fasta_str) as handle:
        # SeqIO.parse is lazy, so the list must be built before the handle
        # is closed at the end of the ``with`` block.
        return [
            {"id": rec.id, "sequence": str(rec.seq)}
            for rec in SeqIO.parse(handle, "fasta")
        ]