Upload modeling_esm_plusplus.py with huggingface_hub
Browse files- modeling_esm_plusplus.py +280 -95
modeling_esm_plusplus.py
CHANGED
|
@@ -23,18 +23,218 @@ inductor_config.max_autotune_gemm_backends = "ATEN,CUTLASS,FBGEMM"
|
|
| 23 |
dynamo.config.capture_scalar_outputs = True
|
| 24 |
torch._dynamo.config.recompile_limit = 16
|
| 25 |
|
|
|
|
| 26 |
import os
|
|
|
|
| 27 |
import sqlite3
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
import networkx as nx
|
| 29 |
import numpy as np
|
| 30 |
import torch
|
| 31 |
from tqdm.auto import tqdm
|
| 32 |
-
from typing import Callable, Dict, List, Optional, Set
|
| 33 |
from torch.utils.data import DataLoader
|
| 34 |
from torch.utils.data import Dataset as TorchDataset
|
| 35 |
from transformers import PreTrainedTokenizerBase
|
| 36 |
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
class Pooler:
|
| 39 |
def __init__(self, pooling_types: List[str]) -> None:
|
| 40 |
self.pooling_types = pooling_types
|
|
@@ -55,9 +255,6 @@ class Pooler:
|
|
| 55 |
return maxed_attentions
|
| 56 |
|
| 57 |
def _page_rank(self, attention_matrix: np.ndarray, personalization: Optional[dict] = None, nstart: Optional[dict] = None, prune_type: str = "top_k_outdegree") -> Dict[int, float]:
|
| 58 |
-
# Run PageRank on the attention matrix converted to a graph.
|
| 59 |
-
# Raises exceptions if the graph doesn't match the token sequence or has no edges.
|
| 60 |
-
# Returns the PageRank scores for each token node.
|
| 61 |
G = self._convert_to_graph(attention_matrix)
|
| 62 |
if G.number_of_nodes() != attention_matrix.shape[0]:
|
| 63 |
raise Exception(
|
|
@@ -68,26 +265,20 @@ class Pooler:
|
|
| 68 |
return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100)
|
| 69 |
|
| 70 |
def _convert_to_graph(self, matrix: np.ndarray) -> nx.DiGraph:
|
| 71 |
-
# Convert a matrix (e.g., attention scores) to a directed graph using networkx.
|
| 72 |
-
# Each element in the matrix represents a directed edge with a weight.
|
| 73 |
G = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
|
| 74 |
return G
|
| 75 |
|
| 76 |
def _calculate_importance_weights(self, dict_importance: Dict[int, float], attention_mask: Optional[torch.Tensor] = None) -> np.ndarray:
|
| 77 |
-
# Remove keys where attention_mask is 0
|
| 78 |
if attention_mask is not None:
|
| 79 |
for k in list(dict_importance.keys()):
|
| 80 |
if attention_mask[k] == 0:
|
| 81 |
del dict_importance[k]
|
| 82 |
|
| 83 |
-
#dict_importance[0] # remove cls
|
| 84 |
-
#dict_importance[-1] # remove eos
|
| 85 |
total = sum(dict_importance.values())
|
| 86 |
return np.array([v / total for _, v in dict_importance.items()])
|
| 87 |
|
| 88 |
-
def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 89 |
maxed_attentions = self._create_pooled_matrices_across_layers(attentions).numpy()
|
| 90 |
-
# emb is (b, L, d), maxed_attentions is (b, L, L)
|
| 91 |
emb_pooled = []
|
| 92 |
for e, a, mask in zip(emb, maxed_attentions, attention_mask):
|
| 93 |
dict_importance = self._page_rank(a)
|
|
@@ -97,58 +288,53 @@ class Pooler:
|
|
| 97 |
pooled = torch.tensor(np.array(emb_pooled))
|
| 98 |
return pooled
|
| 99 |
|
| 100 |
-
def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
|
| 101 |
if attention_mask is None:
|
| 102 |
return emb.mean(dim=1)
|
| 103 |
else:
|
| 104 |
attention_mask = attention_mask.unsqueeze(-1)
|
| 105 |
return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
|
| 106 |
|
| 107 |
-
def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
|
| 108 |
if attention_mask is None:
|
| 109 |
return emb.max(dim=1).values
|
| 110 |
else:
|
| 111 |
mask = attention_mask.unsqueeze(-1).bool()
|
| 112 |
return emb.masked_fill(~mask, float('-inf')).max(dim=1).values
|
| 113 |
|
| 114 |
-
def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
|
| 115 |
if attention_mask is None:
|
| 116 |
return emb.norm(dim=1, p=2)
|
| 117 |
else:
|
| 118 |
attention_mask = attention_mask.unsqueeze(-1)
|
| 119 |
return (emb * attention_mask).norm(dim=1, p=2)
|
| 120 |
|
| 121 |
-
def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
|
| 122 |
if attention_mask is None:
|
| 123 |
return emb.median(dim=1).values
|
| 124 |
else:
|
| 125 |
mask = attention_mask.unsqueeze(-1).bool()
|
| 126 |
return emb.masked_fill(~mask, float('nan')).nanmedian(dim=1).values
|
| 127 |
-
|
| 128 |
-
def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
|
| 129 |
if attention_mask is None:
|
| 130 |
return emb.std(dim=1)
|
| 131 |
else:
|
| 132 |
-
# Compute variance correctly over non-masked positions, then take sqrt
|
| 133 |
var = self.var_pooling(emb, attention_mask, **kwargs)
|
| 134 |
return torch.sqrt(var)
|
| 135 |
-
|
| 136 |
-
def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
|
| 137 |
if attention_mask is None:
|
| 138 |
return emb.var(dim=1)
|
| 139 |
else:
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
# Compute squared differences from mean, only over non-masked positions
|
| 146 |
-
squared_diff = (emb - mean) ** 2 # (b, L, d)
|
| 147 |
-
# Sum squared differences over non-masked positions and divide by count
|
| 148 |
-
var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d)
|
| 149 |
return var
|
| 150 |
|
| 151 |
-
def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
|
| 152 |
return emb[:, 0, :]
|
| 153 |
|
| 154 |
def __call__(
|
|
@@ -156,11 +342,11 @@ class Pooler:
|
|
| 156 |
emb: torch.Tensor,
|
| 157 |
attention_mask: Optional[torch.Tensor] = None,
|
| 158 |
attentions: Optional[torch.Tensor] = None
|
| 159 |
-
) -> torch.Tensor:
|
| 160 |
final_emb: List[torch.Tensor] = []
|
| 161 |
for pooling_type in self.pooling_types:
|
| 162 |
-
final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions))
|
| 163 |
-
return torch.cat(final_emb, dim=-1)
|
| 164 |
|
| 165 |
|
| 166 |
class ProteinDataset(TorchDataset):
|
|
@@ -175,12 +361,6 @@ class ProteinDataset(TorchDataset):
|
|
| 175 |
return self.sequences[idx]
|
| 176 |
|
| 177 |
|
| 178 |
-
def build_collator(tokenizer: PreTrainedTokenizerBase) -> Callable[[List[str]], Dict[str, torch.Tensor]]:
|
| 179 |
-
def _collate_fn(sequences: List[str]) -> Dict[str, torch.Tensor]:
|
| 180 |
-
return tokenizer(sequences, return_tensors="pt", padding='longest')
|
| 181 |
-
return _collate_fn
|
| 182 |
-
|
| 183 |
-
|
| 184 |
def parse_fasta(fasta_path: str) -> List[str]:
|
| 185 |
assert os.path.exists(fasta_path), f"FASTA file does not exist: {fasta_path}"
|
| 186 |
sequences = []
|
|
@@ -212,34 +392,19 @@ class EmbeddingMixin:
|
|
| 212 |
|
| 213 |
def _read_sequences_from_db(self, db_path: str) -> Set[str]:
|
| 214 |
"""Read sequences from SQLite database."""
|
| 215 |
-
|
| 216 |
-
with sqlite3.connect(db_path) as conn:
|
| 217 |
c = conn.cursor()
|
| 218 |
c.execute("SELECT sequence FROM embeddings")
|
| 219 |
-
|
| 220 |
-
row = c.fetchone()
|
| 221 |
-
if row is None:
|
| 222 |
-
break
|
| 223 |
-
sequences.append(row[0])
|
| 224 |
-
return set(sequences)
|
| 225 |
|
| 226 |
def _ensure_embeddings_table(self, conn: sqlite3.Connection) -> None:
|
| 227 |
cursor = conn.cursor()
|
| 228 |
cursor.execute(
|
| 229 |
"CREATE TABLE IF NOT EXISTS embeddings ("
|
| 230 |
"sequence TEXT PRIMARY KEY, "
|
| 231 |
-
"embedding BLOB NOT NULL
|
| 232 |
-
"shape TEXT, "
|
| 233 |
-
"dtype TEXT"
|
| 234 |
")"
|
| 235 |
)
|
| 236 |
-
cursor.execute("PRAGMA table_info(embeddings)")
|
| 237 |
-
rows = cursor.fetchall()
|
| 238 |
-
column_names = [row[1] for row in rows]
|
| 239 |
-
if "shape" not in column_names:
|
| 240 |
-
cursor.execute("ALTER TABLE embeddings ADD COLUMN shape TEXT")
|
| 241 |
-
if "dtype" not in column_names:
|
| 242 |
-
cursor.execute("ALTER TABLE embeddings ADD COLUMN dtype TEXT")
|
| 243 |
conn.commit()
|
| 244 |
|
| 245 |
def load_embeddings_from_pth(self, save_path: str) -> Dict[str, torch.Tensor]:
|
|
@@ -254,17 +419,17 @@ class EmbeddingMixin:
|
|
| 254 |
def load_embeddings_from_db(self, db_path: str, sequences: Optional[List[str]] = None) -> Dict[str, torch.Tensor]:
|
| 255 |
assert os.path.exists(db_path), f"Embedding database does not exist: {db_path}"
|
| 256 |
loaded: Dict[str, torch.Tensor] = {}
|
| 257 |
-
with sqlite3.connect(db_path) as conn:
|
| 258 |
self._ensure_embeddings_table(conn)
|
| 259 |
cursor = conn.cursor()
|
| 260 |
if sequences is None:
|
| 261 |
-
cursor.execute("SELECT sequence, embedding
|
| 262 |
else:
|
| 263 |
if len(sequences) == 0:
|
| 264 |
return loaded
|
| 265 |
placeholders = ",".join(["?"] * len(sequences))
|
| 266 |
cursor.execute(
|
| 267 |
-
f"SELECT sequence, embedding
|
| 268 |
tuple(sequences),
|
| 269 |
)
|
| 270 |
|
|
@@ -272,18 +437,7 @@ class EmbeddingMixin:
|
|
| 272 |
for row in rows:
|
| 273 |
sequence = row[0]
|
| 274 |
embedding_bytes = row[1]
|
| 275 |
-
|
| 276 |
-
dtype_text = row[3]
|
| 277 |
-
assert shape_text is not None, "Missing shape metadata in embeddings table."
|
| 278 |
-
assert dtype_text is not None, "Missing dtype metadata in embeddings table."
|
| 279 |
-
shape_values = [int(value) for value in shape_text.split(",") if len(value) > 0]
|
| 280 |
-
assert len(shape_values) > 0, f"Invalid shape metadata for sequence: {sequence}"
|
| 281 |
-
expected_size = int(np.prod(shape_values))
|
| 282 |
-
np_dtype = np.dtype(dtype_text)
|
| 283 |
-
array = np.frombuffer(embedding_bytes, dtype=np_dtype)
|
| 284 |
-
assert array.size == expected_size, f"Shape mismatch while reading sequence: {sequence}"
|
| 285 |
-
reshaped = array.copy().reshape(tuple(shape_values))
|
| 286 |
-
loaded[sequence] = torch.from_numpy(reshaped)
|
| 287 |
return loaded
|
| 288 |
|
| 289 |
def embed_dataset(
|
|
@@ -302,6 +456,7 @@ class EmbeddingMixin:
|
|
| 302 |
sql_db_path: str = 'embeddings.db',
|
| 303 |
save_path: str = 'embeddings.pth',
|
| 304 |
fasta_path: Optional[str] = None,
|
|
|
|
| 305 |
**kwargs,
|
| 306 |
) -> Optional[Dict[str, torch.Tensor]]:
|
| 307 |
"""
|
|
@@ -324,8 +479,13 @@ class EmbeddingMixin:
|
|
| 324 |
hidden_size = self.config.hidden_size
|
| 325 |
pooler = Pooler(pooling_types) if not full_embeddings else None
|
| 326 |
tokenizer_mode = tokenizer is not None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 327 |
if tokenizer_mode:
|
| 328 |
-
collate_fn = build_collator(tokenizer)
|
| 329 |
device = self.device
|
| 330 |
else:
|
| 331 |
collate_fn = None
|
|
@@ -342,17 +502,25 @@ class EmbeddingMixin:
|
|
| 342 |
assert collate_fn is not None
|
| 343 |
assert device is not None
|
| 344 |
dataset = ProteinDataset(to_embed)
|
| 345 |
-
dataloader = DataLoader(
|
| 346 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 347 |
seqs = to_embed[i * batch_size:(i + 1) * batch_size]
|
| 348 |
input_ids = batch['input_ids'].to(device)
|
| 349 |
attention_mask = batch['attention_mask'].to(device)
|
| 350 |
-
residue_embeddings =
|
| 351 |
yield seqs, residue_embeddings, attention_mask
|
| 352 |
else:
|
| 353 |
for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'):
|
| 354 |
seqs = to_embed[batch_start:batch_start + batch_size]
|
| 355 |
-
batch_output =
|
| 356 |
assert isinstance(batch_output, tuple), "Sequence mode _embed must return (last_hidden_state, attention_mask)."
|
| 357 |
assert len(batch_output) == 2, "Sequence mode _embed must return exactly two values."
|
| 358 |
residue_embeddings, attention_mask = batch_output
|
|
@@ -360,30 +528,47 @@ class EmbeddingMixin:
|
|
| 360 |
yield seqs, residue_embeddings, attention_mask
|
| 361 |
|
| 362 |
if sql:
|
| 363 |
-
conn = sqlite3.connect(sql_db_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 364 |
self._ensure_embeddings_table(conn)
|
| 365 |
-
c = conn.cursor()
|
| 366 |
already_embedded = self._read_sequences_from_db(sql_db_path)
|
| 367 |
to_embed = [seq for seq in sequences if seq not in already_embedded]
|
| 368 |
print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
|
| 369 |
print(f"Embedding {len(to_embed)} new sequences")
|
| 370 |
if len(to_embed) > 0:
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
"INSERT OR REPLACE INTO embeddings (sequence, embedding, shape, dtype) VALUES (?, ?, ?, ?)",
|
| 382 |
-
(seq, emb_np.tobytes(), emb_shape, emb_dtype),
|
| 383 |
-
)
|
| 384 |
-
if tokenizer_mode and (i + 1) % 100 == 0:
|
| 385 |
conn.commit()
|
| 386 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 387 |
conn.close()
|
| 388 |
return None
|
| 389 |
|
|
@@ -398,7 +583,7 @@ class EmbeddingMixin:
|
|
| 398 |
print(f"Embedding {len(to_embed)} new sequences")
|
| 399 |
|
| 400 |
if len(to_embed) > 0:
|
| 401 |
-
with torch.
|
| 402 |
for seqs, residue_embeddings, attention_mask in iter_batches(to_embed):
|
| 403 |
embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
|
| 404 |
for seq, emb, mask in zip(seqs, embeddings, attention_mask):
|
|
|
|
| 23 |
dynamo.config.capture_scalar_outputs = True
|
| 24 |
torch._dynamo.config.recompile_limit = 16
|
| 25 |
|
| 26 |
+
import io
|
| 27 |
import os
|
| 28 |
+
import queue
|
| 29 |
import sqlite3
|
| 30 |
+
import struct
|
| 31 |
+
import threading
|
| 32 |
+
import time
|
| 33 |
+
|
| 34 |
import networkx as nx
|
| 35 |
import numpy as np
|
| 36 |
import torch
|
| 37 |
from tqdm.auto import tqdm
|
| 38 |
+
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple
|
| 39 |
from torch.utils.data import DataLoader
|
| 40 |
from torch.utils.data import Dataset as TorchDataset
|
| 41 |
from transformers import PreTrainedTokenizerBase
|
| 42 |
|
| 43 |
|
| 44 |
+
# Compact blob serialization constants
|
| 45 |
+
# Keep in sync with protify/utils.py and core/atlas/precomputed.py
|
| 46 |
+
_COMPACT_VERSION = 0x01
|
| 47 |
+
_DTYPE_TO_CODE = {torch.float16: 0, torch.bfloat16: 1, torch.float32: 2}
|
| 48 |
+
_CODE_TO_DTYPE = {0: torch.float16, 1: torch.bfloat16, 2: torch.float32}
|
| 49 |
+
_CODE_TO_NP_DTYPE = {0: np.float16, 1: np.float16, 2: np.float32}
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def tensor_to_embedding_blob(tensor: torch.Tensor) -> bytes:
|
| 53 |
+
"""Serialize a tensor to compact binary format for SQLite blob storage.
|
| 54 |
+
|
| 55 |
+
Format: [version:1][dtype_code:1][ndim:4][shape:4*ndim][raw_bytes]
|
| 56 |
+
bfloat16 tensors are stored as float16 bytes (numpy lacks bfloat16)
|
| 57 |
+
but tagged with dtype_code=1 so they can be cast back on read.
|
| 58 |
+
Falls back to torch.save for unsupported dtypes.
|
| 59 |
+
"""
|
| 60 |
+
t = tensor.cpu()
|
| 61 |
+
if t.dtype not in _DTYPE_TO_CODE:
|
| 62 |
+
buffer = io.BytesIO()
|
| 63 |
+
torch.save(t, buffer)
|
| 64 |
+
return buffer.getvalue()
|
| 65 |
+
dtype_code = _DTYPE_TO_CODE[t.dtype]
|
| 66 |
+
|
| 67 |
+
if t.dtype == torch.bfloat16:
|
| 68 |
+
raw = t.half().numpy().tobytes()
|
| 69 |
+
else:
|
| 70 |
+
raw = t.numpy().tobytes()
|
| 71 |
+
|
| 72 |
+
shape = t.shape
|
| 73 |
+
header = struct.pack(f'<BBi{len(shape)}i', _COMPACT_VERSION, dtype_code, len(shape), *shape)
|
| 74 |
+
return header + raw
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _compact_header(dtype: torch.dtype, shape: tuple) -> bytes:
|
| 78 |
+
"""Build just the compact header for a given dtype and shape."""
|
| 79 |
+
dtype_code = _DTYPE_TO_CODE[dtype]
|
| 80 |
+
return struct.pack(f'<BBi{len(shape)}i', _COMPACT_VERSION, dtype_code, len(shape), *shape)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def batch_tensor_to_blobs(batch: torch.Tensor) -> List[bytes]:
|
| 84 |
+
"""Serialize a batch of same-shape tensors to compact blobs (fast path for vectors).
|
| 85 |
+
|
| 86 |
+
Builds the header once and slices raw bytes per row. Much faster than
|
| 87 |
+
per-row tensor_to_embedding_blob calls for uniform-shape batches.
|
| 88 |
+
"""
|
| 89 |
+
assert batch.ndim >= 2, f"Expected batch with >= 2 dims, got {batch.ndim}"
|
| 90 |
+
t = batch.cpu()
|
| 91 |
+
store_dtype = t.dtype
|
| 92 |
+
if t.dtype not in _DTYPE_TO_CODE:
|
| 93 |
+
return [tensor_to_embedding_blob(t[i]) for i in range(t.shape[0])]
|
| 94 |
+
|
| 95 |
+
if t.dtype == torch.bfloat16:
|
| 96 |
+
arr = t.half().numpy()
|
| 97 |
+
store_dtype = torch.bfloat16
|
| 98 |
+
else:
|
| 99 |
+
arr = t.numpy()
|
| 100 |
+
|
| 101 |
+
row_shape = tuple(t.shape[1:])
|
| 102 |
+
header = _compact_header(store_dtype, row_shape)
|
| 103 |
+
raw = arr.tobytes()
|
| 104 |
+
stride = len(raw) // t.shape[0]
|
| 105 |
+
return [header + raw[i * stride:(i + 1) * stride] for i in range(t.shape[0])]
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def embedding_blob_to_tensor(blob: bytes, fallback_shape: Optional[Tuple[int, ...]] = None) -> torch.Tensor:
|
| 109 |
+
"""Deserialize a blob back to a tensor. Auto-detects compact vs legacy formats."""
|
| 110 |
+
if len(blob) >= 6 and blob[0] == _COMPACT_VERSION:
|
| 111 |
+
dtype_code = blob[1]
|
| 112 |
+
ndim = struct.unpack_from('<i', blob, 2)[0]
|
| 113 |
+
shape = struct.unpack_from(f'<{ndim}i', blob, 6)
|
| 114 |
+
data_offset = 6 + 4 * ndim
|
| 115 |
+
np_dtype = _CODE_TO_NP_DTYPE[dtype_code]
|
| 116 |
+
arr = np.frombuffer(blob, dtype=np_dtype, offset=data_offset).copy().reshape(shape)
|
| 117 |
+
t = torch.from_numpy(arr)
|
| 118 |
+
target_dtype = _CODE_TO_DTYPE[dtype_code]
|
| 119 |
+
if target_dtype != t.dtype:
|
| 120 |
+
t = t.to(target_dtype)
|
| 121 |
+
return t
|
| 122 |
+
|
| 123 |
+
# Fallback: try torch.load (pickle format)
|
| 124 |
+
try:
|
| 125 |
+
buffer = io.BytesIO(blob)
|
| 126 |
+
return torch.load(buffer, map_location='cpu', weights_only=True)
|
| 127 |
+
except Exception:
|
| 128 |
+
pass
|
| 129 |
+
|
| 130 |
+
# Legacy fallback: raw float32 bytes with caller-supplied shape
|
| 131 |
+
assert fallback_shape is not None, "Cannot deserialize blob: unknown format and no fallback_shape provided."
|
| 132 |
+
arr = np.frombuffer(blob, dtype=np.float32).copy().reshape(fallback_shape)
|
| 133 |
+
return torch.from_numpy(arr)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def maybe_compile(model: torch.nn.Module, dynamic: bool = False) -> torch.nn.Module:
    """Return *model* wrapped with torch.compile when it is safe to do so.

    Compilation is skipped entirely for dynamic shapes (padding='longest')
    because flex attention's create_block_mask is incompatible with dynamic
    shapes under torch.compile and causes CUDA illegal memory access.
    Any compile failure is reported and the uncompiled model is returned.
    """
    if dynamic:
        print("Skipping torch.compile (dynamic shapes + flex attention incompatible)")
        return model
    try:
        compiled = torch.compile(model)
        print("Model compiled")
        return compiled
    except Exception as e:
        print(f"Skipping torch.compile: {e}")
        return model
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def build_collator(
    tokenizer: PreTrainedTokenizerBase,
    padding: str = 'max_length',
    max_length: int = 512,
) -> Callable[[List[str]], Dict[str, torch.Tensor]]:
    """Make a DataLoader collate_fn that tokenizes a batch of raw sequences.

    With padding='max_length' every batch is padded to *max_length*
    (uniform shapes); any other padding mode additionally pads to a
    multiple of 8. Sequences longer than *max_length* are truncated.
    """
    def _collate_fn(sequences: List[str]) -> Dict[str, torch.Tensor]:
        tok_kwargs: Dict[str, Any] = {
            'return_tensors': "pt",
            'padding': padding,
            'truncation': True,
            'max_length': max_length,
        }
        if padding != 'max_length':
            tok_kwargs['pad_to_multiple_of'] = 8
        return tokenizer(sequences, **tok_kwargs)
    return _collate_fn
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def _make_embedding_progress(
    dataloader: DataLoader,
    padding: str,
    n_warmup: int = 3,
    n_calibration: int = 5,
) -> Iterator[Tuple[int, Any]]:
    """Progress-bar wrapper for embedding loops. Drop-in replacement for enumerate(dataloader).

    When padding='max_length', all batches have uniform cost so plain tqdm works.
    When padding='longest' (sorted longest-first), batch times vary dramatically.
    In that case: yield warmup batches first (compiler warmup + OOM check on longest
    sequences), then time mid-length calibration batches to estimate total ETA.

    Every batch is yielded exactly once, in order, with its original index —
    the phases only change how progress is displayed.
    NOTE(review): assumes the caller drains the generator fully; if the
    consumer stops early, the later bars are never closed — confirm callers.

    Keep in sync with protify/embedder.py and core/atlas/precomputed.py.
    """
    total = len(dataloader)
    # Uniform-cost batches, or too few batches to calibrate: plain tqdm.
    if padding == 'max_length' or total <= n_warmup + n_calibration:
        for i, batch in tqdm(enumerate(dataloader), total=total, desc='Embedding batches'):
            yield i, batch
        return

    dl_iter = iter(dataloader)

    # Phase 1: warmup on longest batches (first n_warmup, since sorted longest-first)
    warmup_bar = tqdm(range(n_warmup), desc='Warmup (longest batches)', leave=False)
    for i in warmup_bar:
        batch = next(dl_iter)
        yield i, batch
    warmup_bar.close()

    # Phase 2: skip to middle of dataset for calibration timing
    # We need to yield all intermediate batches too (they contain real data)
    mid_start = total // 2
    intermediate_bar = tqdm(
        range(n_warmup, mid_start), desc='Embedding batches', leave=False,
    )
    for i in intermediate_bar:
        batch = next(dl_iter)
        yield i, batch
    intermediate_bar.close()

    # Phase 3: time calibration batches from the middle.
    # The timer brackets the yield, so each sample includes the consumer's
    # per-batch work between resumptions — i.e. true wall time per batch.
    calibration_times: List[float] = []
    cal_bar = tqdm(range(n_calibration), desc='Calibrating ETA', leave=False)
    for j in cal_bar:
        t0 = time.perf_counter()
        batch = next(dl_iter)
        yield mid_start + j, batch
        calibration_times.append(time.perf_counter() - t0)
    cal_bar.close()

    # Mid-dataset batches approximate the average cost of what remains.
    avg_time = sum(calibration_times) / len(calibration_times)
    remaining_start = mid_start + n_calibration
    remaining_count = total - remaining_start
    estimated_total_seconds = avg_time * remaining_count

    # Phase 4: remaining batches with calibrated ETA
    main_bar = tqdm(
        range(remaining_count),
        desc='Embedding batches',
        bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]',
    )
    main_bar.set_postfix_str(f'ETA ~{estimated_total_seconds:.0f}s (calibrated)')
    for k in main_bar:
        batch = next(dl_iter)
        yield remaining_start + k, batch
    main_bar.close()
|
| 236 |
+
|
| 237 |
+
|
| 238 |
class Pooler:
|
| 239 |
def __init__(self, pooling_types: List[str]) -> None:
|
| 240 |
self.pooling_types = pooling_types
|
|
|
|
| 255 |
return maxed_attentions
|
| 256 |
|
| 257 |
def _page_rank(self, attention_matrix: np.ndarray, personalization: Optional[dict] = None, nstart: Optional[dict] = None, prune_type: str = "top_k_outdegree") -> Dict[int, float]:
|
|
|
|
|
|
|
|
|
|
| 258 |
G = self._convert_to_graph(attention_matrix)
|
| 259 |
if G.number_of_nodes() != attention_matrix.shape[0]:
|
| 260 |
raise Exception(
|
|
|
|
| 265 |
return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100)
|
| 266 |
|
| 267 |
def _convert_to_graph(self, matrix: np.ndarray) -> nx.DiGraph:
|
|
|
|
|
|
|
| 268 |
G = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
|
| 269 |
return G
|
| 270 |
|
| 271 |
def _calculate_importance_weights(self, dict_importance: Dict[int, float], attention_mask: Optional[torch.Tensor] = None) -> np.ndarray:
|
|
|
|
| 272 |
if attention_mask is not None:
|
| 273 |
for k in list(dict_importance.keys()):
|
| 274 |
if attention_mask[k] == 0:
|
| 275 |
del dict_importance[k]
|
| 276 |
|
|
|
|
|
|
|
| 277 |
total = sum(dict_importance.values())
|
| 278 |
return np.array([v / total for _, v in dict_importance.items()])
|
| 279 |
|
| 280 |
+
def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 281 |
maxed_attentions = self._create_pooled_matrices_across_layers(attentions).numpy()
|
|
|
|
| 282 |
emb_pooled = []
|
| 283 |
for e, a, mask in zip(emb, maxed_attentions, attention_mask):
|
| 284 |
dict_importance = self._page_rank(a)
|
|
|
|
| 288 |
pooled = torch.tensor(np.array(emb_pooled))
|
| 289 |
return pooled
|
| 290 |
|
| 291 |
+
def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
    """Average residue embeddings over the sequence dim; masked positions excluded."""
    if attention_mask is None:
        return emb.mean(dim=1)
    weights = attention_mask.unsqueeze(-1)
    return (emb * weights).sum(dim=1) / weights.sum(dim=1)
|
| 297 |
|
| 298 |
+
def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
    """Per-dimension max over sequence positions; masked positions excluded via -inf fill."""
    if attention_mask is None:
        return emb.max(dim=1).values
    keep = attention_mask.unsqueeze(-1).bool()
    return emb.masked_fill(~keep, float('-inf')).max(dim=1).values
|
| 304 |
|
| 305 |
+
def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
    """L2 norm over sequence positions; masked positions are zeroed before the norm."""
    if attention_mask is None:
        return emb.norm(dim=1, p=2)
    weights = attention_mask.unsqueeze(-1)
    return (emb * weights).norm(dim=1, p=2)
|
| 311 |
|
| 312 |
+
def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
    """Per-dimension median over sequence positions; masked positions become NaN and are skipped."""
    if attention_mask is None:
        return emb.median(dim=1).values
    keep = attention_mask.unsqueeze(-1).bool()
    return emb.masked_fill(~keep, float('nan')).nanmedian(dim=1).values
|
| 318 |
+
|
| 319 |
+
def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
    """Standard deviation over sequence positions; masked path is sqrt of var_pooling.

    NOTE(review): the unmasked branch uses torch's unbiased (n-1) std while
    the masked branch is sqrt of a divide-by-n variance — confirm intended.
    """
    if attention_mask is None:
        return emb.std(dim=1)
    return torch.sqrt(self.var_pooling(emb, attention_mask, **kwargs))
|
| 325 |
+
|
| 326 |
+
def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
    """Variance over sequence positions, restricted to unmasked positions.

    NOTE(review): the unmasked branch uses torch's unbiased (n-1) variance,
    the masked branch the population (divide-by-n) variance over unmasked
    positions — confirm this asymmetry is intended.
    """
    if attention_mask is None:
        return emb.var(dim=1)
    weights = attention_mask.unsqueeze(-1)
    denom = weights.sum(dim=1)
    masked_mean = ((emb * weights).sum(dim=1) / denom).unsqueeze(1)
    return (((emb - masked_mean) ** 2) * weights).sum(dim=1) / denom
|
| 336 |
|
| 337 |
+
def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
    """Return the first-position (CLS/BOS token) embedding of each sequence."""
    return emb[:, 0]
|
| 339 |
|
| 340 |
def __call__(
|
|
|
|
| 342 |
emb: torch.Tensor,
|
| 343 |
attention_mask: Optional[torch.Tensor] = None,
|
| 344 |
attentions: Optional[torch.Tensor] = None
|
| 345 |
+
) -> torch.Tensor:
|
| 346 |
final_emb: List[torch.Tensor] = []
|
| 347 |
for pooling_type in self.pooling_types:
|
| 348 |
+
final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions))
|
| 349 |
+
return torch.cat(final_emb, dim=-1)
|
| 350 |
|
| 351 |
|
| 352 |
class ProteinDataset(TorchDataset):
|
|
|
|
| 361 |
return self.sequences[idx]
|
| 362 |
|
| 363 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 364 |
def parse_fasta(fasta_path: str) -> List[str]:
|
| 365 |
assert os.path.exists(fasta_path), f"FASTA file does not exist: {fasta_path}"
|
| 366 |
sequences = []
|
|
|
|
| 392 |
|
| 393 |
def _read_sequences_from_db(self, db_path: str) -> Set[str]:
|
| 394 |
"""Read sequences from SQLite database."""
|
| 395 |
+
with sqlite3.connect(db_path, timeout=30) as conn:
|
|
|
|
| 396 |
c = conn.cursor()
|
| 397 |
c.execute("SELECT sequence FROM embeddings")
|
| 398 |
+
return {row[0] for row in c.fetchall()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 399 |
|
| 400 |
def _ensure_embeddings_table(self, conn: sqlite3.Connection) -> None:
|
| 401 |
cursor = conn.cursor()
|
| 402 |
cursor.execute(
|
| 403 |
"CREATE TABLE IF NOT EXISTS embeddings ("
|
| 404 |
"sequence TEXT PRIMARY KEY, "
|
| 405 |
+
"embedding BLOB NOT NULL"
|
|
|
|
|
|
|
| 406 |
")"
|
| 407 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 408 |
conn.commit()
|
| 409 |
|
| 410 |
def load_embeddings_from_pth(self, save_path: str) -> Dict[str, torch.Tensor]:
|
|
|
|
| 419 |
def load_embeddings_from_db(self, db_path: str, sequences: Optional[List[str]] = None) -> Dict[str, torch.Tensor]:
|
| 420 |
assert os.path.exists(db_path), f"Embedding database does not exist: {db_path}"
|
| 421 |
loaded: Dict[str, torch.Tensor] = {}
|
| 422 |
+
with sqlite3.connect(db_path, timeout=30) as conn:
|
| 423 |
self._ensure_embeddings_table(conn)
|
| 424 |
cursor = conn.cursor()
|
| 425 |
if sequences is None:
|
| 426 |
+
cursor.execute("SELECT sequence, embedding FROM embeddings")
|
| 427 |
else:
|
| 428 |
if len(sequences) == 0:
|
| 429 |
return loaded
|
| 430 |
placeholders = ",".join(["?"] * len(sequences))
|
| 431 |
cursor.execute(
|
| 432 |
+
f"SELECT sequence, embedding FROM embeddings WHERE sequence IN ({placeholders})",
|
| 433 |
tuple(sequences),
|
| 434 |
)
|
| 435 |
|
|
|
|
| 437 |
for row in rows:
|
| 438 |
sequence = row[0]
|
| 439 |
embedding_bytes = row[1]
|
| 440 |
+
loaded[sequence] = embedding_blob_to_tensor(embedding_bytes)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 441 |
return loaded
|
| 442 |
|
| 443 |
def embed_dataset(
|
|
|
|
| 456 |
sql_db_path: str = 'embeddings.db',
|
| 457 |
save_path: str = 'embeddings.pth',
|
| 458 |
fasta_path: Optional[str] = None,
|
| 459 |
+
padding: str = 'max_length',
|
| 460 |
**kwargs,
|
| 461 |
) -> Optional[Dict[str, torch.Tensor]]:
|
| 462 |
"""
|
|
|
|
| 479 |
hidden_size = self.config.hidden_size
|
| 480 |
pooler = Pooler(pooling_types) if not full_embeddings else None
|
| 481 |
tokenizer_mode = tokenizer is not None
|
| 482 |
+
|
| 483 |
+
# Resolve padding and compilation
|
| 484 |
+
dynamic = padding == 'longest'
|
| 485 |
+
compiled_model = maybe_compile(self, dynamic=dynamic)
|
| 486 |
+
|
| 487 |
if tokenizer_mode:
|
| 488 |
+
collate_fn = build_collator(tokenizer, padding=padding, max_length=max_len)
|
| 489 |
device = self.device
|
| 490 |
else:
|
| 491 |
collate_fn = None
|
|
|
|
| 502 |
assert collate_fn is not None
|
| 503 |
assert device is not None
|
| 504 |
dataset = ProteinDataset(to_embed)
|
| 505 |
+
dataloader = DataLoader(
|
| 506 |
+
dataset,
|
| 507 |
+
batch_size=batch_size,
|
| 508 |
+
num_workers=num_workers,
|
| 509 |
+
prefetch_factor=2 if num_workers > 0 else None,
|
| 510 |
+
collate_fn=collate_fn,
|
| 511 |
+
shuffle=False,
|
| 512 |
+
pin_memory=True,
|
| 513 |
+
)
|
| 514 |
+
for i, batch in _make_embedding_progress(dataloader, padding):
|
| 515 |
seqs = to_embed[i * batch_size:(i + 1) * batch_size]
|
| 516 |
input_ids = batch['input_ids'].to(device)
|
| 517 |
attention_mask = batch['attention_mask'].to(device)
|
| 518 |
+
residue_embeddings = compiled_model._embed(input_ids, attention_mask)
|
| 519 |
yield seqs, residue_embeddings, attention_mask
|
| 520 |
else:
|
| 521 |
for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'):
|
| 522 |
seqs = to_embed[batch_start:batch_start + batch_size]
|
| 523 |
+
batch_output = compiled_model._embed(seqs, return_attention_mask=True, **kwargs)
|
| 524 |
assert isinstance(batch_output, tuple), "Sequence mode _embed must return (last_hidden_state, attention_mask)."
|
| 525 |
assert len(batch_output) == 2, "Sequence mode _embed must return exactly two values."
|
| 526 |
residue_embeddings, attention_mask = batch_output
|
|
|
|
| 528 |
yield seqs, residue_embeddings, attention_mask
|
| 529 |
|
| 530 |
if sql:
|
| 531 |
+
conn = sqlite3.connect(sql_db_path, timeout=30, check_same_thread=False)
|
| 532 |
+
conn.execute('PRAGMA journal_mode=WAL')
|
| 533 |
+
conn.execute('PRAGMA busy_timeout=30000')
|
| 534 |
+
conn.execute('PRAGMA synchronous=OFF')
|
| 535 |
+
conn.execute('PRAGMA cache_size=-64000')
|
| 536 |
self._ensure_embeddings_table(conn)
|
|
|
|
| 537 |
already_embedded = self._read_sequences_from_db(sql_db_path)
|
| 538 |
to_embed = [seq for seq in sequences if seq not in already_embedded]
|
| 539 |
print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
|
| 540 |
print(f"Embedding {len(to_embed)} new sequences")
|
| 541 |
if len(to_embed) > 0:
|
| 542 |
+
sql_queue: queue.Queue = queue.Queue(maxsize=4)
|
| 543 |
+
|
| 544 |
+
def _sql_writer():
|
| 545 |
+
wc = conn.cursor()
|
| 546 |
+
while True:
|
| 547 |
+
item = sql_queue.get()
|
| 548 |
+
if item is None:
|
| 549 |
+
break
|
| 550 |
+
wc.executemany("INSERT OR REPLACE INTO embeddings VALUES (?, ?)", item)
|
| 551 |
+
if sql_queue.qsize() == 0:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 552 |
conn.commit()
|
| 553 |
+
conn.commit()
|
| 554 |
+
|
| 555 |
+
sql_writer_thread = threading.Thread(target=_sql_writer, daemon=True)
|
| 556 |
+
sql_writer_thread.start()
|
| 557 |
+
|
| 558 |
+
with torch.inference_mode():
|
| 559 |
+
for seqs, residue_embeddings, attention_mask in iter_batches(to_embed):
|
| 560 |
+
embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
|
| 561 |
+
if full_embeddings:
|
| 562 |
+
batch_rows = []
|
| 563 |
+
for seq, emb, mask in zip(seqs, embeddings, attention_mask):
|
| 564 |
+
batch_rows.append((seq, tensor_to_embedding_blob(emb[mask.bool()].reshape(-1, hidden_size))))
|
| 565 |
+
else:
|
| 566 |
+
blobs = batch_tensor_to_blobs(embeddings)
|
| 567 |
+
batch_rows = list(zip(seqs, blobs))
|
| 568 |
+
sql_queue.put(batch_rows)
|
| 569 |
+
|
| 570 |
+
sql_queue.put(None)
|
| 571 |
+
sql_writer_thread.join()
|
| 572 |
conn.close()
|
| 573 |
return None
|
| 574 |
|
|
|
|
| 583 |
print(f"Embedding {len(to_embed)} new sequences")
|
| 584 |
|
| 585 |
if len(to_embed) > 0:
|
| 586 |
+
with torch.inference_mode():
|
| 587 |
for seqs, residue_embeddings, attention_mask in iter_batches(to_embed):
|
| 588 |
embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
|
| 589 |
for seq, emb, mask in zip(seqs, embeddings, attention_mask):
|