lhallee committed on
Commit
a9471d6
·
verified ·
1 Parent(s): fd58ef0

Upload modeling_ankh.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_ankh.py +746 -10
modeling_ankh.py CHANGED
@@ -1,5 +1,751 @@
1
  from __future__ import annotations
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import math
4
 
5
  import torch
@@ -10,16 +756,6 @@ from dataclasses import dataclass
10
  from transformers import PreTrainedModel, PretrainedConfig, AutoTokenizer
11
  from transformers.modeling_outputs import ModelOutput
12
 
13
- try:
14
- from fastplms.attention import (
15
- AttentionBackend, VALID_ATTENTION_BACKENDS,
16
- resolve_attention_backend, get_attention_mask,
17
- _get_flex_attention_fn,
18
- create_block_mask, flex_attention, BlockMask,
19
- )
20
- from fastplms.embedding_mixin import Pooler, EmbeddingMixin, ProteinDataset, parse_fasta, build_collator
21
- except ImportError:
22
- pass # Running as HF Hub composite; shared definitions are above
23
 
24
 
25
  # ---------------------------------------------------------------------------
 
1
  from __future__ import annotations
2
 
3
+ import torch
4
+ import torch._inductor.config as inductor_config
5
+ import torch._dynamo as dynamo
6
+
7
+ # Enable TensorFloat32 tensor cores for float32 matmul (Ampere+ GPUs)
8
+ # Provides significant speedup with minimal precision loss
9
+ torch.set_float32_matmul_precision('high')
10
+
11
+ # Enable TF32 for matrix multiplications and cuDNN operations
12
+ torch.backends.cuda.matmul.allow_tf32 = True
13
+ torch.backends.cudnn.allow_tf32 = True
14
+
15
+ # Enable cuDNN autotuner - finds fastest algorithms for your hardware
16
+ # Best when input sizes are consistent; may slow down first iterations
17
+ torch.backends.cudnn.benchmark = True
18
+
19
+ # Deterministic operations off for speed (set True if reproducibility needed)
20
+ torch.backends.cudnn.deterministic = False
21
+ inductor_config.max_autotune_gemm_backends = "ATEN,CUTLASS,FBGEMM"
22
+
23
+ dynamo.config.capture_scalar_outputs = True
24
+ torch._dynamo.config.recompile_limit = 16
25
+
26
+ import os
27
+ import sqlite3
28
+ import networkx as nx
29
+ import numpy as np
30
+ import torch
31
+ from tqdm.auto import tqdm
32
+ from typing import Callable, Dict, List, Optional, Set
33
+ from torch.utils.data import DataLoader
34
+ from torch.utils.data import Dataset as TorchDataset
35
+ from transformers import PreTrainedTokenizerBase
36
+
37
+
38
class Pooler:
    """Reduce per-residue embeddings ``(b, L, d)`` to per-sequence vectors.

    ``pooling_types`` names one or more strategies registered in
    ``pooling_options``. Calling the instance applies each strategy in order
    and concatenates the results along the last dimension, producing a tensor
    of shape ``(b, n_pooling_types * d)``.
    """

    def __init__(self, pooling_types: List[str]) -> None:
        self.pooling_types = pooling_types
        # Strategy name -> bound method. ``__call__`` invokes every strategy
        # with the keyword arguments ``emb``, ``attention_mask`` and ``attentions``.
        self.pooling_options: Dict[str, Callable] = {
            'mean': self.mean_pooling,
            'max': self.max_pooling,
            'norm': self.norm_pooling,
            'median': self.median_pooling,
            'std': self.std_pooling,
            'var': self.var_pooling,
            'cls': self.cls_pooling,
            'parti': self._pool_parti,
        }

    @staticmethod
    def _to_numpy(tensor: torch.Tensor) -> np.ndarray:
        """Detach ``tensor``, move it to the CPU and return it as a NumPy array."""
        tensor = tensor.detach().cpu()
        if tensor.dtype == torch.bfloat16:
            # NumPy has no bfloat16 dtype, so upcast before converting.
            tensor = tensor.float()
        return tensor.numpy()

    def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor:
        """Collapse dimension 1 of ``attentions`` by taking the element-wise maximum."""
        assert isinstance(attentions, torch.Tensor)
        maxed_attentions = torch.max(attentions, dim=1)[0]
        return maxed_attentions

    def _page_rank(self, attention_matrix: np.ndarray, personalization: Optional[dict] = None, nstart: Optional[dict] = None, prune_type: str = "top_k_outdegree") -> Dict[int, float]:
        """Run PageRank on the directed graph built from ``attention_matrix``.

        Returns a mapping from node index to PageRank score. ``prune_type`` is
        accepted for interface compatibility but is not used here.

        Raises:
            Exception: if the graph's node count differs from the matrix size
                or the graph has no edges.
        """
        G = self._convert_to_graph(attention_matrix)
        if G.number_of_nodes() != attention_matrix.shape[0]:
            raise Exception(
                f"The number of nodes in the graph should be equal to the number of tokens in sequence! You have {G.number_of_nodes()} nodes for {attention_matrix.shape[0]} tokens.")
        if G.number_of_edges() == 0:
            raise Exception("You don't seem to have any attention edges left in the graph.")

        return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100)

    def _convert_to_graph(self, matrix: np.ndarray) -> "nx.DiGraph":
        """Build a directed graph whose edge weights are the entries of ``matrix``."""
        G = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
        return G

    def _calculate_importance_weights(self, dict_importance: Dict[int, float], attention_mask: Optional[torch.Tensor] = None) -> np.ndarray:
        """Normalise importance scores of unmasked positions so they sum to one.

        Positions whose ``attention_mask`` entry equals 0 are dropped. The
        input dictionary is left untouched (the previous implementation
        deleted the masked keys from the caller's dict).
        """
        kept = [
            score
            for position, score in dict_importance.items()
            if attention_mask is None or attention_mask[position] != 0
        ]
        total = sum(kept)
        return np.array([score / total for score in kept])

    def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:  # (b, L, d) -> (b, d)
        """PageRank-weighted average of the residue embeddings.

        ``attentions`` is reduced over dimension 1, each per-sequence matrix is
        scored with PageRank, and the scores of the unmasked positions are
        used as averaging weights. The result is returned on ``emb``'s device
        and in ``emb``'s dtype so it can be concatenated with other poolings.
        """
        maxed_attentions = self._to_numpy(self._create_pooled_matrices_across_layers(attentions))
        emb_np = self._to_numpy(emb)
        if attention_mask is None:
            # No mask supplied: treat every position as a real token.
            mask_cpu = torch.ones(emb.shape[:2], dtype=torch.bool)
        else:
            mask_cpu = attention_mask.detach().bool().cpu()

        emb_pooled = []
        for e, a, mask in zip(emb_np, maxed_attentions, mask_cpu):
            dict_importance = self._page_rank(a)
            importance_weights = self._calculate_importance_weights(dict_importance, mask)
            # Select rows by mask position so the rows line up with the
            # weights even when padding is not at the end of the sequence.
            emb_pooled.append(np.average(e[mask.numpy()], weights=importance_weights, axis=0))
        pooled = torch.tensor(np.array(emb_pooled))
        return pooled.to(device=emb.device, dtype=emb.dtype)

    def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:  # (b, L, d) -> (b, d)
        """Average over dimension 1, counting only unmasked positions when a mask is given."""
        if attention_mask is None:
            return emb.mean(dim=1)
        else:
            attention_mask = attention_mask.unsqueeze(-1)
            return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)

    def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:  # (b, L, d) -> (b, d)
        """Maximum over dimension 1; masked positions are filled with ``-inf`` first."""
        if attention_mask is None:
            return emb.max(dim=1).values
        else:
            mask = attention_mask.unsqueeze(-1).bool()
            return emb.masked_fill(~mask, float('-inf')).max(dim=1).values

    def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:  # (b, L, d) -> (b, d)
        """L2 norm over dimension 1; masked positions are zeroed first."""
        if attention_mask is None:
            return emb.norm(dim=1, p=2)
        else:
            attention_mask = attention_mask.unsqueeze(-1)
            return (emb * attention_mask).norm(dim=1, p=2)

    def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:  # (b, L, d) -> (b, d)
        """Median over dimension 1; masked positions become NaN and are skipped by ``nanmedian``."""
        if attention_mask is None:
            return emb.median(dim=1).values
        else:
            mask = attention_mask.unsqueeze(-1).bool()
            return emb.masked_fill(~mask, float('nan')).nanmedian(dim=1).values

    def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:  # (b, L, d) -> (b, d)
        """Standard deviation over dimension 1 (square root of ``var_pooling`` when masked)."""
        if attention_mask is None:
            return emb.std(dim=1)
        else:
            # Compute variance correctly over non-masked positions, then take sqrt
            var = self.var_pooling(emb, attention_mask, **kwargs)
            return torch.sqrt(var)

    def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:  # (b, L, d) -> (b, d)
        """Variance over dimension 1.

        NOTE(review): the masked branch divides by the number of unmasked
        positions, whereas the unmasked branch uses ``Tensor.var`` with its
        default settings; confirm whether the two are meant to agree.
        """
        if attention_mask is None:
            return emb.var(dim=1)
        else:
            attention_mask = attention_mask.unsqueeze(-1)  # (b, L, 1)
            # Mean over non-masked positions
            mean = (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)  # (b, d)
            mean = mean.unsqueeze(1)  # (b, 1, d)
            squared_diff = (emb - mean) ** 2  # (b, L, d)
            # Sum squared differences over non-masked positions and divide by count
            var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)  # (b, d)
            return var

    def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:  # (b, L, d) -> (b, d)
        """Return the embedding at position 0 of every sequence."""
        return emb[:, 0, :]

    def __call__(
        self,
        emb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attentions: Optional[torch.Tensor] = None
    ) -> torch.Tensor:  # [mean, max]
        """Apply every configured pooling and concatenate along the last dimension.

        Raises:
            KeyError: if a configured pooling type is not in ``pooling_options``.
        """
        final_emb: List[torch.Tensor] = []
        for pooling_type in self.pooling_types:
            final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions))  # (b, d)
        return torch.cat(final_emb, dim=-1)  # (b, n_pooling_types * d)
164
+
165
+
166
class ProteinDataset(TorchDataset):
    """Map-style dataset that serves raw protein sequence strings by index."""

    def __init__(self, sequences: List[str]) -> None:
        # Keep a reference to the caller's list; items are returned untouched.
        self.sequences = sequences

    def __len__(self) -> int:
        """Return how many sequences the dataset holds."""
        count = len(self.sequences)
        return count

    def __getitem__(self, idx: int) -> str:
        """Return the sequence stored at position ``idx``."""
        item = self.sequences[idx]
        return item
176
+
177
+
178
def build_collator(tokenizer: PreTrainedTokenizerBase) -> Callable[[List[str]], Dict[str, torch.Tensor]]:
    """Return a collate function that tokenizes a batch of sequences.

    The returned callable forwards the batch to ``tokenizer`` asking for
    PyTorch tensors padded to the longest sequence in the batch.
    """
    def _collate_fn(sequences: List[str]) -> Dict[str, torch.Tensor]:
        encoded = tokenizer(sequences, padding='longest', return_tensors="pt")
        return encoded

    return _collate_fn
182
+
183
+
184
def parse_fasta(fasta_path: str) -> List[str]:
    """Read a FASTA file and return its sequences in file order.

    Header lines (starting with ``>``) delimit records and are discarded;
    the remaining non-blank lines of a record are stripped and concatenated.
    Records without any sequence lines are skipped.
    """
    assert os.path.exists(fasta_path), f"FASTA file does not exist: {fasta_path}"
    records: List[str] = []
    chunks: List[str] = []
    with open(fasta_path, 'r') as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if text == "":
                continue
            if not text.startswith('>'):
                chunks.append(text)
            elif chunks:
                # A new header closes the record collected so far.
                records.append(''.join(chunks))
                chunks = []
    if chunks:
        # Flush the final record, which has no trailing header.
        records.append(''.join(chunks))
    return records
202
+
203
+
204
class EmbeddingMixin:
    """Mixin that adds dataset embedding and embedding (de)serialisation helpers.

    The host class must implement ``_embed``, provide ``parameters()`` (used
    by ``device``) and carry ``config.hidden_size`` (used by ``embed_dataset``).
    """

    # Maximum number of bound parameters placed in a single ``IN (...)`` query.
    # Older SQLite builds reject statements with more than 999 variables, so
    # large lookups are issued in chunks.
    _SQL_MAX_VARIABLES = 900

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Produce residue embeddings for a batch; must be overridden by the host class."""
        raise NotImplementedError

    @property
    def device(self) -> torch.device:
        """Get the device of the model."""
        return next(self.parameters()).device

    def _read_sequences_from_db(self, db_path: str) -> Set[str]:
        """Read sequences from SQLite database."""
        # ``with sqlite3.connect(...)`` only manages the transaction and never
        # closes the connection, so close it explicitly.
        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.execute("SELECT sequence FROM embeddings")
            return {row[0] for row in cursor}
        finally:
            conn.close()

    def _ensure_embeddings_table(self, conn: sqlite3.Connection) -> None:
        """Create the ``embeddings`` table if missing and add any absent metadata columns."""
        cursor = conn.cursor()
        cursor.execute(
            "CREATE TABLE IF NOT EXISTS embeddings ("
            "sequence TEXT PRIMARY KEY, "
            "embedding BLOB NOT NULL, "
            "shape TEXT, "
            "dtype TEXT"
            ")"
        )
        cursor.execute("PRAGMA table_info(embeddings)")
        rows = cursor.fetchall()
        column_names = [row[1] for row in rows]
        # Tables created before the metadata columns existed are upgraded in place.
        if "shape" not in column_names:
            cursor.execute("ALTER TABLE embeddings ADD COLUMN shape TEXT")
        if "dtype" not in column_names:
            cursor.execute("ALTER TABLE embeddings ADD COLUMN dtype TEXT")
        conn.commit()

    def load_embeddings_from_pth(self, save_path: str) -> Dict[str, torch.Tensor]:
        """Load a ``{sequence: tensor}`` dictionary saved with ``torch.save``."""
        assert os.path.exists(save_path), f"Embedding file does not exist: {save_path}"
        payload = torch.load(save_path, map_location="cpu", weights_only=True)
        assert isinstance(payload, dict), "Expected .pth embeddings file to contain a dictionary."
        for sequence, tensor in payload.items():
            assert isinstance(sequence, str), "Expected embedding dictionary keys to be sequences (str)."
            assert isinstance(tensor, torch.Tensor), "Expected embedding dictionary values to be tensors."
        return payload

    def load_embeddings_from_db(self, db_path: str, sequences: Optional[List[str]] = None) -> Dict[str, torch.Tensor]:
        """Load embeddings from the SQLite database at ``db_path``.

        When ``sequences`` is ``None`` every row is returned; otherwise only
        rows whose key is in ``sequences`` (an empty list yields an empty
        dict). Each blob is decoded with the ``dtype`` and ``shape`` stored
        next to it.
        """
        assert os.path.exists(db_path), f"Embedding database does not exist: {db_path}"
        loaded: Dict[str, torch.Tensor] = {}
        conn = sqlite3.connect(db_path)
        try:
            self._ensure_embeddings_table(conn)
            cursor = conn.cursor()
            if sequences is None:
                cursor.execute("SELECT sequence, embedding, shape, dtype FROM embeddings")
                rows = cursor.fetchall()
            else:
                wanted = list(sequences)
                rows = []
                # Query in chunks so the statement never exceeds SQLite's
                # limit on bound parameters.
                for start in range(0, len(wanted), self._SQL_MAX_VARIABLES):
                    chunk = wanted[start:start + self._SQL_MAX_VARIABLES]
                    placeholders = ",".join(["?"] * len(chunk))
                    cursor.execute(
                        f"SELECT sequence, embedding, shape, dtype FROM embeddings WHERE sequence IN ({placeholders})",
                        tuple(chunk),
                    )
                    rows.extend(cursor.fetchall())

            for sequence, embedding_bytes, shape_text, dtype_text in rows:
                assert shape_text is not None, "Missing shape metadata in embeddings table."
                assert dtype_text is not None, "Missing dtype metadata in embeddings table."
                shape_values = [int(value) for value in shape_text.split(",") if len(value) > 0]
                assert len(shape_values) > 0, f"Invalid shape metadata for sequence: {sequence}"
                expected_size = int(np.prod(shape_values))
                np_dtype = np.dtype(dtype_text)
                array = np.frombuffer(embedding_bytes, dtype=np_dtype)
                assert array.size == expected_size, f"Shape mismatch while reading sequence: {sequence}"
                # ``frombuffer`` yields a read-only view; copy before handing to torch.
                reshaped = array.copy().reshape(tuple(shape_values))
                loaded[sequence] = torch.from_numpy(reshaped)
        finally:
            conn.close()
        return loaded

    def embed_dataset(
        self,
        sequences: Optional[List[str]] = None,
        tokenizer: Optional["PreTrainedTokenizerBase"] = None,
        batch_size: int = 2,
        max_len: int = 512,
        truncate: bool = True,
        full_embeddings: bool = False,
        embed_dtype: torch.dtype = torch.float32,
        pooling_types: List[str] = ['mean'],
        num_workers: int = 0,
        sql: bool = False,
        save: bool = True,
        sql_db_path: str = 'embeddings.db',
        save_path: str = 'embeddings.pth',
        fasta_path: Optional[str] = None,
        **kwargs,
    ) -> Optional[Dict[str, torch.Tensor]]:
        """
        Embed a dataset of protein sequences.

        Supports two modes:
        - Tokenizer mode (ESM2/ESM++): provide `tokenizer`, `_embed(input_ids, attention_mask)` is used.
        - Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used.

        Sequences can be supplied as a list via `sequences`, parsed from a FASTA file via
        `fasta_path`, or both (the two sources are combined). At least one must be provided.

        Sequences are optionally truncated to `max_len` characters, de-duplicated and
        processed longest first. Sequences already present in the target store
        (`sql_db_path` when `sql=True`, otherwise `save_path` if it exists) are skipped.

        Returns:
            `None` when `sql=True` (embeddings are written to the database); otherwise the
            `{sequence: embedding}` dictionary, which is also written to `save_path`
            when `save=True`.
        """
        # NOTE: the list default for `pooling_types` is never mutated here.
        if fasta_path is not None:
            fasta_sequences = parse_fasta(fasta_path)
            sequences = list(sequences or []) + fasta_sequences
        assert sequences is not None and len(sequences) > 0, \
            "Must provide at least one sequence via `sequences` or `fasta_path`."
        sequences = list({seq[:max_len] if truncate else seq for seq in sequences})
        sequences = sorted(sequences, key=len, reverse=True)
        hidden_size = self.config.hidden_size
        pooler = Pooler(pooling_types) if not full_embeddings else None
        tokenizer_mode = tokenizer is not None
        if tokenizer_mode:
            collate_fn = build_collator(tokenizer)
            device = self.device
        else:
            collate_fn = None
            device = None

        def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
            # 2-D outputs are passed through unchanged; only 3-D residue
            # embeddings are pooled.
            assert isinstance(residue_embeddings, torch.Tensor)
            if full_embeddings or residue_embeddings.ndim == 2:
                return residue_embeddings
            return pooler(residue_embeddings, attention_mask)

        def iter_batches(to_embed: List[str]):
            # Yields (sequences, residue_embeddings, attention_mask) per batch.
            if tokenizer_mode:
                assert collate_fn is not None
                assert device is not None
                dataset = ProteinDataset(to_embed)
                dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn, shuffle=False)
                for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
                    # shuffle=False, so batch i covers this slice of to_embed.
                    seqs = to_embed[i * batch_size:(i + 1) * batch_size]
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    residue_embeddings = self._embed(input_ids, attention_mask)
                    yield seqs, residue_embeddings, attention_mask
            else:
                for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'):
                    seqs = to_embed[batch_start:batch_start + batch_size]
                    batch_output = self._embed(seqs, return_attention_mask=True, **kwargs)
                    assert isinstance(batch_output, tuple), "Sequence mode _embed must return (last_hidden_state, attention_mask)."
                    assert len(batch_output) == 2, "Sequence mode _embed must return exactly two values."
                    residue_embeddings, attention_mask = batch_output
                    assert isinstance(attention_mask, torch.Tensor), "Sequence mode _embed must return attention_mask as a torch.Tensor."
                    yield seqs, residue_embeddings, attention_mask

        if sql:
            conn = sqlite3.connect(sql_db_path)
            # try/finally guarantees the connection is released even when
            # embedding or insertion raises part-way through.
            try:
                self._ensure_embeddings_table(conn)
                c = conn.cursor()
                already_embedded = self._read_sequences_from_db(sql_db_path)
                to_embed = [seq for seq in sequences if seq not in already_embedded]
                print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
                print(f"Embedding {len(to_embed)} new sequences")
                if len(to_embed) > 0:
                    with torch.no_grad():
                        for i, (seqs, residue_embeddings, attention_mask) in enumerate(iter_batches(to_embed)):
                            embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
                            for seq, emb, mask in zip(seqs, embeddings, attention_mask):
                                if full_embeddings:
                                    # Drop padded positions before storing.
                                    emb = emb[mask.bool()].reshape(-1, hidden_size)
                                emb_np = emb.cpu().numpy()
                                emb_shape = ",".join([str(dim) for dim in emb_np.shape])
                                emb_dtype = str(emb_np.dtype)
                                c.execute(
                                    "INSERT OR REPLACE INTO embeddings (sequence, embedding, shape, dtype) VALUES (?, ?, ?, ?)",
                                    (seq, emb_np.tobytes(), emb_shape, emb_dtype),
                                )
                            # Periodic commit (tokenizer mode only, every 100 batches).
                            if tokenizer_mode and (i + 1) % 100 == 0:
                                conn.commit()
                    conn.commit()
            finally:
                conn.close()
            return None

        embeddings_dict = {}
        if os.path.exists(save_path):
            embeddings_dict = self.load_embeddings_from_pth(save_path)
            to_embed = [seq for seq in sequences if seq not in embeddings_dict]
            print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}")
            print(f"Embedding {len(to_embed)} new sequences")
        else:
            to_embed = sequences
            print(f"Embedding {len(to_embed)} new sequences")

        if len(to_embed) > 0:
            with torch.no_grad():
                for seqs, residue_embeddings, attention_mask in iter_batches(to_embed):
                    embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
                    for seq, emb, mask in zip(seqs, embeddings, attention_mask):
                        if full_embeddings:
                            emb = emb[mask.bool()].reshape(-1, hidden_size)
                        embeddings_dict[seq] = emb.cpu()

        if save:
            torch.save(embeddings_dict, save_path)

        return embeddings_dict
413
+
414
+
415
if __name__ == "__main__":
    # py -m pooler
    # Smoke test: pool random embeddings with the 'max' and 'parti' strategies
    # and print the shape of the concatenated result.
    pooler = Pooler(pooling_types=['max', 'parti'])
    batch_size = 8
    seq_len = 64
    hidden_size = 128
    num_layers = 12
    emb = torch.randn(batch_size, seq_len, hidden_size)
    # Random stand-in for attention maps: (batch, layers, seq, seq).
    attentions = torch.randn(batch_size, num_layers, seq_len, seq_len)
    # All-ones mask: every position is treated as a real token.
    attention_mask = torch.ones(batch_size, seq_len)
    y = pooler(emb=emb, attention_mask=attention_mask, attentions=attentions)
    print(y.shape)
427
+
428
+ """Shared attention infrastructure for all FastPLMs models.
429
+
430
+ Contains: AttentionBackend enum, backend resolution, mask creation,
431
+ flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities.
432
+ """
433
+ from enum import Enum
434
+ from functools import partial
435
+ from typing import Dict, List, Optional, Tuple
436
+
437
+ import torch
438
+ import torch.nn as nn
439
+ from torch.nn import functional as F
440
+ from einops import rearrange
441
+
442
+ try:
443
+ from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask
444
+ except ImportError:
445
+ create_block_mask = None
446
+ flex_attention = None
447
+ BlockMask = None
448
+
449
+ _compiled_flex_attention = None
450
+
451
+
452
def _get_flex_attention_fn():
    """Return the flex-attention callable to use, or ``None`` if unavailable.

    * ``None`` when ``flex_attention`` could not be imported.
    * The eager ``flex_attention`` when the flex-attention module exposes a
      truthy ``_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG`` flag.
    * Otherwise a ``torch.compile``-d wrapper (``dynamic=False``) that passes
      ``kernel_options={"BACKEND": "FLASH"}``. The wrapper is built once and
      cached in the module-level ``_compiled_flex_attention``.

    NOTE(review): the original author states the FLASH backend targets
    Flash Attention 4 on Hopper/Blackwell GPUs (PyTorch 2.11+) with an
    automatic Triton fallback on older hardware; that is not verifiable from
    this file.
    """
    global _compiled_flex_attention
    if flex_attention is None:
        return None

    debug_eager = getattr(
        torch.nn.attention.flex_attention,
        "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG",
        False,
    )
    if debug_eager:
        return flex_attention

    if _compiled_flex_attention is not None:
        return _compiled_flex_attention

    flash_preferring = partial(flex_attention, kernel_options={"BACKEND": "FLASH"})
    _compiled_flex_attention = torch.compile(flash_preferring, dynamic=False)
    return _compiled_flex_attention
471
+
472
+
473
+ ### Kernels Flash Attention Detection
474
def _infer_kernels_flash_variant(kernel) -> Optional[str]:
    """Classify a loaded flash-attention kernel by the entry points it exposes.

    Returns ``"flash_attn2"`` when the kernel has ``fwd`` and ``varlen_fwd``,
    ``"flash_attn3"`` when it has ``flash_attn_func`` and
    ``flash_attn_varlen_func``, and ``None`` otherwise. The flash_attn2
    signature is checked first.
    """
    signatures = (
        ("flash_attn2", ("fwd", "varlen_fwd")),
        ("flash_attn3", ("flash_attn_func", "flash_attn_varlen_func")),
    )
    for variant, required in signatures:
        if all(hasattr(kernel, attr) for attr in required):
            return variant
    return None
480
+
481
+
482
def _try_get_kernels_flash():
    """Try to load a flash-attention kernel through the ``kernels`` package.

    ``kernels-community/flash-attn3`` is attempted first and
    ``kernels-community/flash-attn2`` second; any exception while loading or
    validating a candidate moves on to the next one.

    Returns:
        ``(kernel, variant)`` for the first candidate that loads and exposes
        a supported API, or ``(None, None)`` when ``kernels`` is not
        installed or no candidate succeeds.
    """
    try:
        from kernels import get_kernel
    except ImportError:
        return None, None

    candidates = (
        ("kernels-community/flash-attn3", "Loaded flash-attn3 kernel does not expose a supported API."),
        ("kernels-community/flash-attn2", "Loaded flash-attn2 kernel does not expose a supported API."),
    )
    for repo_id, unsupported_message in candidates:
        try:
            flash_kernel = get_kernel(repo_id)
            flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel)
            assert flash_kernel_variant is not None, unsupported_message
        except Exception:
            # Best-effort probing: fall through to the next candidate.
            continue
        return flash_kernel, flash_kernel_variant
    return None, None
503
+
504
+
505
+ _FLASH_KERNELS_LOADED = False
506
+ FLASH_KERNEL = None
507
+ FLASH_KERNEL_VARIANT = None
508
+
509
+
510
def _ensure_flash_kernels_loaded():
    """Probe for flash-attention kernels at most once per process.

    The first call marks the probe as done and stores the outcome of
    ``_try_get_kernels_flash`` in ``FLASH_KERNEL`` / ``FLASH_KERNEL_VARIANT``;
    later calls do nothing.
    """
    global _FLASH_KERNELS_LOADED, FLASH_KERNEL, FLASH_KERNEL_VARIANT
    if not _FLASH_KERNELS_LOADED:
        # Set the flag before probing so the probe is never repeated.
        _FLASH_KERNELS_LOADED = True
        FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash()
516
+
517
+
518
def _kernels_flash_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    causal: bool = False,
) -> torch.Tensor:
    """Run the loaded flash-attention kernel on a dense (un-packed) batch.

    Dispatches on ``FLASH_KERNEL_VARIANT``. For ``flash_attn3`` the keyword
    call is tried first and a positional call is used if it raises
    ``TypeError``; when the kernel returns a tuple, its first element is
    the result.

    Raises:
        AssertionError: if no kernel is loaded or the variant is unknown.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    variant = FLASH_KERNEL_VARIANT

    if variant == "flash_attn2":
        fwd_result = FLASH_KERNEL.fwd(q=query_states, k=key_states, v=value_states, is_causal=causal)
        return fwd_result[0]

    if variant != "flash_attn3":
        raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")

    try:
        output = FLASH_KERNEL.flash_attn_func(q=query_states, k=key_states, v=value_states, causal=causal)
    except TypeError:
        # Keyword call rejected: retry with the positional signature.
        output = FLASH_KERNEL.flash_attn_func(query_states, key_states, value_states, 0.0, None, causal)
    return output[0] if isinstance(output, tuple) else output
536
+
537
+
538
def _kernels_flash_varlen_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_in_batch_q: int,
    max_seqlen_in_batch_k: int,
    causal: bool = False,
) -> torch.Tensor:
    """Run the loaded flash-attention kernel on packed variable-length input.

    Dispatches on ``FLASH_KERNEL_VARIANT``. For ``flash_attn3`` the keyword
    call is tried first and a positional call is used if it raises
    ``TypeError``; when the kernel returns a tuple, its first element is
    the result.

    Raises:
        AssertionError: if no kernel is loaded or the variant is unknown.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    variant = FLASH_KERNEL_VARIANT

    if variant == "flash_attn2":
        varlen_result = FLASH_KERNEL.varlen_fwd(
            q=query_states, k=key_states, v=value_states,
            cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
            is_causal=causal,
        )
        return varlen_result[0]

    if variant != "flash_attn3":
        raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")

    try:
        output = FLASH_KERNEL.flash_attn_varlen_func(
            q=query_states, k=key_states, v=value_states,
            cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
            causal=causal,
        )
    except TypeError:
        # Keyword call rejected: retry with the positional signature.
        output = FLASH_KERNEL.flash_attn_varlen_func(
            query_states, key_states, value_states,
            cu_seqlens_q, cu_seqlens_k,
            max_seqlen_in_batch_q, max_seqlen_in_batch_k,
            0.0, None, causal,
        )
    return output[0] if isinstance(output, tuple) else output
575
+
576
+
577
+ ### Unpad / Pad helpers for varlen flash attention
578
class IndexFirstAxis(torch.autograd.Function):
    """Differentiable row selection along the first axis.

    ``forward`` gathers the rows named by ``indices``; ``backward`` scatters
    the incoming gradient back into a zero tensor with the original first
    dimension.
    """

    @staticmethod
    def forward(ctx, input, indices) -> torch.Tensor:
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        # Remember the original row count for the backward scatter.
        ctx.first_axis_dim = input.shape[0]
        trailing_shape = input.shape[1:]
        flat_width = trailing_shape.numel()
        flattened = rearrange(input, "b ... -> b (...)")
        gather_index = indices.unsqueeze(1).expand(-1, flat_width)
        selected = torch.gather(flattened, 0, gather_index)
        return selected.reshape(-1, *trailing_shape)

    @staticmethod
    def backward(ctx, grad_output) -> Tuple[torch.Tensor, None]:
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        trailing_shape = grad_output.shape[1:]
        flat_grad = rearrange(grad_output, "b ... -> b (...)")
        flat_width = flat_grad.shape[1]
        grad_input = torch.zeros(
            [ctx.first_axis_dim, flat_width], device=flat_grad.device, dtype=flat_grad.dtype
        )
        scatter_index = indices.unsqueeze(1).expand(-1, flat_width)
        grad_input.scatter_(0, scatter_index, flat_grad)
        # ``indices`` receives no gradient.
        return grad_input.reshape(ctx.first_axis_dim, *trailing_shape), None
600
+
601
+
602
class IndexPutFirstAxis(torch.autograd.Function):
    """Differentiable scatter of rows into a zero-initialised tensor.

    ``forward`` writes ``values`` into the rows ``indices`` of a zero tensor
    with ``first_axis_dim`` rows; ``backward`` reads the gradient back from
    those rows.
    """

    @staticmethod
    def forward(ctx, values, indices, first_axis_dim) -> torch.Tensor:
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        trailing_shape = values.shape[1:]
        padded = torch.zeros(first_axis_dim, *trailing_shape, device=values.device, dtype=values.dtype)
        padded[indices] = values
        return padded

    @staticmethod
    def backward(ctx, grad_output) -> Tuple[torch.Tensor, None, None]:
        (indices,) = ctx.saved_tensors
        grad_values = grad_output[indices]
        # Neither ``indices`` nor ``first_axis_dim`` receives a gradient.
        return grad_values, None, None
616
+
617
+
618
+ index_first_axis = IndexFirstAxis.apply
619
+ index_put_first_axis = IndexPutFirstAxis.apply
620
+
621
+
622
def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
    """Scatter packed rows back into a zero-padded ``(batch, seqlen, ...)`` tensor.

    ``indices`` are positions in the flattened ``batch * seqlen`` axis.
    """
    total_rows = batch * seqlen
    flat_padded = index_put_first_axis(hidden_states, indices, total_rows)
    unflattened = rearrange(flat_padded, "(b s) ... -> b s ...", b=batch)
    return unflattened
625
+
626
+
627
def _unpad_input(
    query_layer: torch.Tensor,
    key_layer: torch.Tensor,
    value_layer: torch.Tensor,
    attention_mask_2d: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
    """Pack query/key/value by dropping positions where the mask is zero.

    Returns the packed q, k, v, the flat indices of the kept positions, the
    cumulative sequence lengths (the same tensor for q and k) and the maximum
    sequence length (the same value for q and k).
    """
    batch_size, seq_len, num_heads, head_dim = query_layer.shape

    seqlens = attention_mask_2d.sum(dim=1).int()
    running_total = seqlens.cumsum(0, dtype=torch.int32)
    # Prepend a zero so entry i is the start offset of sequence i.
    cu_seqlens = F.pad(running_total, (1, 0))
    max_seqlen = int(seqlens.max().item())
    indices = attention_mask_2d.flatten().nonzero(as_tuple=False).flatten()

    flat_shape = (batch_size * seq_len, num_heads, head_dim)
    packed_q = index_first_axis(query_layer.reshape(*flat_shape), indices)
    packed_k = index_first_axis(key_layer.reshape(*flat_shape), indices)
    packed_v = index_first_axis(value_layer.reshape(*flat_shape), indices)
    return packed_q, packed_k, packed_v, indices, (cu_seqlens, cu_seqlens), (max_seqlen, max_seqlen)
642
+
643
+
644
def kernels_flash_attention_func(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask_2d: Optional[torch.Tensor] = None,
    causal: bool = False,
) -> torch.Tensor:
    """Flash attention via the loaded kernel, with optional padding removal.

    When a mask is supplied and ``causal`` is false, inputs are packed, run
    through the variable-length kernel and re-padded to ``(batch, q_len, ...)``.
    Otherwise the dense kernel is called directly; note that in the causal
    case a supplied mask is not used.

    Raises:
        AssertionError: if no flash kernel is loaded.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."

    if causal or attention_mask_2d is None:
        return _kernels_flash_forward(
            query_states=query_states, key_states=key_states, value_states=value_states, causal=causal,
        )

    batch_size, q_len = query_states.shape[:2]
    unpadded = _unpad_input(query_states, key_states, value_states, attention_mask_2d)
    query_states, key_states, value_states, indices_q, cu_seqlens, max_seqlens = unpadded
    cu_seqlens_q, cu_seqlens_k = cu_seqlens
    max_seqlen_q, max_seqlen_k = max_seqlens
    attn_output_unpad = _kernels_flash_varlen_forward(
        query_states=query_states, key_states=key_states, value_states=value_states,
        cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
        max_seqlen_in_batch_q=max_seqlen_q, max_seqlen_in_batch_k=max_seqlen_k,
    )
    return pad_input(attn_output_unpad, indices_q, batch_size, q_len)
668
+
669
+
670
+ ### Attention Backend Enum & Resolution
671
class AttentionBackend(Enum):
    """Attention implementations selectable by their string value."""

    # Resolved at runtime by resolve_attention_backend: kernels flash if loaded,
    # else flex attention if importable, else SDPA.
    AUTO = "auto"
    # Flash-attention kernel loaded through the `kernels` package.
    KERNELS_FLASH = "kernels_flash"
    # torch.nn.attention.flex_attention.
    FLEX = "flex"
    # Presumably torch scaled_dot_product_attention; the call site is outside this chunk.
    SDPA = "sdpa"
676
+
677
+
678
+ VALID_ATTENTION_BACKENDS = tuple(b.value for b in AttentionBackend)
679
+
680
+
681
+ _BACKEND_CONFIRMED = False
682
+
683
+
684
def resolve_attention_backend(requested_backend: str) -> AttentionBackend:
    """Map a configured backend name to the backend that will actually be used.

    ``"auto"`` prefers kernels flash, then flex attention, then SDPA,
    depending on availability. Explicit requests are validated for
    availability. The resolution is printed on the first call only.

    Raises:
        AssertionError: if the name is unknown or the requested backend is
            unavailable in this environment.
    """
    global _BACKEND_CONFIRMED
    assert requested_backend in VALID_ATTENTION_BACKENDS, (
        f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
    )

    auto_name = AttentionBackend.AUTO.value
    flash_name = AttentionBackend.KERNELS_FLASH.value
    flex_name = AttentionBackend.FLEX.value
    sdpa_name = AttentionBackend.SDPA.value

    # Only these two requests need the (lazy) flash kernel probe.
    if requested_backend == auto_name or requested_backend == flash_name:
        _ensure_flash_kernels_loaded()

    if requested_backend == auto_name:
        if FLASH_KERNEL is not None:
            resolved = AttentionBackend.KERNELS_FLASH
        elif flex_attention is not None:
            resolved = AttentionBackend.FLEX
        else:
            resolved = AttentionBackend.SDPA
    elif requested_backend == flash_name:
        assert FLASH_KERNEL is not None, "Kernels Flash Attention is not available in this environment."
        resolved = AttentionBackend.KERNELS_FLASH
    elif requested_backend == flex_name:
        assert flex_attention is not None, "Flex Attention is not available in this environment."
        resolved = AttentionBackend.FLEX
    elif requested_backend == sdpa_name:
        resolved = AttentionBackend.SDPA
    else:
        raise AssertionError(f"Unsupported attention backend: {requested_backend}")

    if not _BACKEND_CONFIRMED:
        print(f"Attention backend: config='{requested_backend}' -> resolved='{resolved.value}'")
        _BACKEND_CONFIRMED = True
    return resolved
712
+
713
+
714
@torch.compiler.disable
def get_attention_mask(
    effective_backend: AttentionBackend,
    batch_size: int,
    seq_len: int,
    device: torch.device,
    attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[BlockMask]]:
    """Build the padding masks a backend needs, once for all encoder layers.

    Returns ``(attention_mask_2d, attention_mask_4d, flex_block_mask)``;
    entries a backend does not use are ``None``, and all three are ``None``
    when no ``attention_mask`` is given.
    """
    if attention_mask is None:
        return None, None, None

    attention_mask_2d = attention_mask.bool()
    mask_4d = None
    block_mask = None

    if effective_backend == AttentionBackend.KERNELS_FLASH:
        # The flash path only needs the 2-D mask.
        pass
    elif effective_backend == AttentionBackend.FLEX:
        assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
        valid_lens = attention_mask_2d.sum(dim=-1)

        def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
            # A (query, key) pair is kept when both indices fall below the
            # count of unmasked positions for that batch item.
            limit = valid_lens[batch_idx]
            return (q_idx < limit) & (kv_idx < limit)

        block_mask = create_block_mask(mask_mod, batch_size, 1, seq_len, seq_len, device=device)
    else:
        # SDPA / manual -- only mask the key dimension so padding query positions attend to
        # real keys and produce valid (non-NaN) outputs instead of NaN from softmax(-inf,...,-inf).
        mask_4d = attention_mask_2d[:, None, None, :]

    return attention_mask_2d, mask_4d, block_mask
748
+
749
  import math
750
 
751
  import torch
 
756
  from transformers import PreTrainedModel, PretrainedConfig, AutoTokenizer
757
  from transformers.modeling_outputs import ModelOutput
758
 
 
 
 
 
 
 
 
 
 
 
759
 
760
 
761
  # ---------------------------------------------------------------------------