"""ESM-2 embedding extractor — loaded once as a singleton.""" import numpy as np import torch from transformers import AutoTokenizer, AutoModel MODEL_NAME = "facebook/esm2_t6_8M_UR50D" # 8M params, 320-dim, fast _tokenizer = None _model = None def _load(): global _tokenizer, _model if _tokenizer is None: print("Loading ESM-2...", flush=True) _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) _model = AutoModel.from_pretrained(MODEL_NAME) _model.eval() print("ESM-2 loaded.", flush=True) def get_embeddings(sequences: list[str], batch_size: int = 32) -> np.ndarray: """ Returns (N, 320) float32 array of mean-pooled ESM-2 embeddings. Processes in batches to avoid OOM. """ _load() all_embs = [] for i in range(0, len(sequences), batch_size): batch = sequences[i:i + batch_size] inputs = _tokenizer( batch, return_tensors="pt", padding=True, truncation=True, max_length=256, ) with torch.no_grad(): outputs = _model(**inputs) # Mean pool over sequence positions (excluding padding tokens) mask = inputs["attention_mask"].unsqueeze(-1).float() # (B, L, 1) emb = (outputs.last_hidden_state * mask).sum(1) / mask.sum(1) # (B, 320) all_embs.append(emb.numpy()) return np.vstack(all_embs).astype(np.float32)