lhallee committed on
Commit
f56f605
·
verified ·
1 Parent(s): e21e4e2

Upload modeling_esm_plusplus.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_esm_plusplus.py +280 -95
modeling_esm_plusplus.py CHANGED
@@ -23,18 +23,218 @@ inductor_config.max_autotune_gemm_backends = "ATEN,CUTLASS,FBGEMM"
23
  dynamo.config.capture_scalar_outputs = True
24
  torch._dynamo.config.recompile_limit = 16
25
 
 
26
  import os
 
27
  import sqlite3
 
 
 
 
28
  import networkx as nx
29
  import numpy as np
30
  import torch
31
  from tqdm.auto import tqdm
32
- from typing import Callable, Dict, List, Optional, Set
33
  from torch.utils.data import DataLoader
34
  from torch.utils.data import Dataset as TorchDataset
35
  from transformers import PreTrainedTokenizerBase
36
 
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  class Pooler:
39
  def __init__(self, pooling_types: List[str]) -> None:
40
  self.pooling_types = pooling_types
@@ -55,9 +255,6 @@ class Pooler:
55
  return maxed_attentions
56
 
57
  def _page_rank(self, attention_matrix: np.ndarray, personalization: Optional[dict] = None, nstart: Optional[dict] = None, prune_type: str = "top_k_outdegree") -> Dict[int, float]:
58
- # Run PageRank on the attention matrix converted to a graph.
59
- # Raises exceptions if the graph doesn't match the token sequence or has no edges.
60
- # Returns the PageRank scores for each token node.
61
  G = self._convert_to_graph(attention_matrix)
62
  if G.number_of_nodes() != attention_matrix.shape[0]:
63
  raise Exception(
@@ -68,26 +265,20 @@ class Pooler:
68
  return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100)
69
 
70
  def _convert_to_graph(self, matrix: np.ndarray) -> nx.DiGraph:
71
- # Convert a matrix (e.g., attention scores) to a directed graph using networkx.
72
- # Each element in the matrix represents a directed edge with a weight.
73
  G = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
74
  return G
75
 
76
  def _calculate_importance_weights(self, dict_importance: Dict[int, float], attention_mask: Optional[torch.Tensor] = None) -> np.ndarray:
77
- # Remove keys where attention_mask is 0
78
  if attention_mask is not None:
79
  for k in list(dict_importance.keys()):
80
  if attention_mask[k] == 0:
81
  del dict_importance[k]
82
 
83
- #dict_importance[0] # remove cls
84
- #dict_importance[-1] # remove eos
85
  total = sum(dict_importance.values())
86
  return np.array([v / total for _, v in dict_importance.items()])
87
 
88
- def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: # (b, L, d) -> (b, d)
89
  maxed_attentions = self._create_pooled_matrices_across_layers(attentions).numpy()
90
- # emb is (b, L, d), maxed_attentions is (b, L, L)
91
  emb_pooled = []
92
  for e, a, mask in zip(emb, maxed_attentions, attention_mask):
93
  dict_importance = self._page_rank(a)
@@ -97,58 +288,53 @@ class Pooler:
97
  pooled = torch.tensor(np.array(emb_pooled))
98
  return pooled
99
 
100
- def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d)
101
  if attention_mask is None:
102
  return emb.mean(dim=1)
103
  else:
104
  attention_mask = attention_mask.unsqueeze(-1)
105
  return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
106
 
107
- def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d)
108
  if attention_mask is None:
109
  return emb.max(dim=1).values
110
  else:
111
  mask = attention_mask.unsqueeze(-1).bool()
112
  return emb.masked_fill(~mask, float('-inf')).max(dim=1).values
113
 
114
- def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d)
115
  if attention_mask is None:
116
  return emb.norm(dim=1, p=2)
117
  else:
118
  attention_mask = attention_mask.unsqueeze(-1)
119
  return (emb * attention_mask).norm(dim=1, p=2)
120
 
121
- def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d)
122
  if attention_mask is None:
123
  return emb.median(dim=1).values
124
  else:
125
  mask = attention_mask.unsqueeze(-1).bool()
126
  return emb.masked_fill(~mask, float('nan')).nanmedian(dim=1).values
127
-
128
- def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d)
129
  if attention_mask is None:
130
  return emb.std(dim=1)
131
  else:
132
- # Compute variance correctly over non-masked positions, then take sqrt
133
  var = self.var_pooling(emb, attention_mask, **kwargs)
134
  return torch.sqrt(var)
135
-
136
- def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d)
137
  if attention_mask is None:
138
  return emb.var(dim=1)
139
  else:
140
- # Correctly compute variance over only non-masked positions
141
- attention_mask = attention_mask.unsqueeze(-1) # (b, L, 1)
142
- # Compute mean over non-masked positions
143
- mean = (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d)
144
- mean = mean.unsqueeze(1) # (b, 1, d)
145
- # Compute squared differences from mean, only over non-masked positions
146
- squared_diff = (emb - mean) ** 2 # (b, L, d)
147
- # Sum squared differences over non-masked positions and divide by count
148
- var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d)
149
  return var
150
 
151
- def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d)
152
  return emb[:, 0, :]
153
 
154
  def __call__(
@@ -156,11 +342,11 @@ class Pooler:
156
  emb: torch.Tensor,
157
  attention_mask: Optional[torch.Tensor] = None,
158
  attentions: Optional[torch.Tensor] = None
159
- ) -> torch.Tensor: # [mean, max]
160
  final_emb: List[torch.Tensor] = []
161
  for pooling_type in self.pooling_types:
162
- final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions)) # (b, d)
163
- return torch.cat(final_emb, dim=-1) # (b, n_pooling_types * d)
164
 
165
 
166
  class ProteinDataset(TorchDataset):
@@ -175,12 +361,6 @@ class ProteinDataset(TorchDataset):
175
  return self.sequences[idx]
176
 
177
 
178
- def build_collator(tokenizer: PreTrainedTokenizerBase) -> Callable[[List[str]], Dict[str, torch.Tensor]]:
179
- def _collate_fn(sequences: List[str]) -> Dict[str, torch.Tensor]:
180
- return tokenizer(sequences, return_tensors="pt", padding='longest')
181
- return _collate_fn
182
-
183
-
184
  def parse_fasta(fasta_path: str) -> List[str]:
185
  assert os.path.exists(fasta_path), f"FASTA file does not exist: {fasta_path}"
186
  sequences = []
@@ -212,34 +392,19 @@ class EmbeddingMixin:
212
 
213
  def _read_sequences_from_db(self, db_path: str) -> Set[str]:
214
  """Read sequences from SQLite database."""
215
- sequences = []
216
- with sqlite3.connect(db_path) as conn:
217
  c = conn.cursor()
218
  c.execute("SELECT sequence FROM embeddings")
219
- while True:
220
- row = c.fetchone()
221
- if row is None:
222
- break
223
- sequences.append(row[0])
224
- return set(sequences)
225
 
226
  def _ensure_embeddings_table(self, conn: sqlite3.Connection) -> None:
227
  cursor = conn.cursor()
228
  cursor.execute(
229
  "CREATE TABLE IF NOT EXISTS embeddings ("
230
  "sequence TEXT PRIMARY KEY, "
231
- "embedding BLOB NOT NULL, "
232
- "shape TEXT, "
233
- "dtype TEXT"
234
  ")"
235
  )
236
- cursor.execute("PRAGMA table_info(embeddings)")
237
- rows = cursor.fetchall()
238
- column_names = [row[1] for row in rows]
239
- if "shape" not in column_names:
240
- cursor.execute("ALTER TABLE embeddings ADD COLUMN shape TEXT")
241
- if "dtype" not in column_names:
242
- cursor.execute("ALTER TABLE embeddings ADD COLUMN dtype TEXT")
243
  conn.commit()
244
 
245
  def load_embeddings_from_pth(self, save_path: str) -> Dict[str, torch.Tensor]:
@@ -254,17 +419,17 @@ class EmbeddingMixin:
254
  def load_embeddings_from_db(self, db_path: str, sequences: Optional[List[str]] = None) -> Dict[str, torch.Tensor]:
255
  assert os.path.exists(db_path), f"Embedding database does not exist: {db_path}"
256
  loaded: Dict[str, torch.Tensor] = {}
257
- with sqlite3.connect(db_path) as conn:
258
  self._ensure_embeddings_table(conn)
259
  cursor = conn.cursor()
260
  if sequences is None:
261
- cursor.execute("SELECT sequence, embedding, shape, dtype FROM embeddings")
262
  else:
263
  if len(sequences) == 0:
264
  return loaded
265
  placeholders = ",".join(["?"] * len(sequences))
266
  cursor.execute(
267
- f"SELECT sequence, embedding, shape, dtype FROM embeddings WHERE sequence IN ({placeholders})",
268
  tuple(sequences),
269
  )
270
 
@@ -272,18 +437,7 @@ class EmbeddingMixin:
272
  for row in rows:
273
  sequence = row[0]
274
  embedding_bytes = row[1]
275
- shape_text = row[2]
276
- dtype_text = row[3]
277
- assert shape_text is not None, "Missing shape metadata in embeddings table."
278
- assert dtype_text is not None, "Missing dtype metadata in embeddings table."
279
- shape_values = [int(value) for value in shape_text.split(",") if len(value) > 0]
280
- assert len(shape_values) > 0, f"Invalid shape metadata for sequence: {sequence}"
281
- expected_size = int(np.prod(shape_values))
282
- np_dtype = np.dtype(dtype_text)
283
- array = np.frombuffer(embedding_bytes, dtype=np_dtype)
284
- assert array.size == expected_size, f"Shape mismatch while reading sequence: {sequence}"
285
- reshaped = array.copy().reshape(tuple(shape_values))
286
- loaded[sequence] = torch.from_numpy(reshaped)
287
  return loaded
288
 
289
  def embed_dataset(
@@ -302,6 +456,7 @@ class EmbeddingMixin:
302
  sql_db_path: str = 'embeddings.db',
303
  save_path: str = 'embeddings.pth',
304
  fasta_path: Optional[str] = None,
 
305
  **kwargs,
306
  ) -> Optional[Dict[str, torch.Tensor]]:
307
  """
@@ -324,8 +479,13 @@ class EmbeddingMixin:
324
  hidden_size = self.config.hidden_size
325
  pooler = Pooler(pooling_types) if not full_embeddings else None
326
  tokenizer_mode = tokenizer is not None
 
 
 
 
 
327
  if tokenizer_mode:
328
- collate_fn = build_collator(tokenizer)
329
  device = self.device
330
  else:
331
  collate_fn = None
@@ -342,17 +502,25 @@ class EmbeddingMixin:
342
  assert collate_fn is not None
343
  assert device is not None
344
  dataset = ProteinDataset(to_embed)
345
- dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn, shuffle=False)
346
- for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
 
 
 
 
 
 
 
 
347
  seqs = to_embed[i * batch_size:(i + 1) * batch_size]
348
  input_ids = batch['input_ids'].to(device)
349
  attention_mask = batch['attention_mask'].to(device)
350
- residue_embeddings = self._embed(input_ids, attention_mask)
351
  yield seqs, residue_embeddings, attention_mask
352
  else:
353
  for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'):
354
  seqs = to_embed[batch_start:batch_start + batch_size]
355
- batch_output = self._embed(seqs, return_attention_mask=True, **kwargs)
356
  assert isinstance(batch_output, tuple), "Sequence mode _embed must return (last_hidden_state, attention_mask)."
357
  assert len(batch_output) == 2, "Sequence mode _embed must return exactly two values."
358
  residue_embeddings, attention_mask = batch_output
@@ -360,30 +528,47 @@ class EmbeddingMixin:
360
  yield seqs, residue_embeddings, attention_mask
361
 
362
  if sql:
363
- conn = sqlite3.connect(sql_db_path)
 
 
 
 
364
  self._ensure_embeddings_table(conn)
365
- c = conn.cursor()
366
  already_embedded = self._read_sequences_from_db(sql_db_path)
367
  to_embed = [seq for seq in sequences if seq not in already_embedded]
368
  print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
369
  print(f"Embedding {len(to_embed)} new sequences")
370
  if len(to_embed) > 0:
371
- with torch.no_grad():
372
- for i, (seqs, residue_embeddings, attention_mask) in enumerate(iter_batches(to_embed)):
373
- embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
374
- for seq, emb, mask in zip(seqs, embeddings, attention_mask):
375
- if full_embeddings:
376
- emb = emb[mask.bool()].reshape(-1, hidden_size)
377
- emb_np = emb.cpu().numpy()
378
- emb_shape = ",".join([str(dim) for dim in emb_np.shape])
379
- emb_dtype = str(emb_np.dtype)
380
- c.execute(
381
- "INSERT OR REPLACE INTO embeddings (sequence, embedding, shape, dtype) VALUES (?, ?, ?, ?)",
382
- (seq, emb_np.tobytes(), emb_shape, emb_dtype),
383
- )
384
- if tokenizer_mode and (i + 1) % 100 == 0:
385
  conn.commit()
386
- conn.commit()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
387
  conn.close()
388
  return None
389
 
@@ -398,7 +583,7 @@ class EmbeddingMixin:
398
  print(f"Embedding {len(to_embed)} new sequences")
399
 
400
  if len(to_embed) > 0:
401
- with torch.no_grad():
402
  for seqs, residue_embeddings, attention_mask in iter_batches(to_embed):
403
  embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
404
  for seq, emb, mask in zip(seqs, embeddings, attention_mask):
 
23
  dynamo.config.capture_scalar_outputs = True
24
  torch._dynamo.config.recompile_limit = 16
25
 
26
+ import io
27
  import os
28
+ import queue
29
  import sqlite3
30
+ import struct
31
+ import threading
32
+ import time
33
+
34
  import networkx as nx
35
  import numpy as np
36
  import torch
37
  from tqdm.auto import tqdm
38
+ from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple
39
  from torch.utils.data import DataLoader
40
  from torch.utils.data import Dataset as TorchDataset
41
  from transformers import PreTrainedTokenizerBase
42
 
43
 
44
+ # Compact blob serialization constants
45
+ # Keep in sync with protify/utils.py and core/atlas/precomputed.py
46
+ _COMPACT_VERSION = 0x01
47
+ _DTYPE_TO_CODE = {torch.float16: 0, torch.bfloat16: 1, torch.float32: 2}
48
+ _CODE_TO_DTYPE = {0: torch.float16, 1: torch.bfloat16, 2: torch.float32}
49
+ _CODE_TO_NP_DTYPE = {0: np.float16, 1: np.float16, 2: np.float32}
50
+
51
+
52
def tensor_to_embedding_blob(tensor: torch.Tensor) -> bytes:
    """Serialize a tensor to the compact binary format for SQLite blob storage.

    Format: [version:1][dtype_code:1][ndim:int32 LE][shape:int32 * ndim][raw_bytes].
    bfloat16 tensors are stored as float16 bytes (numpy has no bfloat16) but
    tagged with dtype_code=1 so they can be cast back on read; note this is
    lossy in exponent range. Dtypes outside _DTYPE_TO_CODE fall back to a
    torch.save (pickle) payload.

    Args:
        tensor: tensor of any shape; moved to CPU before serialization.

    Returns:
        Compact header + raw data bytes, or a torch.save payload for the
        fallback path.
    """
    # detach() so tensors that still require grad can be exported via numpy()
    # (Tensor.numpy() raises on grad-requiring tensors).
    t = tensor.detach().cpu()
    if t.dtype not in _DTYPE_TO_CODE:
        # Unsupported dtype: serialize with torch's pickle-based format.
        buffer = io.BytesIO()
        torch.save(t, buffer)
        return buffer.getvalue()
    dtype_code = _DTYPE_TO_CODE[t.dtype]

    if t.dtype == torch.bfloat16:
        # numpy cannot represent bfloat16; store float16 bytes and rely on
        # dtype_code=1 to restore the dtype on deserialization.
        raw = t.half().numpy().tobytes()
    else:
        raw = t.numpy().tobytes()

    shape = t.shape
    # '<BBi{n}i': little-endian, version byte, dtype byte, ndim, then each dim.
    header = struct.pack(f'<BBi{len(shape)}i', _COMPACT_VERSION, dtype_code, len(shape), *shape)
    return header + raw
75
+
76
+
77
def _compact_header(dtype: torch.dtype, shape: tuple) -> bytes:
    """Pack only the compact-format header for the given dtype and shape."""
    code = _DTYPE_TO_CODE[dtype]
    ndim = len(shape)
    return struct.pack(f'<BBi{ndim}i', _COMPACT_VERSION, code, ndim, *shape)
81
+
82
+
83
def batch_tensor_to_blobs(batch: torch.Tensor) -> List[bytes]:
    """Serialize a batch of same-shape tensors to compact blobs (fast path).

    Builds the header once and slices the contiguous raw bytes per row, which
    is much faster than per-row tensor_to_embedding_blob calls for
    uniform-shape batches.

    Args:
        batch: tensor with >= 2 dims; dim 0 indexes rows.

    Returns:
        One compact blob per row; an empty list for an empty batch.
    """
    assert batch.ndim >= 2, f"Expected batch with >= 2 dims, got {batch.ndim}"
    # detach() so grad-requiring tensors can be exported via numpy().
    t = batch.detach().cpu()
    if t.dtype not in _DTYPE_TO_CODE:
        # Unsupported dtype: fall back to the per-row (pickle) serializer.
        return [tensor_to_embedding_blob(t[i]) for i in range(t.shape[0])]
    if t.shape[0] == 0:
        # Guard: the per-row stride below divides by the row count.
        return []

    store_dtype = t.dtype
    if t.dtype == torch.bfloat16:
        # numpy lacks bfloat16: store float16 bytes, keep the bfloat16 tag
        # so deserialization casts back.
        arr = t.half().numpy()
    else:
        arr = t.numpy()

    row_shape = tuple(t.shape[1:])
    header = _compact_header(store_dtype, row_shape)
    raw = arr.tobytes()
    # All rows have identical byte size, so slice the flat buffer evenly.
    stride = len(raw) // t.shape[0]
    return [header + raw[i * stride:(i + 1) * stride] for i in range(t.shape[0])]
106
+
107
+
108
def embedding_blob_to_tensor(blob: bytes, fallback_shape: Optional[Tuple[int, ...]] = None) -> torch.Tensor:
    """Deserialize a blob back to a tensor, auto-detecting the format.

    Tries, in order: the compact format written by tensor_to_embedding_blob,
    a torch.save (pickle) payload, and finally raw float32 bytes reshaped to
    fallback_shape.

    Args:
        blob: serialized bytes from the embeddings table.
        fallback_shape: shape to use for the legacy raw-float32 format.

    Returns:
        CPU tensor with the dtype recorded in the blob (bfloat16 restored
        from its float16 on-disk representation).

    Raises:
        AssertionError: if the format is unrecognized and no fallback_shape
            was provided.
    """
    if len(blob) >= 6 and blob[0] == _COMPACT_VERSION:
        # A legacy/corrupt blob can start with the version byte by accident,
        # so fall through to the other formats if the compact parse fails
        # instead of raising KeyError/struct.error/ValueError.
        try:
            dtype_code = blob[1]
            ndim = struct.unpack_from('<i', blob, 2)[0]
            shape = struct.unpack_from(f'<{ndim}i', blob, 6)
            data_offset = 6 + 4 * ndim
            np_dtype = _CODE_TO_NP_DTYPE[dtype_code]
            arr = np.frombuffer(blob, dtype=np_dtype, offset=data_offset).copy().reshape(shape)
            t = torch.from_numpy(arr)
            target_dtype = _CODE_TO_DTYPE[dtype_code]
            if target_dtype != t.dtype:
                # bfloat16 rows are stored as float16 bytes; cast back.
                t = t.to(target_dtype)
            return t
        except (KeyError, ValueError, struct.error):
            pass

    # Fallback 1: torch.save payload (pickle format, weights only).
    try:
        return torch.load(io.BytesIO(blob), map_location='cpu', weights_only=True)
    except Exception:
        pass

    # Fallback 2: legacy raw float32 bytes with caller-supplied shape.
    assert fallback_shape is not None, "Cannot deserialize blob: unknown format and no fallback_shape provided."
    arr = np.frombuffer(blob, dtype=np.float32).copy().reshape(fallback_shape)
    return torch.from_numpy(arr)
134
+
135
+
136
def maybe_compile(model: torch.nn.Module, dynamic: bool = False) -> torch.nn.Module:
    """Return a torch.compile'd model when safe, otherwise the model as-is.

    Compilation is skipped when dynamic=True (padding='longest'): flex
    attention's create_block_mask is incompatible with dynamic shapes under
    torch.compile and triggers CUDA illegal memory access.
    """
    if dynamic:
        print("Skipping torch.compile (dynamic shapes + flex attention incompatible)")
        return model
    try:
        compiled = torch.compile(model)
    except Exception as e:
        print(f"Skipping torch.compile: {e}")
        return model
    print("Model compiled")
    return compiled
152
+
153
+
154
def build_collator(
    tokenizer: PreTrainedTokenizerBase,
    padding: str = 'max_length',
    max_length: int = 512,
) -> Callable[[List[str]], Dict[str, torch.Tensor]]:
    """Build a DataLoader collate_fn that tokenizes a list of sequences.

    Always truncates to max_length. For any padding mode other than
    'max_length', additionally pads to a multiple of 8.
    """
    def _collate_fn(sequences: List[str]) -> Dict[str, torch.Tensor]:
        tokenizer_kwargs: Dict[str, Any] = {
            'return_tensors': "pt",
            'padding': padding,
            'truncation': True,
            'max_length': max_length,
        }
        if padding != 'max_length':
            tokenizer_kwargs['pad_to_multiple_of'] = 8
        return tokenizer(sequences, **tokenizer_kwargs)
    return _collate_fn
167
+
168
+
169
def _make_embedding_progress(
    dataloader: DataLoader,
    padding: str,
    n_warmup: int = 3,
    n_calibration: int = 5,
) -> Iterator[Tuple[int, Any]]:
    """Progress-bar wrapper for embedding loops; drop-in for enumerate(dataloader).

    When padding='max_length', all batches have uniform cost so plain tqdm works.
    When padding='longest' (sorted longest-first), batch times vary dramatically.
    In that case: yield warmup batches first (compiler warmup + OOM check on the
    longest sequences), then time mid-length calibration batches to estimate a
    total ETA for the remainder.

    Keep in sync with protify/embedder.py and core/atlas/precomputed.py.

    Args:
        dataloader: iterable of batches supporting len().
        padding: padding mode the batches were built with.
        n_warmup: number of leading (longest) batches in the warmup phase.
        n_calibration: number of mid-dataset batches timed for the ETA.

    Yields:
        (index, batch) pairs in dataset order, exactly like enumerate().
    """
    total = len(dataloader)
    if padding == 'max_length' or total <= n_warmup + n_calibration:
        for i, batch in tqdm(enumerate(dataloader), total=total, desc='Embedding batches'):
            yield i, batch
        return

    dl_iter = iter(dataloader)

    # Phase 1: warmup on the longest batches (first n_warmup, sorted longest-first).
    warmup_bar = tqdm(range(n_warmup), desc='Warmup (longest batches)', leave=False)
    for i in warmup_bar:
        yield i, next(dl_iter)
    warmup_bar.close()

    # Phase 2: advance to the middle of the dataset for calibration timing,
    # yielding every intermediate batch (they contain real data).
    # Clamp to n_warmup: with caller-supplied n_warmup/n_calibration values,
    # total // 2 can fall inside the warmup range, which would desync the
    # yielded indices from the iterator position and mispair batches with
    # their sequences downstream.
    mid_start = max(total // 2, n_warmup)
    intermediate_bar = tqdm(
        range(n_warmup, mid_start), desc='Embedding batches', leave=False,
    )
    for i in intermediate_bar:
        yield i, next(dl_iter)
    intermediate_bar.close()

    # Phase 3: time calibration batches from the middle. The timer brackets
    # the yield, so the consumer's per-batch embedding work is included.
    calibration_times: List[float] = []
    cal_bar = tqdm(range(n_calibration), desc='Calibrating ETA', leave=False)
    for j in cal_bar:
        t0 = time.perf_counter()
        yield mid_start + j, next(dl_iter)
        calibration_times.append(time.perf_counter() - t0)
    cal_bar.close()

    avg_time = sum(calibration_times) / len(calibration_times)
    remaining_start = mid_start + n_calibration
    remaining_count = total - remaining_start
    estimated_total_seconds = avg_time * remaining_count

    # Phase 4: remaining batches with the calibrated ETA shown as a postfix.
    main_bar = tqdm(
        range(remaining_count),
        desc='Embedding batches',
        bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]',
    )
    main_bar.set_postfix_str(f'ETA ~{estimated_total_seconds:.0f}s (calibrated)')
    for k in main_bar:
        yield remaining_start + k, next(dl_iter)
    main_bar.close()
236
+
237
+
238
  class Pooler:
239
  def __init__(self, pooling_types: List[str]) -> None:
240
  self.pooling_types = pooling_types
 
255
  return maxed_attentions
256
 
257
  def _page_rank(self, attention_matrix: np.ndarray, personalization: Optional[dict] = None, nstart: Optional[dict] = None, prune_type: str = "top_k_outdegree") -> Dict[int, float]:
 
 
 
258
  G = self._convert_to_graph(attention_matrix)
259
  if G.number_of_nodes() != attention_matrix.shape[0]:
260
  raise Exception(
 
265
  return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100)
266
 
267
  def _convert_to_graph(self, matrix: np.ndarray) -> nx.DiGraph:
 
 
268
  G = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
269
  return G
270
 
271
  def _calculate_importance_weights(self, dict_importance: Dict[int, float], attention_mask: Optional[torch.Tensor] = None) -> np.ndarray:
 
272
  if attention_mask is not None:
273
  for k in list(dict_importance.keys()):
274
  if attention_mask[k] == 0:
275
  del dict_importance[k]
276
 
 
 
277
  total = sum(dict_importance.values())
278
  return np.array([v / total for _, v in dict_importance.items()])
279
 
280
+ def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
281
  maxed_attentions = self._create_pooled_matrices_across_layers(attentions).numpy()
 
282
  emb_pooled = []
283
  for e, a, mask in zip(emb, maxed_attentions, attention_mask):
284
  dict_importance = self._page_rank(a)
 
288
  pooled = torch.tensor(np.array(emb_pooled))
289
  return pooled
290
 
291
+ def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
292
  if attention_mask is None:
293
  return emb.mean(dim=1)
294
  else:
295
  attention_mask = attention_mask.unsqueeze(-1)
296
  return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
297
 
298
+ def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
299
  if attention_mask is None:
300
  return emb.max(dim=1).values
301
  else:
302
  mask = attention_mask.unsqueeze(-1).bool()
303
  return emb.masked_fill(~mask, float('-inf')).max(dim=1).values
304
 
305
+ def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
306
  if attention_mask is None:
307
  return emb.norm(dim=1, p=2)
308
  else:
309
  attention_mask = attention_mask.unsqueeze(-1)
310
  return (emb * attention_mask).norm(dim=1, p=2)
311
 
312
+ def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
313
  if attention_mask is None:
314
  return emb.median(dim=1).values
315
  else:
316
  mask = attention_mask.unsqueeze(-1).bool()
317
  return emb.masked_fill(~mask, float('nan')).nanmedian(dim=1).values
318
+
319
+ def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
320
  if attention_mask is None:
321
  return emb.std(dim=1)
322
  else:
 
323
  var = self.var_pooling(emb, attention_mask, **kwargs)
324
  return torch.sqrt(var)
325
+
326
+ def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
327
  if attention_mask is None:
328
  return emb.var(dim=1)
329
  else:
330
+ attention_mask = attention_mask.unsqueeze(-1)
331
+ mean = (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
332
+ mean = mean.unsqueeze(1)
333
+ squared_diff = (emb - mean) ** 2
334
+ var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
 
 
 
 
335
  return var
336
 
337
+ def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
338
  return emb[:, 0, :]
339
 
340
  def __call__(
 
342
  emb: torch.Tensor,
343
  attention_mask: Optional[torch.Tensor] = None,
344
  attentions: Optional[torch.Tensor] = None
345
+ ) -> torch.Tensor:
346
  final_emb: List[torch.Tensor] = []
347
  for pooling_type in self.pooling_types:
348
+ final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions))
349
+ return torch.cat(final_emb, dim=-1)
350
 
351
 
352
  class ProteinDataset(TorchDataset):
 
361
  return self.sequences[idx]
362
 
363
 
 
 
 
 
 
 
364
  def parse_fasta(fasta_path: str) -> List[str]:
365
  assert os.path.exists(fasta_path), f"FASTA file does not exist: {fasta_path}"
366
  sequences = []
 
392
 
393
  def _read_sequences_from_db(self, db_path: str) -> Set[str]:
394
  """Read sequences from SQLite database."""
395
+ with sqlite3.connect(db_path, timeout=30) as conn:
 
396
  c = conn.cursor()
397
  c.execute("SELECT sequence FROM embeddings")
398
+ return {row[0] for row in c.fetchall()}
 
 
 
 
 
399
 
400
  def _ensure_embeddings_table(self, conn: sqlite3.Connection) -> None:
401
  cursor = conn.cursor()
402
  cursor.execute(
403
  "CREATE TABLE IF NOT EXISTS embeddings ("
404
  "sequence TEXT PRIMARY KEY, "
405
+ "embedding BLOB NOT NULL"
 
 
406
  ")"
407
  )
 
 
 
 
 
 
 
408
  conn.commit()
409
 
410
  def load_embeddings_from_pth(self, save_path: str) -> Dict[str, torch.Tensor]:
 
419
  def load_embeddings_from_db(self, db_path: str, sequences: Optional[List[str]] = None) -> Dict[str, torch.Tensor]:
420
  assert os.path.exists(db_path), f"Embedding database does not exist: {db_path}"
421
  loaded: Dict[str, torch.Tensor] = {}
422
+ with sqlite3.connect(db_path, timeout=30) as conn:
423
  self._ensure_embeddings_table(conn)
424
  cursor = conn.cursor()
425
  if sequences is None:
426
+ cursor.execute("SELECT sequence, embedding FROM embeddings")
427
  else:
428
  if len(sequences) == 0:
429
  return loaded
430
  placeholders = ",".join(["?"] * len(sequences))
431
  cursor.execute(
432
+ f"SELECT sequence, embedding FROM embeddings WHERE sequence IN ({placeholders})",
433
  tuple(sequences),
434
  )
435
 
 
437
  for row in rows:
438
  sequence = row[0]
439
  embedding_bytes = row[1]
440
+ loaded[sequence] = embedding_blob_to_tensor(embedding_bytes)
 
 
 
 
 
 
 
 
 
 
 
441
  return loaded
442
 
443
  def embed_dataset(
 
456
  sql_db_path: str = 'embeddings.db',
457
  save_path: str = 'embeddings.pth',
458
  fasta_path: Optional[str] = None,
459
+ padding: str = 'max_length',
460
  **kwargs,
461
  ) -> Optional[Dict[str, torch.Tensor]]:
462
  """
 
479
  hidden_size = self.config.hidden_size
480
  pooler = Pooler(pooling_types) if not full_embeddings else None
481
  tokenizer_mode = tokenizer is not None
482
+
483
+ # Resolve padding and compilation
484
+ dynamic = padding == 'longest'
485
+ compiled_model = maybe_compile(self, dynamic=dynamic)
486
+
487
  if tokenizer_mode:
488
+ collate_fn = build_collator(tokenizer, padding=padding, max_length=max_len)
489
  device = self.device
490
  else:
491
  collate_fn = None
 
502
  assert collate_fn is not None
503
  assert device is not None
504
  dataset = ProteinDataset(to_embed)
505
+ dataloader = DataLoader(
506
+ dataset,
507
+ batch_size=batch_size,
508
+ num_workers=num_workers,
509
+ prefetch_factor=2 if num_workers > 0 else None,
510
+ collate_fn=collate_fn,
511
+ shuffle=False,
512
+ pin_memory=True,
513
+ )
514
+ for i, batch in _make_embedding_progress(dataloader, padding):
515
  seqs = to_embed[i * batch_size:(i + 1) * batch_size]
516
  input_ids = batch['input_ids'].to(device)
517
  attention_mask = batch['attention_mask'].to(device)
518
+ residue_embeddings = compiled_model._embed(input_ids, attention_mask)
519
  yield seqs, residue_embeddings, attention_mask
520
  else:
521
  for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'):
522
  seqs = to_embed[batch_start:batch_start + batch_size]
523
+ batch_output = compiled_model._embed(seqs, return_attention_mask=True, **kwargs)
524
  assert isinstance(batch_output, tuple), "Sequence mode _embed must return (last_hidden_state, attention_mask)."
525
  assert len(batch_output) == 2, "Sequence mode _embed must return exactly two values."
526
  residue_embeddings, attention_mask = batch_output
 
528
  yield seqs, residue_embeddings, attention_mask
529
 
530
  if sql:
531
+ conn = sqlite3.connect(sql_db_path, timeout=30, check_same_thread=False)
532
+ conn.execute('PRAGMA journal_mode=WAL')
533
+ conn.execute('PRAGMA busy_timeout=30000')
534
+ conn.execute('PRAGMA synchronous=OFF')
535
+ conn.execute('PRAGMA cache_size=-64000')
536
  self._ensure_embeddings_table(conn)
 
537
  already_embedded = self._read_sequences_from_db(sql_db_path)
538
  to_embed = [seq for seq in sequences if seq not in already_embedded]
539
  print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
540
  print(f"Embedding {len(to_embed)} new sequences")
541
  if len(to_embed) > 0:
542
+ sql_queue: queue.Queue = queue.Queue(maxsize=4)
543
+
544
+ def _sql_writer():
545
+ wc = conn.cursor()
546
+ while True:
547
+ item = sql_queue.get()
548
+ if item is None:
549
+ break
550
+ wc.executemany("INSERT OR REPLACE INTO embeddings VALUES (?, ?)", item)
551
+ if sql_queue.qsize() == 0:
 
 
 
 
552
  conn.commit()
553
+ conn.commit()
554
+
555
+ sql_writer_thread = threading.Thread(target=_sql_writer, daemon=True)
556
+ sql_writer_thread.start()
557
+
558
+ with torch.inference_mode():
559
+ for seqs, residue_embeddings, attention_mask in iter_batches(to_embed):
560
+ embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
561
+ if full_embeddings:
562
+ batch_rows = []
563
+ for seq, emb, mask in zip(seqs, embeddings, attention_mask):
564
+ batch_rows.append((seq, tensor_to_embedding_blob(emb[mask.bool()].reshape(-1, hidden_size))))
565
+ else:
566
+ blobs = batch_tensor_to_blobs(embeddings)
567
+ batch_rows = list(zip(seqs, blobs))
568
+ sql_queue.put(batch_rows)
569
+
570
+ sql_queue.put(None)
571
+ sql_writer_thread.join()
572
  conn.close()
573
  return None
574
 
 
583
  print(f"Embedding {len(to_embed)} new sequences")
584
 
585
  if len(to_embed) > 0:
586
+ with torch.inference_mode():
587
  for seqs, residue_embeddings, attention_mask in iter_batches(to_embed):
588
  embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
589
  for seq, emb, mask in zip(seqs, embeddings, attention_mask):