# esm_utils.py — ESM-2 protein embedding utilities for a Hugging Face Space.
from transformers import AutoTokenizer, AutoModel
import torch
from Bio import SeqIO
from io import StringIO
import numpy as np
import spaces
class EsmEmbedding:
    """Produce protein embeddings with a Hugging Face ESM-2 checkpoint.

    The tokenizer is loaded eagerly in ``__init__``, but the model weights
    are loaded lazily inside the ``@spaces.GPU``-decorated methods, because
    on HF Spaces (ZeroGPU) a CUDA device is only guaranteed to exist inside
    that context.
    """

    def __init__(self, model_name="facebook/esm2_t36_3B_UR50D"):
        """
        Initialize the ESM-2 model via Hugging Face Transformers.

        Parameters:
        - model_name: Hub id of the ESM-2 checkpoint to use.
        """
        # Only the tokenizer is created here; weights are deferred to the
        # GPU context (see _ensure_model).
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = None   # loaded on first GPU-decorated call
        self.device = None  # set alongside the model
        print(f"Initialized EsmEmbedding with {model_name} (model not loaded yet)")

    def _ensure_model(self):
        """
        Load the model onto CUDA once; no-op on later calls.

        Must only be called from within a @spaces.GPU context, where a CUDA
        device is available.
        """
        if self.model is not None:
            return
        self.device = "cuda"
        # float32 weights are sufficient: embeddings are promoted to
        # float64 downstream in _generate_embedding.
        self.model = AutoModel.from_pretrained(
            self.model_name, trust_remote_code=True, use_safetensors=True
        ).to(self.device, dtype=torch.float32)
        self.model.eval()
        print(f"Loaded {self.model_name} on {self.device}")

    def _generate_embedding(self, sequence):
        """
        Helper function to generate embeddings for a single sequence.
        Should only be called within a GPU context where model is initialized.

        Returns:
        - mean_embedding: last-layer hidden states averaged over residue
          positions, excluding the leading [CLS] and trailing [EOS] tokens.
        - cls_embedding: last-layer hidden state of the [CLS] token.
        """
        tokens = self.tokenizer(sequence, return_tensors="pt", add_special_tokens=True)
        tokens = {k: v.to(self.device) for k, v in tokens.items()}
        with torch.no_grad():
            outputs = self.model(**tokens, output_hidden_states=True)
        # hidden_states[36] is the final layer of the 36-layer 3B model;
        # [0] drops the batch dim -> shape [seq_len, hidden_dim].
        hidden = outputs.hidden_states[36][0].detach().to(torch.float64)
        mean_embedding = hidden[1:-1].mean(dim=0)  # mean over non-[CLS]/[EOS]
        cls_embedding = hidden[0]  # CLS token
        return mean_embedding, cls_embedding

    @spaces.GPU(duration=128)
    def predict(self, data):
        """
        Generate embeddings for the provided data.
        This method has GPU decoration for standalone use.

        Parameters:
        - data: List of tuples (id, sequence)

        Returns:
        - mean embedding (excluding [CLS] and [EOS])
        - CLS token embedding
        """
        self._ensure_model()
        _, sequences = zip(*data)
        sequence = sequences[0]  # assume batch size = 1
        return self._generate_embedding(sequence)

    @spaces.GPU(duration=128)
    def predict_proteome(self, protein_fasta):
        """
        Predicts embeddings for each protein sequence in the given FASTA string.
        Loads the model once and processes all sequences within a single GPU context.

        Args:
            protein_fasta (str): A FASTA string containing protein sequences.

        Returns:
            np.ndarray: The mean embedding vector across all protein sequences.
        """
        self._ensure_model()
        # Keyed by record id: duplicate ids keep only the last sequence
        # (preserves the original de-duplication behavior).
        protein_sequences = {}
        for record in SeqIO.parse(StringIO(protein_fasta), "fasta"):
            protein_sequences[record.id] = str(record.seq)
        embeddings = []
        # Process all sequences within this single GPU context.
        for prot_id, sequence in protein_sequences.items():
            # BUG FIX: the previous code unpacked `_, mean_embedding`,
            # which actually grabbed the CLS embedding (second element of
            # the returned tuple) despite the docstring promising the mean
            # embedding. Take the first element, the true mean embedding.
            mean_embedding, _ = self._generate_embedding(sequence)
            embeddings.append(mean_embedding)
        # L2-normalize each per-protein embedding, then average them.
        embeddings_stack = torch.stack(embeddings, dim=0).to(torch.float64)
        embeddings_stack = torch.nn.functional.normalize(embeddings_stack, p=2, dim=1)
        embeddings_mean = torch.mean(embeddings_stack, dim=0)
        return embeddings_mean.cpu().numpy()  # move to CPU, convert to numpy
# Example usage:
# esm_model = EsmEmbedding()
# fasta_string = ">protein1\nMKLSTVLVLLLAGALATLVTPAPGS\n>protein2\nMTVLDLSPGGDQTLAHRLSRSSAPGSPRT"
# embeddings = esm_model.predict_proteome(fasta_string)