diff --git "a/modeling_ankh.py" "b/modeling_ankh.py" --- "a/modeling_ankh.py" +++ "b/modeling_ankh.py" @@ -1,1576 +1,1570 @@ -from __future__ import annotations - -import torch -import torch._inductor.config as inductor_config -import torch._dynamo as dynamo - -# Enable TensorFloat32 tensor cores for float32 matmul (Ampere+ GPUs) -# Provides significant speedup with minimal precision loss -torch.set_float32_matmul_precision('high') - -# Enable TF32 for matrix multiplications and cuDNN operations -torch.backends.cuda.matmul.allow_tf32 = True -torch.backends.cudnn.allow_tf32 = True - -# Enable cuDNN autotuner - finds fastest algorithms for your hardware -# Best when input sizes are consistent; may slow down first iterations -torch.backends.cudnn.benchmark = True - -# Deterministic operations off for speed (set True if reproducibility needed) -torch.backends.cudnn.deterministic = False -inductor_config.max_autotune_gemm_backends = "ATEN,CUTLASS,FBGEMM" - -dynamo.config.capture_scalar_outputs = True -torch._dynamo.config.recompile_limit = 16 - -import os -import sqlite3 -import networkx as nx -import numpy as np -import torch -from tqdm.auto import tqdm -from typing import Callable, Dict, List, Optional, Set -from torch.utils.data import DataLoader -from torch.utils.data import Dataset as TorchDataset -from transformers import PreTrainedTokenizerBase - - -class Pooler: - def __init__(self, pooling_types: List[str]) -> None: - self.pooling_types = pooling_types - self.pooling_options: Dict[str, Callable] = { - 'mean': self.mean_pooling, - 'max': self.max_pooling, - 'norm': self.norm_pooling, - 'median': self.median_pooling, - 'std': self.std_pooling, - 'var': self.var_pooling, - 'cls': self.cls_pooling, - 'parti': self._pool_parti, - } - - def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor: - assert isinstance(attentions, torch.Tensor) - maxed_attentions = torch.max(attentions, dim=1)[0] - return maxed_attentions - - def 
_page_rank(self, attention_matrix: np.ndarray, personalization: Optional[dict] = None, nstart: Optional[dict] = None, prune_type: str = "top_k_outdegree") -> Dict[int, float]: - # Run PageRank on the attention matrix converted to a graph. - # Raises exceptions if the graph doesn't match the token sequence or has no edges. - # Returns the PageRank scores for each token node. - G = self._convert_to_graph(attention_matrix) - if G.number_of_nodes() != attention_matrix.shape[0]: - raise Exception( - f"The number of nodes in the graph should be equal to the number of tokens in sequence! You have {G.number_of_nodes()} nodes for {attention_matrix.shape[0]} tokens.") - if G.number_of_edges() == 0: - raise Exception(f"You don't seem to have any attention edges left in the graph.") - - return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100) - - def _convert_to_graph(self, matrix: np.ndarray) -> nx.DiGraph: - # Convert a matrix (e.g., attention scores) to a directed graph using networkx. - # Each element in the matrix represents a directed edge with a weight. 
- G = nx.from_numpy_array(matrix, create_using=nx.DiGraph) - return G - - def _calculate_importance_weights(self, dict_importance: Dict[int, float], attention_mask: Optional[torch.Tensor] = None) -> np.ndarray: - # Remove keys where attention_mask is 0 - if attention_mask is not None: - for k in list(dict_importance.keys()): - if attention_mask[k] == 0: - del dict_importance[k] - - #dict_importance[0] # remove cls - #dict_importance[-1] # remove eos - total = sum(dict_importance.values()) - return np.array([v / total for _, v in dict_importance.items()]) - - def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: # (b, L, d) -> (b, d) - maxed_attentions = self._create_pooled_matrices_across_layers(attentions).numpy() - # emb is (b, L, d), maxed_attentions is (b, L, L) - emb_pooled = [] - for e, a, mask in zip(emb, maxed_attentions, attention_mask): - dict_importance = self._page_rank(a) - importance_weights = self._calculate_importance_weights(dict_importance, mask) - num_tokens = int(mask.sum().item()) - emb_pooled.append(np.average(e[:num_tokens], weights=importance_weights, axis=0)) - pooled = torch.tensor(np.array(emb_pooled)) - return pooled - - def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d) - if attention_mask is None: - return emb.mean(dim=1) - else: - attention_mask = attention_mask.unsqueeze(-1) - return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) - - def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d) - if attention_mask is None: - return emb.max(dim=1).values - else: - mask = attention_mask.unsqueeze(-1).bool() - return emb.masked_fill(~mask, float('-inf')).max(dim=1).values - - def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> 
torch.Tensor: # (b, L, d) -> (b, d) - if attention_mask is None: - return emb.norm(dim=1, p=2) - else: - attention_mask = attention_mask.unsqueeze(-1) - return (emb * attention_mask).norm(dim=1, p=2) - - def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d) - if attention_mask is None: - return emb.median(dim=1).values - else: - mask = attention_mask.unsqueeze(-1).bool() - return emb.masked_fill(~mask, float('nan')).nanmedian(dim=1).values - - def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d) - if attention_mask is None: - return emb.std(dim=1) - else: - # Compute variance correctly over non-masked positions, then take sqrt - var = self.var_pooling(emb, attention_mask, **kwargs) - return torch.sqrt(var) - - def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d) - if attention_mask is None: - return emb.var(dim=1) - else: - # Correctly compute variance over only non-masked positions - attention_mask = attention_mask.unsqueeze(-1) # (b, L, 1) - # Compute mean over non-masked positions - mean = (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d) - mean = mean.unsqueeze(1) # (b, 1, d) - # Compute squared differences from mean, only over non-masked positions - squared_diff = (emb - mean) ** 2 # (b, L, d) - # Sum squared differences over non-masked positions and divide by count - var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d) - return var - - def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d) - return emb[:, 0, :] - - def __call__( - self, - emb: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - attentions: Optional[torch.Tensor] = None - ) -> torch.Tensor: # 
[mean, max] - final_emb: List[torch.Tensor] = [] - for pooling_type in self.pooling_types: - final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions)) # (b, d) - return torch.cat(final_emb, dim=-1) # (b, n_pooling_types * d) - - -class ProteinDataset(TorchDataset): - """Simple dataset for protein sequences.""" - def __init__(self, sequences: List[str]) -> None: - self.sequences = sequences - - def __len__(self) -> int: - return len(self.sequences) - - def __getitem__(self, idx: int) -> str: - return self.sequences[idx] - - -def build_collator(tokenizer: PreTrainedTokenizerBase) -> Callable[[List[str]], Dict[str, torch.Tensor]]: - def _collate_fn(sequences: List[str]) -> Dict[str, torch.Tensor]: - return tokenizer(sequences, return_tensors="pt", padding='longest') - return _collate_fn - - -def parse_fasta(fasta_path: str) -> List[str]: - assert os.path.exists(fasta_path), f"FASTA file does not exist: {fasta_path}" - sequences = [] - current_seq = [] - with open(fasta_path, 'r') as f: - for line in f: - line = line.strip() - if not line: - continue - if line.startswith('>'): - if current_seq: - sequences.append(''.join(current_seq)) - current_seq = [] - else: - current_seq.append(line) - if current_seq: - sequences.append(''.join(current_seq)) - return sequences - - -class EmbeddingMixin: - def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - raise NotImplementedError - - @property - def device(self) -> torch.device: - """Get the device of the model.""" - return next(self.parameters()).device - - def _read_sequences_from_db(self, db_path: str) -> Set[str]: - """Read sequences from SQLite database.""" - sequences = [] - with sqlite3.connect(db_path) as conn: - c = conn.cursor() - c.execute("SELECT sequence FROM embeddings") - while True: - row = c.fetchone() - if row is None: - break - sequences.append(row[0]) - return set(sequences) - - def 
_ensure_embeddings_table(self, conn: sqlite3.Connection) -> None: - cursor = conn.cursor() - cursor.execute( - "CREATE TABLE IF NOT EXISTS embeddings (" - "sequence TEXT PRIMARY KEY, " - "embedding BLOB NOT NULL, " - "shape TEXT, " - "dtype TEXT" - ")" - ) - cursor.execute("PRAGMA table_info(embeddings)") - rows = cursor.fetchall() - column_names = [row[1] for row in rows] - if "shape" not in column_names: - cursor.execute("ALTER TABLE embeddings ADD COLUMN shape TEXT") - if "dtype" not in column_names: - cursor.execute("ALTER TABLE embeddings ADD COLUMN dtype TEXT") - conn.commit() - - def load_embeddings_from_pth(self, save_path: str) -> Dict[str, torch.Tensor]: - assert os.path.exists(save_path), f"Embedding file does not exist: {save_path}" - payload = torch.load(save_path, map_location="cpu", weights_only=True) - assert isinstance(payload, dict), "Expected .pth embeddings file to contain a dictionary." - for sequence, tensor in payload.items(): - assert isinstance(sequence, str), "Expected embedding dictionary keys to be sequences (str)." - assert isinstance(tensor, torch.Tensor), "Expected embedding dictionary values to be tensors." 
- return payload - - def load_embeddings_from_db(self, db_path: str, sequences: Optional[List[str]] = None) -> Dict[str, torch.Tensor]: - assert os.path.exists(db_path), f"Embedding database does not exist: {db_path}" - loaded: Dict[str, torch.Tensor] = {} - with sqlite3.connect(db_path) as conn: - self._ensure_embeddings_table(conn) - cursor = conn.cursor() - if sequences is None: - cursor.execute("SELECT sequence, embedding, shape, dtype FROM embeddings") - else: - if len(sequences) == 0: - return loaded - placeholders = ",".join(["?"] * len(sequences)) - cursor.execute( - f"SELECT sequence, embedding, shape, dtype FROM embeddings WHERE sequence IN ({placeholders})", - tuple(sequences), - ) - - rows = cursor.fetchall() - for row in rows: - sequence = row[0] - embedding_bytes = row[1] - shape_text = row[2] - dtype_text = row[3] - assert shape_text is not None, "Missing shape metadata in embeddings table." - assert dtype_text is not None, "Missing dtype metadata in embeddings table." - shape_values = [int(value) for value in shape_text.split(",") if len(value) > 0] - assert len(shape_values) > 0, f"Invalid shape metadata for sequence: {sequence}" - expected_size = int(np.prod(shape_values)) - np_dtype = np.dtype(dtype_text) - array = np.frombuffer(embedding_bytes, dtype=np_dtype) - assert array.size == expected_size, f"Shape mismatch while reading sequence: {sequence}" - reshaped = array.copy().reshape(tuple(shape_values)) - loaded[sequence] = torch.from_numpy(reshaped) - return loaded - - def embed_dataset( - self, - sequences: Optional[List[str]] = None, - tokenizer: Optional[PreTrainedTokenizerBase] = None, - batch_size: int = 2, - max_len: int = 512, - truncate: bool = True, - full_embeddings: bool = False, - embed_dtype: torch.dtype = torch.float32, - pooling_types: List[str] = ['mean'], - num_workers: int = 0, - sql: bool = False, - save: bool = True, - sql_db_path: str = 'embeddings.db', - save_path: str = 'embeddings.pth', - fasta_path: Optional[str] = 
None, - **kwargs, - ) -> Optional[Dict[str, torch.Tensor]]: - """ - Embed a dataset of protein sequences. - - Supports two modes: - - Tokenizer mode (ESM2/ESM++): provide `tokenizer`, `_embed(input_ids, attention_mask)` is used. - - Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used. - - Sequences can be supplied as a list via `sequences`, parsed from a FASTA file via - `fasta_path`, or both (the two sources are combined). At least one must be provided. - """ - if fasta_path is not None: - fasta_sequences = parse_fasta(fasta_path) - sequences = list(sequences or []) + fasta_sequences - assert sequences is not None and len(sequences) > 0, \ - "Must provide at least one sequence via `sequences` or `fasta_path`." - sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences])) - sequences = sorted(sequences, key=len, reverse=True) - hidden_size = self.config.hidden_size - pooler = Pooler(pooling_types) if not full_embeddings else None - tokenizer_mode = tokenizer is not None - if tokenizer_mode: - collate_fn = build_collator(tokenizer) - device = self.device - else: - collate_fn = None - device = None - - def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - assert isinstance(residue_embeddings, torch.Tensor) - if full_embeddings or residue_embeddings.ndim == 2: - return residue_embeddings - return pooler(residue_embeddings, attention_mask) - - def iter_batches(to_embed: List[str]): - if tokenizer_mode: - assert collate_fn is not None - assert device is not None - dataset = ProteinDataset(to_embed) - dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn, shuffle=False) - for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'): - seqs = to_embed[i * batch_size:(i + 1) * batch_size] - input_ids = batch['input_ids'].to(device) - attention_mask = 
batch['attention_mask'].to(device) - residue_embeddings = self._embed(input_ids, attention_mask) - yield seqs, residue_embeddings, attention_mask - else: - for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'): - seqs = to_embed[batch_start:batch_start + batch_size] - batch_output = self._embed(seqs, return_attention_mask=True, **kwargs) - assert isinstance(batch_output, tuple), "Sequence mode _embed must return (last_hidden_state, attention_mask)." - assert len(batch_output) == 2, "Sequence mode _embed must return exactly two values." - residue_embeddings, attention_mask = batch_output - assert isinstance(attention_mask, torch.Tensor), "Sequence mode _embed must return attention_mask as a torch.Tensor." - yield seqs, residue_embeddings, attention_mask - - if sql: - conn = sqlite3.connect(sql_db_path) - self._ensure_embeddings_table(conn) - c = conn.cursor() - already_embedded = self._read_sequences_from_db(sql_db_path) - to_embed = [seq for seq in sequences if seq not in already_embedded] - print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}") - print(f"Embedding {len(to_embed)} new sequences") - if len(to_embed) > 0: - with torch.no_grad(): - for i, (seqs, residue_embeddings, attention_mask) in enumerate(iter_batches(to_embed)): - embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype) - for seq, emb, mask in zip(seqs, embeddings, attention_mask): - if full_embeddings: - emb = emb[mask.bool()].reshape(-1, hidden_size) - emb_np = emb.cpu().numpy() - emb_shape = ",".join([str(dim) for dim in emb_np.shape]) - emb_dtype = str(emb_np.dtype) - c.execute( - "INSERT OR REPLACE INTO embeddings (sequence, embedding, shape, dtype) VALUES (?, ?, ?, ?)", - (seq, emb_np.tobytes(), emb_shape, emb_dtype), - ) - if tokenizer_mode and (i + 1) % 100 == 0: - conn.commit() - conn.commit() - conn.close() - return None - - embeddings_dict = {} - if os.path.exists(save_path): - embeddings_dict = 
self.load_embeddings_from_pth(save_path) - to_embed = [seq for seq in sequences if seq not in embeddings_dict] - print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}") - print(f"Embedding {len(to_embed)} new sequences") - else: - to_embed = sequences - print(f"Embedding {len(to_embed)} new sequences") - - if len(to_embed) > 0: - with torch.no_grad(): - for seqs, residue_embeddings, attention_mask in iter_batches(to_embed): - embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype) - for seq, emb, mask in zip(seqs, embeddings, attention_mask): - if full_embeddings: - emb = emb[mask.bool()].reshape(-1, hidden_size) - embeddings_dict[seq] = emb.cpu() - - if save: - torch.save(embeddings_dict, save_path) - - return embeddings_dict - - -if __name__ == "__main__": - # py -m pooler - pooler = Pooler(pooling_types=['max', 'parti']) - batch_size = 8 - seq_len = 64 - hidden_size = 128 - num_layers = 12 - emb = torch.randn(batch_size, seq_len, hidden_size) - attentions = torch.randn(batch_size, num_layers, seq_len, seq_len) - attention_mask = torch.ones(batch_size, seq_len) - y = pooler(emb=emb, attention_mask=attention_mask, attentions=attentions) - print(y.shape) - -"""Shared attention infrastructure for all FastPLMs models. - -Contains: AttentionBackend enum, backend resolution, mask creation, -flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities. 
-""" -from enum import Enum -from functools import partial -from typing import Dict, List, Optional, Tuple - -import torch -import torch.nn as nn -from torch.nn import functional as F -from einops import rearrange - -try: - from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask -except ImportError: - create_block_mask = None - flex_attention = None - BlockMask = None - -_compiled_flex_attention = None - - -def _get_flex_attention_fn(): - """Return flex_attention callable: compiled (fused kernel) by default, or eager when debug flag is set. - - Uses kernel_options={"BACKEND": "FLASH"} to prefer Flash Attention 4 (FA4) - on Hopper/Blackwell GPUs (PyTorch 2.11+). Automatically falls back to Triton - on older hardware. - """ - global _compiled_flex_attention - if flex_attention is None: - return None - flex_mod = torch.nn.attention.flex_attention - if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False): - return flex_attention - if _compiled_flex_attention is None: - _compiled_flex_attention = torch.compile( - partial(flex_attention, kernel_options={"BACKEND": "FLASH"}), - dynamic=False, - ) - return _compiled_flex_attention - - -### Kernels Flash Attention Detection -def _infer_kernels_flash_variant(kernel) -> Optional[str]: - if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"): - return "flash_attn2" - if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"): - return "flash_attn3" - return None - - -def _try_get_kernels_flash(): - try: - from kernels import get_kernel - except ImportError: - return None, None - - flash_kernel = None - flash_kernel_variant = None - try: - flash_kernel = get_kernel("kernels-community/flash-attn3") - flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel) - assert flash_kernel_variant is not None, "Loaded flash-attn3 kernel does not expose a supported API." 
- except Exception: - try: - flash_kernel = get_kernel("kernels-community/flash-attn2") - flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel) - assert flash_kernel_variant is not None, "Loaded flash-attn2 kernel does not expose a supported API." - except Exception: - flash_kernel = None - flash_kernel_variant = None - return flash_kernel, flash_kernel_variant - - -_FLASH_KERNELS_LOADED = False -FLASH_KERNEL = None -FLASH_KERNEL_VARIANT = None - - -def _ensure_flash_kernels_loaded(): - global _FLASH_KERNELS_LOADED, FLASH_KERNEL, FLASH_KERNEL_VARIANT - if _FLASH_KERNELS_LOADED: - return - _FLASH_KERNELS_LOADED = True - FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash() - - -def _kernels_flash_forward( - query_states: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - causal: bool = False, -) -> torch.Tensor: - assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment." - if FLASH_KERNEL_VARIANT == "flash_attn2": - return FLASH_KERNEL.fwd(q=query_states, k=key_states, v=value_states, is_causal=causal)[0] - if FLASH_KERNEL_VARIANT == "flash_attn3": - try: - output = FLASH_KERNEL.flash_attn_func(q=query_states, k=key_states, v=value_states, causal=causal) - except TypeError: - output = FLASH_KERNEL.flash_attn_func(query_states, key_states, value_states, 0.0, None, causal) - if isinstance(output, tuple): - return output[0] - return output - raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}") - - -def _kernels_flash_varlen_forward( - query_states: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_in_batch_q: int, - max_seqlen_in_batch_k: int, - causal: bool = False, -) -> torch.Tensor: - assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment." 
- if FLASH_KERNEL_VARIANT == "flash_attn2": - return FLASH_KERNEL.varlen_fwd( - q=query_states, k=key_states, v=value_states, - cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k, - is_causal=causal, - )[0] - if FLASH_KERNEL_VARIANT == "flash_attn3": - try: - output = FLASH_KERNEL.flash_attn_varlen_func( - q=query_states, k=key_states, v=value_states, - cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k, - causal=causal, - ) - except TypeError: - output = FLASH_KERNEL.flash_attn_varlen_func( - query_states, key_states, value_states, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_in_batch_q, max_seqlen_in_batch_k, - 0.0, None, causal, - ) - if isinstance(output, tuple): - return output[0] - return output - raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}") - - -### Unpad / Pad helpers for varlen flash attention -class IndexFirstAxis(torch.autograd.Function): - @staticmethod - def forward(ctx, input, indices) -> torch.Tensor: - ctx.save_for_backward(indices) - assert input.ndim >= 2 - ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] - second_dim = other_shape.numel() - return torch.gather( - rearrange(input, "b ... -> b (...)"), 0, indices.unsqueeze(1).expand(-1, second_dim) - ).reshape(-1, *other_shape) - - @staticmethod - def backward(ctx, grad_output) -> Tuple[torch.Tensor, None]: - (indices,) = ctx.saved_tensors - assert grad_output.ndim >= 2 - other_shape = grad_output.shape[1:] - grad_output = rearrange(grad_output, "b ... 
-> b (...)") - grad_input = torch.zeros( - [ctx.first_axis_dim, grad_output.shape[1]], device=grad_output.device, dtype=grad_output.dtype - ) - grad_input.scatter_(0, indices.unsqueeze(1).expand(-1, grad_output.shape[1]), grad_output) - return grad_input.reshape(ctx.first_axis_dim, *other_shape), None - - -class IndexPutFirstAxis(torch.autograd.Function): - @staticmethod - def forward(ctx, values, indices, first_axis_dim) -> torch.Tensor: - ctx.save_for_backward(indices) - assert indices.ndim == 1 - assert values.ndim >= 2 - output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype) - output[indices] = values - return output - - @staticmethod - def backward(ctx, grad_output) -> Tuple[torch.Tensor, None, None]: - (indices,) = ctx.saved_tensors - return grad_output[indices], None, None - - -index_first_axis = IndexFirstAxis.apply -index_put_first_axis = IndexPutFirstAxis.apply - - -def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor: - output = index_put_first_axis(hidden_states, indices, batch * seqlen) - return rearrange(output, "(b s) ... 
-> b s ...", b=batch) - - -def _unpad_input( - query_layer: torch.Tensor, - key_layer: torch.Tensor, - value_layer: torch.Tensor, - attention_mask_2d: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]: - batch_size, seq_len, num_heads, head_dim = query_layer.shape - seqlens = attention_mask_2d.sum(dim=1).int() - cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0)) - max_seqlen = int(seqlens.max().item()) - indices = attention_mask_2d.flatten().nonzero(as_tuple=False).flatten() - query_layer = index_first_axis(query_layer.reshape(batch_size * seq_len, num_heads, head_dim), indices) - key_layer = index_first_axis(key_layer.reshape(batch_size * seq_len, num_heads, head_dim), indices) - value_layer = index_first_axis(value_layer.reshape(batch_size * seq_len, num_heads, head_dim), indices) - return query_layer, key_layer, value_layer, indices, (cu_seqlens, cu_seqlens), (max_seqlen, max_seqlen) - - -def kernels_flash_attention_func( - query_states: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - attention_mask_2d: Optional[torch.Tensor] = None, - causal: bool = False, -) -> torch.Tensor: - assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment." 
- if not causal and attention_mask_2d is not None: - batch_size, q_len = query_states.shape[:2] - ( - query_states, key_states, value_states, - indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k), - ) = _unpad_input(query_states, key_states, value_states, attention_mask_2d) - attn_output_unpad = _kernels_flash_varlen_forward( - query_states=query_states, key_states=key_states, value_states=value_states, - cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, - max_seqlen_in_batch_q=max_seqlen_q, max_seqlen_in_batch_k=max_seqlen_k, - ) - return pad_input(attn_output_unpad, indices_q, batch_size, q_len) - else: - return _kernels_flash_forward( - query_states=query_states, key_states=key_states, value_states=value_states, causal=causal, - ) - - -### Attention Backend Enum & Resolution -class AttentionBackend(Enum): - AUTO = "auto" - KERNELS_FLASH = "kernels_flash" - FLEX = "flex" - SDPA = "sdpa" - - -VALID_ATTENTION_BACKENDS = tuple(b.value for b in AttentionBackend) - - -_BACKEND_CONFIRMED = False - - -def resolve_attention_backend(requested_backend: str) -> AttentionBackend: - global _BACKEND_CONFIRMED - assert requested_backend in VALID_ATTENTION_BACKENDS, ( - f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}." - ) - if requested_backend in (AttentionBackend.AUTO.value, AttentionBackend.KERNELS_FLASH.value): - _ensure_flash_kernels_loaded() - if requested_backend == AttentionBackend.AUTO.value: - if FLASH_KERNEL is not None: - resolved = AttentionBackend.KERNELS_FLASH - elif flex_attention is not None: - resolved = AttentionBackend.FLEX - else: - resolved = AttentionBackend.SDPA - elif requested_backend == AttentionBackend.KERNELS_FLASH.value: - assert FLASH_KERNEL is not None, "Kernels Flash Attention is not available in this environment." 
- resolved = AttentionBackend.KERNELS_FLASH - elif requested_backend == AttentionBackend.FLEX.value: - assert flex_attention is not None, "Flex Attention is not available in this environment." - resolved = AttentionBackend.FLEX - elif requested_backend == AttentionBackend.SDPA.value: - resolved = AttentionBackend.SDPA - else: - raise AssertionError(f"Unsupported attention backend: {requested_backend}") - if not _BACKEND_CONFIRMED: - print(f"Attention backend: config='{requested_backend}' -> resolved='{resolved.value}'") - _BACKEND_CONFIRMED = True - return resolved - - -@torch.compiler.disable -def get_attention_mask( - effective_backend: AttentionBackend, - batch_size: int, - seq_len: int, - device: torch.device, - attention_mask: Optional[torch.Tensor] = None, -) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[BlockMask]]: - """Build padding masks once for all encoder layers. - - Returns (attention_mask_2d, attention_mask_4d, flex_block_mask). - """ - if attention_mask is None: - return None, None, None - - attention_mask_2d = attention_mask.bool() - - if effective_backend == AttentionBackend.KERNELS_FLASH: - return attention_mask_2d, None, None - - if effective_backend == AttentionBackend.FLEX: - assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable." - valid_lens = attention_mask_2d.sum(dim=-1) - - def mask_mod(batch_idx, head_idx, q_idx, kv_idx): - return (q_idx < valid_lens[batch_idx]) & (kv_idx < valid_lens[batch_idx]) - - flex_block_mask = create_block_mask(mask_mod, batch_size, 1, seq_len, seq_len, device=device) - return attention_mask_2d, None, flex_block_mask - - # SDPA / manual -- only mask the key dimension so padding query positions attend to - # real keys and produce valid (non-NaN) outputs instead of NaN from softmax(-inf,...,-inf). 
- attention_mask_4d = attention_mask_2d[:, None, None, :] - return attention_mask_2d, attention_mask_4d, None - -import math - -import torch -import torch.nn as nn -from torch.nn import functional as F -from typing import Optional, Tuple, Dict, Any -from dataclasses import dataclass -from transformers import PreTrainedModel, PretrainedConfig, AutoTokenizer -from transformers.modeling_outputs import ModelOutput - - - -# --------------------------------------------------------------------------- -# Output dataclasses -# --------------------------------------------------------------------------- - -@dataclass -class AnkhEncoderOutput(ModelOutput): - last_hidden_state: Optional[torch.Tensor] = None - hidden_states: Optional[Tuple[torch.Tensor, ...]] = None - attentions: Optional[Tuple[torch.Tensor, ...]] = None - - -@dataclass -class AnkhMaskedLMOutput(ModelOutput): - loss: Optional[torch.Tensor] = None - logits: Optional[torch.Tensor] = None - last_hidden_state: Optional[torch.Tensor] = None - hidden_states: Optional[Tuple[torch.Tensor, ...]] = None - attentions: Optional[Tuple[torch.Tensor, ...]] = None - - -# --------------------------------------------------------------------------- -# Config -# --------------------------------------------------------------------------- - -class FastAnkhConfig(PretrainedConfig): - model_type = "fast_ankh" - attribute_map = {"hidden_size": "d_model"} - - def __init__( - self, - vocab_size: int = 144, - d_model: int = 768, - d_kv: int = 64, - d_ff: int = 3072, - num_heads: int = 12, - num_layers: int = 48, - relative_attention_num_buckets: int = 64, - relative_attention_max_distance: int = 128, - dense_act_fn: str = "gelu_new", - layer_norm_epsilon: float = 1e-6, - initializer_factor: float = 1.0, - pad_token_id: int = 0, - eos_token_id: int = 1, - attn_backend: str = "sdpa", - **kwargs, - ): - super().__init__( - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - **kwargs, - ) - self.vocab_size = vocab_size - self.d_model = 
d_model - self.d_kv = d_kv - self.d_ff = d_ff - self.num_heads = num_heads - self.num_layers = num_layers - self.relative_attention_num_buckets = relative_attention_num_buckets - self.relative_attention_max_distance = relative_attention_max_distance - self.dense_act_fn = dense_act_fn - self.layer_norm_epsilon = layer_norm_epsilon - self.initializer_factor = initializer_factor - self.tie_word_embeddings = False - self.attn_backend = attn_backend - - def to_dict(self) -> Dict[str, Any]: - output = super().to_dict() - return output - - -# --------------------------------------------------------------------------- -# Submodules -# --------------------------------------------------------------------------- - -class AnkhRMSNorm(nn.Module): - """T5-style RMS layer norm: scales without mean subtraction or bias.""" - - def __init__(self, hidden_size: int, eps: float = 1e-6): - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(self.weight.dtype) - - -def _gelu_new(x: torch.Tensor) -> torch.Tensor: - return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) - - -class AnkhGatedFFN(nn.Module): - """T5-style gated feed-forward: activation(wi_0(x)) * wi_1(x) -> wo.""" - - def __init__(self, config: FastAnkhConfig): - super().__init__() - self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) - self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) - self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) - self.act = F.silu if config.dense_act_fn == "silu" else _gelu_new - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - return self.wo(self.act(self.wi_0(hidden_states)) * 
self.wi_1(hidden_states)) - - -# --------------------------------------------------------------------------- -# Attention -# --------------------------------------------------------------------------- - -class AnkhSelfAttention(nn.Module): - """T5-style self-attention with relative position bias and multi-backend dispatch. - - Only layer 0 has ``has_relative_attention_bias=True`` and owns the - ``nn.Embedding`` that produces the position bias. All other layers - receive the precomputed bias through the forward call. - """ - - def __init__(self, config: FastAnkhConfig, has_relative_attention_bias: bool = False): - super().__init__() - self.num_heads = config.num_heads - self.d_kv = config.d_kv - self.inner_dim = self.num_heads * self.d_kv - self.has_relative_attention_bias = has_relative_attention_bias - self.relative_attention_num_buckets = config.relative_attention_num_buckets - self.relative_attention_max_distance = config.relative_attention_max_distance - - self.q = nn.Linear(config.d_model, self.inner_dim, bias=False) - self.k = nn.Linear(config.d_model, self.inner_dim, bias=False) - self.v = nn.Linear(config.d_model, self.inner_dim, bias=False) - self.o = nn.Linear(self.inner_dim, config.d_model, bias=False) - self.scale = self.d_kv ** -0.5 - - if self.has_relative_attention_bias: - self.relative_attention_bias = nn.Embedding( - config.relative_attention_num_buckets, config.num_heads - ) - - self.attn_backend: AttentionBackend = AttentionBackend.SDPA # set by encoder - - # ---- T5 relative position bucketing ---- - - @staticmethod - def _relative_position_bucket( - relative_position: torch.Tensor, - num_buckets: int = 32, - max_distance: int = 128, - ) -> torch.Tensor: - """Bidirectional log-bucketed relative position mapping (T5 style).""" - # Bidirectional: half buckets for negative, half for positive - num_buckets //= 2 - relative_buckets = (relative_position > 0).to(torch.long) * num_buckets - relative_position = torch.abs(relative_position) - - max_exact 
= num_buckets // 2 - is_small = relative_position < max_exact - - relative_position_if_large = max_exact + ( - torch.log(relative_position.float() / max_exact) - / math.log(max_distance / max_exact) - * (num_buckets - max_exact) - ).to(torch.long) - relative_position_if_large = torch.clamp(relative_position_if_large, max=num_buckets - 1) - - relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) - return relative_buckets - - def compute_bias(self, query_length: int, key_length: int, device: torch.device) -> torch.Tensor: - """Compute (1, H, Q, K) position bias tensor for SDPA / manual paths.""" - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] - memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] - relative_position = memory_position - context_position - buckets = self._relative_position_bucket( - relative_position, - num_buckets=self.relative_attention_num_buckets, - max_distance=self.relative_attention_max_distance, - ) - values = self.relative_attention_bias(buckets) # (Q, K, H) - return values.permute(2, 0, 1).unsqueeze(0) # (1, H, Q, K) - - # ---- Forward ---- - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask_2d: Optional[torch.Tensor] = None, - attention_mask_4d: Optional[torch.Tensor] = None, - flex_block_mask: Optional[BlockMask] = None, - position_bias: Optional[torch.Tensor] = None, - flex_score_mod=None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - """Returns (attn_output, attn_weights_or_none, position_bias).""" - batch_size, seq_length = hidden_states.shape[:2] - hidden_shape = (batch_size, seq_length, self.num_heads, self.d_kv) - - query_BHLD = self.q(hidden_states).view(hidden_shape).transpose(1, 2) - key_BHLD = self.k(hidden_states).view(hidden_shape).transpose(1, 2) - value_BHLD = self.v(hidden_states).view(hidden_shape).transpose(1, 2) - - # Compute 
position bias on first layer (SDPA/manual only; flex uses score_mod) - if position_bias is None and self.has_relative_attention_bias and self.attn_backend != AttentionBackend.FLEX: - position_bias = self.compute_bias(seq_length, seq_length, hidden_states.device) - # Fold padding mask into position bias so layers don't need separate mask - if attention_mask_4d is not None: - position_bias = position_bias + attention_mask_4d.masked_fill( - attention_mask_4d.logical_not(), float("-inf") - ) - - if output_attentions: - attn_output, attn_weights = self._manual_attn(query_BHLD, key_BHLD, value_BHLD, position_bias) - return self.o(attn_output), attn_weights, position_bias - - if self.attn_backend == AttentionBackend.FLEX: - attn_output = self._flex_attn(query_BHLD, key_BHLD, value_BHLD, flex_block_mask, flex_score_mod) - elif self.attn_backend == AttentionBackend.SDPA: - attn_output = self._sdpa_attn(query_BHLD, key_BHLD, value_BHLD, position_bias) - else: - raise AssertionError(f"Unsupported backend for ANKH: {self.attn_backend}") - - return self.o(attn_output), None, position_bias - - def _sdpa_attn( - self, - query_BHLD: torch.Tensor, - key_BHLD: torch.Tensor, - value_BHLD: torch.Tensor, - position_bias: Optional[torch.Tensor], - ) -> torch.Tensor: - # SDPA: position_bias is (1, H, Q, K) additive bias (includes padding mask) - context_BHLD = F.scaled_dot_product_attention( - query_BHLD, key_BHLD, value_BHLD, - attn_mask=position_bias, - scale=self.scale, - ) - return context_BHLD.transpose(1, 2).contiguous().view( - query_BHLD.shape[0], -1, self.inner_dim - ) - - def _flex_attn( - self, - query_BHLD: torch.Tensor, - key_BHLD: torch.Tensor, - value_BHLD: torch.Tensor, - flex_block_mask: Optional[BlockMask], - flex_score_mod, - ) -> torch.Tensor: - assert flex_attention is not None, "Flex attention is not available." 
- fn = _get_flex_attention_fn() - context_BHLD = fn( - query_BHLD, key_BHLD, value_BHLD, - score_mod=flex_score_mod, - block_mask=flex_block_mask, - scale=self.scale, - ) - return context_BHLD.transpose(1, 2).contiguous().view( - query_BHLD.shape[0], -1, self.inner_dim - ) - - def _manual_attn( - self, - query_BHLD: torch.Tensor, - key_BHLD: torch.Tensor, - value_BHLD: torch.Tensor, - position_bias: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: - attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-1, -2)) * self.scale - if position_bias is not None: - attn_weights = attn_weights + position_bias - attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights) - context_BHLD = torch.matmul(attn_weights, value_BHLD) - attn_output = context_BHLD.transpose(1, 2).contiguous().view( - query_BHLD.shape[0], -1, self.inner_dim - ) - return attn_output, attn_weights - - -# --------------------------------------------------------------------------- -# Encoder block & stack (T5-compatible key naming) -# --------------------------------------------------------------------------- - -class AnkhSelfAttentionLayer(nn.Module): - """Wraps AnkhSelfAttention + layer_norm to match T5Block.layer[0] key naming.""" - - def __init__(self, config: FastAnkhConfig, has_relative_attention_bias: bool = False): - super().__init__() - self.SelfAttention = AnkhSelfAttention(config, has_relative_attention_bias) - self.layer_norm = AnkhRMSNorm(config.d_model, eps=config.layer_norm_epsilon) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask_2d: Optional[torch.Tensor] = None, - attention_mask_4d: Optional[torch.Tensor] = None, - flex_block_mask: Optional[BlockMask] = None, - position_bias: Optional[torch.Tensor] = None, - flex_score_mod=None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - normed = self.layer_norm(hidden_states) - attn_output, attn_weights, position_bias = 
self.SelfAttention( - normed, - attention_mask_2d=attention_mask_2d, - attention_mask_4d=attention_mask_4d, - flex_block_mask=flex_block_mask, - position_bias=position_bias, - flex_score_mod=flex_score_mod, - output_attentions=output_attentions, - ) - hidden_states = hidden_states + attn_output - return hidden_states, attn_weights, position_bias - - -class AnkhFFLayer(nn.Module): - """Wraps AnkhGatedFFN + layer_norm to match T5Block.layer[1] key naming.""" - - def __init__(self, config: FastAnkhConfig): - super().__init__() - self.DenseReluDense = AnkhGatedFFN(config) - self.layer_norm = AnkhRMSNorm(config.d_model, eps=config.layer_norm_epsilon) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - normed = self.layer_norm(hidden_states) - hidden_states = hidden_states + self.DenseReluDense(normed) - return hidden_states - - -class AnkhBlock(nn.Module): - """Single transformer block with T5-compatible .layer ModuleList naming.""" - - def __init__(self, config: FastAnkhConfig, has_relative_attention_bias: bool = False): - super().__init__() - self.layer = nn.ModuleList([ - AnkhSelfAttentionLayer(config, has_relative_attention_bias), - AnkhFFLayer(config), - ]) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask_2d: Optional[torch.Tensor] = None, - attention_mask_4d: Optional[torch.Tensor] = None, - flex_block_mask: Optional[BlockMask] = None, - position_bias: Optional[torch.Tensor] = None, - flex_score_mod=None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - hidden_states, attn_weights, position_bias = self.layer[0]( - hidden_states, - attention_mask_2d=attention_mask_2d, - attention_mask_4d=attention_mask_4d, - flex_block_mask=flex_block_mask, - position_bias=position_bias, - flex_score_mod=flex_score_mod, - output_attentions=output_attentions, - ) - hidden_states = self.layer[1](hidden_states) - return hidden_states, attn_weights, position_bias - - -# 
--------------------------------------------------------------------------- -# PreTrainedModel base -# --------------------------------------------------------------------------- - -class AnkhPreTrainedModel(PreTrainedModel): - config_class = FastAnkhConfig - base_model_prefix = "encoder" - supports_gradient_checkpointing = True - _no_split_modules = ["AnkhBlock"] - - @classmethod - def is_remote_code(cls) -> bool: - return True - - @torch.no_grad() - def _init_weights(self, module: nn.Module) -> None: - factor = self.config.initializer_factor - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=factor * (self.config.d_model ** -0.5)) - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=factor * 1.0) - elif isinstance(module, AnkhRMSNorm): - module.weight.data.fill_(1.0) - - def post_init(self) -> None: - super().post_init() - - def get_output_embeddings(self): - return None - - @property - def attn_backend(self) -> str: - return self.config.attn_backend - - @attn_backend.setter - def attn_backend(self, backend: str) -> None: - assert backend in VALID_ATTENTION_BACKENDS, ( - f"Unsupported attn_backend: {backend}. Expected one of {VALID_ATTENTION_BACKENDS}." 
- ) - self.config.attn_backend = backend - resolved = resolve_attention_backend(backend) - if resolved == AttentionBackend.KERNELS_FLASH: - print("ANKH: kernels_flash -> flex/sdpa fallback") - resolved = AttentionBackend.FLEX if flex_attention is not None else AttentionBackend.SDPA - for module in self.modules(): - if isinstance(module, FAST_ANKH_ENCODER): - module.attention_backend = resolved - elif isinstance(module, AnkhSelfAttention): - module.attn_backend = resolved - - -# --------------------------------------------------------------------------- -# FAST_ANKH_ENCODER (mirrors T5Stack key naming) -# --------------------------------------------------------------------------- - -class FAST_ANKH_ENCODER(AnkhPreTrainedModel, EmbeddingMixin): - """Inner encoder that mirrors T5Stack attribute naming for weight compliance. - - State dict keys: embed_tokens.*, block.{i}.layer.0.SelfAttention.*, - block.{i}.layer.1.DenseReluDense.*, final_layer_norm.*. - """ - - def __init__(self, config: FastAnkhConfig, **kwargs): - AnkhPreTrainedModel.__init__(self, config, **kwargs) - self.config = config - - resolved = resolve_attention_backend(config.attn_backend) - if resolved == AttentionBackend.KERNELS_FLASH: - print("ANKH: kernels_flash not supported (relative position bias); falling back to flex/sdpa") - resolved = AttentionBackend.FLEX if flex_attention is not None else AttentionBackend.SDPA - self.attention_backend = resolved - - self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model) - self.block = nn.ModuleList([ - AnkhBlock(config, has_relative_attention_bias=(i == 0)) - for i in range(config.num_layers) - ]) - for blk in self.block: - blk.layer[0].SelfAttention.attn_backend = self.attention_backend - - self.final_layer_norm = AnkhRMSNorm(config.d_model, eps=config.layer_norm_epsilon) - self.gradient_checkpointing = False - self.tokenizer = AutoTokenizer.from_pretrained("ElnaggarLab/ankh-base") - self.post_init() - - def get_input_embeddings(self): - return 
self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - @torch.compiler.disable - def _compute_materialized_bias(self, seq_len: int, device: torch.device) -> torch.Tensor: - """Precompute full (Q, K, H) bias tensor for flex score_mod lookup.""" - bias_embedding = self.block[0].layer[0].SelfAttention.relative_attention_bias - context_position = torch.arange(seq_len, dtype=torch.long, device=device)[:, None] - memory_position = torch.arange(seq_len, dtype=torch.long, device=device)[None, :] - relative_position = memory_position - context_position - buckets = AnkhSelfAttention._relative_position_bucket( - relative_position, - num_buckets=self.config.relative_attention_num_buckets, - max_distance=self.config.relative_attention_max_distance, - ) - return bias_embedding(buckets) # (Q, K, H) - - def _build_flex_score_mod(self, seq_len: int, device: torch.device): - """Build score_mod closure that reads from materialized bias tensor.""" - bias = self._compute_materialized_bias(seq_len, device) - - def score_mod(score, b, h, q_idx, kv_idx): - return score + bias[q_idx, kv_idx, h] - - return score_mod - - def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - encoder_output = self._run_encoder(hidden_states, attention_mask=attention_mask) - return encoder_output.last_hidden_state - - def _run_encoder( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - output_hidden_states: bool = False, - output_attentions: bool = False, - ) -> AnkhEncoderOutput: - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - batch_size, seq_len = hidden_states.shape[:2] - attention_mask_2d, attention_mask_4d, flex_block_mask = get_attention_mask( - effective_backend=self.attention_backend, - batch_size=batch_size, - seq_len=seq_len, - device=hidden_states.device, - 
attention_mask=attention_mask, - ) - - flex_score_mod = None - position_bias = None - if self.attention_backend == AttentionBackend.FLEX: - flex_score_mod = self._build_flex_score_mod(seq_len, hidden_states.device) - - for layer_module in self.block: - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - hidden_states, attn_weights, position_bias = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask_2d, - attention_mask_4d, - flex_block_mask, - position_bias, - flex_score_mod, - output_attentions, - ) - else: - hidden_states, attn_weights, position_bias = layer_module( - hidden_states, - attention_mask_2d=attention_mask_2d, - attention_mask_4d=attention_mask_4d, - flex_block_mask=flex_block_mask, - position_bias=position_bias, - flex_score_mod=flex_score_mod, - output_attentions=output_attentions, - ) - - if all_attentions is not None: - all_attentions = all_attentions + (attn_weights,) - - hidden_states = self.final_layer_norm(hidden_states) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - return AnkhEncoderOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_attentions, - ) - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - output_hidden_states: Optional[bool] = None, - output_attentions: Optional[bool] = None, - **kwargs, - ) -> AnkhEncoderOutput: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids 
is not None: - hidden_states = self.embed_tokens(input_ids) - elif inputs_embeds is not None: - hidden_states = inputs_embeds - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - return self._run_encoder( - hidden_states, - attention_mask=attention_mask, - output_hidden_states=output_hidden_states or False, - output_attentions=output_attentions or False, - ) - - -# --------------------------------------------------------------------------- -# Model classes -# --------------------------------------------------------------------------- - -class FastAnkhModel(AnkhPreTrainedModel, EmbeddingMixin): - """ANKH encoder model for embedding extraction.""" - - def __init__(self, config: FastAnkhConfig, **kwargs): - AnkhPreTrainedModel.__init__(self, config, **kwargs) - self.config = config - self.shared = nn.Embedding(config.vocab_size, config.d_model) - self.encoder = FAST_ANKH_ENCODER(config) - self.post_init() - - @property - def tokenizer(self): - return self.encoder.tokenizer - - def get_input_embeddings(self): - return self.encoder.embed_tokens - - def set_input_embeddings(self, value): - self.encoder.embed_tokens = value - - def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - return self.encoder._embed(input_ids, attention_mask) - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - output_hidden_states: Optional[bool] = None, - output_attentions: Optional[bool] = None, - **kwargs, - ) -> AnkhEncoderOutput: - return self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - ) - - -class FastAnkhForMaskedLM(AnkhPreTrainedModel, EmbeddingMixin): - """ANKH encoder with LM head for masked language modeling. 
- - NOTE: The LM head is initialized from the shared embedding weights but is NOT - tied. The original ANKH models were trained with T5's span corruption objective - using an encoder-decoder architecture. This encoder-only MaskedLM variant is - not pre-trained for standard MLM and requires additional fine-tuning. - """ - - def __init__(self, config: FastAnkhConfig, **kwargs): - AnkhPreTrainedModel.__init__(self, config, **kwargs) - self.config = config - self.shared = nn.Embedding(config.vocab_size, config.d_model) - self.encoder = FAST_ANKH_ENCODER(config) - self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) - self.loss_fct = nn.CrossEntropyLoss() - self.post_init() - - @property - def tokenizer(self): - return self.encoder.tokenizer - - def get_input_embeddings(self): - return self.encoder.embed_tokens - - def set_input_embeddings(self, value): - self.encoder.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - return self.encoder._embed(input_ids, attention_mask) - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_hidden_states: Optional[bool] = None, - output_attentions: Optional[bool] = None, - **kwargs, - ) -> AnkhMaskedLMOutput: - outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - ) - sequence_output = outputs.last_hidden_state - logits = self.lm_head(sequence_output) - - loss = None - if labels is not None: - labels = labels.to(logits.device) - loss = self.loss_fct(logits.view(-1, self.config.vocab_size), 
labels.view(-1)) - - return AnkhMaskedLMOutput( - loss=loss, - logits=logits, - last_hidden_state=sequence_output, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -class FastAnkhForSequenceClassification(AnkhPreTrainedModel, EmbeddingMixin): - def __init__(self, config: FastAnkhConfig, **kwargs): - AnkhPreTrainedModel.__init__(self, config, **kwargs) - self.num_labels = config.num_labels - self.config = config - self.shared = nn.Embedding(config.vocab_size, config.d_model) - self.encoder = FAST_ANKH_ENCODER(config) - self.classifier = nn.Linear(config.d_model, config.num_labels) - self.mse = nn.MSELoss() - self.ce = nn.CrossEntropyLoss() - self.bce = nn.BCEWithLogitsLoss() - self.post_init() - - @property - def tokenizer(self): - return self.encoder.tokenizer - - def get_input_embeddings(self): - return self.encoder.embed_tokens - - def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - return self.encoder._embed(input_ids, attention_mask) - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_hidden_states: Optional[bool] = None, - output_attentions: Optional[bool] = None, - **kwargs, - ) -> AnkhMaskedLMOutput: - outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - ) - # Pool: mean over non-padding tokens - sequence_output = outputs.last_hidden_state - if attention_mask is not None: - mask = attention_mask.unsqueeze(-1).to(sequence_output.dtype) - pooled = (sequence_output * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1) - else: - pooled = sequence_output.mean(dim=1) - logits = self.classifier(pooled) - - loss = None - if labels is not None: - labels = labels.to(logits.device) - 
if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss = self.mse(logits.squeeze(), labels.squeeze()) if self.num_labels == 1 else self.mse(logits, labels) - elif self.config.problem_type == "single_label_classification": - loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss = self.bce(logits, labels) - - return AnkhMaskedLMOutput( - loss=loss, - logits=logits, - last_hidden_state=sequence_output, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -class FastAnkhForTokenClassification(AnkhPreTrainedModel, EmbeddingMixin): - def __init__(self, config: FastAnkhConfig, **kwargs): - AnkhPreTrainedModel.__init__(self, config, **kwargs) - self.num_labels = config.num_labels - self.shared = nn.Embedding(config.vocab_size, config.d_model) - self.encoder = FAST_ANKH_ENCODER(config) - self.classifier = nn.Linear(config.d_model, config.num_labels) - self.loss_fct = nn.CrossEntropyLoss() - self.post_init() - - @property - def tokenizer(self): - return self.encoder.tokenizer - - def get_input_embeddings(self): - return self.encoder.embed_tokens - - def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - return self.encoder._embed(input_ids, attention_mask) - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_hidden_states: Optional[bool] = None, - output_attentions: Optional[bool] = None, - **kwargs, - ) -> AnkhMaskedLMOutput: - outputs = 
self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - ) - sequence_output = outputs.last_hidden_state - logits = self.classifier(sequence_output) - - loss = None - if labels is not None: - labels = labels.to(logits.device) - loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - return AnkhMaskedLMOutput( - loss=loss, - logits=logits, - last_hidden_state=sequence_output, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) +from __future__ import annotations + +import torch +import torch._inductor.config as inductor_config +import torch._dynamo as dynamo + +# Enable TensorFloat32 tensor cores for float32 matmul (Ampere+ GPUs) +# Provides significant speedup with minimal precision loss +torch.set_float32_matmul_precision('high') + +# Enable TF32 for matrix multiplications and cuDNN operations +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True + +# Enable cuDNN autotuner - finds fastest algorithms for your hardware +# Best when input sizes are consistent; may slow down first iterations +torch.backends.cudnn.benchmark = True + +# Deterministic operations off for speed (set True if reproducibility needed) +torch.backends.cudnn.deterministic = False +inductor_config.max_autotune_gemm_backends = "ATEN,CUTLASS,FBGEMM" + +dynamo.config.capture_scalar_outputs = True +torch._dynamo.config.recompile_limit = 16 + +import os +import sqlite3 +import networkx as nx +import numpy as np +import torch +from tqdm.auto import tqdm +from typing import Callable, Dict, List, Optional, Set +from torch.utils.data import DataLoader +from torch.utils.data import Dataset as TorchDataset +from transformers import PreTrainedTokenizerBase + + +class Pooler: + def __init__(self, pooling_types: List[str]) -> None: + self.pooling_types = pooling_types + self.pooling_options: Dict[str, 
Callable] = { + 'mean': self.mean_pooling, + 'max': self.max_pooling, + 'norm': self.norm_pooling, + 'median': self.median_pooling, + 'std': self.std_pooling, + 'var': self.var_pooling, + 'cls': self.cls_pooling, + 'parti': self._pool_parti, + } + + def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor: + assert isinstance(attentions, torch.Tensor) + maxed_attentions = torch.max(attentions, dim=1)[0] + return maxed_attentions + + def _page_rank(self, attention_matrix: np.ndarray, personalization: Optional[dict] = None, nstart: Optional[dict] = None, prune_type: str = "top_k_outdegree") -> Dict[int, float]: + # Run PageRank on the attention matrix converted to a graph. + # Raises exceptions if the graph doesn't match the token sequence or has no edges. + # Returns the PageRank scores for each token node. + G = self._convert_to_graph(attention_matrix) + if G.number_of_nodes() != attention_matrix.shape[0]: + raise Exception( + f"The number of nodes in the graph should be equal to the number of tokens in sequence! You have {G.number_of_nodes()} nodes for {attention_matrix.shape[0]} tokens.") + if G.number_of_edges() == 0: + raise Exception(f"You don't seem to have any attention edges left in the graph.") + + return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100) + + def _convert_to_graph(self, matrix: np.ndarray) -> nx.DiGraph: + # Convert a matrix (e.g., attention scores) to a directed graph using networkx. + # Each element in the matrix represents a directed edge with a weight. 
+ G = nx.from_numpy_array(matrix, create_using=nx.DiGraph) + return G + + def _calculate_importance_weights(self, dict_importance: Dict[int, float], attention_mask: Optional[torch.Tensor] = None) -> np.ndarray: + # Remove keys where attention_mask is 0 + if attention_mask is not None: + for k in list(dict_importance.keys()): + if attention_mask[k] == 0: + del dict_importance[k] + + #dict_importance[0] # remove cls + #dict_importance[-1] # remove eos + total = sum(dict_importance.values()) + return np.array([v / total for _, v in dict_importance.items()]) + + def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: # (b, L, d) -> (b, d) + maxed_attentions = self._create_pooled_matrices_across_layers(attentions).numpy() + # emb is (b, L, d), maxed_attentions is (b, L, L) + emb_pooled = [] + for e, a, mask in zip(emb, maxed_attentions, attention_mask): + dict_importance = self._page_rank(a) + importance_weights = self._calculate_importance_weights(dict_importance, mask) + num_tokens = int(mask.sum().item()) + emb_pooled.append(np.average(e[:num_tokens], weights=importance_weights, axis=0)) + pooled = torch.tensor(np.array(emb_pooled)) + return pooled + + def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d) + if attention_mask is None: + return emb.mean(dim=1) + else: + attention_mask = attention_mask.unsqueeze(-1) + return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) + + def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d) + if attention_mask is None: + return emb.max(dim=1).values + else: + mask = attention_mask.unsqueeze(-1).bool() + return emb.masked_fill(~mask, float('-inf')).max(dim=1).values + + def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> 
torch.Tensor: # (b, L, d) -> (b, d) + if attention_mask is None: + return emb.norm(dim=1, p=2) + else: + attention_mask = attention_mask.unsqueeze(-1) + return (emb * attention_mask).norm(dim=1, p=2) + + def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d) + if attention_mask is None: + return emb.median(dim=1).values + else: + mask = attention_mask.unsqueeze(-1).bool() + return emb.masked_fill(~mask, float('nan')).nanmedian(dim=1).values + + def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d) + if attention_mask is None: + return emb.std(dim=1) + else: + # Compute variance correctly over non-masked positions, then take sqrt + var = self.var_pooling(emb, attention_mask, **kwargs) + return torch.sqrt(var) + + def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d) + if attention_mask is None: + return emb.var(dim=1) + else: + # Correctly compute variance over only non-masked positions + attention_mask = attention_mask.unsqueeze(-1) # (b, L, 1) + # Compute mean over non-masked positions + mean = (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d) + mean = mean.unsqueeze(1) # (b, 1, d) + # Compute squared differences from mean, only over non-masked positions + squared_diff = (emb - mean) ** 2 # (b, L, d) + # Sum squared differences over non-masked positions and divide by count + var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d) + return var + + def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: # (b, L, d) -> (b, d) + return emb[:, 0, :] + + def __call__( + self, + emb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + attentions: Optional[torch.Tensor] = None + ) -> torch.Tensor: # 
[mean, max] + final_emb: List[torch.Tensor] = [] + for pooling_type in self.pooling_types: + final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions)) # (b, d) + return torch.cat(final_emb, dim=-1) # (b, n_pooling_types * d) + + +class ProteinDataset(TorchDataset): + """Simple dataset for protein sequences.""" + def __init__(self, sequences: List[str]) -> None: + self.sequences = sequences + + def __len__(self) -> int: + return len(self.sequences) + + def __getitem__(self, idx: int) -> str: + return self.sequences[idx] + + +def build_collator(tokenizer: PreTrainedTokenizerBase) -> Callable[[List[str]], Dict[str, torch.Tensor]]: + def _collate_fn(sequences: List[str]) -> Dict[str, torch.Tensor]: + return tokenizer(sequences, return_tensors="pt", padding='longest') + return _collate_fn + + +def parse_fasta(fasta_path: str) -> List[str]: + assert os.path.exists(fasta_path), f"FASTA file does not exist: {fasta_path}" + sequences = [] + current_seq = [] + with open(fasta_path, 'r') as f: + for line in f: + line = line.strip() + if not line: + continue + if line.startswith('>'): + if current_seq: + sequences.append(''.join(current_seq)) + current_seq = [] + else: + current_seq.append(line) + if current_seq: + sequences.append(''.join(current_seq)) + return sequences + + +class EmbeddingMixin: + def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + raise NotImplementedError + + @property + def device(self) -> torch.device: + """Get the device of the model.""" + return next(self.parameters()).device + + def _read_sequences_from_db(self, db_path: str) -> Set[str]: + """Read sequences from SQLite database.""" + sequences = [] + with sqlite3.connect(db_path) as conn: + c = conn.cursor() + c.execute("SELECT sequence FROM embeddings") + while True: + row = c.fetchone() + if row is None: + break + sequences.append(row[0]) + return set(sequences) + + def 
_ensure_embeddings_table(self, conn: sqlite3.Connection) -> None: + cursor = conn.cursor() + cursor.execute( + "CREATE TABLE IF NOT EXISTS embeddings (" + "sequence TEXT PRIMARY KEY, " + "embedding BLOB NOT NULL, " + "shape TEXT, " + "dtype TEXT" + ")" + ) + cursor.execute("PRAGMA table_info(embeddings)") + rows = cursor.fetchall() + column_names = [row[1] for row in rows] + if "shape" not in column_names: + cursor.execute("ALTER TABLE embeddings ADD COLUMN shape TEXT") + if "dtype" not in column_names: + cursor.execute("ALTER TABLE embeddings ADD COLUMN dtype TEXT") + conn.commit() + + def load_embeddings_from_pth(self, save_path: str) -> Dict[str, torch.Tensor]: + assert os.path.exists(save_path), f"Embedding file does not exist: {save_path}" + payload = torch.load(save_path, map_location="cpu", weights_only=True) + assert isinstance(payload, dict), "Expected .pth embeddings file to contain a dictionary." + for sequence, tensor in payload.items(): + assert isinstance(sequence, str), "Expected embedding dictionary keys to be sequences (str)." + assert isinstance(tensor, torch.Tensor), "Expected embedding dictionary values to be tensors." 
def load_embeddings_from_db(self, db_path: str, sequences: Optional[List[str]] = None) -> Dict[str, torch.Tensor]:
    """Load stored embeddings from a SQLite database into tensors.

    Args:
        db_path: Path of an existing SQLite database file.
        sequences: Sequences to look up. ``None`` loads every row; an empty
            list returns an empty dict.

    Returns:
        Mapping from sequence to its embedding, rebuilt from the stored
        bytes using the ``shape`` and ``dtype`` columns. Requested sequences
        that are not in the table are simply absent from the result.

    Raises:
        AssertionError: If the file does not exist, or a row has missing or
            inconsistent shape/dtype metadata.
    """
    assert os.path.exists(db_path), f"Embedding database does not exist: {db_path}"
    loaded: Dict[str, torch.Tensor] = {}
    select_sql = "SELECT sequence, embedding, shape, dtype FROM embeddings"
    # SQLite caps the number of bound parameters per statement (999 in older
    # builds), so large lookups are split into several IN (...) queries.
    max_params = 900

    # ``with sqlite3.connect(...)`` would not close the connection, hence the
    # explicit try/finally.
    conn = sqlite3.connect(db_path)
    try:
        self._ensure_embeddings_table(conn)
        cursor = conn.cursor()
        if sequences is None:
            cursor.execute(select_sql)
            rows = cursor.fetchall()
        else:
            if len(sequences) == 0:
                return loaded
            rows = []
            for start in range(0, len(sequences), max_params):
                chunk = sequences[start:start + max_params]
                placeholders = ",".join(["?"] * len(chunk))
                cursor.execute(
                    f"{select_sql} WHERE sequence IN ({placeholders})",
                    tuple(chunk),
                )
                rows.extend(cursor.fetchall())

        for sequence, embedding_bytes, shape_text, dtype_text in rows:
            assert shape_text is not None, "Missing shape metadata in embeddings table."
            assert dtype_text is not None, "Missing dtype metadata in embeddings table."
            shape_values = [int(value) for value in shape_text.split(",") if len(value) > 0]
            assert len(shape_values) > 0, f"Invalid shape metadata for sequence: {sequence}"
            expected_size = int(np.prod(shape_values))
            array = np.frombuffer(embedding_bytes, dtype=np.dtype(dtype_text))
            assert array.size == expected_size, f"Shape mismatch while reading sequence: {sequence}"
            # frombuffer yields a read-only view of the bytes; copy before
            # handing the memory to torch.
            loaded[sequence] = torch.from_numpy(array.copy().reshape(tuple(shape_values)))
    finally:
        conn.close()
    return loaded
None, + **kwargs, + ) -> Optional[Dict[str, torch.Tensor]]: + """ + Embed a dataset of protein sequences. + + Supports two modes: + - Tokenizer mode (ESM2/ESM++): provide `tokenizer`, `_embed(input_ids, attention_mask)` is used. + - Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used. + + Sequences can be supplied as a list via `sequences`, parsed from a FASTA file via + `fasta_path`, or both (the two sources are combined). At least one must be provided. + """ + if fasta_path is not None: + fasta_sequences = parse_fasta(fasta_path) + sequences = list(sequences or []) + fasta_sequences + assert sequences is not None and len(sequences) > 0, \ + "Must provide at least one sequence via `sequences` or `fasta_path`." + sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences])) + sequences = sorted(sequences, key=len, reverse=True) + hidden_size = self.config.hidden_size + pooler = Pooler(pooling_types) if not full_embeddings else None + tokenizer_mode = tokenizer is not None + if tokenizer_mode: + collate_fn = build_collator(tokenizer) + device = self.device + else: + collate_fn = None + device = None + + def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + assert isinstance(residue_embeddings, torch.Tensor) + if full_embeddings or residue_embeddings.ndim == 2: + return residue_embeddings + return pooler(residue_embeddings, attention_mask) + + def iter_batches(to_embed: List[str]): + if tokenizer_mode: + assert collate_fn is not None + assert device is not None + dataset = ProteinDataset(to_embed) + dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn, shuffle=False) + for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'): + seqs = to_embed[i * batch_size:(i + 1) * batch_size] + input_ids = batch['input_ids'].to(device) + attention_mask = 
batch['attention_mask'].to(device) + residue_embeddings = self._embed(input_ids, attention_mask) + yield seqs, residue_embeddings, attention_mask + else: + for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'): + seqs = to_embed[batch_start:batch_start + batch_size] + batch_output = self._embed(seqs, return_attention_mask=True, **kwargs) + assert isinstance(batch_output, tuple), "Sequence mode _embed must return (last_hidden_state, attention_mask)." + assert len(batch_output) == 2, "Sequence mode _embed must return exactly two values." + residue_embeddings, attention_mask = batch_output + assert isinstance(attention_mask, torch.Tensor), "Sequence mode _embed must return attention_mask as a torch.Tensor." + yield seqs, residue_embeddings, attention_mask + + if sql: + conn = sqlite3.connect(sql_db_path) + self._ensure_embeddings_table(conn) + c = conn.cursor() + already_embedded = self._read_sequences_from_db(sql_db_path) + to_embed = [seq for seq in sequences if seq not in already_embedded] + print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}") + print(f"Embedding {len(to_embed)} new sequences") + if len(to_embed) > 0: + with torch.no_grad(): + for i, (seqs, residue_embeddings, attention_mask) in enumerate(iter_batches(to_embed)): + embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype) + for seq, emb, mask in zip(seqs, embeddings, attention_mask): + if full_embeddings: + emb = emb[mask.bool()].reshape(-1, hidden_size) + emb_np = emb.cpu().numpy() + emb_shape = ",".join([str(dim) for dim in emb_np.shape]) + emb_dtype = str(emb_np.dtype) + c.execute( + "INSERT OR REPLACE INTO embeddings (sequence, embedding, shape, dtype) VALUES (?, ?, ?, ?)", + (seq, emb_np.tobytes(), emb_shape, emb_dtype), + ) + if tokenizer_mode and (i + 1) % 100 == 0: + conn.commit() + conn.commit() + conn.close() + return None + + embeddings_dict = {} + if os.path.exists(save_path): + embeddings_dict = 
self.load_embeddings_from_pth(save_path) + to_embed = [seq for seq in sequences if seq not in embeddings_dict] + print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}") + print(f"Embedding {len(to_embed)} new sequences") + else: + to_embed = sequences + print(f"Embedding {len(to_embed)} new sequences") + + if len(to_embed) > 0: + with torch.no_grad(): + for seqs, residue_embeddings, attention_mask in iter_batches(to_embed): + embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype) + for seq, emb, mask in zip(seqs, embeddings, attention_mask): + if full_embeddings: + emb = emb[mask.bool()].reshape(-1, hidden_size) + embeddings_dict[seq] = emb.cpu() + + if save: + torch.save(embeddings_dict, save_path) + + return embeddings_dict + + +if __name__ == "__main__": + # py -m pooler + pooler = Pooler(pooling_types=['max', 'parti']) + batch_size = 8 + seq_len = 64 + hidden_size = 128 + num_layers = 12 + emb = torch.randn(batch_size, seq_len, hidden_size) + attentions = torch.randn(batch_size, num_layers, seq_len, seq_len) + attention_mask = torch.ones(batch_size, seq_len) + y = pooler(emb=emb, attention_mask=attention_mask, attentions=attentions) + print(y.shape) + +"""Shared attention infrastructure for all FastPLMs models. + +Contains: AttentionBackend enum, backend resolution, mask creation, +flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities. 
+""" +from enum import Enum +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange + +try: + from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask +except ImportError: + create_block_mask = None + flex_attention = None + BlockMask = None + +_compiled_flex_attention = None + + +def _get_flex_attention_fn(): + """Return flex_attention callable: compiled (fused kernel) by default, or eager when debug flag is set.""" + global _compiled_flex_attention + if flex_attention is None: + return None + flex_mod = torch.nn.attention.flex_attention + if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False): + return flex_attention + if _compiled_flex_attention is None: + _compiled_flex_attention = torch.compile( + flex_attention, + dynamic=False, + ) + return _compiled_flex_attention + + +### Kernels Flash Attention Detection +def _infer_kernels_flash_variant(kernel) -> Optional[str]: + if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"): + return "flash_attn2" + if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"): + return "flash_attn3" + return None + + +def _try_get_kernels_flash(): + try: + from kernels import get_kernel + except ImportError: + return None, None + + flash_kernel = None + flash_kernel_variant = None + try: + flash_kernel = get_kernel("kernels-community/flash-attn3") + flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel) + assert flash_kernel_variant is not None, "Loaded flash-attn3 kernel does not expose a supported API." + except Exception: + try: + flash_kernel = get_kernel("kernels-community/flash-attn2") + flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel) + assert flash_kernel_variant is not None, "Loaded flash-attn2 kernel does not expose a supported API." 
+ except Exception: + flash_kernel = None + flash_kernel_variant = None + return flash_kernel, flash_kernel_variant + + +_FLASH_KERNELS_LOADED = False +FLASH_KERNEL = None +FLASH_KERNEL_VARIANT = None + + +def _ensure_flash_kernels_loaded(): + global _FLASH_KERNELS_LOADED, FLASH_KERNEL, FLASH_KERNEL_VARIANT + if _FLASH_KERNELS_LOADED: + return + _FLASH_KERNELS_LOADED = True + FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash() + + +def _kernels_flash_forward( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + causal: bool = False, +) -> torch.Tensor: + assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment." + if FLASH_KERNEL_VARIANT == "flash_attn2": + return FLASH_KERNEL.fwd(q=query_states, k=key_states, v=value_states, is_causal=causal)[0] + if FLASH_KERNEL_VARIANT == "flash_attn3": + try: + output = FLASH_KERNEL.flash_attn_func(q=query_states, k=key_states, v=value_states, causal=causal) + except TypeError: + output = FLASH_KERNEL.flash_attn_func(query_states, key_states, value_states, 0.0, None, causal) + if isinstance(output, tuple): + return output[0] + return output + raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}") + + +def _kernels_flash_varlen_forward( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_in_batch_q: int, + max_seqlen_in_batch_k: int, + causal: bool = False, +) -> torch.Tensor: + assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment." 
+ if FLASH_KERNEL_VARIANT == "flash_attn2": + return FLASH_KERNEL.varlen_fwd( + q=query_states, k=key_states, v=value_states, + cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k, + is_causal=causal, + )[0] + if FLASH_KERNEL_VARIANT == "flash_attn3": + try: + output = FLASH_KERNEL.flash_attn_varlen_func( + q=query_states, k=key_states, v=value_states, + cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k, + causal=causal, + ) + except TypeError: + output = FLASH_KERNEL.flash_attn_varlen_func( + query_states, key_states, value_states, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_in_batch_q, max_seqlen_in_batch_k, + 0.0, None, causal, + ) + if isinstance(output, tuple): + return output[0] + return output + raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}") + + +### Unpad / Pad helpers for varlen flash attention +class IndexFirstAxis(torch.autograd.Function): + @staticmethod + def forward(ctx, input, indices) -> torch.Tensor: + ctx.save_for_backward(indices) + assert input.ndim >= 2 + ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] + second_dim = other_shape.numel() + return torch.gather( + rearrange(input, "b ... -> b (...)"), 0, indices.unsqueeze(1).expand(-1, second_dim) + ).reshape(-1, *other_shape) + + @staticmethod + def backward(ctx, grad_output) -> Tuple[torch.Tensor, None]: + (indices,) = ctx.saved_tensors + assert grad_output.ndim >= 2 + other_shape = grad_output.shape[1:] + grad_output = rearrange(grad_output, "b ... 
class IndexPutFirstAxis(torch.autograd.Function):
    """Scatter rows of ``values`` into a zero-filled tensor along the first axis."""

    @staticmethod
    def forward(ctx, values, indices, first_axis_dim) -> torch.Tensor:
        # Keep the row indices so backward can gather the matching gradients.
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        padded = values.new_zeros((first_axis_dim, *values.shape[1:]))
        padded[indices] = values
        return padded

    @staticmethod
    def backward(ctx, grad_output) -> Tuple[torch.Tensor, None, None]:
        # Only the rows that were written in forward receive gradient.
        row_indices = ctx.saved_tensors[0]
        return grad_output[row_indices], None, None
def _unpad_input(
    query_layer: torch.Tensor,
    key_layer: torch.Tensor,
    value_layer: torch.Tensor,
    attention_mask_2d: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
    """Drop masked-out positions from q/k/v for a variable-length attention call.

    The three inputs are flattened over their first two axes and only the
    rows where ``attention_mask_2d`` is non-zero are kept.

    Returns:
        ``(query, key, value, indices, (cu_seqlens, cu_seqlens),
        (max_seqlen, max_seqlen))`` where ``indices`` are the kept flat
        positions, ``cu_seqlens`` is the zero-prefixed int32 cumulative sum
        of per-row lengths, and ``max_seqlen`` is the longest row length.
    """
    batch_size, seq_len, num_heads, head_dim = query_layer.shape
    flat_shape = (batch_size * seq_len, num_heads, head_dim)

    lengths = attention_mask_2d.sum(dim=1).int()
    cumulative = F.pad(lengths.cumsum(0, dtype=torch.int32), (1, 0))
    longest = int(lengths.max().item())
    kept = attention_mask_2d.flatten().nonzero(as_tuple=False).flatten()

    unpadded_q = index_first_axis(query_layer.reshape(flat_shape), kept)
    unpadded_k = index_first_axis(key_layer.reshape(flat_shape), kept)
    unpadded_v = index_first_axis(value_layer.reshape(flat_shape), kept)
    return unpadded_q, unpadded_k, unpadded_v, kept, (cumulative, cumulative), (longest, longest)
+ if not causal and attention_mask_2d is not None: + batch_size, q_len = query_states.shape[:2] + ( + query_states, key_states, value_states, + indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k), + ) = _unpad_input(query_states, key_states, value_states, attention_mask_2d) + attn_output_unpad = _kernels_flash_varlen_forward( + query_states=query_states, key_states=key_states, value_states=value_states, + cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, + max_seqlen_in_batch_q=max_seqlen_q, max_seqlen_in_batch_k=max_seqlen_k, + ) + return pad_input(attn_output_unpad, indices_q, batch_size, q_len) + else: + return _kernels_flash_forward( + query_states=query_states, key_states=key_states, value_states=value_states, causal=causal, + ) + + +### Attention Backend Enum & Resolution +class AttentionBackend(Enum): + AUTO = "auto" + KERNELS_FLASH = "kernels_flash" + FLEX = "flex" + SDPA = "sdpa" + + +VALID_ATTENTION_BACKENDS = tuple(b.value for b in AttentionBackend) + + +_BACKEND_CONFIRMED = False + + +def resolve_attention_backend(requested_backend: str) -> AttentionBackend: + global _BACKEND_CONFIRMED + assert requested_backend in VALID_ATTENTION_BACKENDS, ( + f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}." + ) + if requested_backend in (AttentionBackend.AUTO.value, AttentionBackend.KERNELS_FLASH.value): + _ensure_flash_kernels_loaded() + if requested_backend == AttentionBackend.AUTO.value: + if FLASH_KERNEL is not None: + resolved = AttentionBackend.KERNELS_FLASH + elif flex_attention is not None: + resolved = AttentionBackend.FLEX + else: + resolved = AttentionBackend.SDPA + elif requested_backend == AttentionBackend.KERNELS_FLASH.value: + assert FLASH_KERNEL is not None, "Kernels Flash Attention is not available in this environment." 
+ resolved = AttentionBackend.KERNELS_FLASH + elif requested_backend == AttentionBackend.FLEX.value: + assert flex_attention is not None, "Flex Attention is not available in this environment." + resolved = AttentionBackend.FLEX + elif requested_backend == AttentionBackend.SDPA.value: + resolved = AttentionBackend.SDPA + else: + raise AssertionError(f"Unsupported attention backend: {requested_backend}") + if not _BACKEND_CONFIRMED: + print(f"Attention backend: config='{requested_backend}' -> resolved='{resolved.value}'") + _BACKEND_CONFIRMED = True + return resolved + + +@torch.compiler.disable +def get_attention_mask( + effective_backend: AttentionBackend, + batch_size: int, + seq_len: int, + device: torch.device, + attention_mask: Optional[torch.Tensor] = None, +) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[BlockMask]]: + """Build padding masks once for all encoder layers. + + Returns (attention_mask_2d, attention_mask_4d, flex_block_mask). + """ + if attention_mask is None: + return None, None, None + + attention_mask_2d = attention_mask.bool() + + if effective_backend == AttentionBackend.KERNELS_FLASH: + return attention_mask_2d, None, None + + if effective_backend == AttentionBackend.FLEX: + assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable." + valid_lens = attention_mask_2d.sum(dim=-1) + + def mask_mod(batch_idx, head_idx, q_idx, kv_idx): + return (q_idx < valid_lens[batch_idx]) & (kv_idx < valid_lens[batch_idx]) + + flex_block_mask = create_block_mask(mask_mod, batch_size, 1, seq_len, seq_len, device=device) + return attention_mask_2d, None, flex_block_mask + + # SDPA / manual -- only mask the key dimension so padding query positions attend to + # real keys and produce valid (non-NaN) outputs instead of NaN from softmax(-inf,...,-inf). 
+ attention_mask_4d = attention_mask_2d[:, None, None, :] + return attention_mask_2d, attention_mask_4d, None + +import math + +import torch +import torch.nn as nn +from torch.nn import functional as F +from typing import Optional, Tuple, Dict, Any +from dataclasses import dataclass +from transformers import PreTrainedModel, PretrainedConfig, AutoTokenizer +from transformers.modeling_outputs import ModelOutput + + + +# --------------------------------------------------------------------------- +# Output dataclasses +# --------------------------------------------------------------------------- + +@dataclass +class AnkhEncoderOutput(ModelOutput): + last_hidden_state: Optional[torch.Tensor] = None + hidden_states: Optional[Tuple[torch.Tensor, ...]] = None + attentions: Optional[Tuple[torch.Tensor, ...]] = None + + +@dataclass +class AnkhMaskedLMOutput(ModelOutput): + loss: Optional[torch.Tensor] = None + logits: Optional[torch.Tensor] = None + last_hidden_state: Optional[torch.Tensor] = None + hidden_states: Optional[Tuple[torch.Tensor, ...]] = None + attentions: Optional[Tuple[torch.Tensor, ...]] = None + + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + +class FastAnkhConfig(PretrainedConfig): + model_type = "fast_ankh" + attribute_map = {"hidden_size": "d_model"} + + def __init__( + self, + vocab_size: int = 144, + d_model: int = 768, + d_kv: int = 64, + d_ff: int = 3072, + num_heads: int = 12, + num_layers: int = 48, + relative_attention_num_buckets: int = 64, + relative_attention_max_distance: int = 128, + dense_act_fn: str = "gelu_new", + layer_norm_epsilon: float = 1e-6, + initializer_factor: float = 1.0, + pad_token_id: int = 0, + eos_token_id: int = 1, + attn_backend: str = "sdpa", + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + self.vocab_size = vocab_size + self.d_model = 
class AnkhRMSNorm(nn.Module):
    """RMS layer norm: rescales by the root mean square, with no mean subtraction or bias."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The mean of squares is accumulated in float32 regardless of the
        # input dtype.
        mean_square = torch.mean(hidden_states.to(torch.float32).pow(2), dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalized = hidden_states * inv_rms
        return self.weight * normalized.to(self.weight.dtype)
class AnkhSelfAttention(nn.Module):
    """T5-style self-attention with relative position bias and multi-backend dispatch.

    Only layer 0 has ``has_relative_attention_bias=True`` and owns the
    ``nn.Embedding`` that produces the position bias. All other layers
    receive the precomputed bias through the forward call.
    """

    def __init__(self, config: FastAnkhConfig, has_relative_attention_bias: bool = False):
        super().__init__()
        self.num_heads = config.num_heads
        self.d_kv = config.d_kv
        self.inner_dim = self.num_heads * self.d_kv
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance

        self.q = nn.Linear(config.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(config.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(config.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, config.d_model, bias=False)
        # NOTE(review): stock T5 attention applies no 1/sqrt(d_kv) scaling;
        # confirm the checkpoints loaded into this module account for it.
        self.scale = self.d_kv ** -0.5

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(
                config.relative_attention_num_buckets, config.num_heads
            )

        self.attn_backend: AttentionBackend = AttentionBackend.SDPA  # set by encoder

    # ---- T5 relative position bucketing ----

    @staticmethod
    def _relative_position_bucket(
        relative_position: torch.Tensor,
        num_buckets: int = 32,
        max_distance: int = 128,
    ) -> torch.Tensor:
        """Bidirectional log-bucketed relative position mapping (T5 style).

        Args:
            relative_position: Integer tensor of key-minus-query offsets.
            num_buckets: Total bucket count; half serve positive offsets and
                half serve the rest.
            max_distance: Offset magnitude at which the logarithmic buckets
                saturate into the last bucket.

        Returns:
            Long tensor of bucket ids with the same shape as the input.
        """
        # Bidirectional: half buckets for negative, half for positive
        num_buckets //= 2
        relative_buckets = (relative_position > 0).to(torch.long) * num_buckets
        relative_position = torch.abs(relative_position)

        # Small offsets get one bucket each; larger ones share log-spaced buckets.
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.clamp(relative_position_if_large, max=num_buckets - 1)

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def compute_bias(self, query_length: int, key_length: int, device: torch.device) -> torch.Tensor:
        """Compute (1, H, Q, K) position bias tensor for SDPA / manual paths."""
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position
        buckets = self._relative_position_bucket(
            relative_position,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(buckets)  # (Q, K, H)
        return values.permute(2, 0, 1).unsqueeze(0)  # (1, H, Q, K)

    # ---- Forward ----

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask_2d: Optional[torch.Tensor] = None,
        attention_mask_4d: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[BlockMask] = None,
        position_bias: Optional[torch.Tensor] = None,
        flex_score_mod=None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Run self-attention and return (attn_output, attn_weights_or_none, position_bias).

        Args:
            hidden_states: Input of shape (batch, seq, d_model).
            attention_mask_2d: Accepted for signature parity; not read here.
            attention_mask_4d: Mask broadcastable to (batch, H, Q, K) whose
                zero/False entries mark padding keys. Only used when this
                layer computes the position bias.
            flex_block_mask: Block mask forwarded to the flex backend.
            position_bias: Precomputed additive bias. When ``None`` and this
                layer owns the bias embedding (and the backend is not flex),
                it is computed here with the padding mask folded in.
            flex_score_mod: Score modifier forwarded to the flex backend.
            output_attentions: If True, use the manual path and also return
                the attention probabilities.

        Raises:
            AssertionError: If the configured backend is neither flex nor SDPA.
        """
        batch_size, seq_length = hidden_states.shape[:2]
        hidden_shape = (batch_size, seq_length, self.num_heads, self.d_kv)

        query_BHLD = self.q(hidden_states).view(hidden_shape).transpose(1, 2)
        key_BHLD = self.k(hidden_states).view(hidden_shape).transpose(1, 2)
        value_BHLD = self.v(hidden_states).view(hidden_shape).transpose(1, 2)

        # Compute position bias on first layer (SDPA/manual only; flex uses score_mod)
        if position_bias is None and self.has_relative_attention_bias and self.attn_backend != AttentionBackend.FLEX:
            position_bias = self.compute_bias(seq_length, seq_length, hidden_states.device)
            # Fold padding mask into position bias so layers don't need separate mask
            if attention_mask_4d is not None:
                # Build the additive mask in the bias dtype. Calling
                # masked_fill directly on a bool mask would cast -inf to True,
                # so the sum would add 1.0 everywhere and mask nothing.
                padding_bias = torch.zeros_like(attention_mask_4d, dtype=position_bias.dtype)
                padding_bias = padding_bias.masked_fill(attention_mask_4d.logical_not(), float("-inf"))
                position_bias = position_bias + padding_bias

        if output_attentions:
            attn_output, attn_weights = self._manual_attn(query_BHLD, key_BHLD, value_BHLD, position_bias)
            return self.o(attn_output), attn_weights, position_bias

        if self.attn_backend == AttentionBackend.FLEX:
            attn_output = self._flex_attn(query_BHLD, key_BHLD, value_BHLD, flex_block_mask, flex_score_mod)
        elif self.attn_backend == AttentionBackend.SDPA:
            attn_output = self._sdpa_attn(query_BHLD, key_BHLD, value_BHLD, position_bias)
        else:
            raise AssertionError(f"Unsupported backend for ANKH: {self.attn_backend}")

        return self.o(attn_output), None, position_bias

    def _sdpa_attn(
        self,
        query_BHLD: torch.Tensor,
        key_BHLD: torch.Tensor,
        value_BHLD: torch.Tensor,
        position_bias: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Attention through ``F.scaled_dot_product_attention``; returns (batch, seq, inner_dim)."""
        # position_bias is passed as the additive attn_mask; the padding mask,
        # when one was supplied, is already folded into it.
        context_BHLD = F.scaled_dot_product_attention(
            query_BHLD, key_BHLD, value_BHLD,
            attn_mask=position_bias,
            scale=self.scale,
        )
        return context_BHLD.transpose(1, 2).contiguous().view(
            query_BHLD.shape[0], -1, self.inner_dim
        )

    def _flex_attn(
        self,
        query_BHLD: torch.Tensor,
        key_BHLD: torch.Tensor,
        value_BHLD: torch.Tensor,
        flex_block_mask: Optional[BlockMask],
        flex_score_mod,
    ) -> torch.Tensor:
        """Attention through flex attention; returns (batch, seq, inner_dim)."""
        assert flex_attention is not None, "Flex attention is not available."
        fn = _get_flex_attention_fn()
        context_BHLD = fn(
            query_BHLD, key_BHLD, value_BHLD,
            score_mod=flex_score_mod,
            block_mask=flex_block_mask,
            scale=self.scale,
        )
        return context_BHLD.transpose(1, 2).contiguous().view(
            query_BHLD.shape[0], -1, self.inner_dim
        )

    def _manual_attn(
        self,
        query_BHLD: torch.Tensor,
        key_BHLD: torch.Tensor,
        value_BHLD: torch.Tensor,
        position_bias: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Explicit softmax attention; returns (attn_output, attn_weights)."""
        attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-1, -2)) * self.scale
        if position_bias is not None:
            attn_weights = attn_weights + position_bias
        # Softmax in float32, then cast back to the score dtype.
        attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)
        context_BHLD = torch.matmul(attn_weights, value_BHLD)
        attn_output = context_BHLD.transpose(1, 2).contiguous().view(
            query_BHLD.shape[0], -1, self.inner_dim
        )
        return attn_output, attn_weights
self.SelfAttention( + normed, + attention_mask_2d=attention_mask_2d, + attention_mask_4d=attention_mask_4d, + flex_block_mask=flex_block_mask, + position_bias=position_bias, + flex_score_mod=flex_score_mod, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + attn_output + return hidden_states, attn_weights, position_bias + + +class AnkhFFLayer(nn.Module): + """Wraps AnkhGatedFFN + layer_norm to match T5Block.layer[1] key naming.""" + + def __init__(self, config: FastAnkhConfig): + super().__init__() + self.DenseReluDense = AnkhGatedFFN(config) + self.layer_norm = AnkhRMSNorm(config.d_model, eps=config.layer_norm_epsilon) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + normed = self.layer_norm(hidden_states) + hidden_states = hidden_states + self.DenseReluDense(normed) + return hidden_states + + +class AnkhBlock(nn.Module): + """Single transformer block with T5-compatible .layer ModuleList naming.""" + + def __init__(self, config: FastAnkhConfig, has_relative_attention_bias: bool = False): + super().__init__() + self.layer = nn.ModuleList([ + AnkhSelfAttentionLayer(config, has_relative_attention_bias), + AnkhFFLayer(config), + ]) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask_2d: Optional[torch.Tensor] = None, + attention_mask_4d: Optional[torch.Tensor] = None, + flex_block_mask: Optional[BlockMask] = None, + position_bias: Optional[torch.Tensor] = None, + flex_score_mod=None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + hidden_states, attn_weights, position_bias = self.layer[0]( + hidden_states, + attention_mask_2d=attention_mask_2d, + attention_mask_4d=attention_mask_4d, + flex_block_mask=flex_block_mask, + position_bias=position_bias, + flex_score_mod=flex_score_mod, + output_attentions=output_attentions, + ) + hidden_states = self.layer[1](hidden_states) + return hidden_states, attn_weights, position_bias + + +# 
# ---------------------------------------------------------------------------
# PreTrainedModel base
# ---------------------------------------------------------------------------

class AnkhPreTrainedModel(PreTrainedModel):
    """Shared base for the ANKH models: weight init and attention-backend switching."""

    config_class = FastAnkhConfig
    base_model_prefix = "encoder"
    supports_gradient_checkpointing = True
    _no_split_modules = ["AnkhBlock"]

    @classmethod
    def is_remote_code(cls) -> bool:
        return True

    @torch.no_grad()
    def _init_weights(self, module: nn.Module) -> None:
        """Initialise one sub-module.

        Linear weights use std ``factor * d_model ** -0.5``, embedding weights
        std ``factor``, and RMSNorm weights are set to 1. Linear biases are not
        touched here.
        """
        factor = self.config.initializer_factor
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=factor * (self.config.d_model ** -0.5))
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=factor * 1.0)
        elif isinstance(module, AnkhRMSNorm):
            module.weight.data.fill_(1.0)

    def post_init(self) -> None:
        super().post_init()

    def get_output_embeddings(self):
        # Base models have no output head; subclasses with one override this.
        return None

    @property
    def attn_backend(self) -> str:
        """The backend name stored on the config."""
        return self.config.attn_backend

    @attn_backend.setter
    def attn_backend(self, backend: str) -> None:
        """Switch the attention backend on the config and on every sub-module.

        Raises:
            AssertionError: if ``backend`` is not in ``VALID_ATTENTION_BACKENDS``.
        """
        # Explicit raise instead of a bare `assert` so an invalid name cannot be
        # written to the config under `python -O`; exception type and message
        # are unchanged.
        if backend not in VALID_ATTENTION_BACKENDS:
            raise AssertionError(
                f"Unsupported attn_backend: {backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
            )
        self.config.attn_backend = backend
        resolved = resolve_attention_backend(backend)
        if resolved == AttentionBackend.KERNELS_FLASH:
            # kernels_flash is remapped: flex when available, else SDPA.
            print("ANKH: kernels_flash -> flex/sdpa fallback")
            resolved = AttentionBackend.FLEX if flex_attention is not None else AttentionBackend.SDPA
        # Encoders keep the value in `attention_backend`, attention modules in
        # `attn_backend`; both must be updated.
        for module in self.modules():
            if isinstance(module, FAST_ANKH_ENCODER):
                module.attention_backend = resolved
            elif isinstance(module, AnkhSelfAttention):
                module.attn_backend = resolved


# ---------------------------------------------------------------------------
# FAST_ANKH_ENCODER (mirrors T5Stack key naming)
# ---------------------------------------------------------------------------

class FAST_ANKH_ENCODER(AnkhPreTrainedModel, EmbeddingMixin):
    """Inner encoder that mirrors T5Stack attribute naming for weight compliance.

    State dict keys: embed_tokens.*, block.{i}.layer.0.SelfAttention.*,
    block.{i}.layer.1.DenseReluDense.*, final_layer_norm.*.
    """

    def __init__(self, config: FastAnkhConfig, **kwargs):
        AnkhPreTrainedModel.__init__(self, config, **kwargs)
        self.config = config

        resolved = resolve_attention_backend(config.attn_backend)
        if resolved == AttentionBackend.KERNELS_FLASH:
            print("ANKH: kernels_flash not supported (relative position bias); falling back to flex/sdpa")
            resolved = AttentionBackend.FLEX if flex_attention is not None else AttentionBackend.SDPA
        self.attention_backend = resolved

        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        # Only the first block owns the relative-attention-bias table.
        self.block = nn.ModuleList([
            AnkhBlock(config, has_relative_attention_bias=(i == 0))
            for i in range(config.num_layers)
        ])
        for blk in self.block:
            blk.layer[0].SelfAttention.attn_backend = self.attention_backend

        self.final_layer_norm = AnkhRMSNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.gradient_checkpointing = False
        # NOTE(review): the tokenizer id is hard-coded and fetched at construction
        # time regardless of `config` -- confirm this is intended for non-base
        # checkpoints and offline use.
        self.tokenizer = AutoTokenizer.from_pretrained("ElnaggarLab/ankh-base")
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @torch.compiler.disable
    def _compute_materialized_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Precompute full (Q, K, H) bias tensor for flex score_mod lookup."""
        bias_embedding = self.block[0].layer[0].SelfAttention.relative_attention_bias
        context_position = torch.arange(seq_len, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(seq_len, dtype=torch.long, device=device)[None, :]
        # Entry [q, k] is k - q (key position minus query position).
        relative_position = memory_position - context_position
        buckets = AnkhSelfAttention._relative_position_bucket(
            relative_position,
            num_buckets=self.config.relative_attention_num_buckets,
            max_distance=self.config.relative_attention_max_distance,
        )
        return bias_embedding(buckets)  # (Q, K, H)

    def _build_flex_score_mod(self, seq_len: int, device: torch.device):
        """Build score_mod closure that reads from materialized bias tensor."""
        bias = self._compute_materialized_bias(seq_len, device)

        def score_mod(score, b, h, q_idx, kv_idx):
            return score + bias[q_idx, kv_idx, h]

        return score_mod

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Embed ``input_ids`` and return the encoder's ``last_hidden_state``."""
        hidden_states = self.embed_tokens(input_ids)
        encoder_output = self._run_encoder(hidden_states, attention_mask=attention_mask)
        return encoder_output.last_hidden_state

    def _run_encoder(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> AnkhEncoderOutput:
        """Run all blocks plus the final norm over already-embedded inputs.

        When ``output_hidden_states`` is set, the collected tuple holds the
        input to every block followed by the final normalised output. When
        ``output_attentions`` is set, one ``attn_weights`` entry per block is
        collected.
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        batch_size, seq_len = hidden_states.shape[:2]
        attention_mask_2d, attention_mask_4d, flex_block_mask = get_attention_mask(
            effective_backend=self.attention_backend,
            batch_size=batch_size,
            seq_len=seq_len,
            device=hidden_states.device,
            attention_mask=attention_mask,
        )

        flex_score_mod = None
        # position_bias starts empty; each block returns it and it is fed to the next.
        position_bias = None
        if self.attention_backend == AttentionBackend.FLEX:
            flex_score_mod = self._build_flex_score_mod(seq_len, hidden_states.device)
        # NOTE(review): with the FLEX backend position_bias is None here; whether
        # the output_attentions path still applies relative bias and padding
        # mask depends on AnkhSelfAttention.forward -- confirm.

        for layer_module in self.block:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Arguments are positional here and must follow AnkhBlock.forward's order.
                hidden_states, attn_weights, position_bias = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask_2d,
                    attention_mask_4d,
                    flex_block_mask,
                    position_bias,
                    flex_score_mod,
                    output_attentions,
                )
            else:
                hidden_states, attn_weights, position_bias = layer_module(
                    hidden_states,
                    attention_mask_2d=attention_mask_2d,
                    attention_mask_4d=attention_mask_4d,
                    flex_block_mask=flex_block_mask,
                    position_bias=position_bias,
                    flex_score_mod=flex_score_mod,
                    output_attentions=output_attentions,
                )

            if all_attentions is not None:
                all_attentions = all_attentions + (attn_weights,)

        hidden_states = self.final_layer_norm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        return AnkhEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        **kwargs,
    ) -> AnkhEncoderOutput:
        """Encode either ``input_ids`` or ``inputs_embeds`` (exactly one).

        ``output_hidden_states`` / ``output_attentions`` fall back to the
        config values when ``None``. Extra ``kwargs`` are accepted and ignored.

        Raises:
            ValueError: if both or neither of ``input_ids`` and
                ``inputs_embeds`` are given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            hidden_states = self.embed_tokens(input_ids)
        elif inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        return self._run_encoder(
            hidden_states,
            attention_mask=attention_mask,
            # `or False` turns a None config value into a plain bool.
            output_hidden_states=output_hidden_states or False,
            output_attentions=output_attentions or False,
        )


# ---------------------------------------------------------------------------
# Model classes
# ---------------------------------------------------------------------------

class FastAnkhModel(AnkhPreTrainedModel, EmbeddingMixin):
    """ANKH encoder model for embedding extraction."""

    def __init__(self, config: FastAnkhConfig, **kwargs):
        AnkhPreTrainedModel.__init__(self, config, **kwargs)
        self.config = config
        # `shared` is not used in forward(); presumably kept so a T5 checkpoint's
        # `shared.weight` key has a destination -- TODO confirm.
        self.shared = nn.Embedding(config.vocab_size, config.d_model)
        self.encoder = FAST_ANKH_ENCODER(config)
        self.post_init()

    @property
    def tokenizer(self):
        return self.encoder.tokenizer

    def get_input_embeddings(self):
        return self.encoder.embed_tokens

    def set_input_embeddings(self, value):
        self.encoder.embed_tokens = value

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self.encoder._embed(input_ids, attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        **kwargs,
    ) -> AnkhEncoderOutput:
        """Delegate to the inner encoder; extra ``kwargs`` are not forwarded."""
        return self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )


class FastAnkhForMaskedLM(AnkhPreTrainedModel, EmbeddingMixin):
    """ANKH encoder with LM head for masked language modeling.

    NOTE: The LM head is initialized from the shared embedding weights but is NOT
    tied. The original ANKH models were trained with T5's span corruption objective
    using an encoder-decoder architecture. This encoder-only MaskedLM variant is
    not pre-trained for standard MLM and requires additional fine-tuning.

    NOTE(review): no copy from ``shared`` into ``lm_head`` is visible in this
    class; presumably it happens when weights are loaded -- confirm.
    """

    def __init__(self, config: FastAnkhConfig, **kwargs):
        AnkhPreTrainedModel.__init__(self, config, **kwargs)
        self.config = config
        self.shared = nn.Embedding(config.vocab_size, config.d_model)
        self.encoder = FAST_ANKH_ENCODER(config)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.loss_fct = nn.CrossEntropyLoss()
        self.post_init()

    @property
    def tokenizer(self):
        return self.encoder.tokenizer

    def get_input_embeddings(self):
        return self.encoder.embed_tokens

    def set_input_embeddings(self, value):
        self.encoder.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self.encoder._embed(input_ids, attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        **kwargs,
    ) -> AnkhMaskedLMOutput:
        """Encode, project to vocabulary logits, and optionally compute the loss.

        When ``labels`` is given, the loss is cross-entropy over the flattened
        logits and labels; otherwise ``loss`` is ``None``.
        """
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        sequence_output = outputs.last_hidden_state
        logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss = self.loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        return AnkhMaskedLMOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=sequence_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class FastAnkhForSequenceClassification(AnkhPreTrainedModel, EmbeddingMixin):
    """ANKH encoder with a linear head over the mean-pooled token states."""

    def __init__(self, config: FastAnkhConfig, **kwargs):
        AnkhPreTrainedModel.__init__(self, config, **kwargs)
        self.num_labels = config.num_labels
        self.config = config
        self.shared = nn.Embedding(config.vocab_size, config.d_model)
        self.encoder = FAST_ANKH_ENCODER(config)
        self.classifier = nn.Linear(config.d_model, config.num_labels)
        # One loss module per supported problem_type.
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.post_init()

    @property
    def tokenizer(self):
        return self.encoder.tokenizer

    def get_input_embeddings(self):
        return self.encoder.embed_tokens

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self.encoder._embed(input_ids, attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        **kwargs,
    ) -> AnkhMaskedLMOutput:
        """Encode, mean-pool, classify, and optionally compute the loss.

        If ``config.problem_type`` is ``None`` when labels are first supplied,
        it is inferred from ``num_labels`` and the label dtype and written back
        to the config, so later calls reuse it.
        """
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        # Pool: mean over non-padding tokens
        sequence_output = outputs.last_hidden_state
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).to(sequence_output.dtype)
            # clamp(min=1) avoids dividing by zero for an all-padding row.
            pooled = (sequence_output * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        else:
            pooled = sequence_output.mean(dim=1)
        logits = self.classifier(pooled)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss = self.mse(logits.squeeze(), labels.squeeze()) if self.num_labels == 1 else self.mse(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss = self.bce(logits, labels)

        return AnkhMaskedLMOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=sequence_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class FastAnkhForTokenClassification(AnkhPreTrainedModel, EmbeddingMixin):
    """ANKH encoder with a per-token linear classification head."""

    def __init__(self, config: FastAnkhConfig, **kwargs):
        AnkhPreTrainedModel.__init__(self, config, **kwargs)
        self.num_labels = config.num_labels
        self.shared = nn.Embedding(config.vocab_size, config.d_model)
        self.encoder = FAST_ANKH_ENCODER(config)
        self.classifier = nn.Linear(config.d_model, config.num_labels)
        self.loss_fct = nn.CrossEntropyLoss()
        self.post_init()

    @property
    def tokenizer(self):
        return self.encoder.tokenizer

    def get_input_embeddings(self):
        return self.encoder.embed_tokens

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self.encoder._embed(input_ids, attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        **kwargs,
    ) -> AnkhMaskedLMOutput:
        """Encode, classify every token, and optionally compute the loss.

        When ``labels`` is given, the loss is cross-entropy over the flattened
        logits and labels; otherwise ``loss`` is ``None``.
        """
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        sequence_output = outputs.last_hidden_state
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return AnkhMaskedLMOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=sequence_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )