peppareto-api / embedder.py
ilkhamfy's picture
Upload embedder.py with huggingface_hub
8283feb verified
"""ESM-2 embedding extractor — loaded once as a singleton."""
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
# Hugging Face checkpoint id of the ESM-2 model used for embeddings
# (8M params, 320-dim, fast).
MODEL_NAME = "facebook/esm2_t6_8M_UR50D"

# Lazily initialised singletons: both remain None until _load() fills them.
_tokenizer = _model = None
def _load() -> None:
    """Populate the module-level ``_tokenizer`` and ``_model`` singletons.

    Returns immediately when both are already set. Otherwise the tokenizer
    and model for ``MODEL_NAME`` are loaded, the model is switched to eval
    mode, and progress messages are printed to stdout.

    The globals are assigned only after *both* loads have succeeded, so an
    exception while loading the model cannot leave the module with a
    tokenizer but no model (which would make every later call skip loading
    and then fail on a ``None`` model).
    """
    global _tokenizer, _model
    if _tokenizer is not None and _model is not None:
        return

    print("Loading ESM-2...", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModel.from_pretrained(MODEL_NAME)
    model.eval()

    # Publish both objects together, only once everything above succeeded.
    _tokenizer, _model = tokenizer, model
    print("ESM-2 loaded.", flush=True)
def get_embeddings(sequences: list[str], batch_size: int = 32) -> np.ndarray:
    """Return mean-pooled ESM-2 embeddings for a list of sequences.

    Sequences are tokenized and run through the model ``batch_size`` at a
    time to avoid running out of memory. Each sequence is truncated to at
    most 256 tokens. Its embedding is the mean of the model's last hidden
    state over the positions whose attention mask is 1, so padding added to
    even out a batch is excluded.

    Args:
        sequences: Input sequences, one string per item.
        batch_size: Number of sequences per forward pass; must be >= 1.

    Returns:
        A float32 array of shape (N, D) with one row per input sequence, in
        input order, where D is the model's hidden size (320 for the default
        checkpoint). Empty input yields an array of shape (0, D).

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    # Validate before doing any expensive work; a non-positive step would
    # otherwise surface as an opaque range()/vstack error further down.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    _load()

    all_embs = []
    for i in range(0, len(sequences), batch_size):
        batch = sequences[i:i + batch_size]
        inputs = _tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=256,
        )
        with torch.no_grad():
            outputs = _model(**inputs)
        # Mean pool over sequence positions (excluding padding tokens).
        # NOTE(review): any special tokens the tokenizer adds carry mask=1
        # and are therefore part of the mean — confirm this is intended.
        mask = inputs["attention_mask"].unsqueeze(-1).float()  # (B, L, 1)
        emb = (outputs.last_hidden_state * mask).sum(1) / mask.sum(1)  # (B, D)
        all_embs.append(emb.numpy())

    if not all_embs:
        # np.vstack raises on an empty list; return a well-shaped empty
        # result so callers can treat "no sequences" like any other input.
        return np.empty((0, _model.config.hidden_size), dtype=np.float32)
    return np.vstack(all_embs).astype(np.float32)