lhallee committed on
Commit
0045ed7
·
verified ·
1 Parent(s): 0106d87

Upload modeling_e1.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_e1.py +49 -31
modeling_e1.py CHANGED
@@ -42,7 +42,7 @@ from transformers import PreTrainedTokenizerBase
42
 
43
 
44
  # Compact blob serialization constants
45
- # Keep in sync with protify/utils.py and core/atlas/precomputed.py
46
  _COMPACT_VERSION = 0x01
47
  _DTYPE_TO_CODE = {torch.float16: 0, torch.bfloat16: 1, torch.float32: 2}
48
  _CODE_TO_DTYPE = {0: torch.float16, 1: torch.bfloat16, 2: torch.float32}
@@ -235,6 +235,40 @@ def _make_embedding_progress(
235
  main_bar.close()
236
 
237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  class Pooler:
239
  def __init__(self, pooling_types: List[str]) -> None:
240
  self.pooling_types = pooling_types
@@ -528,6 +562,7 @@ class EmbeddingMixin:
528
  yield seqs, residue_embeddings, attention_mask
529
 
530
  if sql:
 
531
  conn = sqlite3.connect(sql_db_path, timeout=30, check_same_thread=False)
532
  conn.execute('PRAGMA journal_mode=WAL')
533
  conn.execute('PRAGMA busy_timeout=30000')
@@ -539,36 +574,19 @@ class EmbeddingMixin:
539
  print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
540
  print(f"Embedding {len(to_embed)} new sequences")
541
  if len(to_embed) > 0:
542
- sql_queue: queue.Queue = queue.Queue(maxsize=4)
543
-
544
- def _sql_writer():
545
- wc = conn.cursor()
546
- while True:
547
- item = sql_queue.get()
548
- if item is None:
549
- break
550
- wc.executemany("INSERT OR REPLACE INTO embeddings VALUES (?, ?)", item)
551
- if sql_queue.qsize() == 0:
552
- conn.commit()
553
- conn.commit()
554
-
555
- sql_writer_thread = threading.Thread(target=_sql_writer, daemon=True)
556
- sql_writer_thread.start()
557
-
558
- with torch.inference_mode():
559
- for seqs, residue_embeddings, attention_mask in iter_batches(to_embed):
560
- embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
561
- if full_embeddings:
562
- batch_rows = []
563
- for seq, emb, mask in zip(seqs, embeddings, attention_mask):
564
- batch_rows.append((seq, tensor_to_embedding_blob(emb[mask.bool()].reshape(-1, hidden_size))))
565
- else:
566
- blobs = batch_tensor_to_blobs(embeddings)
567
- batch_rows = list(zip(seqs, blobs))
568
- sql_queue.put(batch_rows)
569
-
570
- sql_queue.put(None)
571
- sql_writer_thread.join()
572
  conn.close()
573
  return None
574
 
 
42
 
43
 
44
  # Compact blob serialization constants
45
+ # Canonical source: core/embed/blob.py. Keep in sync with protify/utils.py.
46
  _COMPACT_VERSION = 0x01
47
  _DTYPE_TO_CODE = {torch.float16: 0, torch.bfloat16: 1, torch.float32: 2}
48
  _CODE_TO_DTYPE = {0: torch.float16, 1: torch.bfloat16, 2: torch.float32}
 
235
  main_bar.close()
236
 
237
 
238
+ class _SQLWriter:
239
+ """Context manager for async SQL embedding writes. Matches core/embed/storage.SQLEmbeddingWriter."""
240
+
241
+ def __init__(self, conn: sqlite3.Connection, queue_maxsize: int = 4) -> None:
242
+ self._conn = conn
243
+ self._queue: queue.Queue = queue.Queue(maxsize=queue_maxsize)
244
+ self._thread: Optional[threading.Thread] = None
245
+
246
+ def __enter__(self) -> "_SQLWriter":
247
+ self._thread = threading.Thread(target=self._writer_loop, daemon=True)
248
+ self._thread.start()
249
+ return self
250
+
251
+ def write_batch(self, rows: List[Tuple[str, bytes]]) -> None:
252
+ self._queue.put(rows)
253
+
254
+ def _writer_loop(self) -> None:
255
+ cursor = self._conn.cursor()
256
+ while True:
257
+ item = self._queue.get()
258
+ if item is None:
259
+ break
260
+ cursor.executemany("INSERT OR REPLACE INTO embeddings VALUES (?, ?)", item)
261
+ if self._queue.qsize() == 0:
262
+ self._conn.commit()
263
+ self._conn.commit()
264
+
265
+ def __exit__(self, *exc) -> None:
266
+ if self._thread is not None:
267
+ self._queue.put(None)
268
+ self._thread.join()
269
+ self._thread = None
270
+
271
+
272
  class Pooler:
273
  def __init__(self, pooling_types: List[str]) -> None:
274
  self.pooling_types = pooling_types
 
562
  yield seqs, residue_embeddings, attention_mask
563
 
564
  if sql:
565
+ # Step 1: DEDUPLICATE - check existing embeddings in SQL
566
  conn = sqlite3.connect(sql_db_path, timeout=30, check_same_thread=False)
567
  conn.execute('PRAGMA journal_mode=WAL')
568
  conn.execute('PRAGMA busy_timeout=30000')
 
574
  print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
575
  print(f"Embedding {len(to_embed)} new sequences")
576
  if len(to_embed) > 0:
577
+ # Steps 4-7: BATCH+EMBED -> POOL/TRIM -> SERIALIZE -> WRITE (async)
578
+ with _SQLWriter(conn) as writer:
579
+ with torch.inference_mode():
580
+ for seqs, residue_embeddings, attention_mask in iter_batches(to_embed):
581
+ embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
582
+ if full_embeddings:
583
+ batch_rows = []
584
+ for seq, emb, mask in zip(seqs, embeddings, attention_mask):
585
+ batch_rows.append((seq, tensor_to_embedding_blob(emb[mask.bool()].reshape(-1, hidden_size))))
586
+ else:
587
+ blobs = batch_tensor_to_blobs(embeddings)
588
+ batch_rows = list(zip(seqs, blobs))
589
+ writer.write_batch(batch_rows)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
590
  conn.close()
591
  return None
592