from transformers import AutoTokenizer, AutoModel
import torch
from Bio import SeqIO
from io import StringIO
import numpy as np
import spaces
|
|
class EsmEmbedding:
    """Produce protein-sequence embeddings with an ESM-2 model.

    The tokenizer is created eagerly, but the (multi-GB) model is loaded
    lazily inside the @spaces.GPU-decorated entry points, so the object
    can be constructed on a CPU-only front-end process without pulling
    the model weights onto the GPU.
    """

    def __init__(self, model_name="facebook/esm2_t36_3B_UR50D"):
        """
        Initialize the ESM-2 tokenizer via Hugging Face Transformers.

        Parameters:
        - model_name: Hub id of the ESM-2 checkpoint to use.
        """
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = None   # loaded lazily inside a GPU context
        self.device = None
        print(f"Initialized EsmEmbedding with {model_name} (model not loaded yet)")

    def _load_model(self):
        """Load the model onto the GPU on first use (idempotent).

        Shared by predict() and predict_proteome() so both entry points
        configure the model identically (previously one path forced
        dtype=torch.float32 while the other kept the checkpoint dtype).
        """
        if self.model is None:
            self.device = "cuda"
            self.model = AutoModel.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                use_safetensors=True,
            ).to(self.device, dtype=torch.float32)
            self.model.eval()
            print(f"Loaded {self.model_name} on {self.device}")

    def _generate_embedding(self, sequence):
        """
        Generate embeddings for a single sequence.
        Should only be called within a GPU context where the model is loaded.

        Parameters:
        - sequence: amino-acid sequence string.
        Returns:
        - mean embedding over residue positions (excluding [CLS]/[EOS])
        - [CLS] token embedding
        Both are float64 tensors on self.device.
        """
        tokens = self.tokenizer(sequence, return_tensors="pt", add_special_tokens=True)
        tokens = {k: v.to(self.device) for k, v in tokens.items()}

        with torch.no_grad():
            outputs = self.model(**tokens, output_hidden_states=True)
        # Final layer. Using [-1] instead of the hard-coded index 36 keeps
        # this correct for any ESM-2 checkpoint depth, not just t36.
        hidden = outputs.hidden_states[-1][0].detach().to(torch.float64)

        mean_embedding = hidden[1:-1].mean(dim=0)  # drop [CLS] and [EOS]
        cls_embedding = hidden[0]
        return mean_embedding, cls_embedding

    @spaces.GPU(duration=128)
    def predict(self, data):
        """
        Generate embeddings for the provided data.
        This method has GPU decoration for standalone use.

        Parameters:
        - data: List of tuples (id, sequence). NOTE: only the FIRST
          sequence is embedded; any further entries are ignored.
        Returns:
        - mean embedding (excluding [CLS] and [EOS])
        - CLS token embedding
        """
        self._load_model()
        _, sequences = zip(*data)
        return self._generate_embedding(sequences[0])

    @spaces.GPU(duration=128)
    def predict_proteome(self, protein_fasta):
        """
        Predicts embeddings for each protein sequence in the given FASTA string.
        Loads the model once and processes all sequences within a single GPU context.

        Args:
            protein_fasta (str): A FASTA string containing protein sequences.
        Returns:
            np.ndarray: Mean of the L2-normalized per-protein mean
            embeddings across all sequences in the FASTA.
        Raises:
            RuntimeError: if the FASTA contains no sequences (empty stack).
        """
        self._load_model()

        protein_sequences = {
            record.id: str(record.seq)
            for record in SeqIO.parse(StringIO(protein_fasta), "fasta")
        }

        embeddings = []
        for sequence in protein_sequences.values():
            # BUG FIX: _generate_embedding returns (mean, cls). The old
            # code unpacked "_, mean_embedding" and therefore averaged the
            # [CLS] embeddings while the name/docstring promised the
            # per-residue mean. Take the first element.
            mean_embedding, _ = self._generate_embedding(sequence)
            embeddings.append(mean_embedding)

        embeddings_stack = torch.stack(embeddings, dim=0).to(torch.float64)
        # L2-normalize each protein's embedding before averaging so long
        # proteins don't dominate the proteome-level mean.
        embeddings_stack = torch.nn.functional.normalize(embeddings_stack, p=2, dim=1)
        embeddings_mean = torch.mean(embeddings_stack, dim=0)
        print("Mean", embeddings_mean)
        return embeddings_mean.cpu().numpy()
|
|
| |
| |
| |
| |