File size: 1,757 Bytes
b8c7219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
############################################  IMPORTS  ###########################################################
import torch
from typing import Tuple, List, Dict
from transformers import AutoTokenizer, EsmForSequenceClassification
from peft import PeftModel
from Bio import SeqIO
from io import StringIO
##################################################################################################################

def load_model(
    model_name: str, path_model: str, device: str
) -> Tuple[torch.nn.Module, AutoTokenizer]:
    """
    Load a pre-trained ESM-2 classifier, attach a fine-tuned PEFT LoRA adapter, and freeze it.

    The base model is fetched as ``facebook/{model_name}`` and built with a single output
    label (``num_labels=1``). The LoRA adapter stored at ``path_model`` is then applied on
    top of it on-the-fly with ``PeftModel.from_pretrained``. The result is put in eval mode
    and every parameter has ``requires_grad`` disabled, so it is ready for inference only.

    Args:
        model_name (str): Name of the base ESM model under the ``facebook/`` namespace
            (e.g., 'esm2_t33_650M_UR50D').
        path_model (str): Local path to the fine-tuned LoRA adapter. It is passed through
            ``str()``, so a path-like object is accepted as well.
        device (str): Torch device to place the model on (e.g., 'cpu', 'cuda').

    Returns:
        Tuple[torch.nn.Module, AutoTokenizer]: The adapter-wrapped model (eval mode, all
        parameters frozen) and the tokenizer of the base model.
    """
    esm_model = f"facebook/{model_name}"
    tokenizer = AutoTokenizer.from_pretrained(esm_model)
    base_model = EsmForSequenceClassification.from_pretrained(
        esm_model, num_labels=1
    ).to(device)
    # The adapter is only ever read from local disk, never downloaded from the Hub.
    model = PeftModel.from_pretrained(
        base_model, str(path_model), is_local=True, local_files_only=True
    ).to(device)  # move again so the freshly loaded adapter weights also sit on `device`
    model.eval()
    # Freeze every parameter (base + adapter): this model is used for inference only.
    model.requires_grad_(False)
    return model, tokenizer

def parse_fasta_string(fasta_str: str) -> List[Dict[str, str]]:
    """
    Parse FASTA-formatted text into a list of ``{"id": ..., "sequence": ...}`` dicts.

    Args:
        fasta_str (str): FASTA content held in memory (not a file path).

    Returns:
        List[Dict[str, str]]: One dict per record, in input order. ``"id"`` is the record
        identifier assigned by Biopython (``rec.id``). ``"sequence"`` is the record's
        sequence as a plain string. Text with no FASTA records yields an empty list.
    """
    # SeqIO.parse needs a file-like handle, so wrap the string in an in-memory buffer.
    # The comprehension consumes the parser fully inside the `with`, so the buffer is
    # not used after it is closed.
    with StringIO(fasta_str) as handle:
        return [
            {"id": rec.id, "sequence": str(rec.seq)}
            for rec in SeqIO.parse(handle, "fasta")
        ]