ilkhamfy committed on
Commit
8283feb
·
verified ·
1 Parent(s): 80cf349

Upload embedder.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. embedder.py +42 -0
embedder.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ESM-2 embedding extractor — loaded once as a singleton."""
2
+ import numpy as np
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModel
5
+
6
+ MODEL_NAME = "facebook/esm2_t6_8M_UR50D" # 8M params, 320-dim, fast
7
+
8
+ _tokenizer = None
9
+ _model = None
10
+
11
def _load():
    """Load the ESM-2 tokenizer and model into the module globals on first use.

    Subsequent calls are no-ops once both ``_tokenizer`` and ``_model`` are
    populated. The model is switched to eval mode before it is published.

    Raises:
        Whatever ``AutoTokenizer.from_pretrained`` / ``AutoModel.from_pretrained``
        raise (e.g. on a failed download). In that case neither global is
        modified, so a later call retries the full load.
    """
    global _tokenizer, _model
    # Guard on BOTH globals: checking only the tokenizer would treat a
    # half-initialized state (tokenizer set, model missing) as "loaded".
    if _tokenizer is not None and _model is not None:
        return

    print("Loading ESM-2...", flush=True)
    # Load into locals first so a failure in the second call cannot leave
    # the module with a tokenizer but no model.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModel.from_pretrained(MODEL_NAME)
    model.eval()

    _tokenizer = tokenizer
    _model = model
    print("ESM-2 loaded.", flush=True)
19
+
20
def get_embeddings(sequences: list[str], batch_size: int = 32) -> np.ndarray:
    """Return mean-pooled ESM-2 embeddings for ``sequences``.

    Sequences are tokenized and run through the model ``batch_size`` at a
    time to bound peak memory. Each sequence is truncated to at most 256
    tokens by the tokenizer call below.

    Args:
        sequences: Sequences to embed. A single bare ``str`` is treated as
            one sequence rather than being sliced character-wise.
        batch_size: Number of sequences per forward pass; must be >= 1.

    Returns:
        A ``(N, D)`` float32 array, one row per input sequence, where ``D``
        is the model's hidden size (320 for the default ``MODEL_NAME``
        according to the comment on that constant). For empty input the
        result has shape ``(0, D)``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Validate before touching the model so bad arguments fail fast and with
    # a clear message (a negative step would otherwise yield an empty range
    # and an obscure np.vstack error).
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # Slicing a bare string would split it into character chunks and embed
    # each chunk as if it were its own sequence.
    if isinstance(sequences, str):
        sequences = [sequences]

    _load()

    # np.vstack([]) raises, so handle empty input explicitly.
    if len(sequences) == 0:
        return np.empty((0, _model.config.hidden_size), dtype=np.float32)

    all_embs = []
    for start in range(0, len(sequences), batch_size):
        batch = sequences[start:start + batch_size]
        inputs = _tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=256,
        )
        with torch.no_grad():
            outputs = _model(**inputs)
            # Mean pool over sequence positions, weighting by the attention
            # mask so padding positions contribute nothing.
            # NOTE(review): the mask presumably also covers the special
            # tokens the tokenizer adds, so they are included in the mean.
            mask = inputs["attention_mask"].unsqueeze(-1).float()  # (B, L, 1)
            emb = (outputs.last_hidden_state * mask).sum(1) / mask.sum(1)  # (B, D)
        all_embs.append(emb.numpy())
    return np.vstack(all_embs).astype(np.float32)