Upload modeling_e1.py with huggingface_hub
Browse files — modeling_e1.py (+49 lines, −31 lines)
modeling_e1.py
CHANGED
|
@@ -42,7 +42,7 @@ from transformers import PreTrainedTokenizerBase
|
|
| 42 |
|
| 43 |
|
| 44 |
# Compact blob serialization constants
|
| 45 |
-
# Keep in sync with protify/utils.py
|
| 46 |
_COMPACT_VERSION = 0x01
|
| 47 |
_DTYPE_TO_CODE = {torch.float16: 0, torch.bfloat16: 1, torch.float32: 2}
|
| 48 |
_CODE_TO_DTYPE = {0: torch.float16, 1: torch.bfloat16, 2: torch.float32}
|
|
@@ -235,6 +235,40 @@ def _make_embedding_progress(
|
|
| 235 |
main_bar.close()
|
| 236 |
|
| 237 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
class Pooler:
|
| 239 |
def __init__(self, pooling_types: List[str]) -> None:
|
| 240 |
self.pooling_types = pooling_types
|
|
@@ -528,6 +562,7 @@ class EmbeddingMixin:
|
|
| 528 |
yield seqs, residue_embeddings, attention_mask
|
| 529 |
|
| 530 |
if sql:
|
|
|
|
| 531 |
conn = sqlite3.connect(sql_db_path, timeout=30, check_same_thread=False)
|
| 532 |
conn.execute('PRAGMA journal_mode=WAL')
|
| 533 |
conn.execute('PRAGMA busy_timeout=30000')
|
|
@@ -539,36 +574,19 @@ class EmbeddingMixin:
|
|
| 539 |
print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
|
| 540 |
print(f"Embedding {len(to_embed)} new sequences")
|
| 541 |
if len(to_embed) > 0:
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
sql_writer_thread = threading.Thread(target=_sql_writer, daemon=True)
|
| 556 |
-
sql_writer_thread.start()
|
| 557 |
-
|
| 558 |
-
with torch.inference_mode():
|
| 559 |
-
for seqs, residue_embeddings, attention_mask in iter_batches(to_embed):
|
| 560 |
-
embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
|
| 561 |
-
if full_embeddings:
|
| 562 |
-
batch_rows = []
|
| 563 |
-
for seq, emb, mask in zip(seqs, embeddings, attention_mask):
|
| 564 |
-
batch_rows.append((seq, tensor_to_embedding_blob(emb[mask.bool()].reshape(-1, hidden_size))))
|
| 565 |
-
else:
|
| 566 |
-
blobs = batch_tensor_to_blobs(embeddings)
|
| 567 |
-
batch_rows = list(zip(seqs, blobs))
|
| 568 |
-
sql_queue.put(batch_rows)
|
| 569 |
-
|
| 570 |
-
sql_queue.put(None)
|
| 571 |
-
sql_writer_thread.join()
|
| 572 |
conn.close()
|
| 573 |
return None
|
| 574 |
|
|
|
|
| 42 |
|
| 43 |
|
| 44 |
# Compact blob serialization constants
|
| 45 |
+
# Canonical source: core/embed/blob.py. Keep in sync with protify/utils.py.
|
| 46 |
_COMPACT_VERSION = 0x01
|
| 47 |
_DTYPE_TO_CODE = {torch.float16: 0, torch.bfloat16: 1, torch.float32: 2}
|
| 48 |
_CODE_TO_DTYPE = {0: torch.float16, 1: torch.bfloat16, 2: torch.float32}
|
|
|
|
| 235 |
main_bar.close()
|
| 236 |
|
| 237 |
|
| 238 |
+
class _SQLWriter:
|
| 239 |
+
"""Context manager for async SQL embedding writes. Matches core/embed/storage.SQLEmbeddingWriter."""
|
| 240 |
+
|
| 241 |
+
def __init__(self, conn: sqlite3.Connection, queue_maxsize: int = 4) -> None:
|
| 242 |
+
self._conn = conn
|
| 243 |
+
self._queue: queue.Queue = queue.Queue(maxsize=queue_maxsize)
|
| 244 |
+
self._thread: Optional[threading.Thread] = None
|
| 245 |
+
|
| 246 |
+
def __enter__(self) -> "_SQLWriter":
|
| 247 |
+
self._thread = threading.Thread(target=self._writer_loop, daemon=True)
|
| 248 |
+
self._thread.start()
|
| 249 |
+
return self
|
| 250 |
+
|
| 251 |
+
def write_batch(self, rows: List[Tuple[str, bytes]]) -> None:
|
| 252 |
+
self._queue.put(rows)
|
| 253 |
+
|
| 254 |
+
def _writer_loop(self) -> None:
|
| 255 |
+
cursor = self._conn.cursor()
|
| 256 |
+
while True:
|
| 257 |
+
item = self._queue.get()
|
| 258 |
+
if item is None:
|
| 259 |
+
break
|
| 260 |
+
cursor.executemany("INSERT OR REPLACE INTO embeddings VALUES (?, ?)", item)
|
| 261 |
+
if self._queue.qsize() == 0:
|
| 262 |
+
self._conn.commit()
|
| 263 |
+
self._conn.commit()
|
| 264 |
+
|
| 265 |
+
def __exit__(self, *exc) -> None:
|
| 266 |
+
if self._thread is not None:
|
| 267 |
+
self._queue.put(None)
|
| 268 |
+
self._thread.join()
|
| 269 |
+
self._thread = None
|
| 270 |
+
|
| 271 |
+
|
| 272 |
class Pooler:
|
| 273 |
def __init__(self, pooling_types: List[str]) -> None:
|
| 274 |
self.pooling_types = pooling_types
|
|
|
|
| 562 |
yield seqs, residue_embeddings, attention_mask
|
| 563 |
|
| 564 |
if sql:
|
| 565 |
+
# Step 1: DEDUPLICATE - check existing embeddings in SQL
|
| 566 |
conn = sqlite3.connect(sql_db_path, timeout=30, check_same_thread=False)
|
| 567 |
conn.execute('PRAGMA journal_mode=WAL')
|
| 568 |
conn.execute('PRAGMA busy_timeout=30000')
|
|
|
|
| 574 |
print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
|
| 575 |
print(f"Embedding {len(to_embed)} new sequences")
|
| 576 |
if len(to_embed) > 0:
|
| 577 |
+
# Steps 4-7: BATCH+EMBED -> POOL/TRIM -> SERIALIZE -> WRITE (async)
|
| 578 |
+
with _SQLWriter(conn) as writer:
|
| 579 |
+
with torch.inference_mode():
|
| 580 |
+
for seqs, residue_embeddings, attention_mask in iter_batches(to_embed):
|
| 581 |
+
embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
|
| 582 |
+
if full_embeddings:
|
| 583 |
+
batch_rows = []
|
| 584 |
+
for seq, emb, mask in zip(seqs, embeddings, attention_mask):
|
| 585 |
+
batch_rows.append((seq, tensor_to_embedding_blob(emb[mask.bool()].reshape(-1, hidden_size))))
|
| 586 |
+
else:
|
| 587 |
+
blobs = batch_tensor_to_blobs(embeddings)
|
| 588 |
+
batch_rows = list(zip(seqs, blobs))
|
| 589 |
+
writer.write_batch(batch_rows)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 590 |
conn.close()
|
| 591 |
return None
|
| 592 |
|