lhallee commited on
Commit
fd58ef0
·
verified ·
1 Parent(s): a9d284d

Upload modeling_ankh.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_ankh.py +10 -746
modeling_ankh.py CHANGED
@@ -1,751 +1,5 @@
1
  from __future__ import annotations
2
 
3
- import torch
4
- import torch._inductor.config as inductor_config
5
- import torch._dynamo as dynamo
6
-
7
- # Enable TensorFloat32 tensor cores for float32 matmul (Ampere+ GPUs)
8
- # Provides significant speedup with minimal precision loss
9
- torch.set_float32_matmul_precision('high')
10
-
11
- # Enable TF32 for matrix multiplications and cuDNN operations
12
- torch.backends.cuda.matmul.allow_tf32 = True
13
- torch.backends.cudnn.allow_tf32 = True
14
-
15
- # Enable cuDNN autotuner - finds fastest algorithms for your hardware
16
- # Best when input sizes are consistent; may slow down first iterations
17
- torch.backends.cudnn.benchmark = True
18
-
19
- # Deterministic operations off for speed (set True if reproducibility needed)
20
- torch.backends.cudnn.deterministic = False
21
- inductor_config.max_autotune_gemm_backends = "ATEN,CUTLASS,FBGEMM"
22
-
23
- dynamo.config.capture_scalar_outputs = True
24
- torch._dynamo.config.recompile_limit = 16
25
-
26
- import os
27
- import sqlite3
28
- import networkx as nx
29
- import numpy as np
30
- import torch
31
- from tqdm.auto import tqdm
32
- from typing import Callable, Dict, List, Optional, Set
33
- from torch.utils.data import DataLoader
34
- from torch.utils.data import Dataset as TorchDataset
35
- from transformers import PreTrainedTokenizerBase
36
-
37
-
38
- class Pooler:
39
- def __init__(self, pooling_types: List[str]) -> None:
40
- self.pooling_types = pooling_types
41
- self.pooling_options: Dict[str, Callable] = {
42
- 'mean': self.mean_pooling,
43
- 'max': self.max_pooling,
44
- 'norm': self.norm_pooling,
45
- 'median': self.median_pooling,
46
- 'std': self.std_pooling,
47
- 'var': self.var_pooling,
48
- 'cls': self.cls_pooling,
49
- 'parti': self._pool_parti,
50
- }
51
-
52
- def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor:
53
- assert isinstance(attentions, torch.Tensor)
54
- maxed_attentions = torch.max(attentions, dim=1)[0]
55
- return maxed_attentions
56
-
57
- def _page_rank(self, attention_matrix: np.ndarray, personalization: Optional[dict] = None, nstart: Optional[dict] = None, prune_type: str = "top_k_outdegree") -> Dict[int, float]:
58
- # Run PageRank on the attention matrix converted to a graph.
59
- # Raises exceptions if the graph doesn't match the token sequence or has no edges.
60
- # Returns the PageRank scores for each token node.
61
- G = self._convert_to_graph(attention_matrix)
62
- if G.number_of_nodes() != attention_matrix.shape[0]:
63
- raise Exception(
64
- f"The number of nodes in the graph should be equal to the number of tokens in sequence! You have {G.number_of_nodes()} nodes for {attention_matrix.shape[0]} tokens.")
65
- if G.number_of_edges() == 0:
66
- raise Exception(f"You don't seem to have any attention edges left in the graph.")
67
-
68
- return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100)
69
-
70
- def _convert_to_graph(self, matrix: np.ndarray) -> nx.DiGraph:
71
- # Convert a matrix (e.g., attention scores) to a directed graph using networkx.
72
- # Each element in the matrix represents a directed edge with a weight.
73
- G = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
74
- return G
75
-
76
- def _calculate_importance_weights(self, dict_importance: Dict[int, float], attention_mask: Optional[torch.Tensor] = None) -> np.ndarray:
77
- # Remove keys where attention_mask is 0
78
- if attention_mask is not None:
79
- for k in list(dict_importance.keys()):
80
- if attention_mask[k] == 0:
81
- del dict_importance[k]
82
-
83
- #dict_importance[0] # remove cls
84
- #dict_importance[-1] # remove eos
85
- total = sum(dict_importance.values())
86
- return np.array([v / total for _, v in dict_importance.items()])
87
-
88
- def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: # (b, L, d) -> (b, d)
89
- maxed_attentions = self._create_pooled_matrices_across_layers(attentions).numpy()
90
- # emb is (b, L, d), maxed_attentions is (b, L, L)
91
- emb_pooled = []
92
- for e, a, mask in zip(emb, maxed_attentions, attention_mask):
93
- dict_importance = self._page_rank(a)
94
- importance_weights = self._calculate_importance_weights(dict_importance, mask)
95
- num_tokens = int(mask.sum().item())
96
- emb_pooled.append(np.average(e[:num_tokens], weights=importance_weights, axis=0))
97
- pooled = torch.tensor(np.array(emb_pooled))
98
- return pooled
99
-
100
- def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d)
101
- if attention_mask is None:
102
- return emb.mean(dim=1)
103
- else:
104
- attention_mask = attention_mask.unsqueeze(-1)
105
- return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
106
-
107
- def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d)
108
- if attention_mask is None:
109
- return emb.max(dim=1).values
110
- else:
111
- mask = attention_mask.unsqueeze(-1).bool()
112
- return emb.masked_fill(~mask, float('-inf')).max(dim=1).values
113
-
114
- def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d)
115
- if attention_mask is None:
116
- return emb.norm(dim=1, p=2)
117
- else:
118
- attention_mask = attention_mask.unsqueeze(-1)
119
- return (emb * attention_mask).norm(dim=1, p=2)
120
-
121
- def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d)
122
- if attention_mask is None:
123
- return emb.median(dim=1).values
124
- else:
125
- mask = attention_mask.unsqueeze(-1).bool()
126
- return emb.masked_fill(~mask, float('nan')).nanmedian(dim=1).values
127
-
128
- def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d)
129
- if attention_mask is None:
130
- return emb.std(dim=1)
131
- else:
132
- # Compute variance correctly over non-masked positions, then take sqrt
133
- var = self.var_pooling(emb, attention_mask, **kwargs)
134
- return torch.sqrt(var)
135
-
136
- def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d)
137
- if attention_mask is None:
138
- return emb.var(dim=1)
139
- else:
140
- # Correctly compute variance over only non-masked positions
141
- attention_mask = attention_mask.unsqueeze(-1) # (b, L, 1)
142
- # Compute mean over non-masked positions
143
- mean = (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d)
144
- mean = mean.unsqueeze(1) # (b, 1, d)
145
- # Compute squared differences from mean, only over non-masked positions
146
- squared_diff = (emb - mean) ** 2 # (b, L, d)
147
- # Sum squared differences over non-masked positions and divide by count
148
- var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d)
149
- return var
150
-
151
- def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d)
152
- return emb[:, 0, :]
153
-
154
- def __call__(
155
- self,
156
- emb: torch.Tensor,
157
- attention_mask: Optional[torch.Tensor] = None,
158
- attentions: Optional[torch.Tensor] = None
159
- ) -> torch.Tensor: # [mean, max]
160
- final_emb: List[torch.Tensor] = []
161
- for pooling_type in self.pooling_types:
162
- final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions)) # (b, d)
163
- return torch.cat(final_emb, dim=-1) # (b, n_pooling_types * d)
164
-
165
-
166
- class ProteinDataset(TorchDataset):
167
- """Simple dataset for protein sequences."""
168
- def __init__(self, sequences: List[str]) -> None:
169
- self.sequences = sequences
170
-
171
- def __len__(self) -> int:
172
- return len(self.sequences)
173
-
174
- def __getitem__(self, idx: int) -> str:
175
- return self.sequences[idx]
176
-
177
-
178
- def build_collator(tokenizer: PreTrainedTokenizerBase) -> Callable[[List[str]], Dict[str, torch.Tensor]]:
179
- def _collate_fn(sequences: List[str]) -> Dict[str, torch.Tensor]:
180
- return tokenizer(sequences, return_tensors="pt", padding='longest')
181
- return _collate_fn
182
-
183
-
184
- def parse_fasta(fasta_path: str) -> List[str]:
185
- assert os.path.exists(fasta_path), f"FASTA file does not exist: {fasta_path}"
186
- sequences = []
187
- current_seq = []
188
- with open(fasta_path, 'r') as f:
189
- for line in f:
190
- line = line.strip()
191
- if not line:
192
- continue
193
- if line.startswith('>'):
194
- if current_seq:
195
- sequences.append(''.join(current_seq))
196
- current_seq = []
197
- else:
198
- current_seq.append(line)
199
- if current_seq:
200
- sequences.append(''.join(current_seq))
201
- return sequences
202
-
203
-
204
- class EmbeddingMixin:
205
- def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
206
- raise NotImplementedError
207
-
208
- @property
209
- def device(self) -> torch.device:
210
- """Get the device of the model."""
211
- return next(self.parameters()).device
212
-
213
- def _read_sequences_from_db(self, db_path: str) -> Set[str]:
214
- """Read sequences from SQLite database."""
215
- sequences = []
216
- with sqlite3.connect(db_path) as conn:
217
- c = conn.cursor()
218
- c.execute("SELECT sequence FROM embeddings")
219
- while True:
220
- row = c.fetchone()
221
- if row is None:
222
- break
223
- sequences.append(row[0])
224
- return set(sequences)
225
-
226
- def _ensure_embeddings_table(self, conn: sqlite3.Connection) -> None:
227
- cursor = conn.cursor()
228
- cursor.execute(
229
- "CREATE TABLE IF NOT EXISTS embeddings ("
230
- "sequence TEXT PRIMARY KEY, "
231
- "embedding BLOB NOT NULL, "
232
- "shape TEXT, "
233
- "dtype TEXT"
234
- ")"
235
- )
236
- cursor.execute("PRAGMA table_info(embeddings)")
237
- rows = cursor.fetchall()
238
- column_names = [row[1] for row in rows]
239
- if "shape" not in column_names:
240
- cursor.execute("ALTER TABLE embeddings ADD COLUMN shape TEXT")
241
- if "dtype" not in column_names:
242
- cursor.execute("ALTER TABLE embeddings ADD COLUMN dtype TEXT")
243
- conn.commit()
244
-
245
- def load_embeddings_from_pth(self, save_path: str) -> Dict[str, torch.Tensor]:
246
- assert os.path.exists(save_path), f"Embedding file does not exist: {save_path}"
247
- payload = torch.load(save_path, map_location="cpu", weights_only=True)
248
- assert isinstance(payload, dict), "Expected .pth embeddings file to contain a dictionary."
249
- for sequence, tensor in payload.items():
250
- assert isinstance(sequence, str), "Expected embedding dictionary keys to be sequences (str)."
251
- assert isinstance(tensor, torch.Tensor), "Expected embedding dictionary values to be tensors."
252
- return payload
253
-
254
- def load_embeddings_from_db(self, db_path: str, sequences: Optional[List[str]] = None) -> Dict[str, torch.Tensor]:
255
- assert os.path.exists(db_path), f"Embedding database does not exist: {db_path}"
256
- loaded: Dict[str, torch.Tensor] = {}
257
- with sqlite3.connect(db_path) as conn:
258
- self._ensure_embeddings_table(conn)
259
- cursor = conn.cursor()
260
- if sequences is None:
261
- cursor.execute("SELECT sequence, embedding, shape, dtype FROM embeddings")
262
- else:
263
- if len(sequences) == 0:
264
- return loaded
265
- placeholders = ",".join(["?"] * len(sequences))
266
- cursor.execute(
267
- f"SELECT sequence, embedding, shape, dtype FROM embeddings WHERE sequence IN ({placeholders})",
268
- tuple(sequences),
269
- )
270
-
271
- rows = cursor.fetchall()
272
- for row in rows:
273
- sequence = row[0]
274
- embedding_bytes = row[1]
275
- shape_text = row[2]
276
- dtype_text = row[3]
277
- assert shape_text is not None, "Missing shape metadata in embeddings table."
278
- assert dtype_text is not None, "Missing dtype metadata in embeddings table."
279
- shape_values = [int(value) for value in shape_text.split(",") if len(value) > 0]
280
- assert len(shape_values) > 0, f"Invalid shape metadata for sequence: {sequence}"
281
- expected_size = int(np.prod(shape_values))
282
- np_dtype = np.dtype(dtype_text)
283
- array = np.frombuffer(embedding_bytes, dtype=np_dtype)
284
- assert array.size == expected_size, f"Shape mismatch while reading sequence: {sequence}"
285
- reshaped = array.copy().reshape(tuple(shape_values))
286
- loaded[sequence] = torch.from_numpy(reshaped)
287
- return loaded
288
-
289
- def embed_dataset(
290
- self,
291
- sequences: Optional[List[str]] = None,
292
- tokenizer: Optional[PreTrainedTokenizerBase] = None,
293
- batch_size: int = 2,
294
- max_len: int = 512,
295
- truncate: bool = True,
296
- full_embeddings: bool = False,
297
- embed_dtype: torch.dtype = torch.float32,
298
- pooling_types: List[str] = ['mean'],
299
- num_workers: int = 0,
300
- sql: bool = False,
301
- save: bool = True,
302
- sql_db_path: str = 'embeddings.db',
303
- save_path: str = 'embeddings.pth',
304
- fasta_path: Optional[str] = None,
305
- **kwargs,
306
- ) -> Optional[Dict[str, torch.Tensor]]:
307
- """
308
- Embed a dataset of protein sequences.
309
-
310
- Supports two modes:
311
- - Tokenizer mode (ESM2/ESM++): provide `tokenizer`, `_embed(input_ids, attention_mask)` is used.
312
- - Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used.
313
-
314
- Sequences can be supplied as a list via `sequences`, parsed from a FASTA file via
315
- `fasta_path`, or both (the two sources are combined). At least one must be provided.
316
- """
317
- if fasta_path is not None:
318
- fasta_sequences = parse_fasta(fasta_path)
319
- sequences = list(sequences or []) + fasta_sequences
320
- assert sequences is not None and len(sequences) > 0, \
321
- "Must provide at least one sequence via `sequences` or `fasta_path`."
322
- sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences]))
323
- sequences = sorted(sequences, key=len, reverse=True)
324
- hidden_size = self.config.hidden_size
325
- pooler = Pooler(pooling_types) if not full_embeddings else None
326
- tokenizer_mode = tokenizer is not None
327
- if tokenizer_mode:
328
- collate_fn = build_collator(tokenizer)
329
- device = self.device
330
- else:
331
- collate_fn = None
332
- device = None
333
-
334
- def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
335
- assert isinstance(residue_embeddings, torch.Tensor)
336
- if full_embeddings or residue_embeddings.ndim == 2:
337
- return residue_embeddings
338
- return pooler(residue_embeddings, attention_mask)
339
-
340
- def iter_batches(to_embed: List[str]):
341
- if tokenizer_mode:
342
- assert collate_fn is not None
343
- assert device is not None
344
- dataset = ProteinDataset(to_embed)
345
- dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn, shuffle=False)
346
- for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
347
- seqs = to_embed[i * batch_size:(i + 1) * batch_size]
348
- input_ids = batch['input_ids'].to(device)
349
- attention_mask = batch['attention_mask'].to(device)
350
- residue_embeddings = self._embed(input_ids, attention_mask)
351
- yield seqs, residue_embeddings, attention_mask
352
- else:
353
- for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'):
354
- seqs = to_embed[batch_start:batch_start + batch_size]
355
- batch_output = self._embed(seqs, return_attention_mask=True, **kwargs)
356
- assert isinstance(batch_output, tuple), "Sequence mode _embed must return (last_hidden_state, attention_mask)."
357
- assert len(batch_output) == 2, "Sequence mode _embed must return exactly two values."
358
- residue_embeddings, attention_mask = batch_output
359
- assert isinstance(attention_mask, torch.Tensor), "Sequence mode _embed must return attention_mask as a torch.Tensor."
360
- yield seqs, residue_embeddings, attention_mask
361
-
362
- if sql:
363
- conn = sqlite3.connect(sql_db_path)
364
- self._ensure_embeddings_table(conn)
365
- c = conn.cursor()
366
- already_embedded = self._read_sequences_from_db(sql_db_path)
367
- to_embed = [seq for seq in sequences if seq not in already_embedded]
368
- print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
369
- print(f"Embedding {len(to_embed)} new sequences")
370
- if len(to_embed) > 0:
371
- with torch.no_grad():
372
- for i, (seqs, residue_embeddings, attention_mask) in enumerate(iter_batches(to_embed)):
373
- embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
374
- for seq, emb, mask in zip(seqs, embeddings, attention_mask):
375
- if full_embeddings:
376
- emb = emb[mask.bool()].reshape(-1, hidden_size)
377
- emb_np = emb.cpu().numpy()
378
- emb_shape = ",".join([str(dim) for dim in emb_np.shape])
379
- emb_dtype = str(emb_np.dtype)
380
- c.execute(
381
- "INSERT OR REPLACE INTO embeddings (sequence, embedding, shape, dtype) VALUES (?, ?, ?, ?)",
382
- (seq, emb_np.tobytes(), emb_shape, emb_dtype),
383
- )
384
- if tokenizer_mode and (i + 1) % 100 == 0:
385
- conn.commit()
386
- conn.commit()
387
- conn.close()
388
- return None
389
-
390
- embeddings_dict = {}
391
- if os.path.exists(save_path):
392
- embeddings_dict = self.load_embeddings_from_pth(save_path)
393
- to_embed = [seq for seq in sequences if seq not in embeddings_dict]
394
- print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}")
395
- print(f"Embedding {len(to_embed)} new sequences")
396
- else:
397
- to_embed = sequences
398
- print(f"Embedding {len(to_embed)} new sequences")
399
-
400
- if len(to_embed) > 0:
401
- with torch.no_grad():
402
- for seqs, residue_embeddings, attention_mask in iter_batches(to_embed):
403
- embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
404
- for seq, emb, mask in zip(seqs, embeddings, attention_mask):
405
- if full_embeddings:
406
- emb = emb[mask.bool()].reshape(-1, hidden_size)
407
- embeddings_dict[seq] = emb.cpu()
408
-
409
- if save:
410
- torch.save(embeddings_dict, save_path)
411
-
412
- return embeddings_dict
413
-
414
-
415
- if __name__ == "__main__":
416
- # py -m pooler
417
- pooler = Pooler(pooling_types=['max', 'parti'])
418
- batch_size = 8
419
- seq_len = 64
420
- hidden_size = 128
421
- num_layers = 12
422
- emb = torch.randn(batch_size, seq_len, hidden_size)
423
- attentions = torch.randn(batch_size, num_layers, seq_len, seq_len)
424
- attention_mask = torch.ones(batch_size, seq_len)
425
- y = pooler(emb=emb, attention_mask=attention_mask, attentions=attentions)
426
- print(y.shape)
427
-
428
- """Shared attention infrastructure for all FastPLMs models.
429
-
430
- Contains: AttentionBackend enum, backend resolution, mask creation,
431
- flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities.
432
- """
433
- from enum import Enum
434
- from functools import partial
435
- from typing import Dict, List, Optional, Tuple
436
-
437
- import torch
438
- import torch.nn as nn
439
- from torch.nn import functional as F
440
- from einops import rearrange
441
-
442
- try:
443
- from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask
444
- except ImportError:
445
- create_block_mask = None
446
- flex_attention = None
447
- BlockMask = None
448
-
449
- _compiled_flex_attention = None
450
-
451
-
452
- def _get_flex_attention_fn():
453
- """Return flex_attention callable: compiled (fused kernel) by default, or eager when debug flag is set.
454
-
455
- Uses kernel_options={"BACKEND": "FLASH"} to prefer Flash Attention 4 (FA4)
456
- on Hopper/Blackwell GPUs (PyTorch 2.11+). Automatically falls back to Triton
457
- on older hardware.
458
- """
459
- global _compiled_flex_attention
460
- if flex_attention is None:
461
- return None
462
- flex_mod = torch.nn.attention.flex_attention
463
- if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False):
464
- return flex_attention
465
- if _compiled_flex_attention is None:
466
- _compiled_flex_attention = torch.compile(
467
- partial(flex_attention, kernel_options={"BACKEND": "FLASH"}),
468
- dynamic=False,
469
- )
470
- return _compiled_flex_attention
471
-
472
-
473
- ### Kernels Flash Attention Detection
474
- def _infer_kernels_flash_variant(kernel) -> Optional[str]:
475
- if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
476
- return "flash_attn2"
477
- if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
478
- return "flash_attn3"
479
- return None
480
-
481
-
482
- def _try_get_kernels_flash():
483
- try:
484
- from kernels import get_kernel
485
- except ImportError:
486
- return None, None
487
-
488
- flash_kernel = None
489
- flash_kernel_variant = None
490
- try:
491
- flash_kernel = get_kernel("kernels-community/flash-attn3")
492
- flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel)
493
- assert flash_kernel_variant is not None, "Loaded flash-attn3 kernel does not expose a supported API."
494
- except Exception:
495
- try:
496
- flash_kernel = get_kernel("kernels-community/flash-attn2")
497
- flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel)
498
- assert flash_kernel_variant is not None, "Loaded flash-attn2 kernel does not expose a supported API."
499
- except Exception:
500
- flash_kernel = None
501
- flash_kernel_variant = None
502
- return flash_kernel, flash_kernel_variant
503
-
504
-
505
- _FLASH_KERNELS_LOADED = False
506
- FLASH_KERNEL = None
507
- FLASH_KERNEL_VARIANT = None
508
-
509
-
510
- def _ensure_flash_kernels_loaded():
511
- global _FLASH_KERNELS_LOADED, FLASH_KERNEL, FLASH_KERNEL_VARIANT
512
- if _FLASH_KERNELS_LOADED:
513
- return
514
- _FLASH_KERNELS_LOADED = True
515
- FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash()
516
-
517
-
518
- def _kernels_flash_forward(
519
- query_states: torch.Tensor,
520
- key_states: torch.Tensor,
521
- value_states: torch.Tensor,
522
- causal: bool = False,
523
- ) -> torch.Tensor:
524
- assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
525
- if FLASH_KERNEL_VARIANT == "flash_attn2":
526
- return FLASH_KERNEL.fwd(q=query_states, k=key_states, v=value_states, is_causal=causal)[0]
527
- if FLASH_KERNEL_VARIANT == "flash_attn3":
528
- try:
529
- output = FLASH_KERNEL.flash_attn_func(q=query_states, k=key_states, v=value_states, causal=causal)
530
- except TypeError:
531
- output = FLASH_KERNEL.flash_attn_func(query_states, key_states, value_states, 0.0, None, causal)
532
- if isinstance(output, tuple):
533
- return output[0]
534
- return output
535
- raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")
536
-
537
-
538
- def _kernels_flash_varlen_forward(
539
- query_states: torch.Tensor,
540
- key_states: torch.Tensor,
541
- value_states: torch.Tensor,
542
- cu_seqlens_q: torch.Tensor,
543
- cu_seqlens_k: torch.Tensor,
544
- max_seqlen_in_batch_q: int,
545
- max_seqlen_in_batch_k: int,
546
- causal: bool = False,
547
- ) -> torch.Tensor:
548
- assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
549
- if FLASH_KERNEL_VARIANT == "flash_attn2":
550
- return FLASH_KERNEL.varlen_fwd(
551
- q=query_states, k=key_states, v=value_states,
552
- cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
553
- max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
554
- is_causal=causal,
555
- )[0]
556
- if FLASH_KERNEL_VARIANT == "flash_attn3":
557
- try:
558
- output = FLASH_KERNEL.flash_attn_varlen_func(
559
- q=query_states, k=key_states, v=value_states,
560
- cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
561
- max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
562
- causal=causal,
563
- )
564
- except TypeError:
565
- output = FLASH_KERNEL.flash_attn_varlen_func(
566
- query_states, key_states, value_states,
567
- cu_seqlens_q, cu_seqlens_k,
568
- max_seqlen_in_batch_q, max_seqlen_in_batch_k,
569
- 0.0, None, causal,
570
- )
571
- if isinstance(output, tuple):
572
- return output[0]
573
- return output
574
- raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")
575
-
576
-
577
- ### Unpad / Pad helpers for varlen flash attention
578
- class IndexFirstAxis(torch.autograd.Function):
579
- @staticmethod
580
- def forward(ctx, input, indices) -> torch.Tensor:
581
- ctx.save_for_backward(indices)
582
- assert input.ndim >= 2
583
- ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
584
- second_dim = other_shape.numel()
585
- return torch.gather(
586
- rearrange(input, "b ... -> b (...)"), 0, indices.unsqueeze(1).expand(-1, second_dim)
587
- ).reshape(-1, *other_shape)
588
-
589
- @staticmethod
590
- def backward(ctx, grad_output) -> Tuple[torch.Tensor, None]:
591
- (indices,) = ctx.saved_tensors
592
- assert grad_output.ndim >= 2
593
- other_shape = grad_output.shape[1:]
594
- grad_output = rearrange(grad_output, "b ... -> b (...)")
595
- grad_input = torch.zeros(
596
- [ctx.first_axis_dim, grad_output.shape[1]], device=grad_output.device, dtype=grad_output.dtype
597
- )
598
- grad_input.scatter_(0, indices.unsqueeze(1).expand(-1, grad_output.shape[1]), grad_output)
599
- return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
600
-
601
-
602
- class IndexPutFirstAxis(torch.autograd.Function):
603
- @staticmethod
604
- def forward(ctx, values, indices, first_axis_dim) -> torch.Tensor:
605
- ctx.save_for_backward(indices)
606
- assert indices.ndim == 1
607
- assert values.ndim >= 2
608
- output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
609
- output[indices] = values
610
- return output
611
-
612
- @staticmethod
613
- def backward(ctx, grad_output) -> Tuple[torch.Tensor, None, None]:
614
- (indices,) = ctx.saved_tensors
615
- return grad_output[indices], None, None
616
-
617
-
618
- index_first_axis = IndexFirstAxis.apply
619
- index_put_first_axis = IndexPutFirstAxis.apply
620
-
621
-
622
- def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
623
- output = index_put_first_axis(hidden_states, indices, batch * seqlen)
624
- return rearrange(output, "(b s) ... -> b s ...", b=batch)
625
-
626
-
627
- def _unpad_input(
628
- query_layer: torch.Tensor,
629
- key_layer: torch.Tensor,
630
- value_layer: torch.Tensor,
631
- attention_mask_2d: torch.Tensor,
632
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
633
- batch_size, seq_len, num_heads, head_dim = query_layer.shape
634
- seqlens = attention_mask_2d.sum(dim=1).int()
635
- cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
636
- max_seqlen = int(seqlens.max().item())
637
- indices = attention_mask_2d.flatten().nonzero(as_tuple=False).flatten()
638
- query_layer = index_first_axis(query_layer.reshape(batch_size * seq_len, num_heads, head_dim), indices)
639
- key_layer = index_first_axis(key_layer.reshape(batch_size * seq_len, num_heads, head_dim), indices)
640
- value_layer = index_first_axis(value_layer.reshape(batch_size * seq_len, num_heads, head_dim), indices)
641
- return query_layer, key_layer, value_layer, indices, (cu_seqlens, cu_seqlens), (max_seqlen, max_seqlen)
642
-
643
-
644
- def kernels_flash_attention_func(
645
- query_states: torch.Tensor,
646
- key_states: torch.Tensor,
647
- value_states: torch.Tensor,
648
- attention_mask_2d: Optional[torch.Tensor] = None,
649
- causal: bool = False,
650
- ) -> torch.Tensor:
651
- assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
652
- if not causal and attention_mask_2d is not None:
653
- batch_size, q_len = query_states.shape[:2]
654
- (
655
- query_states, key_states, value_states,
656
- indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k),
657
- ) = _unpad_input(query_states, key_states, value_states, attention_mask_2d)
658
- attn_output_unpad = _kernels_flash_varlen_forward(
659
- query_states=query_states, key_states=key_states, value_states=value_states,
660
- cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
661
- max_seqlen_in_batch_q=max_seqlen_q, max_seqlen_in_batch_k=max_seqlen_k,
662
- )
663
- return pad_input(attn_output_unpad, indices_q, batch_size, q_len)
664
- else:
665
- return _kernels_flash_forward(
666
- query_states=query_states, key_states=key_states, value_states=value_states, causal=causal,
667
- )
668
-
669
-
670
- ### Attention Backend Enum & Resolution
671
- class AttentionBackend(Enum):
672
- AUTO = "auto"
673
- KERNELS_FLASH = "kernels_flash"
674
- FLEX = "flex"
675
- SDPA = "sdpa"
676
-
677
-
678
- VALID_ATTENTION_BACKENDS = tuple(b.value for b in AttentionBackend)
679
-
680
-
681
- _BACKEND_CONFIRMED = False
682
-
683
-
684
- def resolve_attention_backend(requested_backend: str) -> AttentionBackend:
685
- global _BACKEND_CONFIRMED
686
- assert requested_backend in VALID_ATTENTION_BACKENDS, (
687
- f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
688
- )
689
- if requested_backend in (AttentionBackend.AUTO.value, AttentionBackend.KERNELS_FLASH.value):
690
- _ensure_flash_kernels_loaded()
691
- if requested_backend == AttentionBackend.AUTO.value:
692
- if FLASH_KERNEL is not None:
693
- resolved = AttentionBackend.KERNELS_FLASH
694
- elif flex_attention is not None:
695
- resolved = AttentionBackend.FLEX
696
- else:
697
- resolved = AttentionBackend.SDPA
698
- elif requested_backend == AttentionBackend.KERNELS_FLASH.value:
699
- assert FLASH_KERNEL is not None, "Kernels Flash Attention is not available in this environment."
700
- resolved = AttentionBackend.KERNELS_FLASH
701
- elif requested_backend == AttentionBackend.FLEX.value:
702
- assert flex_attention is not None, "Flex Attention is not available in this environment."
703
- resolved = AttentionBackend.FLEX
704
- elif requested_backend == AttentionBackend.SDPA.value:
705
- resolved = AttentionBackend.SDPA
706
- else:
707
- raise AssertionError(f"Unsupported attention backend: {requested_backend}")
708
- if not _BACKEND_CONFIRMED:
709
- print(f"Attention backend: config='{requested_backend}' -> resolved='{resolved.value}'")
710
- _BACKEND_CONFIRMED = True
711
- return resolved
712
-
713
-
714
- @torch.compiler.disable
715
- def get_attention_mask(
716
- effective_backend: AttentionBackend,
717
- batch_size: int,
718
- seq_len: int,
719
- device: torch.device,
720
- attention_mask: Optional[torch.Tensor] = None,
721
- ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[BlockMask]]:
722
- """Build padding masks once for all encoder layers.
723
-
724
- Returns (attention_mask_2d, attention_mask_4d, flex_block_mask).
725
- """
726
- if attention_mask is None:
727
- return None, None, None
728
-
729
- attention_mask_2d = attention_mask.bool()
730
-
731
- if effective_backend == AttentionBackend.KERNELS_FLASH:
732
- return attention_mask_2d, None, None
733
-
734
- if effective_backend == AttentionBackend.FLEX:
735
- assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
736
- valid_lens = attention_mask_2d.sum(dim=-1)
737
-
738
- def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
739
- return (q_idx < valid_lens[batch_idx]) & (kv_idx < valid_lens[batch_idx])
740
-
741
- flex_block_mask = create_block_mask(mask_mod, batch_size, 1, seq_len, seq_len, device=device)
742
- return attention_mask_2d, None, flex_block_mask
743
-
744
- # SDPA / manual -- only mask the key dimension so padding query positions attend to
745
- # real keys and produce valid (non-NaN) outputs instead of NaN from softmax(-inf,...,-inf).
746
- attention_mask_4d = attention_mask_2d[:, None, None, :]
747
- return attention_mask_2d, attention_mask_4d, None
748
-
749
  import math
750
 
751
  import torch
@@ -756,6 +10,16 @@ from dataclasses import dataclass
756
  from transformers import PreTrainedModel, PretrainedConfig, AutoTokenizer
757
  from transformers.modeling_outputs import ModelOutput
758
 
 
 
 
 
 
 
 
 
 
 
759
 
760
 
761
  # ---------------------------------------------------------------------------
 
1
  from __future__ import annotations
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import math
4
 
5
  import torch
 
10
  from transformers import PreTrainedModel, PretrainedConfig, AutoTokenizer
11
  from transformers.modeling_outputs import ModelOutput
12
 
13
+ try:
14
+ from fastplms.attention import (
15
+ AttentionBackend, VALID_ATTENTION_BACKENDS,
16
+ resolve_attention_backend, get_attention_mask,
17
+ _get_flex_attention_fn,
18
+ create_block_mask, flex_attention, BlockMask,
19
+ )
20
+ from fastplms.embedding_mixin import Pooler, EmbeddingMixin, ProteinDataset, parse_fasta, build_collator
21
+ except ImportError:
22
+ pass # Running as HF Hub composite; shared definitions are above
23
 
24
 
25
  # ---------------------------------------------------------------------------