lhallee committed on
Commit
0106d87
·
verified ·
1 Parent(s): 61ef43f

Upload modeling_e1.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_e1.py +338 -119
modeling_e1.py CHANGED
@@ -23,18 +23,218 @@ inductor_config.max_autotune_gemm_backends = "ATEN,CUTLASS,FBGEMM"
23
  dynamo.config.capture_scalar_outputs = True
24
  torch._dynamo.config.recompile_limit = 16
25
 
 
26
  import os
 
27
  import sqlite3
 
 
 
 
28
  import networkx as nx
29
  import numpy as np
30
  import torch
31
  from tqdm.auto import tqdm
32
- from typing import Callable, Dict, List, Optional, Set
33
  from torch.utils.data import DataLoader
34
  from torch.utils.data import Dataset as TorchDataset
35
  from transformers import PreTrainedTokenizerBase
36
 
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  class Pooler:
39
  def __init__(self, pooling_types: List[str]) -> None:
40
  self.pooling_types = pooling_types
@@ -55,9 +255,6 @@ class Pooler:
55
  return maxed_attentions
56
 
57
  def _page_rank(self, attention_matrix: np.ndarray, personalization: Optional[dict] = None, nstart: Optional[dict] = None, prune_type: str = "top_k_outdegree") -> Dict[int, float]:
58
- # Run PageRank on the attention matrix converted to a graph.
59
- # Raises exceptions if the graph doesn't match the token sequence or has no edges.
60
- # Returns the PageRank scores for each token node.
61
  G = self._convert_to_graph(attention_matrix)
62
  if G.number_of_nodes() != attention_matrix.shape[0]:
63
  raise Exception(
@@ -68,26 +265,20 @@ class Pooler:
68
  return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100)
69
 
70
  def _convert_to_graph(self, matrix: np.ndarray) -> nx.DiGraph:
71
- # Convert a matrix (e.g., attention scores) to a directed graph using networkx.
72
- # Each element in the matrix represents a directed edge with a weight.
73
  G = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
74
  return G
75
 
76
  def _calculate_importance_weights(self, dict_importance: Dict[int, float], attention_mask: Optional[torch.Tensor] = None) -> np.ndarray:
77
- # Remove keys where attention_mask is 0
78
  if attention_mask is not None:
79
  for k in list(dict_importance.keys()):
80
  if attention_mask[k] == 0:
81
  del dict_importance[k]
82
 
83
- #dict_importance[0] # remove cls
84
- #dict_importance[-1] # remove eos
85
  total = sum(dict_importance.values())
86
  return np.array([v / total for _, v in dict_importance.items()])
87
 
88
- def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: # (b, L, d) -> (b, d)
89
  maxed_attentions = self._create_pooled_matrices_across_layers(attentions).numpy()
90
- # emb is (b, L, d), maxed_attentions is (b, L, L)
91
  emb_pooled = []
92
  for e, a, mask in zip(emb, maxed_attentions, attention_mask):
93
  dict_importance = self._page_rank(a)
@@ -97,58 +288,53 @@ class Pooler:
97
  pooled = torch.tensor(np.array(emb_pooled))
98
  return pooled
99
 
100
- def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d)
101
  if attention_mask is None:
102
  return emb.mean(dim=1)
103
  else:
104
  attention_mask = attention_mask.unsqueeze(-1)
105
  return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
106
 
107
- def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d)
108
  if attention_mask is None:
109
  return emb.max(dim=1).values
110
  else:
111
  mask = attention_mask.unsqueeze(-1).bool()
112
  return emb.masked_fill(~mask, float('-inf')).max(dim=1).values
113
 
114
- def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d)
115
  if attention_mask is None:
116
  return emb.norm(dim=1, p=2)
117
  else:
118
  attention_mask = attention_mask.unsqueeze(-1)
119
  return (emb * attention_mask).norm(dim=1, p=2)
120
 
121
- def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d)
122
  if attention_mask is None:
123
  return emb.median(dim=1).values
124
  else:
125
  mask = attention_mask.unsqueeze(-1).bool()
126
  return emb.masked_fill(~mask, float('nan')).nanmedian(dim=1).values
127
-
128
- def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d)
129
  if attention_mask is None:
130
  return emb.std(dim=1)
131
  else:
132
- # Compute variance correctly over non-masked positions, then take sqrt
133
  var = self.var_pooling(emb, attention_mask, **kwargs)
134
  return torch.sqrt(var)
135
-
136
- def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d)
137
  if attention_mask is None:
138
  return emb.var(dim=1)
139
  else:
140
- # Correctly compute variance over only non-masked positions
141
- attention_mask = attention_mask.unsqueeze(-1) # (b, L, 1)
142
- # Compute mean over non-masked positions
143
- mean = (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d)
144
- mean = mean.unsqueeze(1) # (b, 1, d)
145
- # Compute squared differences from mean, only over non-masked positions
146
- squared_diff = (emb - mean) ** 2 # (b, L, d)
147
- # Sum squared differences over non-masked positions and divide by count
148
- var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d)
149
  return var
150
 
151
- def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d)
152
  return emb[:, 0, :]
153
 
154
  def __call__(
@@ -156,11 +342,11 @@ class Pooler:
156
  emb: torch.Tensor,
157
  attention_mask: Optional[torch.Tensor] = None,
158
  attentions: Optional[torch.Tensor] = None
159
- ) -> torch.Tensor: # [mean, max]
160
  final_emb: List[torch.Tensor] = []
161
  for pooling_type in self.pooling_types:
162
- final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions)) # (b, d)
163
- return torch.cat(final_emb, dim=-1) # (b, n_pooling_types * d)
164
 
165
 
166
  class ProteinDataset(TorchDataset):
@@ -175,12 +361,6 @@ class ProteinDataset(TorchDataset):
175
  return self.sequences[idx]
176
 
177
 
178
- def build_collator(tokenizer: PreTrainedTokenizerBase) -> Callable[[List[str]], Dict[str, torch.Tensor]]:
179
- def _collate_fn(sequences: List[str]) -> Dict[str, torch.Tensor]:
180
- return tokenizer(sequences, return_tensors="pt", padding='longest')
181
- return _collate_fn
182
-
183
-
184
  def parse_fasta(fasta_path: str) -> List[str]:
185
  assert os.path.exists(fasta_path), f"FASTA file does not exist: {fasta_path}"
186
  sequences = []
@@ -212,34 +392,19 @@ class EmbeddingMixin:
212
 
213
  def _read_sequences_from_db(self, db_path: str) -> Set[str]:
214
  """Read sequences from SQLite database."""
215
- sequences = []
216
- with sqlite3.connect(db_path) as conn:
217
  c = conn.cursor()
218
  c.execute("SELECT sequence FROM embeddings")
219
- while True:
220
- row = c.fetchone()
221
- if row is None:
222
- break
223
- sequences.append(row[0])
224
- return set(sequences)
225
 
226
  def _ensure_embeddings_table(self, conn: sqlite3.Connection) -> None:
227
  cursor = conn.cursor()
228
  cursor.execute(
229
  "CREATE TABLE IF NOT EXISTS embeddings ("
230
  "sequence TEXT PRIMARY KEY, "
231
- "embedding BLOB NOT NULL, "
232
- "shape TEXT, "
233
- "dtype TEXT"
234
  ")"
235
  )
236
- cursor.execute("PRAGMA table_info(embeddings)")
237
- rows = cursor.fetchall()
238
- column_names = [row[1] for row in rows]
239
- if "shape" not in column_names:
240
- cursor.execute("ALTER TABLE embeddings ADD COLUMN shape TEXT")
241
- if "dtype" not in column_names:
242
- cursor.execute("ALTER TABLE embeddings ADD COLUMN dtype TEXT")
243
  conn.commit()
244
 
245
  def load_embeddings_from_pth(self, save_path: str) -> Dict[str, torch.Tensor]:
@@ -254,17 +419,17 @@ class EmbeddingMixin:
254
  def load_embeddings_from_db(self, db_path: str, sequences: Optional[List[str]] = None) -> Dict[str, torch.Tensor]:
255
  assert os.path.exists(db_path), f"Embedding database does not exist: {db_path}"
256
  loaded: Dict[str, torch.Tensor] = {}
257
- with sqlite3.connect(db_path) as conn:
258
  self._ensure_embeddings_table(conn)
259
  cursor = conn.cursor()
260
  if sequences is None:
261
- cursor.execute("SELECT sequence, embedding, shape, dtype FROM embeddings")
262
  else:
263
  if len(sequences) == 0:
264
  return loaded
265
  placeholders = ",".join(["?"] * len(sequences))
266
  cursor.execute(
267
- f"SELECT sequence, embedding, shape, dtype FROM embeddings WHERE sequence IN ({placeholders})",
268
  tuple(sequences),
269
  )
270
 
@@ -272,18 +437,7 @@ class EmbeddingMixin:
272
  for row in rows:
273
  sequence = row[0]
274
  embedding_bytes = row[1]
275
- shape_text = row[2]
276
- dtype_text = row[3]
277
- assert shape_text is not None, "Missing shape metadata in embeddings table."
278
- assert dtype_text is not None, "Missing dtype metadata in embeddings table."
279
- shape_values = [int(value) for value in shape_text.split(",") if len(value) > 0]
280
- assert len(shape_values) > 0, f"Invalid shape metadata for sequence: {sequence}"
281
- expected_size = int(np.prod(shape_values))
282
- np_dtype = np.dtype(dtype_text)
283
- array = np.frombuffer(embedding_bytes, dtype=np_dtype)
284
- assert array.size == expected_size, f"Shape mismatch while reading sequence: {sequence}"
285
- reshaped = array.copy().reshape(tuple(shape_values))
286
- loaded[sequence] = torch.from_numpy(reshaped)
287
  return loaded
288
 
289
  def embed_dataset(
@@ -302,6 +456,7 @@ class EmbeddingMixin:
302
  sql_db_path: str = 'embeddings.db',
303
  save_path: str = 'embeddings.pth',
304
  fasta_path: Optional[str] = None,
 
305
  **kwargs,
306
  ) -> Optional[Dict[str, torch.Tensor]]:
307
  """
@@ -324,8 +479,13 @@ class EmbeddingMixin:
324
  hidden_size = self.config.hidden_size
325
  pooler = Pooler(pooling_types) if not full_embeddings else None
326
  tokenizer_mode = tokenizer is not None
 
 
 
 
 
327
  if tokenizer_mode:
328
- collate_fn = build_collator(tokenizer)
329
  device = self.device
330
  else:
331
  collate_fn = None
@@ -342,17 +502,25 @@ class EmbeddingMixin:
342
  assert collate_fn is not None
343
  assert device is not None
344
  dataset = ProteinDataset(to_embed)
345
- dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn, shuffle=False)
346
- for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
 
 
 
 
 
 
 
 
347
  seqs = to_embed[i * batch_size:(i + 1) * batch_size]
348
  input_ids = batch['input_ids'].to(device)
349
  attention_mask = batch['attention_mask'].to(device)
350
- residue_embeddings = self._embed(input_ids, attention_mask)
351
  yield seqs, residue_embeddings, attention_mask
352
  else:
353
  for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'):
354
  seqs = to_embed[batch_start:batch_start + batch_size]
355
- batch_output = self._embed(seqs, return_attention_mask=True, **kwargs)
356
  assert isinstance(batch_output, tuple), "Sequence mode _embed must return (last_hidden_state, attention_mask)."
357
  assert len(batch_output) == 2, "Sequence mode _embed must return exactly two values."
358
  residue_embeddings, attention_mask = batch_output
@@ -360,30 +528,47 @@ class EmbeddingMixin:
360
  yield seqs, residue_embeddings, attention_mask
361
 
362
  if sql:
363
- conn = sqlite3.connect(sql_db_path)
 
 
 
 
364
  self._ensure_embeddings_table(conn)
365
- c = conn.cursor()
366
  already_embedded = self._read_sequences_from_db(sql_db_path)
367
  to_embed = [seq for seq in sequences if seq not in already_embedded]
368
  print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
369
  print(f"Embedding {len(to_embed)} new sequences")
370
  if len(to_embed) > 0:
371
- with torch.no_grad():
372
- for i, (seqs, residue_embeddings, attention_mask) in enumerate(iter_batches(to_embed)):
373
- embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
374
- for seq, emb, mask in zip(seqs, embeddings, attention_mask):
375
- if full_embeddings:
376
- emb = emb[mask.bool()].reshape(-1, hidden_size)
377
- emb_np = emb.cpu().numpy()
378
- emb_shape = ",".join([str(dim) for dim in emb_np.shape])
379
- emb_dtype = str(emb_np.dtype)
380
- c.execute(
381
- "INSERT OR REPLACE INTO embeddings (sequence, embedding, shape, dtype) VALUES (?, ?, ?, ?)",
382
- (seq, emb_np.tobytes(), emb_shape, emb_dtype),
383
- )
384
- if tokenizer_mode and (i + 1) % 100 == 0:
385
  conn.commit()
386
- conn.commit()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
387
  conn.close()
388
  return None
389
 
@@ -398,7 +583,7 @@ class EmbeddingMixin:
398
  print(f"Embedding {len(to_embed)} new sequences")
399
 
400
  if len(to_embed) > 0:
401
- with torch.no_grad():
402
  for seqs, residue_embeddings, attention_mask in iter_batches(to_embed):
403
  embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
404
  for seq, emb, mask in zip(seqs, embeddings, attention_mask):
@@ -2208,10 +2393,11 @@ class FAST_E1_ENCODER(E1PreTrainedModel, EmbeddingMixin):
2208
  # Ignore copy
2209
  def forward(
2210
  self,
2211
- input_ids: torch.LongTensor,
2212
- within_seq_position_ids: torch.LongTensor,
2213
- global_position_ids: torch.LongTensor,
2214
- sequence_ids: torch.LongTensor,
 
2215
  past_key_values: Optional[DynamicCache] = None,
2216
  use_cache: bool = False,
2217
  output_attentions: bool = False,
@@ -2234,6 +2420,9 @@ class FAST_E1_ENCODER(E1PreTrainedModel, EmbeddingMixin):
2234
  This tensor contains the sequence id of each residue.
2235
  For example, if the input is ["<bos>1ABC2<eos><bos>1DEF2<eos>", "<bos>1GH2<eos><bos>1JKL2<eos>"],
2236
  the tensor would be [[0,0,0,0,0,0,0,1,1,1,1,1,1,1], [0,0,0,0,0,0,1,1,1,1,1,1,1,-1]]
 
 
 
2237
  past_key_values: DynamicCache
2238
  use_cache: bool
2239
  output_attentions: bool
@@ -2243,7 +2432,17 @@ class FAST_E1_ENCODER(E1PreTrainedModel, EmbeddingMixin):
2243
  Returns:
2244
  E1ModelOutputWithPast: Model Outputs
2245
  """
2246
- batch_size, seq_length = input_ids.shape
 
 
 
 
 
 
 
 
 
 
2247
 
2248
  if self.gradient_checkpointing and self.training and torch.is_grad_enabled():
2249
  if use_cache:
@@ -2257,6 +2456,16 @@ class FAST_E1_ENCODER(E1PreTrainedModel, EmbeddingMixin):
2257
  elif not use_cache:
2258
  past_key_values = None
2259
 
 
 
 
 
 
 
 
 
 
 
2260
  global_position_ids = global_position_ids.view(-1, seq_length).long()
2261
  within_seq_position_ids = within_seq_position_ids.view(-1, seq_length).long()
2262
  sequence_ids = sequence_ids.view(-1, seq_length).long()
@@ -2267,8 +2476,9 @@ class FAST_E1_ENCODER(E1PreTrainedModel, EmbeddingMixin):
2267
  f"Position ids must be in the range [-1, {self.config.max_num_positions_within_seq}); got max {max_position_id} and min {min_position_id}"
2268
  )
2269
 
2270
- inputs_embeds = self.embed_tokens(input_ids)
2271
- inputs_embeds = inputs_embeds + self.embed_seq_id(sequence_ids.clamp(min=0))
 
2272
 
2273
  if torch.is_autocast_enabled():
2274
  target_dtype = torch.get_autocast_gpu_dtype()
@@ -2380,10 +2590,11 @@ class E1Model(E1PreTrainedModel, EmbeddingMixin):
2380
 
2381
  def forward(
2382
  self,
2383
- input_ids: torch.LongTensor,
2384
- within_seq_position_ids: torch.LongTensor,
2385
- global_position_ids: torch.LongTensor,
2386
- sequence_ids: torch.LongTensor,
 
2387
  past_key_values: Optional[DynamicCache] = None,
2388
  use_cache: bool = False,
2389
  output_attentions: bool = False,
@@ -2396,6 +2607,7 @@ class E1Model(E1PreTrainedModel, EmbeddingMixin):
2396
  within_seq_position_ids=within_seq_position_ids,
2397
  global_position_ids=global_position_ids,
2398
  sequence_ids=sequence_ids,
 
2399
  past_key_values=past_key_values,
2400
  use_cache=use_cache,
2401
  output_attentions=output_attentions,
@@ -2438,10 +2650,11 @@ class E1ForMaskedLM(E1PreTrainedModel, EmbeddingMixin):
2438
 
2439
  def forward(
2440
  self,
2441
- input_ids: torch.LongTensor,
2442
- within_seq_position_ids: torch.LongTensor,
2443
- global_position_ids: torch.LongTensor,
2444
- sequence_ids: torch.LongTensor,
 
2445
  labels: Optional[torch.LongTensor] = None,
2446
  past_key_values: Optional[DynamicCache] = None,
2447
  use_cache: bool = False,
@@ -2465,6 +2678,7 @@ class E1ForMaskedLM(E1PreTrainedModel, EmbeddingMixin):
2465
  This tensor contains the sequence id of each residue.
2466
  For example, if the input is ["<bos>1ABC2<eos><bos>1DEF2<eos>", "<bos>1GH2<eos><bos>1JKL2<eos>"],
2467
  the tensor would be [[0,0,0,0,0,0,0,1,1,1,1,1,1,1], [0,0,0,0,0,0,1,1,1,1,1,1,1,-1]]
 
2468
  labels: (batch_size, seq_length)
2469
  past_key_values: DynamicCache
2470
  use_cache: bool
@@ -2480,6 +2694,7 @@ class E1ForMaskedLM(E1PreTrainedModel, EmbeddingMixin):
2480
  within_seq_position_ids=within_seq_position_ids,
2481
  global_position_ids=global_position_ids,
2482
  sequence_ids=sequence_ids,
 
2483
  past_key_values=past_key_values,
2484
  use_cache=use_cache,
2485
  output_attentions=output_attentions,
@@ -2557,10 +2772,11 @@ class E1ForSequenceClassification(E1PreTrainedModel, EmbeddingMixin):
2557
 
2558
  def forward(
2559
  self,
2560
- input_ids: torch.LongTensor,
2561
- within_seq_position_ids: torch.LongTensor,
2562
- global_position_ids: torch.LongTensor,
2563
- sequence_ids: torch.LongTensor,
 
2564
  labels: Optional[torch.LongTensor] = None,
2565
  past_key_values: Optional[DynamicCache] = None,
2566
  use_cache: bool = False,
@@ -2574,6 +2790,7 @@ class E1ForSequenceClassification(E1PreTrainedModel, EmbeddingMixin):
2574
  within_seq_position_ids=within_seq_position_ids,
2575
  global_position_ids=global_position_ids,
2576
  sequence_ids=sequence_ids,
 
2577
  past_key_values=past_key_values,
2578
  use_cache=use_cache,
2579
  output_attentions=output_attentions,
@@ -2581,7 +2798,7 @@ class E1ForSequenceClassification(E1PreTrainedModel, EmbeddingMixin):
2581
  output_s_max=output_s_max,
2582
  )
2583
 
2584
- attention_mask = (sequence_ids != -1).long()
2585
  x = outputs.last_hidden_state
2586
  features = self.pooler(x, attention_mask)
2587
  logits = self.classifier(features)
@@ -2652,10 +2869,11 @@ class E1ForTokenClassification(E1PreTrainedModel, EmbeddingMixin):
2652
 
2653
  def forward(
2654
  self,
2655
- input_ids: torch.LongTensor,
2656
- within_seq_position_ids: torch.LongTensor,
2657
- global_position_ids: torch.LongTensor,
2658
- sequence_ids: torch.LongTensor,
 
2659
  labels: Optional[torch.LongTensor] = None,
2660
  past_key_values: Optional[DynamicCache] = None,
2661
  use_cache: bool = False,
@@ -2669,6 +2887,7 @@ class E1ForTokenClassification(E1PreTrainedModel, EmbeddingMixin):
2669
  within_seq_position_ids=within_seq_position_ids,
2670
  global_position_ids=global_position_ids,
2671
  sequence_ids=sequence_ids,
 
2672
  past_key_values=past_key_values,
2673
  use_cache=use_cache,
2674
  output_attentions=output_attentions,
 
23
  dynamo.config.capture_scalar_outputs = True
24
  torch._dynamo.config.recompile_limit = 16
25
 
26
+ import io
27
  import os
28
+ import queue
29
  import sqlite3
30
+ import struct
31
+ import threading
32
+ import time
33
+
34
  import networkx as nx
35
  import numpy as np
36
  import torch
37
  from tqdm.auto import tqdm
38
+ from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple
39
  from torch.utils.data import DataLoader
40
  from torch.utils.data import Dataset as TorchDataset
41
  from transformers import PreTrainedTokenizerBase
42
 
43
 
44
# Compact blob serialization constants
# Keep in sync with protify/utils.py and core/atlas/precomputed.py
_COMPACT_VERSION = 0x01
_DTYPE_TO_CODE = {torch.float16: 0, torch.bfloat16: 1, torch.float32: 2}
_CODE_TO_DTYPE = {0: torch.float16, 1: torch.bfloat16, 2: torch.float32}
# bfloat16 (code 1) is stored as float16 bytes because numpy has no bfloat16;
# it is cast back to bfloat16 on read using _CODE_TO_DTYPE.
_CODE_TO_NP_DTYPE = {0: np.float16, 1: np.float16, 2: np.float32}


def tensor_to_embedding_blob(tensor: torch.Tensor) -> bytes:
    """Serialize a tensor to compact binary format for SQLite blob storage.

    Format: [version:1][dtype_code:1][ndim:4][shape:4*ndim][raw_bytes]
    bfloat16 tensors are stored as float16 bytes (numpy lacks bfloat16)
    but tagged with dtype_code=1 so they can be cast back on read.
    NOTE: the bfloat16 -> float16 narrowing is lossy for values outside
    the float16 range.
    Falls back to torch.save for unsupported dtypes.
    """
    t = tensor.cpu()
    if t.dtype not in _DTYPE_TO_CODE:
        # Unsupported dtype: store a full torch.save payload instead.
        buffer = io.BytesIO()
        torch.save(t, buffer)
        return buffer.getvalue()

    if t.dtype == torch.bfloat16:
        raw = t.half().numpy().tobytes()
    else:
        raw = t.numpy().tobytes()
    return _compact_header(t.dtype, tuple(t.shape)) + raw


def _compact_header(dtype: torch.dtype, shape: tuple) -> bytes:
    """Build just the compact header for a given dtype and shape."""
    dtype_code = _DTYPE_TO_CODE[dtype]
    return struct.pack(f'<BBi{len(shape)}i', _COMPACT_VERSION, dtype_code, len(shape), *shape)


def batch_tensor_to_blobs(batch: torch.Tensor) -> List[bytes]:
    """Serialize a batch of same-shape tensors to compact blobs (fast path for vectors).

    Builds the header once and slices raw bytes per row. Much faster than
    per-row tensor_to_embedding_blob calls for uniform-shape batches.
    """
    assert batch.ndim >= 2, f"Expected batch with >= 2 dims, got {batch.ndim}"
    t = batch.cpu()
    if t.dtype not in _DTYPE_TO_CODE:
        # Unsupported dtype: fall back to the (slower) per-row path.
        return [tensor_to_embedding_blob(t[i]) for i in range(t.shape[0])]

    if t.dtype == torch.bfloat16:
        arr = t.half().numpy()
    else:
        arr = t.numpy()

    # One header shared by every row (rows all have shape t.shape[1:]).
    row_shape = tuple(t.shape[1:])
    header = _compact_header(t.dtype, row_shape)
    raw = arr.tobytes()
    stride = len(raw) // t.shape[0]
    return [header + raw[i * stride:(i + 1) * stride] for i in range(t.shape[0])]


def _try_parse_compact(blob: bytes) -> Optional[torch.Tensor]:
    """Parse a compact-format blob; return None if the header does not validate.

    Validates the version byte, dtype code, ndim, shape, and that the payload
    length matches the declared shape exactly. This prevents legacy or corrupt
    blobs that merely start with the version byte from being misparsed.
    """
    if len(blob) < 6 or blob[0] != _COMPACT_VERSION:
        return None
    dtype_code = blob[1]
    if dtype_code not in _CODE_TO_NP_DTYPE:
        return None
    ndim = struct.unpack_from('<i', blob, 2)[0]
    if ndim < 0 or len(blob) < 6 + 4 * ndim:
        return None
    shape = struct.unpack_from(f'<{ndim}i', blob, 6)
    if any(dim < 0 for dim in shape):
        return None
    data_offset = 6 + 4 * ndim
    np_dtype = np.dtype(_CODE_TO_NP_DTYPE[dtype_code])
    expected_bytes = int(np.prod(shape, dtype=np.int64)) * np_dtype.itemsize
    if len(blob) - data_offset != expected_bytes:
        return None
    # .copy() detaches from the read-only buffer so torch owns writable memory.
    arr = np.frombuffer(blob, dtype=np_dtype, offset=data_offset).copy().reshape(shape)
    t = torch.from_numpy(arr)
    target_dtype = _CODE_TO_DTYPE[dtype_code]
    if target_dtype != t.dtype:
        t = t.to(target_dtype)
    return t


def embedding_blob_to_tensor(blob: bytes, fallback_shape: Optional[Tuple[int, ...]] = None) -> torch.Tensor:
    """Deserialize a blob back to a tensor. Auto-detects compact vs legacy formats.

    Order of attempts:
      1. compact format (fully validated header),
      2. torch.load payload (written for unsupported dtypes or by older code),
      3. legacy raw float32 bytes reshaped to ``fallback_shape``.

    Raises AssertionError when the format is unknown and no fallback_shape is given.
    """
    parsed = _try_parse_compact(blob)
    if parsed is not None:
        return parsed

    # Fallback: try torch.load (pickle/zip format)
    try:
        buffer = io.BytesIO(blob)
        return torch.load(buffer, map_location='cpu', weights_only=True)
    except Exception:
        pass

    # Legacy fallback: raw float32 bytes with caller-supplied shape
    assert fallback_shape is not None, "Cannot deserialize blob: unknown format and no fallback_shape provided."
    arr = np.frombuffer(blob, dtype=np.float32).copy().reshape(fallback_shape)
    return torch.from_numpy(arr)
134
+
135
+
136
def maybe_compile(model: torch.nn.Module, dynamic: bool = False) -> torch.nn.Module:
    """Return ``model`` wrapped by torch.compile when it is safe to do so.

    Compilation is skipped entirely when dynamic=True (padding='longest'),
    because flex attention's create_block_mask is incompatible with dynamic
    shapes under torch.compile, causing CUDA illegal memory access.
    Any compilation failure is tolerated and the uncompiled model is returned.
    """
    if dynamic:
        print("Skipping torch.compile (dynamic shapes + flex attention incompatible)")
        return model
    try:
        compiled = torch.compile(model)
        print("Model compiled")
        return compiled
    except Exception as e:
        print(f"Skipping torch.compile: {e}")
        return model
152
+
153
+
154
def build_collator(
    tokenizer: PreTrainedTokenizerBase,
    padding: str = 'max_length',
    max_length: int = 512,
) -> Callable[[List[str]], Dict[str, torch.Tensor]]:
    """Create a DataLoader collate function that tokenizes raw sequence strings.

    All batches are truncated to ``max_length``. With padding='max_length'
    every batch is padded to that fixed width; any other padding mode
    additionally pads to a multiple of 8.
    """
    pad_to_max = padding == 'max_length'

    def _collate_fn(sequences: List[str]) -> Dict[str, torch.Tensor]:
        tok_kwargs: Dict[str, Any] = {
            'return_tensors': "pt",
            'padding': padding,
            'truncation': True,
            'max_length': max_length,
        }
        if not pad_to_max:
            tok_kwargs['pad_to_multiple_of'] = 8
        return tokenizer(sequences, **tok_kwargs)

    return _collate_fn
167
+
168
+
169
def _make_embedding_progress(
    dataloader: DataLoader,
    padding: str,
    n_warmup: int = 3,
    n_calibration: int = 5,
) -> Iterator[Tuple[int, Any]]:
    """Progress-bar wrapper for embedding loops. Drop-in replacement for enumerate(dataloader).

    When padding='max_length', all batches have uniform cost so plain tqdm works.
    When padding='longest' (sorted longest-first), batch times vary dramatically.
    In that case: yield warmup batches first (compiler warmup + OOM check on longest
    sequences), then time mid-length calibration batches to estimate the ETA for
    the remaining batches.

    Keep in sync with protify/embedder.py and core/atlas/precomputed.py.
    """
    total = len(dataloader)
    # Uniform batch cost, or too few batches to calibrate: plain tqdm is accurate.
    if padding == 'max_length' or total <= n_warmup + n_calibration:
        for i, batch in tqdm(enumerate(dataloader), total=total, desc='Embedding batches'):
            yield i, batch
        return

    dl_iter = iter(dataloader)

    # Phase 1: warmup on longest batches (first n_warmup, since sorted longest-first)
    warmup_bar = tqdm(range(n_warmup), desc='Warmup (longest batches)', leave=False)
    for i in warmup_bar:
        batch = next(dl_iter)
        yield i, batch
    warmup_bar.close()

    # Phase 2: skip to middle of dataset for calibration timing.
    # We need to yield all intermediate batches too (they contain real data).
    # Clamp mid_start to n_warmup so calibration never re-yields warmup indices
    # (and over-consumes the iterator) when the midpoint falls inside warmup.
    mid_start = max(total // 2, n_warmup)
    # Clamp the calibration count to the batches actually left after mid_start.
    n_cal_eff = min(n_calibration, total - mid_start)
    intermediate_bar = tqdm(
        range(n_warmup, mid_start), desc='Embedding batches', leave=False,
    )
    for i in intermediate_bar:
        batch = next(dl_iter)
        yield i, batch
    intermediate_bar.close()

    # Phase 3: time calibration batches from the middle (mid-length batches give
    # a representative per-batch cost). The timing spans the consumer's work
    # between yield and resume on purpose — that is the embedding time.
    calibration_times: List[float] = []
    cal_bar = tqdm(range(n_cal_eff), desc='Calibrating ETA', leave=False)
    for j in cal_bar:
        t0 = time.perf_counter()
        batch = next(dl_iter)
        yield mid_start + j, batch
        calibration_times.append(time.perf_counter() - t0)
    cal_bar.close()

    avg_time = sum(calibration_times) / len(calibration_times)
    remaining_start = mid_start + n_cal_eff
    remaining_count = total - remaining_start
    # ETA covers only the batches after calibration.
    estimated_remaining_seconds = avg_time * remaining_count

    # Phase 4: remaining batches with calibrated ETA
    main_bar = tqdm(
        range(remaining_count),
        desc='Embedding batches',
        bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]',
    )
    main_bar.set_postfix_str(f'ETA ~{estimated_remaining_seconds:.0f}s (calibrated)')
    for k in main_bar:
        batch = next(dl_iter)
        yield remaining_start + k, batch
    main_bar.close()
236
+
237
+
238
  class Pooler:
239
  def __init__(self, pooling_types: List[str]) -> None:
240
  self.pooling_types = pooling_types
 
255
  return maxed_attentions
256
 
257
  def _page_rank(self, attention_matrix: np.ndarray, personalization: Optional[dict] = None, nstart: Optional[dict] = None, prune_type: str = "top_k_outdegree") -> Dict[int, float]:
 
 
 
258
  G = self._convert_to_graph(attention_matrix)
259
  if G.number_of_nodes() != attention_matrix.shape[0]:
260
  raise Exception(
 
265
  return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100)
266
 
267
  def _convert_to_graph(self, matrix: np.ndarray) -> nx.DiGraph:
 
 
268
  G = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
269
  return G
270
 
271
  def _calculate_importance_weights(self, dict_importance: Dict[int, float], attention_mask: Optional[torch.Tensor] = None) -> np.ndarray:
 
272
  if attention_mask is not None:
273
  for k in list(dict_importance.keys()):
274
  if attention_mask[k] == 0:
275
  del dict_importance[k]
276
 
 
 
277
  total = sum(dict_importance.values())
278
  return np.array([v / total for _, v in dict_importance.items()])
279
 
280
+ def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
281
  maxed_attentions = self._create_pooled_matrices_across_layers(attentions).numpy()
 
282
  emb_pooled = []
283
  for e, a, mask in zip(emb, maxed_attentions, attention_mask):
284
  dict_importance = self._page_rank(a)
 
288
  pooled = torch.tensor(np.array(emb_pooled))
289
  return pooled
290
 
291
+ def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
292
  if attention_mask is None:
293
  return emb.mean(dim=1)
294
  else:
295
  attention_mask = attention_mask.unsqueeze(-1)
296
  return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
297
 
298
+ def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
299
  if attention_mask is None:
300
  return emb.max(dim=1).values
301
  else:
302
  mask = attention_mask.unsqueeze(-1).bool()
303
  return emb.masked_fill(~mask, float('-inf')).max(dim=1).values
304
 
305
+ def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
306
  if attention_mask is None:
307
  return emb.norm(dim=1, p=2)
308
  else:
309
  attention_mask = attention_mask.unsqueeze(-1)
310
  return (emb * attention_mask).norm(dim=1, p=2)
311
 
312
+ def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
313
  if attention_mask is None:
314
  return emb.median(dim=1).values
315
  else:
316
  mask = attention_mask.unsqueeze(-1).bool()
317
  return emb.masked_fill(~mask, float('nan')).nanmedian(dim=1).values
318
+
319
+ def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
320
  if attention_mask is None:
321
  return emb.std(dim=1)
322
  else:
 
323
  var = self.var_pooling(emb, attention_mask, **kwargs)
324
  return torch.sqrt(var)
325
+
326
+ def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
327
  if attention_mask is None:
328
  return emb.var(dim=1)
329
  else:
330
+ attention_mask = attention_mask.unsqueeze(-1)
331
+ mean = (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
332
+ mean = mean.unsqueeze(1)
333
+ squared_diff = (emb - mean) ** 2
334
+ var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
 
 
 
 
335
  return var
336
 
337
+ def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
338
  return emb[:, 0, :]
339
 
340
  def __call__(
 
342
  emb: torch.Tensor,
343
  attention_mask: Optional[torch.Tensor] = None,
344
  attentions: Optional[torch.Tensor] = None
345
+ ) -> torch.Tensor:
346
  final_emb: List[torch.Tensor] = []
347
  for pooling_type in self.pooling_types:
348
+ final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions))
349
+ return torch.cat(final_emb, dim=-1)
350
 
351
 
352
  class ProteinDataset(TorchDataset):
 
361
  return self.sequences[idx]
362
 
363
 
 
 
 
 
 
 
364
  def parse_fasta(fasta_path: str) -> List[str]:
365
  assert os.path.exists(fasta_path), f"FASTA file does not exist: {fasta_path}"
366
  sequences = []
 
392
 
393
  def _read_sequences_from_db(self, db_path: str) -> Set[str]:
394
  """Read sequences from SQLite database."""
395
+ with sqlite3.connect(db_path, timeout=30) as conn:
 
396
  c = conn.cursor()
397
  c.execute("SELECT sequence FROM embeddings")
398
+ return {row[0] for row in c.fetchall()}
 
 
 
 
 
399
 
400
  def _ensure_embeddings_table(self, conn: sqlite3.Connection) -> None:
401
  cursor = conn.cursor()
402
  cursor.execute(
403
  "CREATE TABLE IF NOT EXISTS embeddings ("
404
  "sequence TEXT PRIMARY KEY, "
405
+ "embedding BLOB NOT NULL"
 
 
406
  ")"
407
  )
 
 
 
 
 
 
 
408
  conn.commit()
409
 
410
  def load_embeddings_from_pth(self, save_path: str) -> Dict[str, torch.Tensor]:
 
419
  def load_embeddings_from_db(self, db_path: str, sequences: Optional[List[str]] = None) -> Dict[str, torch.Tensor]:
420
  assert os.path.exists(db_path), f"Embedding database does not exist: {db_path}"
421
  loaded: Dict[str, torch.Tensor] = {}
422
+ with sqlite3.connect(db_path, timeout=30) as conn:
423
  self._ensure_embeddings_table(conn)
424
  cursor = conn.cursor()
425
  if sequences is None:
426
+ cursor.execute("SELECT sequence, embedding FROM embeddings")
427
  else:
428
  if len(sequences) == 0:
429
  return loaded
430
  placeholders = ",".join(["?"] * len(sequences))
431
  cursor.execute(
432
+ f"SELECT sequence, embedding FROM embeddings WHERE sequence IN ({placeholders})",
433
  tuple(sequences),
434
  )
435
 
 
437
  for row in rows:
438
  sequence = row[0]
439
  embedding_bytes = row[1]
440
+ loaded[sequence] = embedding_blob_to_tensor(embedding_bytes)
 
 
 
 
 
 
 
 
 
 
 
441
  return loaded
442
 
443
  def embed_dataset(
 
456
  sql_db_path: str = 'embeddings.db',
457
  save_path: str = 'embeddings.pth',
458
  fasta_path: Optional[str] = None,
459
+ padding: str = 'max_length',
460
  **kwargs,
461
  ) -> Optional[Dict[str, torch.Tensor]]:
462
  """
 
479
  hidden_size = self.config.hidden_size
480
  pooler = Pooler(pooling_types) if not full_embeddings else None
481
  tokenizer_mode = tokenizer is not None
482
+
483
+ # Resolve padding and compilation
484
+ dynamic = padding == 'longest'
485
+ compiled_model = maybe_compile(self, dynamic=dynamic)
486
+
487
  if tokenizer_mode:
488
+ collate_fn = build_collator(tokenizer, padding=padding, max_length=max_len)
489
  device = self.device
490
  else:
491
  collate_fn = None
 
502
  assert collate_fn is not None
503
  assert device is not None
504
  dataset = ProteinDataset(to_embed)
505
+ dataloader = DataLoader(
506
+ dataset,
507
+ batch_size=batch_size,
508
+ num_workers=num_workers,
509
+ prefetch_factor=2 if num_workers > 0 else None,
510
+ collate_fn=collate_fn,
511
+ shuffle=False,
512
+ pin_memory=True,
513
+ )
514
+ for i, batch in _make_embedding_progress(dataloader, padding):
515
  seqs = to_embed[i * batch_size:(i + 1) * batch_size]
516
  input_ids = batch['input_ids'].to(device)
517
  attention_mask = batch['attention_mask'].to(device)
518
+ residue_embeddings = compiled_model._embed(input_ids, attention_mask)
519
  yield seqs, residue_embeddings, attention_mask
520
  else:
521
  for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'):
522
  seqs = to_embed[batch_start:batch_start + batch_size]
523
+ batch_output = compiled_model._embed(seqs, return_attention_mask=True, **kwargs)
524
  assert isinstance(batch_output, tuple), "Sequence mode _embed must return (last_hidden_state, attention_mask)."
525
  assert len(batch_output) == 2, "Sequence mode _embed must return exactly two values."
526
  residue_embeddings, attention_mask = batch_output
 
528
  yield seqs, residue_embeddings, attention_mask
529
 
530
  if sql:
531
+ conn = sqlite3.connect(sql_db_path, timeout=30, check_same_thread=False)
532
+ conn.execute('PRAGMA journal_mode=WAL')
533
+ conn.execute('PRAGMA busy_timeout=30000')
534
+ conn.execute('PRAGMA synchronous=OFF')
535
+ conn.execute('PRAGMA cache_size=-64000')
536
  self._ensure_embeddings_table(conn)
 
537
  already_embedded = self._read_sequences_from_db(sql_db_path)
538
  to_embed = [seq for seq in sequences if seq not in already_embedded]
539
  print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
540
  print(f"Embedding {len(to_embed)} new sequences")
541
  if len(to_embed) > 0:
542
+ sql_queue: queue.Queue = queue.Queue(maxsize=4)
543
+
544
+ def _sql_writer():
545
+ wc = conn.cursor()
546
+ while True:
547
+ item = sql_queue.get()
548
+ if item is None:
549
+ break
550
+ wc.executemany("INSERT OR REPLACE INTO embeddings VALUES (?, ?)", item)
551
+ if sql_queue.qsize() == 0:
 
 
 
 
552
  conn.commit()
553
+ conn.commit()
554
+
555
+ sql_writer_thread = threading.Thread(target=_sql_writer, daemon=True)
556
+ sql_writer_thread.start()
557
+
558
+ with torch.inference_mode():
559
+ for seqs, residue_embeddings, attention_mask in iter_batches(to_embed):
560
+ embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
561
+ if full_embeddings:
562
+ batch_rows = []
563
+ for seq, emb, mask in zip(seqs, embeddings, attention_mask):
564
+ batch_rows.append((seq, tensor_to_embedding_blob(emb[mask.bool()].reshape(-1, hidden_size))))
565
+ else:
566
+ blobs = batch_tensor_to_blobs(embeddings)
567
+ batch_rows = list(zip(seqs, blobs))
568
+ sql_queue.put(batch_rows)
569
+
570
+ sql_queue.put(None)
571
+ sql_writer_thread.join()
572
  conn.close()
573
  return None
574
 
 
583
  print(f"Embedding {len(to_embed)} new sequences")
584
 
585
  if len(to_embed) > 0:
586
+ with torch.inference_mode():
587
  for seqs, residue_embeddings, attention_mask in iter_batches(to_embed):
588
  embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
589
  for seq, emb, mask in zip(seqs, embeddings, attention_mask):
 
2393
  # Ignore copy
2394
  def forward(
2395
  self,
2396
+ input_ids: Optional[torch.LongTensor] = None,
2397
+ within_seq_position_ids: Optional[torch.LongTensor] = None,
2398
+ global_position_ids: Optional[torch.LongTensor] = None,
2399
+ sequence_ids: Optional[torch.LongTensor] = None,
2400
+ inputs_embeds: Optional[torch.FloatTensor] = None,
2401
  past_key_values: Optional[DynamicCache] = None,
2402
  use_cache: bool = False,
2403
  output_attentions: bool = False,
 
2420
  This tensor contains the sequence id of each residue.
2421
  For example, if the input is ["<bos>1ABC2<eos><bos>1DEF2<eos>", "<bos>1GH2<eos><bos>1JKL2<eos>"],
2422
  the tensor would be [[0,0,0,0,0,0,0,1,1,1,1,1,1,1], [0,0,0,0,0,0,1,1,1,1,1,1,1,-1]]
2423
+ inputs_embeds: (batch_size, seq_length, hidden_size) - pre-computed embeddings,
2424
+ bypasses embed_tokens and embed_seq_id when provided. Used by PDE for
2425
+ differentiable soft sequence optimization.
2426
  past_key_values: DynamicCache
2427
  use_cache: bool
2428
  output_attentions: bool
 
2432
  Returns:
2433
  E1ModelOutputWithPast: Model Outputs
2434
  """
2435
+ assert not (input_ids is not None and inputs_embeds is not None), (
2436
+ "Cannot specify both input_ids and inputs_embeds"
2437
+ )
2438
+ assert input_ids is not None or inputs_embeds is not None, (
2439
+ "Must specify either input_ids or inputs_embeds"
2440
+ )
2441
+
2442
+ if input_ids is not None:
2443
+ batch_size, seq_length = input_ids.shape
2444
+ else:
2445
+ batch_size, seq_length = inputs_embeds.shape[:2]
2446
 
2447
  if self.gradient_checkpointing and self.training and torch.is_grad_enabled():
2448
  if use_cache:
 
2456
  elif not use_cache:
2457
  past_key_values = None
2458
 
2459
+ # Synthesize positional IDs for soft embedding path (single-sequence)
2460
+ if inputs_embeds is not None:
2461
+ device = inputs_embeds.device
2462
+ if within_seq_position_ids is None:
2463
+ within_seq_position_ids = torch.arange(seq_length, device=device).unsqueeze(0).expand(batch_size, -1)
2464
+ if global_position_ids is None:
2465
+ global_position_ids = torch.arange(seq_length, device=device).unsqueeze(0).expand(batch_size, -1)
2466
+ if sequence_ids is None:
2467
+ sequence_ids = torch.zeros(batch_size, seq_length, device=device, dtype=torch.long)
2468
+
2469
  global_position_ids = global_position_ids.view(-1, seq_length).long()
2470
  within_seq_position_ids = within_seq_position_ids.view(-1, seq_length).long()
2471
  sequence_ids = sequence_ids.view(-1, seq_length).long()
 
2476
  f"Position ids must be in the range [-1, {self.config.max_num_positions_within_seq}); got max {max_position_id} and min {min_position_id}"
2477
  )
2478
 
2479
+ if inputs_embeds is None:
2480
+ inputs_embeds = self.embed_tokens(input_ids)
2481
+ inputs_embeds = inputs_embeds + self.embed_seq_id(sequence_ids.clamp(min=0))
2482
 
2483
  if torch.is_autocast_enabled():
2484
  target_dtype = torch.get_autocast_gpu_dtype()
 
2590
 
2591
  def forward(
2592
  self,
2593
+ input_ids: Optional[torch.LongTensor] = None,
2594
+ within_seq_position_ids: Optional[torch.LongTensor] = None,
2595
+ global_position_ids: Optional[torch.LongTensor] = None,
2596
+ sequence_ids: Optional[torch.LongTensor] = None,
2597
+ inputs_embeds: Optional[torch.FloatTensor] = None,
2598
  past_key_values: Optional[DynamicCache] = None,
2599
  use_cache: bool = False,
2600
  output_attentions: bool = False,
 
2607
  within_seq_position_ids=within_seq_position_ids,
2608
  global_position_ids=global_position_ids,
2609
  sequence_ids=sequence_ids,
2610
+ inputs_embeds=inputs_embeds,
2611
  past_key_values=past_key_values,
2612
  use_cache=use_cache,
2613
  output_attentions=output_attentions,
 
2650
 
2651
  def forward(
2652
  self,
2653
+ input_ids: Optional[torch.LongTensor] = None,
2654
+ within_seq_position_ids: Optional[torch.LongTensor] = None,
2655
+ global_position_ids: Optional[torch.LongTensor] = None,
2656
+ sequence_ids: Optional[torch.LongTensor] = None,
2657
+ inputs_embeds: Optional[torch.FloatTensor] = None,
2658
  labels: Optional[torch.LongTensor] = None,
2659
  past_key_values: Optional[DynamicCache] = None,
2660
  use_cache: bool = False,
 
2678
  This tensor contains the sequence id of each residue.
2679
  For example, if the input is ["<bos>1ABC2<eos><bos>1DEF2<eos>", "<bos>1GH2<eos><bos>1JKL2<eos>"],
2680
  the tensor would be [[0,0,0,0,0,0,0,1,1,1,1,1,1,1], [0,0,0,0,0,0,1,1,1,1,1,1,1,-1]]
2681
+ inputs_embeds: (batch_size, seq_length, hidden_size) - pre-computed embeddings
2682
  labels: (batch_size, seq_length)
2683
  past_key_values: DynamicCache
2684
  use_cache: bool
 
2694
  within_seq_position_ids=within_seq_position_ids,
2695
  global_position_ids=global_position_ids,
2696
  sequence_ids=sequence_ids,
2697
+ inputs_embeds=inputs_embeds,
2698
  past_key_values=past_key_values,
2699
  use_cache=use_cache,
2700
  output_attentions=output_attentions,
 
2772
 
2773
  def forward(
2774
  self,
2775
+ input_ids: Optional[torch.LongTensor] = None,
2776
+ within_seq_position_ids: Optional[torch.LongTensor] = None,
2777
+ global_position_ids: Optional[torch.LongTensor] = None,
2778
+ sequence_ids: Optional[torch.LongTensor] = None,
2779
+ inputs_embeds: Optional[torch.FloatTensor] = None,
2780
  labels: Optional[torch.LongTensor] = None,
2781
  past_key_values: Optional[DynamicCache] = None,
2782
  use_cache: bool = False,
 
2790
  within_seq_position_ids=within_seq_position_ids,
2791
  global_position_ids=global_position_ids,
2792
  sequence_ids=sequence_ids,
2793
+ inputs_embeds=inputs_embeds,
2794
  past_key_values=past_key_values,
2795
  use_cache=use_cache,
2796
  output_attentions=output_attentions,
 
2798
  output_s_max=output_s_max,
2799
  )
2800
 
2801
+ attention_mask = (sequence_ids != -1).long() if sequence_ids is not None else torch.ones(outputs.last_hidden_state.shape[:2], device=outputs.last_hidden_state.device, dtype=torch.long)
2802
  x = outputs.last_hidden_state
2803
  features = self.pooler(x, attention_mask)
2804
  logits = self.classifier(features)
 
2869
 
2870
  def forward(
2871
  self,
2872
+ input_ids: Optional[torch.LongTensor] = None,
2873
+ within_seq_position_ids: Optional[torch.LongTensor] = None,
2874
+ global_position_ids: Optional[torch.LongTensor] = None,
2875
+ sequence_ids: Optional[torch.LongTensor] = None,
2876
+ inputs_embeds: Optional[torch.FloatTensor] = None,
2877
  labels: Optional[torch.LongTensor] = None,
2878
  past_key_values: Optional[DynamicCache] = None,
2879
  use_cache: bool = False,
 
2887
  within_seq_position_ids=within_seq_position_ids,
2888
  global_position_ids=global_position_ids,
2889
  sequence_ids=sequence_ids,
2890
+ inputs_embeds=inputs_embeds,
2891
  past_key_values=past_key_values,
2892
  use_cache=use_cache,
2893
  output_attentions=output_attentions,