Spaces:
Sleeping
Sleeping
| """ESM-2 embedding extractor — loaded once as a singleton.""" | |
| import numpy as np | |
| import torch | |
| from transformers import AutoTokenizer, AutoModel | |
# Hugging Face checkpoint identifier passed to ``from_pretrained`` in _load().
MODEL_NAME = "facebook/esm2_t6_8M_UR50D"  # 8M params, 320-dim, fast
# Module-level singletons: both start unset and are populated by _load() so the
# tokenizer and model are created at most once per process.
_tokenizer = None
_model = None
def _load():
    """Populate the module-level ESM-2 tokenizer and model singletons.

    Does nothing when both singletons are already set. Otherwise the
    tokenizer and model are created from ``MODEL_NAME`` in local variables
    and only published to the globals once both have loaded, so an exception
    raised while loading the model cannot leave ``_tokenizer`` set with
    ``_model`` still ``None`` (a state in which later calls would skip
    loading and then fail when the model is used).
    """
    global _tokenizer, _model
    if _tokenizer is not None and _model is not None:
        return

    print("Loading ESM-2...", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModel.from_pretrained(MODEL_NAME)
    model.eval()  # switch to evaluation mode before any forward pass
    # Publish both together only after every step above has succeeded.
    _tokenizer, _model = tokenizer, model
    print("ESM-2 loaded.", flush=True)
def get_embeddings(sequences: list[str], batch_size: int = 32) -> np.ndarray:
    """Return mean-pooled ESM-2 embeddings for ``sequences``.

    The sequences are processed in batches to avoid running out of memory.
    Each batch is tokenized with padding and truncation to 256 tokens, run
    through the model without gradient tracking, and mean-pooled over the
    positions whose attention mask is 1 (i.e. excluding padding tokens).

    Args:
        sequences: Sequences to embed.
        batch_size: Number of sequences per forward pass; must be >= 1.

    Returns:
        A float32 array of shape ``(N, hidden)`` where ``N`` is
        ``len(sequences)``; ``hidden`` is 320 for the default ``MODEL_NAME``.
        An empty input yields an array of shape ``(0, hidden)``.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    _load()

    # np.vstack raises on an empty list, so handle "nothing to embed" up front
    # while still returning an array with the right second dimension.
    if len(sequences) == 0:
        return np.empty((0, _model.config.hidden_size), dtype=np.float32)

    pooled_batches = []
    for start in range(0, len(sequences), batch_size):
        batch = sequences[start:start + batch_size]
        inputs = _tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=256,
        )
        with torch.no_grad():
            outputs = _model(**inputs)
            # Mean pool over sequence positions (excluding padding tokens).
            mask = inputs["attention_mask"].unsqueeze(-1).float()  # (B, L, 1)
            summed = (outputs.last_hidden_state * mask).sum(1)  # (B, hidden)
            # Clamp so a row with an all-zero mask cannot divide by zero.
            counts = mask.sum(1).clamp(min=1.0)  # (B, 1)
            emb = summed / counts
        pooled_batches.append(emb.numpy())

    return np.vstack(pooled_batches).astype(np.float32)