# esm3bedding.py
"""Protein-sequence featurizer backed by the ESM C (esmc_600m) model."""
import os
from typing import Optional

import torch
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig
from huggingface_hub import login

from utils import get_logger
from base import Featurizer

logg = get_logger()

# Model identifier passed to ESMC.from_pretrained.
_ESMC_MODEL_NAME = "esmc_600m"
# The 20 standard amino acids; any other character is substituted before encoding.
_VALID_AA = frozenset("ACDEFGHIKLMNPQRSTVWY")


class ESM3Featurizer(Featurizer):
    """Featurizer that embeds protein sequences with the ESMC 600M model.

    Registers with the base ``Featurizer`` under the name ``"ESM3"`` with a
    feature dimension of 1152.
    """

    def __init__(self, save_dir: str, api_key: Optional[str], per_tok: bool = True):
        """Log in to the Hugging Face Hub (if a key is given) and load ESMC.

        Args:
            save_dir: Forwarded to the ``Featurizer`` base class.
            api_key: Hugging Face token passed to ``huggingface_hub.login``.
                If falsy (``None`` or ``""``) login is skipped, e.g. on a node
                without internet access that relies on cached weights.
            per_tok: If True, ``_transform`` returns one embedding row per
                token; otherwise it returns the mean over tokens.

        Raises:
            RuntimeError: If login or model loading fails.
        """
        super().__init__("ESM3", 1152, save_dir=save_dir)
        self.per_tok = per_tok
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.client = None
        self._login(api_key)
        self._initialize_model()

    def _login(self, api_key: Optional[str]) -> None:
        """Authenticate with the Hugging Face Hub, or skip if no key is given.

        Raises:
            RuntimeError: If ``login`` raises; the original error is chained.
        """
        if not api_key:
            # Calling login(None) would fall back to an interactive prompt.
            logg.info("No Hugging Face API key provided; skipping login.")
            return
        try:
            login(api_key)
        except Exception as e:
            logg.error(f"Failed to log in to Hugging Face Hub: {e}")
            raise RuntimeError("Hugging Face login failed. Check your API key.") from e
        logg.info("Successfully logged into Hugging Face Hub.")

    def _load_client(self) -> None:
        """Load ESMC weights, move them to ``self._device``, then publish them.

        ``self.client`` is only assigned once both steps succeed, so a failed
        device move does not leave a half-initialized client behind.
        """
        client = ESMC.from_pretrained(_ESMC_MODEL_NAME)
        client.to(self._device)
        self.client = client

    def _initialize_model(self) -> None:
        """Load ESMC normally, falling back to Hugging Face offline mode.

        Raises:
            RuntimeError: If both the normal and the offline attempt fail.
                The offline error is chained as ``__cause__``.
        """
        logg.info("Initializing ESMC model (esmc_600m)...")
        try:
            self._load_client()
        except Exception as online_error:
            logg.warning(f"Online model loading failed: {online_error}")
            logg.info("Attempting offline mode (using local cache)...")
        else:
            logg.info("ESMC model loaded.")
            return

        # Fallback: switch Hugging Face libraries to offline mode so only the
        # local cache is consulted. This mutates the process environment and
        # is intentionally left set for the rest of the process.
        # NOTE(review): huggingface_hub may read HF_HUB_OFFLINE at import time,
        # in which case setting it here has no effect — confirm for the
        # installed version.
        os.environ["HF_HUB_OFFLINE"] = "1"
        os.environ["TRANSFORMERS_OFFLINE"] = "1"
        try:
            # NOTE(review): esm's ESMC.from_pretrained appears to accept only
            # (model_name, device) — the previous `local_files_only=True`
            # keyword would raise TypeError. Confirm against the installed esm.
            self._load_client()
        except Exception as offline_error:
            logg.error(f"Offline loading also failed: {offline_error}")
            logg.error("=" * 60)
            logg.error("ESMC MODEL NOT FOUND IN CACHE!")
            logg.error("Run this on a node with internet access to cache the model:")
            logg.error(" python -c \"from esm.models.esmc import ESMC; ESMC.from_pretrained('esmc_600m')\"")
            logg.error("=" * 60)
            # Raised outside any outer catch-all so this specific message
            # reaches the caller instead of being re-wrapped generically.
            raise RuntimeError(
                "ESMC model not available. See error messages above."
            ) from offline_error
        logg.info("ESMC model loaded from local cache (offline mode).")

    def _transform(self, sequence: str) -> Optional[torch.Tensor]:
        """Embed a single protein sequence with ESMC.

        The sequence is upper-cased, and every character outside the 20
        standard amino acids is replaced with ``'A'``. Characters are replaced
        rather than removed so the residue count is preserved.

        Args:
            sequence: Protein sequence string.

        Returns:
            The ``output.embeddings`` tensor from ``client.logits``. A leading
            batch dimension of size 1 is squeezed away. If ``per_tok`` is
            False, the result is then averaged over dim 0. Returns ``None`` if
            any step raises; the error is logged.
        """
        try:
            clean_sequence = "".join(
                ch if ch in _VALID_AA else "A" for ch in sequence.upper()
            )
            protein = ESMProtein(sequence=clean_sequence)
            # Inference only: avoid building an autograd graph.
            with torch.no_grad():
                protein_tensor = self.client.encode(protein)
                logits_config = LogitsConfig(sequence=True, return_embeddings=True)
                output = self.client.logits(protein_tensor, logits_config)
            embeddings = output.embeddings
            # NOTE(review): the ESMC tokenizer presumably adds BOS/EOS tokens,
            # so the token axis may be len(sequence) + 2 and the mean below
            # includes them — confirm before relying on per-residue alignment.
            if embeddings.dim() == 3 and embeddings.shape[0] == 1:
                embeddings = embeddings.squeeze(0)
            if not self.per_tok:
                embeddings = embeddings.mean(dim=0)
            return embeddings
        except Exception as e:
            logg.error(f"Error generating embeddings for sequence: {e}")
            return None