lhallee committed on
Commit
dd707be
·
verified ·
1 Parent(s): 296ff4a

Upload embedding_mixin.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. embedding_mixin.py +74 -6
embedding_mixin.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import networkx as nx
3
  import numpy as np
4
  import torch
@@ -165,7 +166,6 @@ class EmbeddingMixin:
165
 
166
  def _read_sequences_from_db(self, db_path: str) -> set[str]:
167
  """Read sequences from SQLite database."""
168
- import sqlite3
169
  sequences = []
170
  with sqlite3.connect(db_path) as conn:
171
  c = conn.cursor()
@@ -177,6 +177,69 @@ class EmbeddingMixin:
177
  sequences.append(row[0])
178
  return set(sequences)
179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  def embed_dataset(
181
  self,
182
  sequences: List[str],
@@ -241,10 +304,9 @@ class EmbeddingMixin:
241
  yield seqs, residue_embeddings, attention_mask
242
 
243
  if sql:
244
- import sqlite3
245
  conn = sqlite3.connect(sql_db_path)
 
246
  c = conn.cursor()
247
- c.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence text PRIMARY KEY, embedding blob)')
248
  already_embedded = self._read_sequences_from_db(sql_db_path)
249
  to_embed = [seq for seq in sequences if seq not in already_embedded]
250
  print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
@@ -252,11 +314,17 @@ class EmbeddingMixin:
252
  if len(to_embed) > 0:
253
  with torch.no_grad():
254
  for i, (seqs, residue_embeddings, attention_mask) in enumerate(iter_batches(to_embed)):
255
- embeddings = get_embeddings(residue_embeddings, attention_mask).float()
256
  for seq, emb, mask in zip(seqs, embeddings, attention_mask):
257
  if full_embeddings:
258
  emb = emb[mask.bool()].reshape(-1, hidden_size)
259
- c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)", (seq, emb.cpu().numpy().tobytes()))
 
 
 
 
 
 
260
  if tokenizer_mode and (i + 1) % 100 == 0:
261
  conn.commit()
262
  conn.commit()
@@ -265,7 +333,7 @@ class EmbeddingMixin:
265
 
266
  embeddings_dict = {}
267
  if os.path.exists(save_path):
268
- embeddings_dict = torch.load(save_path, map_location='cpu', weights_only=True)
269
  to_embed = [seq for seq in sequences if seq not in embeddings_dict]
270
  print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}")
271
  print(f"Embedding {len(to_embed)} new sequences")
 
1
  import os
2
+ import sqlite3
3
  import networkx as nx
4
  import numpy as np
5
  import torch
 
166
 
167
  def _read_sequences_from_db(self, db_path: str) -> set[str]:
168
  """Read sequences from SQLite database."""
 
169
  sequences = []
170
  with sqlite3.connect(db_path) as conn:
171
  c = conn.cursor()
 
177
  sequences.append(row[0])
178
  return set(sequences)
179
 
180
+ def _ensure_embeddings_table(self, conn: sqlite3.Connection) -> None:
181
+ cursor = conn.cursor()
182
+ cursor.execute(
183
+ "CREATE TABLE IF NOT EXISTS embeddings ("
184
+ "sequence TEXT PRIMARY KEY, "
185
+ "embedding BLOB NOT NULL, "
186
+ "shape TEXT, "
187
+ "dtype TEXT"
188
+ ")"
189
+ )
190
+ cursor.execute("PRAGMA table_info(embeddings)")
191
+ rows = cursor.fetchall()
192
+ column_names = [row[1] for row in rows]
193
+ if "shape" not in column_names:
194
+ cursor.execute("ALTER TABLE embeddings ADD COLUMN shape TEXT")
195
+ if "dtype" not in column_names:
196
+ cursor.execute("ALTER TABLE embeddings ADD COLUMN dtype TEXT")
197
+ conn.commit()
198
+
def load_embeddings_from_pth(self, save_path: str) -> dict[str, torch.Tensor]:
    """Load a ``{sequence: tensor}`` embedding dictionary from a ``.pth`` file.

    Args:
        save_path: Path to a file written by ``torch.save`` containing a
            ``dict`` mapping sequence strings to tensors.

    Returns:
        The validated ``{sequence: tensor}`` dictionary.

    Raises:
        FileNotFoundError: If ``save_path`` does not exist.
        TypeError: If the payload is not a ``dict`` of ``str`` -> ``Tensor``.
    """
    # Real exceptions instead of asserts: asserts are stripped under `python -O`,
    # which would silently skip all validation.
    if not os.path.exists(save_path):
        raise FileNotFoundError(f"Embedding file does not exist: {save_path}")
    # weights_only=True keeps torch.load from unpickling arbitrary objects.
    payload = torch.load(save_path, map_location="cpu", weights_only=True)
    if not isinstance(payload, dict):
        raise TypeError("Expected .pth embeddings file to contain a dictionary.")
    for sequence, tensor in payload.items():
        if not isinstance(sequence, str):
            raise TypeError("Expected embedding dictionary keys to be sequences (str).")
        if not isinstance(tensor, torch.Tensor):
            raise TypeError("Expected embedding dictionary values to be tensors.")
    return payload
def load_embeddings_from_db(self, db_path: str, sequences: Optional[List[str]] = None) -> dict[str, torch.Tensor]:
    """Load embeddings stored by ``embed_dataset`` from a SQLite database.

    Args:
        db_path: Path to the SQLite database.
        sequences: Optional subset of sequences to load; ``None`` loads all
            rows. Sequences with no row are silently absent from the result.

    Returns:
        ``{sequence: tensor}`` reconstructed from each row's blob plus its
        ``shape``/``dtype`` metadata columns.

    Raises:
        FileNotFoundError: If ``db_path`` does not exist.
        ValueError: If a row has missing or inconsistent shape/dtype metadata.
    """
    # Real exceptions instead of asserts: asserts vanish under `python -O`.
    if not os.path.exists(db_path):
        raise FileNotFoundError(f"Embedding database does not exist: {db_path}")
    loaded: dict[str, torch.Tensor] = {}
    # `with sqlite3.connect(...)` only manages the transaction, not the
    # connection lifetime — close explicitly to avoid leaking the handle.
    conn = sqlite3.connect(db_path)
    try:
        self._ensure_embeddings_table(conn)
        cursor = conn.cursor()
        rows = []
        if sequences is None:
            cursor.execute("SELECT sequence, embedding, shape, dtype FROM embeddings")
            rows = cursor.fetchall()
        else:
            # Chunk the IN clause so we never exceed SQLite's bound-parameter
            # limit (999 by default). An empty list yields zero queries.
            for start in range(0, len(sequences), 900):
                chunk = sequences[start:start + 900]
                placeholders = ",".join(["?"] * len(chunk))
                cursor.execute(
                    f"SELECT sequence, embedding, shape, dtype FROM embeddings WHERE sequence IN ({placeholders})",
                    tuple(chunk),
                )
                rows.extend(cursor.fetchall())
        for sequence, embedding_bytes, shape_text, dtype_text in rows:
            if shape_text is None:
                raise ValueError(f"Missing shape metadata for sequence: {sequence}")
            if dtype_text is None:
                raise ValueError(f"Missing dtype metadata for sequence: {sequence}")
            shape_values = [int(value) for value in shape_text.split(",") if value]
            if not shape_values:
                raise ValueError(f"Invalid shape metadata for sequence: {sequence}")
            array = np.frombuffer(embedding_bytes, dtype=np.dtype(dtype_text))
            if array.size != int(np.prod(shape_values)):
                raise ValueError(f"Shape mismatch while reading sequence: {sequence}")
            # frombuffer returns a read-only view; copy so the tensor is writable.
            loaded[sequence] = torch.from_numpy(array.copy().reshape(tuple(shape_values)))
        return loaded
    finally:
        conn.close()
243
  def embed_dataset(
244
  self,
245
  sequences: List[str],
 
304
  yield seqs, residue_embeddings, attention_mask
305
 
306
  if sql:
 
307
  conn = sqlite3.connect(sql_db_path)
308
+ self._ensure_embeddings_table(conn)
309
  c = conn.cursor()
 
310
  already_embedded = self._read_sequences_from_db(sql_db_path)
311
  to_embed = [seq for seq in sequences if seq not in already_embedded]
312
  print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
 
314
  if len(to_embed) > 0:
315
  with torch.no_grad():
316
  for i, (seqs, residue_embeddings, attention_mask) in enumerate(iter_batches(to_embed)):
317
+ embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
318
  for seq, emb, mask in zip(seqs, embeddings, attention_mask):
319
  if full_embeddings:
320
  emb = emb[mask.bool()].reshape(-1, hidden_size)
321
+ emb_np = emb.cpu().numpy()
322
+ emb_shape = ",".join([str(dim) for dim in emb_np.shape])
323
+ emb_dtype = str(emb_np.dtype)
324
+ c.execute(
325
+ "INSERT OR REPLACE INTO embeddings (sequence, embedding, shape, dtype) VALUES (?, ?, ?, ?)",
326
+ (seq, emb_np.tobytes(), emb_shape, emb_dtype),
327
+ )
328
  if tokenizer_mode and (i + 1) % 100 == 0:
329
  conn.commit()
330
  conn.commit()
 
333
 
334
  embeddings_dict = {}
335
  if os.path.exists(save_path):
336
+ embeddings_dict = self.load_embeddings_from_pth(save_path)
337
  to_embed = [seq for seq in sequences if seq not in embeddings_dict]
338
  print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}")
339
  print(f"Embedding {len(to_embed)} new sequences")