class Pooler:
    """Pool per-token embeddings (b, L, d) into per-sequence vectors (b, d).

    Several pooling strategies may be combined: ``__call__`` applies every
    requested strategy and concatenates the results along the feature axis,
    yielding (b, len(pooling_types) * d).
    """

    def __init__(self, pooling_types: List[str]):
        self.pooling_types = pooling_types
        # Dispatch table: pooling name -> bound implementation.
        self.pooling_options = {
            'mean': self.mean_pooling,
            'max': self.max_pooling,
            'norm': self.norm_pooling,
            'median': self.median_pooling,
            'std': self.std_pooling,
            'var': self.var_pooling,
            'cls': self.cls_pooling,
            'parti': self._pool_parti,
        }

    def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor:
        # Collapse dim=1 (heads/layers) with an elementwise max, producing one
        # (L, L) attention matrix per batch item.
        maxed_attentions = torch.max(attentions, dim=1)[0]
        return maxed_attentions

    def _page_rank(self, attention_matrix, personalization=None, nstart=None, prune_type="top_k_outdegree"):
        # Run PageRank on the attention matrix converted to a directed graph.
        # Raises if the graph doesn't match the token sequence or has no edges.
        # Returns {node_index: score} for each token node.
        G = self._convert_to_graph(attention_matrix)
        if G.number_of_nodes() != attention_matrix.shape[0]:
            raise Exception(
                f"The number of nodes in the graph should be equal to the number of tokens in sequence! You have {G.number_of_nodes()} nodes for {attention_matrix.shape[0]} tokens.")
        if G.number_of_edges() == 0:
            raise Exception(f"You don't seem to have any attention edges left in the graph.")

        return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100)

    def _convert_to_graph(self, matrix):
        # Every matrix element becomes a weighted directed edge.
        G = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
        return G

    def _calculate_importance_weights(self, dict_importance, attention_mask: Optional[torch.Tensor] = None):
        # Drop nodes at padded positions, then renormalize so weights sum to 1.
        if attention_mask is not None:
            for k in list(dict_importance.keys()):
                if attention_mask[k] == 0:
                    del dict_importance[k]
        total = sum(dict_importance.values())
        return np.array([v / total for _, v in dict_importance.items()])

    def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):  # (b, L, d) -> (b, d)
        # PageRank-based pooling: token importance derives from the attention graph.
        # NOTE: .numpy() requires CPU tensors without grad — callers run under no_grad.
        maxed_attentions = self._create_pooled_matrices_across_layers(attentions).numpy()
        if attention_mask is None:
            # Fix: previously `zip(..., None)` raised; treat all tokens as valid.
            attention_mask = torch.ones(emb.shape[:2], dtype=torch.long)
        emb_pooled = []
        for e, a, mask in zip(emb, maxed_attentions, attention_mask):
            dict_importance = self._page_rank(a)
            importance_weights = self._calculate_importance_weights(dict_importance, mask)
            num_tokens = int(mask.sum().item())
            emb_pooled.append(np.average(e[:num_tokens], weights=importance_weights, axis=0))
        pooled = torch.tensor(np.array(emb_pooled))
        return pooled

    def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        if attention_mask is None:
            return emb.mean(dim=1)
        attention_mask = attention_mask.unsqueeze(-1)
        return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)

    def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        if attention_mask is None:
            return emb.max(dim=1).values
        # Fix: multiplying by the mask zeroed padding, and those zeros wrongly
        # won the max whenever all valid activations were negative. Fill the
        # padded positions with -inf instead so they can never be selected.
        mask = attention_mask.unsqueeze(-1).bool()
        return emb.masked_fill(~mask, float('-inf')).max(dim=1).values

    def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        # Zeroed padding contributes nothing to an L2 norm, so masking by
        # multiplication is correct here.
        if attention_mask is None:
            return emb.norm(dim=1, p=2)
        attention_mask = attention_mask.unsqueeze(-1)
        return (emb * attention_mask).norm(dim=1, p=2)

    def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        if attention_mask is None:
            return emb.median(dim=1).values
        # Fix: the previous mask-multiply kept padding zeros inside the median
        # computation. Mark padding as NaN and use nanmedian so only valid
        # tokens participate.
        mask = attention_mask.unsqueeze(-1).bool()
        return emb.masked_fill(~mask, float('nan')).nanmedian(dim=1).values

    def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        if attention_mask is None:
            return emb.std(dim=1)
        # Variance over non-masked positions, then sqrt.
        var = self.var_pooling(emb, attention_mask, **kwargs)
        return torch.sqrt(var)

    def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        if attention_mask is None:
            return emb.var(dim=1)
        # Population variance computed only over non-masked positions.
        attention_mask = attention_mask.unsqueeze(-1)  # (b, L, 1)
        mean = (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)  # (b, d)
        mean = mean.unsqueeze(1)  # (b, 1, d)
        squared_diff = (emb - mean) ** 2  # (b, L, d)
        var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)  # (b, d)
        return var

    def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        # Assumes the CLS token sits at position 0 of every sequence.
        return emb[:, 0, :]

    def __call__(
        self,
        emb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attentions: Optional[torch.Tensor] = None
    ):  # e.g. pooling_types=[mean, max]
        final_emb = []
        for pooling_type in self.pooling_types:
            final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions))  # (b, d)
        return torch.cat(final_emb, dim=-1)  # (b, n_pooling_types * d)
class ProteinDataset(TorchDataset):
    """Simple dataset for protein sequences."""
    def __init__(self, sequences: list[str]):
        self.sequences = sequences

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> str:
        return self.sequences[idx]


def build_collator(tokenizer: PreTrainedTokenizerBase) -> Callable[[list[str]], dict[str, torch.Tensor]]:
    """Return a DataLoader collate_fn that tokenizes a batch of sequences,
    padding to the longest sequence in the batch."""
    def _collate_fn(sequences: list[str]) -> dict[str, torch.Tensor]:
        return tokenizer(sequences, return_tensors="pt", padding='longest')
    return _collate_fn


def parse_fasta(fasta_path: str) -> List[str]:
    """Parse a FASTA file and return the list of sequences (headers dropped).

    Multi-line records are concatenated; blank lines are skipped.
    """
    assert os.path.exists(fasta_path), f"FASTA file does not exist: {fasta_path}"
    sequences = []
    current_seq = []
    with open(fasta_path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            if line.startswith('>'):
                # Header line: flush the record accumulated so far.
                if current_seq:
                    sequences.append(''.join(current_seq))
                current_seq = []
            else:
                current_seq.append(line)
    # Flush the final record (files don't end with a header).
    if current_seq:
        sequences.append(''.join(current_seq))
    return sequences


class EmbeddingMixin:
    """Mixin adding bulk-embedding helpers (with .pth or SQLite caching) to a
    model that implements `_embed`."""

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Implemented by the concrete model class.
        raise NotImplementedError

    @property
    def device(self) -> torch.device:
        """Get the device of the model."""
        return next(self.parameters()).device

    def _read_sequences_from_db(self, db_path: str) -> set[str]:
        """Read sequences from SQLite database."""
        sequences = []
        with sqlite3.connect(db_path) as conn:
            c = conn.cursor()
            c.execute("SELECT sequence FROM embeddings")
            # Fetch row-by-row to keep memory bounded for large tables.
            while True:
                row = c.fetchone()
                if row is None:
                    break
                sequences.append(row[0])
        return set(sequences)

    def _ensure_embeddings_table(self, conn: sqlite3.Connection) -> None:
        # Create the embeddings table if needed and migrate older tables that
        # predate the shape/dtype metadata columns.
        cursor = conn.cursor()
        cursor.execute(
            "CREATE TABLE IF NOT EXISTS embeddings ("
            "sequence TEXT PRIMARY KEY, "
            "embedding BLOB NOT NULL, "
            "shape TEXT, "
            "dtype TEXT"
            ")"
        )
        cursor.execute("PRAGMA table_info(embeddings)")
        rows = cursor.fetchall()
        column_names = [row[1] for row in rows]
        if "shape" not in column_names:
            cursor.execute("ALTER TABLE embeddings ADD COLUMN shape TEXT")
        if "dtype" not in column_names:
            cursor.execute("ALTER TABLE embeddings ADD COLUMN dtype TEXT")
        conn.commit()

    def load_embeddings_from_pth(self, save_path: str) -> dict[str, torch.Tensor]:
        """Load a {sequence: tensor} dict saved by `embed_dataset` (.pth mode)."""
        assert os.path.exists(save_path), f"Embedding file does not exist: {save_path}"
        payload = torch.load(save_path, map_location="cpu", weights_only=True)
        assert isinstance(payload, dict), "Expected .pth embeddings file to contain a dictionary."
        for sequence, tensor in payload.items():
            assert isinstance(sequence, str), "Expected embedding dictionary keys to be sequences (str)."
            assert isinstance(tensor, torch.Tensor), "Expected embedding dictionary values to be tensors."
        return payload

    def load_embeddings_from_db(self, db_path: str, sequences: Optional[List[str]] = None) -> dict[str, torch.Tensor]:
        """Load embeddings from the SQLite cache.

        If `sequences` is given, only those rows are fetched; otherwise the
        whole table is loaded. Blobs are reconstructed from the stored
        shape/dtype metadata.
        """
        assert os.path.exists(db_path), f"Embedding database does not exist: {db_path}"
        loaded: dict[str, torch.Tensor] = {}
        with sqlite3.connect(db_path) as conn:
            self._ensure_embeddings_table(conn)
            cursor = conn.cursor()
            if sequences is None:
                cursor.execute("SELECT sequence, embedding, shape, dtype FROM embeddings")
            else:
                if len(sequences) == 0:
                    return loaded
                placeholders = ",".join(["?"] * len(sequences))
                cursor.execute(
                    f"SELECT sequence, embedding, shape, dtype FROM embeddings WHERE sequence IN ({placeholders})",
                    tuple(sequences),
                )

            rows = cursor.fetchall()
            for row in rows:
                sequence = row[0]
                embedding_bytes = row[1]
                shape_text = row[2]
                dtype_text = row[3]
                assert shape_text is not None, "Missing shape metadata in embeddings table."
                assert dtype_text is not None, "Missing dtype metadata in embeddings table."
                shape_values = [int(value) for value in shape_text.split(",") if len(value) > 0]
                assert len(shape_values) > 0, f"Invalid shape metadata for sequence: {sequence}"
                expected_size = int(np.prod(shape_values))
                np_dtype = np.dtype(dtype_text)
                array = np.frombuffer(embedding_bytes, dtype=np_dtype)
                assert array.size == expected_size, f"Shape mismatch while reading sequence: {sequence}"
                # .copy() because frombuffer returns a read-only view of the blob.
                reshaped = array.copy().reshape(tuple(shape_values))
                loaded[sequence] = torch.from_numpy(reshaped)
        return loaded

    def embed_dataset(
        self,
        sequences: Optional[List[str]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        batch_size: int = 2,
        max_len: int = 512,
        truncate: bool = True,
        full_embeddings: bool = False,
        embed_dtype: torch.dtype = torch.float32,
        pooling_types: List[str] = ['mean'],  # NOTE(review): mutable default — only read here, never mutated
        num_workers: int = 0,
        sql: bool = False,
        save: bool = True,
        sql_db_path: str = 'embeddings.db',
        save_path: str = 'embeddings.pth',
        fasta_path: Optional[str] = None,
        **kwargs,
    ) -> Optional[dict[str, torch.Tensor]]:
        """
        Embed a dataset of protein sequences.

        Supports two modes:
        - Tokenizer mode (ESM2/ESM++): provide `tokenizer`, `_embed(input_ids, attention_mask)` is used.
        - Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used.

        Sequences can be supplied as a list via `sequences`, parsed from a FASTA file via
        `fasta_path`, or both (the two sources are combined). At least one must be provided.
        """
        if fasta_path is not None:
            fasta_sequences = parse_fasta(fasta_path)
            sequences = list(sequences or []) + fasta_sequences
        assert sequences is not None and len(sequences) > 0, \
            "Must provide at least one sequence via `sequences` or `fasta_path`."
        # Deduplicate (after optional truncation), then sort longest-first so
        # batches contain similar lengths and padding is minimized.
        sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences]))
        sequences = sorted(sequences, key=len, reverse=True)
        hidden_size = self.config.hidden_size
        pooler = Pooler(pooling_types) if not full_embeddings else None
        tokenizer_mode = tokenizer is not None
        if tokenizer_mode:
            collate_fn = build_collator(tokenizer)
            device = self.device
        else:
            collate_fn = None
            device = None

        def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
            # 2-D output means the model already pooled; pass it through.
            if full_embeddings or residue_embeddings.ndim == 2:
                return residue_embeddings
            return pooler(residue_embeddings, attention_mask)

        def iter_batches(to_embed: List[str]):
            # Yields (sequences, residue_embeddings, attention_mask) per batch.
            if tokenizer_mode:
                assert collate_fn is not None
                assert device is not None
                dataset = ProteinDataset(to_embed)
                # shuffle=False keeps dataloader order aligned with `to_embed`
                # so the manual slice below recovers the batch's sequences.
                dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn, shuffle=False)
                for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
                    seqs = to_embed[i * batch_size:(i + 1) * batch_size]
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    residue_embeddings = self._embed(input_ids, attention_mask)
                    yield seqs, residue_embeddings, attention_mask
            else:
                # Sequence mode: the model tokenizes internally and must return
                # (last_hidden_state, attention_mask).
                for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'):
                    seqs = to_embed[batch_start:batch_start + batch_size]
                    batch_output = self._embed(seqs, return_attention_mask=True, **kwargs)
                    assert isinstance(batch_output, tuple), "Sequence mode _embed must return (last_hidden_state, attention_mask)."
                    assert len(batch_output) == 2, "Sequence mode _embed must return exactly two values."
                    residue_embeddings, attention_mask = batch_output
                    assert isinstance(attention_mask, torch.Tensor), "Sequence mode _embed must return attention_mask as a torch.Tensor."
                    yield seqs, residue_embeddings, attention_mask

        if sql:
            # SQLite mode: stream results straight into the database, skipping
            # sequences already cached. Returns None (read back via
            # load_embeddings_from_db).
            conn = sqlite3.connect(sql_db_path)
            self._ensure_embeddings_table(conn)
            c = conn.cursor()
            already_embedded = self._read_sequences_from_db(sql_db_path)
            to_embed = [seq for seq in sequences if seq not in already_embedded]
            print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
            print(f"Embedding {len(to_embed)} new sequences")
            if len(to_embed) > 0:
                with torch.no_grad():
                    for i, (seqs, residue_embeddings, attention_mask) in enumerate(iter_batches(to_embed)):
                        embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
                        for seq, emb, mask in zip(seqs, embeddings, attention_mask):
                            if full_embeddings:
                                # Strip padding rows before storing.
                                emb = emb[mask.bool()].reshape(-1, hidden_size)
                            emb_np = emb.cpu().numpy()
                            emb_shape = ",".join([str(dim) for dim in emb_np.shape])
                            emb_dtype = str(emb_np.dtype)
                            c.execute(
                                "INSERT OR REPLACE INTO embeddings (sequence, embedding, shape, dtype) VALUES (?, ?, ?, ?)",
                                (seq, emb_np.tobytes(), emb_shape, emb_dtype),
                            )
                        # Periodic commit to bound the open transaction size.
                        if tokenizer_mode and (i + 1) % 100 == 0:
                            conn.commit()
            conn.commit()
            conn.close()
            return None

        # .pth mode: resume from an existing file when present.
        embeddings_dict = {}
        if os.path.exists(save_path):
            embeddings_dict = self.load_embeddings_from_pth(save_path)
            to_embed = [seq for seq in sequences if seq not in embeddings_dict]
            print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}")
            print(f"Embedding {len(to_embed)} new sequences")
        else:
            to_embed = sequences
            print(f"Embedding {len(to_embed)} new sequences")

        if len(to_embed) > 0:
            with torch.no_grad():
                for seqs, residue_embeddings, attention_mask in iter_batches(to_embed):
                    embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
                    for seq, emb, mask in zip(seqs, embeddings, attention_mask):
                        if full_embeddings:
                            emb = emb[mask.bool()].reshape(-1, hidden_size)
                        embeddings_dict[seq] = emb.cpu()

        if save:
            torch.save(embeddings_dict, save_path)

        return embeddings_dict


# Second dependency block: the model half of the file starts here.
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from einops import rearrange, repeat
from enum import Enum
from typing import Any, TypedDict, Callable, List
from dataclasses import dataclass
from tokenizers import Tokenizer
from transformers import PretrainedConfig, PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_outputs import ModelOutput
from transformers.utils import logging


logger = logging.get_logger(__name__)

### Kernels Flash Attention Detection
def _infer_kernels_flash_variant(kernel) -> str | None:
    """Classify a loaded `kernels` module by the API surface it exposes.

    Returns "flash_attn2", "flash_attn3", or None when neither API matches.
    """
    if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
        return "flash_attn2"
    if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
        return "flash_attn3"
    return None
- except Exception: - flash_kernel = None - flash_kernel_variant = None - return flash_kernel, flash_kernel_variant - - -FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash() - - -def _kernels_flash_forward( - query_states: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - causal: bool = False, -) -> torch.Tensor: - assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment." - if FLASH_KERNEL_VARIANT == "flash_attn2": - return FLASH_KERNEL.fwd(q=query_states, k=key_states, v=value_states, is_causal=causal)[0] - if FLASH_KERNEL_VARIANT == "flash_attn3": - try: - output = FLASH_KERNEL.flash_attn_func(q=query_states, k=key_states, v=value_states, causal=causal) - except TypeError: - output = FLASH_KERNEL.flash_attn_func(query_states, key_states, value_states, 0.0, None, causal) - if isinstance(output, tuple): - return output[0] - return output - raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}") - - -def _kernels_flash_varlen_forward( - query_states: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_in_batch_q: int, - max_seqlen_in_batch_k: int, - causal: bool = False, -) -> torch.Tensor: - assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment." 
def _kernels_flash_varlen_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_in_batch_q: int,
    max_seqlen_in_batch_k: int,
    causal: bool = False,
) -> torch.Tensor:
    """Variable-length (packed, unpadded) attention forward via the loaded
    flash kernel. Inputs follow the flash-attn varlen convention: tokens are
    concatenated across sequences and delimited by the cu_seqlens offsets."""
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    if FLASH_KERNEL_VARIANT == "flash_attn2":
        # flash-attn2 returns a tuple; the attention output is slot 0.
        return FLASH_KERNEL.varlen_fwd(
            q=query_states, k=key_states, v=value_states,
            cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
            is_causal=causal,
        )[0]
    if FLASH_KERNEL_VARIANT == "flash_attn3":
        # Newer builds accept keyword arguments; older builds are positional.
        try:
            output = FLASH_KERNEL.flash_attn_varlen_func(
                q=query_states, k=key_states, v=value_states,
                cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
                causal=causal,
            )
        except TypeError:
            output = FLASH_KERNEL.flash_attn_varlen_func(
                query_states, key_states, value_states,
                cu_seqlens_q, cu_seqlens_k,
                max_seqlen_in_batch_q, max_seqlen_in_batch_k,
                0.0, None, causal,
            )
        if isinstance(output, tuple):
            return output[0]
        return output
    raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")


from torch.nn.attention.flex_attention import (
    BlockMask,
    create_block_mask,
    flex_attention,
    _create_sparse_block_from_block_mask
)

# Optional fused layer-norm kernel; falls back to PyTorch RMSNorm when the
# `kernels` package or the kernel itself is unavailable.
try:
    from kernels import get_kernel
    layer_norm = get_kernel("kernels-community/triton-layer-norm")
except Exception as e:
    logger.warning(f"Failed to load triton layer norm kernel: {e}; Will be using PyTorch RMSNorm instead")
    layer_norm = None


### Attention Backend Enum & Resolution
class AttentionBackend(Enum):
    AUTO = "auto"
    KERNELS_FLASH = "kernels_flash"
    FLEX = "flex"
    SDPA = "sdpa"


VALID_ATTENTION_BACKENDS = tuple(b.value for b in AttentionBackend)


# Print the resolved backend only once per process (see resolve_attention_backend).
_BACKEND_CONFIRMED = False


def resolve_attention_backend(requested_backend: str) -> AttentionBackend:
    """Map a config string to a concrete AttentionBackend.

    "auto" prefers kernels-flash, then flex, then SDPA, depending on what is
    available. Explicit requests assert that the backend can actually be used.
    """
    global _BACKEND_CONFIRMED
    assert requested_backend in VALID_ATTENTION_BACKENDS, (
        f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
    )
    if requested_backend == AttentionBackend.AUTO.value:
        if FLASH_KERNEL is not None:
            resolved = AttentionBackend.KERNELS_FLASH
        elif flex_attention is not None:
            resolved = AttentionBackend.FLEX
        else:
            resolved = AttentionBackend.SDPA
    elif requested_backend == AttentionBackend.KERNELS_FLASH.value:
        assert FLASH_KERNEL is not None, "Kernels Flash Attention is not available in this environment."
        resolved = AttentionBackend.KERNELS_FLASH
    elif requested_backend == AttentionBackend.FLEX.value:
        assert flex_attention is not None, "Flex Attention is not available in this environment."
        resolved = AttentionBackend.FLEX
    elif requested_backend == AttentionBackend.SDPA.value:
        resolved = AttentionBackend.SDPA
    else:
        # Unreachable given the membership assert above; kept as a guard.
        raise AssertionError(f"Unsupported attention backend: {requested_backend}")
    if not _BACKEND_CONFIRMED:
        print(f"Attention backend: config='{requested_backend}' -> resolved='{resolved.value}'")
        _BACKEND_CONFIRMED = True
    return resolved
- def document_mask(b, h, q_idx, kv_idx): # type: ignore[no-untyped-def] - return ( - (sequence_ids[b, q_idx] >= sequence_ids[b, kv_idx]) - & (sequence_ids[b, q_idx] != -1) - & (sequence_ids[b, kv_idx] != -1) - ) - - batch_size, seqlen = sequence_ids.shape - return create_block_mask(document_mask, batch_size, 1, seqlen, seqlen, device=sequence_ids.device) - - -def create_within_seq_block_mask(sequence_ids: torch.Tensor) -> BlockMask: - def document_mask(b, h, q_idx, kv_idx): # type: ignore[no-untyped-def] - return ( - (sequence_ids[b, q_idx] == sequence_ids[b, kv_idx]) - & (sequence_ids[b, q_idx] != -1) - & (sequence_ids[b, kv_idx] != -1) - ) - - batch_size, seqlen = sequence_ids.shape - return create_block_mask(document_mask, batch_size, 1, seqlen, seqlen, device=sequence_ids.device) - - -def build_within_seq_mask_4d(sequence_ids: torch.Tensor) -> torch.Tensor: - not_pad = (sequence_ids != -1) - same_seq = sequence_ids.unsqueeze(-1) == sequence_ids.unsqueeze(-2) - valid = not_pad.unsqueeze(-1) & not_pad.unsqueeze(-2) - return (same_seq & valid).unsqueeze(1) - - -def build_block_causal_mask_4d(sequence_ids: torch.Tensor) -> torch.Tensor: - not_pad = (sequence_ids != -1) - causal = sequence_ids.unsqueeze(-1) >= sequence_ids.unsqueeze(-2) - valid = not_pad.unsqueeze(-1) & not_pad.unsqueeze(-2) - return (causal & valid).unsqueeze(1) - - -def flex_attention_func( - query_states: torch.Tensor, # (bs, seqlen, nh, hs) - key_states: torch.Tensor, # (bs, seqlen, nkv, hs) - value_states: torch.Tensor, # (bs, seqlen, nkv, hs) - score_mod: Callable | None = None, - block_mask: BlockMask | None = None, -) -> torch.Tensor: - assert flex_attention is not None, "Flex Attention is not available in this environment" - assert score_mod is None, "Score mod is not supported yet" - query_states = query_states.transpose(1, 2).contiguous() # (bs, nh, seqlen, hs) - key_states = key_states.transpose(1, 2).contiguous() # (bs, nkv, seqlen, hs) - value_states = value_states.transpose(1, 
def flex_attention_func(
    query_states: torch.Tensor,  # (bs, seqlen, nh, hs)
    key_states: torch.Tensor,  # (bs, seqlen, nkv, hs)
    value_states: torch.Tensor,  # (bs, seqlen, nkv, hs)
    score_mod: Callable | None = None,
    block_mask: BlockMask | None = None,
) -> torch.Tensor:
    """Run flex attention on (bs, seqlen, heads, head_dim) inputs.

    Transposes to the (bs, heads, seqlen, head_dim) layout flex_attention
    expects, then transposes back. GQA is enabled automatically when the
    number of KV heads differs from the number of query heads.
    """
    assert flex_attention is not None, "Flex Attention is not available in this environment"
    assert score_mod is None, "Score mod is not supported yet"
    query_states = query_states.transpose(1, 2).contiguous()  # (bs, nh, seqlen, hs)
    key_states = key_states.transpose(1, 2).contiguous()  # (bs, nkv, seqlen, hs)
    value_states = value_states.transpose(1, 2).contiguous()  # (bs, nkv, seqlen, hs)

    outputs = flex_attention(
        query_states,
        key_states,
        value_states,
        block_mask=block_mask,
        score_mod=score_mod,
        enable_gqa=query_states.shape[1] != key_states.shape[1],  # if nkv != nh
    )

    outputs = outputs.transpose(1, 2)  # (bs, seqlen, nh, hs)
    return outputs


def kernels_flash_attention_func(
    query_states: torch.Tensor,  # (bs, seqlen, nh, hs)
    key_states: torch.Tensor,  # (bs, seqlen, nkv, hs)
    value_states: torch.Tensor,  # (bs, seqlen, nkv, hs)
    q_sequence_ids: torch.Tensor,
    k_sequence_ids: torch.Tensor,
    causal: bool = False,
) -> torch.Tensor:  # (bs, seqlen, nh, hs)
    """Attention via the `kernels` flash backend.

    Non-causal path: unpads to the packed varlen layout (sequence ids delimit
    the packed sequences, -1 = padding), runs the varlen kernel, and re-pads.
    Causal path: runs the dense kernel directly on the padded tensors.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."

    if not causal:
        batch_size, q_len = query_states.shape[0], query_states.shape[1]
        (
            query_states,
            key_states,
            value_states,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        ) = _unpad_input(query_states, key_states, value_states, q_sequence_ids, k_sequence_ids)

        attn_output_unpad = _kernels_flash_varlen_forward(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_in_batch_q=max_seqlen_in_batch_q,
            max_seqlen_in_batch_k=max_seqlen_in_batch_k,
            causal=False,
        )
        # Scatter the packed outputs back to the padded (bs, seqlen, ...) layout.
        attn_output = pad_input(attn_output_unpad, indices_q, batch_size, q_len)

    else:
        attn_output = _kernels_flash_forward(query_states, key_states, value_states, causal=True)

    return attn_output
class IndexFirstAxis(torch.autograd.Function):
    """Differentiable `input[indices]` along the first axis (gather on dim 0),
    used to select non-padding token rows during unpadding."""

    @staticmethod
    def forward(ctx, input, indices) -> torch.Tensor:  # type: ignore[no-untyped-def]
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
        second_dim = other_shape.numel()
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        # return input[indices]
        return torch.gather(rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)).reshape(
            -1, *other_shape
        )

    @staticmethod
    def backward(ctx, grad_output) -> tuple[torch.Tensor, None]:  # type: ignore[no-untyped-def]
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        grad_output = rearrange(grad_output, "b ... -> b (...)")
        # Scatter gradients back to the original (unselected) rows; rows that
        # were not gathered receive zero gradient.
        grad_input = torch.zeros(
            [ctx.first_axis_dim, grad_output.shape[1]], device=grad_output.device, dtype=grad_output.dtype
        )
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        # grad_input[indices] = grad_output
        grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


def block_min_max_seq_ids(SLEN: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
    """For each block of `block_size` tokens in the packed layout, compute the
    first (MIN_SEQ_ID) and last (MAX_SEQ_ID) sequence index it touches.

    SLEN holds per-sequence token counts; a synthetic tail "sequence" is
    appended so the total is a whole number of blocks.
    """
    device = SLEN.device
    total_tokens = torch.sum(SLEN)
    B = (total_tokens + block_size - 1) // block_size
    padding_tokens = B * block_size - total_tokens
    SLEN = torch.cat([SLEN, padding_tokens.reshape(1).to(device=device, dtype=SLEN.dtype)], dim=0)

    assert torch.sum(SLEN) == B * block_size

    # Cumulative ends (exclusive) for each sequence; cum[i] == end offset of seq i
    cum = torch.cumsum(SLEN.to(torch.long), dim=0)  # (N,)
    total_tokens = cum[-1].item()

    # Block start/end offsets [start, end) in token index space
    block_starts = torch.arange(0, B * block_size, block_size, device=device, dtype=torch.long)  # (B,)
    block_ends = torch.minimum(block_starts + block_size, torch.tensor(total_tokens, device=device))  # (B,)

    # MIN_SEQ_ID[i] = first sequence whose end > block_start
    # searchsorted with right=True returns first index where cum > value
    MIN_SEQ_ID = torch.searchsorted(cum, block_starts, right=True)

    # MAX_SEQ_ID[i] = sequence containing the last token in the block (block_end - 1)
    # For empty tail beyond total_tokens we already clipped block_ends.
    last_token_in_block = torch.clamp(block_ends - 1, min=0)  # valid only if block has at least 1 token
    MAX_SEQ_ID = torch.searchsorted(cum, last_token_in_block, right=True)

    return MIN_SEQ_ID, MAX_SEQ_ID


def get_overlapping_blocks(SLEN_Q: torch.Tensor, SLEN_K: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Classify every (q-block, k-block) pair for block-sparse attention.

    Returns (full_blocks, partial_blocks): a pair overlaps when the sequence-id
    ranges intersect; it is "full" when both blocks lie entirely within one
    sequence (no per-token mask needed), otherwise "partial".
    """
    MIN_Q, MAX_Q = block_min_max_seq_ids(SLEN_Q)
    MIN_K, MAX_K = block_min_max_seq_ids(SLEN_K)

    cond1 = MIN_Q.unsqueeze(1) <= MAX_K.unsqueeze(0)
    cond2 = MIN_K.unsqueeze(0) <= MAX_Q.unsqueeze(1)
    overlap = cond1 & cond2

    cond1 = (MIN_Q == MAX_Q).unsqueeze(1)
    cond2 = (MIN_K == MAX_K).unsqueeze(0)
    same_seq_in_qk = cond1 & cond2

    full_blocks = overlap & same_seq_in_qk
    partial_blocks = overlap & ~same_seq_in_qk

    return full_blocks, partial_blocks


@torch.compiler.disable
def direct_block_mask(SLEN_Q: torch.Tensor, SLEN_K: torch.Tensor) -> BlockMask:
    """Build a flex-attention BlockMask directly from per-sequence lengths,
    bypassing create_block_mask's dense evaluation (fast path)."""
    full_blocks, partial_blocks = get_overlapping_blocks(SLEN_Q, SLEN_K)
    partial_blocks = partial_blocks[None, None]
    full_blocks = full_blocks[None, None]

    # Expand lengths into a per-token sequence-id vector.
    q_doc_id = torch.repeat_interleave(SLEN_Q)
    k_doc_id = torch.repeat_interleave(SLEN_K)

    def doc_mask(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor) -> torch.Tensor:
        return q_doc_id[q_idx] == k_doc_id[kv_idx]

    total_q_len = q_doc_id.shape[0]
    total_k_len = k_doc_id.shape[0]

    return _create_sparse_block_from_block_mask(
        (partial_blocks, full_blocks),
        doc_mask,
        seq_lengths=(total_q_len, total_k_len),
        Q_BLOCK_SIZE=128,
        KV_BLOCK_SIZE=128,
    )
@torch.compiler.disable
def doc_id_mask(SLEN_Q: torch.Tensor, SLEN_K: torch.Tensor) -> BlockMask:
    """Reference (slow-path) BlockMask builder: same-document attention mask
    evaluated through create_block_mask."""
    # Expand per-sequence lengths into per-token sequence ids.
    q_doc_id = torch.repeat_interleave(SLEN_Q)
    k_doc_id = torch.repeat_interleave(SLEN_K)

    def doc_mask(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor) -> torch.Tensor:
        return q_doc_id[q_idx] == k_doc_id[kv_idx]

    total_q_len = q_doc_id.shape[0]
    total_k_len = k_doc_id.shape[0]

    return create_block_mask(doc_mask, 1, 1, total_q_len, total_k_len, BLOCK_SIZE=128, device=SLEN_Q.device)


def varlen_flex_attention_func(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    q_sequence_ids: torch.Tensor,
    k_sequence_ids: torch.Tensor,
) -> torch.Tensor:
    """Variable-length attention via flex attention.

    Unpads the batch to the packed layout, builds a per-document block mask
    from the derived sequence lengths, runs flex attention over the single
    packed "batch", and re-pads the output to (bs, seqlen, nh, hs).
    """
    batch_size, q_len = query_states.shape[0], query_states.shape[1]
    (
        query_states,
        key_states,
        value_states,
        indices_q,
        (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
    ) = _unpad_input(query_states, key_states, value_states, q_sequence_ids, k_sequence_ids)

    # Packed tokens form one pseudo-batch: (1, heads, total_tokens, head_dim).
    query_states = query_states.unsqueeze(0).transpose(1, 2).contiguous()
    key_states = key_states.unsqueeze(0).transpose(1, 2).contiguous()
    value_states = value_states.unsqueeze(0).transpose(1, 2).contiguous()

    # Recover per-sequence lengths from the cumulative offsets.
    seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
    seqlens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
    block_mask = block_mask_creator(seqlens_q, seqlens_k)

    attn_output_unpad = flex_attention(
        query_states,
        key_states,
        value_states,
        block_mask=block_mask,
        enable_gqa=query_states.shape[1] != key_states.shape[1],
    )

    attn_output = pad_input(attn_output_unpad.transpose(1, 2).squeeze(0), indices_q, batch_size, q_len)

    return attn_output
class IndexPutFirstAxis(torch.autograd.Function):
    """Differentiable `output[indices] = values` along the first axis —
    the inverse of IndexFirstAxis, used to re-pad packed token rows."""

    @staticmethod
    def forward(ctx, values, indices, first_axis_dim) -> torch.Tensor:  # type: ignore[no-untyped-def]
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        # Rows not listed in `indices` stay zero (padding positions).
        output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        output[indices] = values
        # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
        return output

    @staticmethod
    def backward(ctx, grad_output) -> tuple[torch.Tensor, None, None]:  # type: ignore[no-untyped-def]
        (indices,) = ctx.saved_tensors
        # Gradient w.r.t. `values` is simply the gradient at the scattered rows.
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        grad_values = grad_output[indices]
        # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
        return grad_values, None, None


index_put_first_axis = IndexPutFirstAxis.apply
-> b s ...", b=batch) - - -def _get_unpad_data(sequence_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: - non_pad_indices = sequence_ids != -1 - non_pad_indices = torch.nonzero(non_pad_indices.flatten(), as_tuple=False).flatten() - sequence_ids = sequence_ids + torch.arange(len(sequence_ids), device=sequence_ids.device)[:, None] * 1e5 - sequence_ids = sequence_ids.flatten()[non_pad_indices] - _, seqlens_in_batch = torch.unique_consecutive(sequence_ids, return_counts=True) - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) - return non_pad_indices, cu_seqlens, max_seqlen_in_batch - - -def _unpad_input( - query_layer: torch.Tensor, - key_layer: torch.Tensor, - value_layer: torch.Tensor, - q_sequence_ids: torch.Tensor, - k_sequence_ids: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor], tuple[int, int]]: - batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape - query_length, num_q_heads = query_layer.shape[1], query_layer.shape[2] - assert query_layer.shape[:2] == q_sequence_ids.shape, ( - f"Shape mismatch between query layer and query sequence ids: {query_layer.shape[:2]} != {q_sequence_ids.shape}" - ) - assert key_layer.shape[:2] == k_sequence_ids.shape, ( - f"Shape mismatch between key layer and key sequence ids: {key_layer.shape[:2]} != {k_sequence_ids.shape}" - ) - assert query_length <= kv_seq_len, ( - f"Query length should be less than or equal to KV sequence length: {query_length} <= {kv_seq_len}" - ) - - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(k_sequence_ids) - - key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) - value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) - - if torch.equal(q_sequence_ids, k_sequence_ids): - indices_q = indices_k - 
cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - else: - indices_q, cu_seqlens_q, max_seqlen_in_batch_q = _get_unpad_data(q_sequence_ids) - - query_layer = index_first_axis(query_layer.reshape(batch_size * query_length, num_q_heads, head_dim), indices_q) - - assert cu_seqlens_q.shape == cu_seqlens_k.shape, ( - f"Query and KV should have the same number of sequences: {cu_seqlens_q.shape} != {cu_seqlens_k.shape}" - ) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -index_first_axis = IndexFirstAxis.apply -block_mask_creator = direct_block_mask if os.getenv("FAST_BLOCK_MASK", "1") == "1" else doc_id_mask -PAD_TOKEN_ID = 0 - - -def get_tokenizer() -> Tokenizer: - try: - fname = os.path.join(os.path.dirname(__file__), "tokenizer.json") - tokenizer: Tokenizer = Tokenizer.from_file(fname) - except Exception: - print("E1 Tokenizer not found in local directory, downloading from Hugging Face") - from huggingface_hub import hf_hub_download - fname = hf_hub_download(repo_id="Synthyra/Profluent-E1-150M", filename="tokenizer.json") - tokenizer: Tokenizer = Tokenizer.from_file(fname) - assert tokenizer.padding["pad_id"] == PAD_TOKEN_ID, ( - f"Padding token id must be {PAD_TOKEN_ID}, but got {tokenizer.padding['pad_id']}" - ) - - return tokenizer - - -@dataclass -class DataPrepConfig: - max_num_sequences: int = 512 - max_num_positions_within_seq: int = 8192 - remove_X_tokens: bool = False - - -def get_context(sequence: str) -> str | None: - if "," in sequence: - return sequence.rsplit(",", 1)[0] - return None - - -class E1BatchPreparer: - def __init__( - self, - data_prep_config: DataPrepConfig | None = None, - tokenizer: Tokenizer | None = None, - preserve_context_labels: bool = False, - ): - self.tokenizer = tokenizer or get_tokenizer() - self.data_prep_config = data_prep_config or DataPrepConfig() - self.pad_token_id = 
self.tokenizer.token_to_id("") - self.preserve_context_labels = preserve_context_labels - device = torch.cuda.current_device() if torch.cuda.is_available() else torch.device("cpu") - self.boundary_token_ids = torch.tensor( - [self.tokenizer.token_to_id(token) for token in ["", "", "1", "2", ""]], device=device - ).long() - self.mask_token = "?" # nosec - self.mask_token_id = self.tokenizer.token_to_id(self.mask_token) - self.X_token_id = self.tokenizer.token_to_id("X") - self.vocab = self.tokenizer.get_vocab() - - def get_batch_kwargs( # type: ignore[override] - self, sequences: list[str], device: torch.device = torch.device("cpu"), non_blocking: bool = False - ) -> dict[str, torch.Tensor | list[str] | list[int]]: - sequence_encodings = [self.prepare_multiseq(sequence) for sequence in sequences] - return self.pad_encodings(sequence_encodings, device, non_blocking) - - def pad_encodings( - self, - sequence_encodings: list[dict[str, torch.Tensor]], - device: torch.device = torch.device("cpu"), - non_blocking: bool = False, - ) -> dict[str, torch.Tensor | list[str] | list[int]]: - non_blocking = non_blocking and device.type == "cuda" - padded_encodings = {} - # Note: We use -1 as the padding value for sequence and position ids because the 0 value - # is a valid value for sequence and position ids. -1 is then used to distinguish valid - # tokens from padding tokens, for example, when doing padding/unpadding for flash attention. 
- for key, padding_value in { - "input_ids": self.pad_token_id, - "sequence_ids": -1, - "within_seq_position_ids": -1, - "global_position_ids": -1, - "labels": self.pad_token_id, - }.items(): - padded_encodings[key] = pad_sequence( - [enc[key] for enc in sequence_encodings], batch_first=True, padding_value=padding_value - ).to(device=device, dtype=torch.long, non_blocking=non_blocking) - - padded_encodings["context"] = [enc["context"] for enc in sequence_encodings] - padded_encodings["context_len"] = [enc["context_len"] for enc in sequence_encodings] - - return padded_encodings - - def prepare_multiseq(self, sequence: str) -> dict[str, torch.Tensor | str | int]: - single_sequences = sequence.split(",") - if len(single_sequences) > self.data_prep_config.max_num_sequences: - raise ValueError( - f"Number of sequences {len(single_sequences)} exceeds max number of sequences {self.data_prep_config.max_num_sequences}" - " in the provided multi-sequence instance. Please remove some homologous sequences before trying again." 
- ) - - single_sequence_encodings = [self.prepare_singleseq(sequence) for sequence in single_sequences] - - num_tokens = [len(x["input_ids"]) for x in single_sequence_encodings] - input_ids = torch.cat([x["input_ids"] for x in single_sequence_encodings]) - labels = torch.cat([x["labels"] for x in single_sequence_encodings]) - - within_seq_position_ids = torch.cat([encoding["position_ids"] for encoding in single_sequence_encodings]) - global_position_ids, ctx_len = [], 0 - for encoding in single_sequence_encodings: - global_position_ids.append(encoding["position_ids"] + ctx_len) - ctx_len = max(ctx_len, encoding["position_ids"].max().item() + ctx_len + 1) - global_position_ids = torch.cat(global_position_ids) - - sequence_ids = torch.repeat_interleave(torch.tensor(num_tokens)) - - # Get multi-seq context & mask out all but last sequence in multi-seq instance if desired - context_len = sum(num_tokens[:-1]) - context = self.tokenizer.decode(input_ids[:context_len].tolist(), skip_special_tokens=False) - if not self.preserve_context_labels: - labels[:context_len] = self.pad_token_id - - assert ( - input_ids.shape - == sequence_ids.shape - == within_seq_position_ids.shape - == global_position_ids.shape - == labels.shape - ), "Input ids, sequence ids, within seq position ids, global position ids, and labels must have the same shape" - - assert input_ids.shape[0] >= context_len, "Input ids must have at least as many tokens as the context length" - - return { - "input_ids": input_ids, - "sequence_ids": sequence_ids, - "within_seq_position_ids": within_seq_position_ids, - "global_position_ids": global_position_ids, - "labels": labels, - "context": context, - "context_len": context_len, - } - - def prepare_singleseq(self, sequence: str) -> dict[str, torch.Tensor]: - if not self.validate_sequence(sequence): - raise ValueError(f"Invalid sequence: {sequence}; Input sequence should contain [A-Z] or ? 
characters only") - - if len(sequence) > self.data_prep_config.max_num_positions_within_seq: - raise ValueError( - f"Sequence length {len(sequence)} exceeds max length {self.data_prep_config.max_num_positions_within_seq}" - ) - - # Can also use `tokens = torch.tensor(self.tokenizer.encode(f"1{sequence}2").ids)` - # but following is faster since our vocabulary is simple. - tokens = torch.tensor([self.vocab[token] for token in ["", "1", *sequence, "2", ""]]) - position_ids = torch.arange(len(tokens)) - - if self.data_prep_config.remove_X_tokens: - X_positions = torch.where(tokens != self.X_token_id)[0] - tokens = tokens[X_positions] - position_ids = position_ids[X_positions] - - return {"input_ids": tokens, "labels": tokens, "position_ids": position_ids} - - def get_boundary_token_mask(self, tokens: torch.Tensor) -> torch.BoolTensor: - return torch.isin(tokens, self.boundary_token_ids.to(tokens.device)) - - def get_mask_positions_mask(self, tokens: torch.Tensor) -> torch.BoolTensor: - return tokens == self.mask_token_id - - def validate_sequence(self, sequence: str) -> bool: - assert isinstance(sequence, str), "Sequence must be a string" - sequence = sequence.replace(self.mask_token, "") - return sequence.isalpha() and sequence.isupper() - - -class E1Config(PretrainedConfig): - model_type = "E1" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( # type: ignore - self, - # Model architecture/initialization - vocab_size=None, - hidden_size=4096, - intermediate_size=16384, - gated_mlp=False, - num_hidden_layers=40, - num_attention_heads=32, - num_key_value_heads=8, - hidden_act="silu", - rms_norm_eps=1e-5, - initializer_range=0.02, - dtype="bfloat16", - gradient_checkpointing=False, - no_ffn_gradient_checkpointing=False, - # Tokenization - pad_token_id=None, - bos_token_id=None, - eos_token_id=None, - tie_word_embeddings=False, - # Attention implementation & rotary positional embeddings - global_attention_every_n_layers=0, - max_num_sequences=512, - 
max_num_positions_within_seq=8192, - max_num_positions_global=1024 * 128, - rope_theta_within_seq=10000.0, - rope_theta_global=100000.0, - clip_qkv=None, - attn_backend="sdpa", - **kwargs, - ) -> None: - tokenizer = get_tokenizer() - super().__init__( - pad_token_id=tokenizer.token_to_id(""), - bos_token_id=tokenizer.token_to_id(""), - eos_token_id=tokenizer.token_to_id(""), - tie_word_embeddings=tie_word_embeddings, - dtype=dtype, - **kwargs, - ) - - self.hidden_size = hidden_size - if intermediate_size is None: - intermediate_size = 3 * hidden_size if gated_mlp else 4 * hidden_size - self.intermediate_size = intermediate_size - self.gated_mlp = gated_mlp - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.max_num_positions_within_seq = max_num_positions_within_seq - self.max_num_positions_global = max_num_positions_global - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.rope_theta_within_seq = rope_theta_within_seq - self.rope_theta_global = rope_theta_global - self.max_num_sequences = max_num_sequences - assert clip_qkv is None or clip_qkv > 0 - self.clip_qkv = clip_qkv - self.global_attention_every_n_layers = global_attention_every_n_layers - - self.vocab_size = tokenizer.get_vocab_size() - self.gradient_checkpointing = gradient_checkpointing - self.no_ffn_gradient_checkpointing = no_ffn_gradient_checkpointing - self.attn_backend = attn_backend - - if vocab_size is not None: - if vocab_size < self.vocab_size: - logger.warning( - f"Using vocab_size {vocab_size} smaller than {self.vocab_size} from tokenizer. MAKE SURE THIS IS INTENTIONAL." 
- ) - self.vocab_size = vocab_size - elif vocab_size > self.vocab_size: - logger.warning(f"Using vocab_size {vocab_size} instead of smaller {self.vocab_size} from tokenizer.") - self.vocab_size = vocab_size - if pad_token_id is not None and pad_token_id != self.pad_token_id: - logger.warning(f"Ignoring pad_token_id. Using {self.pad_token_id} from tokenizer") - if bos_token_id is not None and bos_token_id != self.bos_token_id: - logger.warning(f"Ignoring bos_token_id. Using {self.bos_token_id} from tokenizer") - if eos_token_id is not None and eos_token_id != self.eos_token_id: - logger.warning(f"Ignoring eos_token_id. Using {self.eos_token_id} from tokenizer") - - -class DynamicCache: - """ - A cache layer that grows dynamically as more tokens are generated. This is the default for generative models. - It stores the key and value states as tensors of shape `[batch_size, seq_len, num_heads, head_dim]`. - - Args: - key_cache (`list[torch.Tensor]`): The list of key states. - value_cache (`list[torch.Tensor]`): The list of value states. - """ - - def __init__(self) -> None: - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - - def update( - self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Update the key and value caches in-place, and return the necessary keys and value states. - - Args: - key_states (`torch.Tensor`): The new key states to cache of shape [batch_size, seq_len, num_heads, head_dim] - value_states (`torch.Tensor`): The new value states to cache of shape [batch_size, seq_len, num_heads, head_dim] - layer_idx (`int`): The index of the layer to update. - - Returns: - tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states of shape [batch_size, seq_len, num_heads, head_dim]. 
- """ - # Lazy initialization - if len(self.key_cache) <= layer_idx: - # There may be skipped layers, fill them with empty lists - for _ in range(len(self.key_cache), layer_idx): - self.key_cache.append(torch.tensor([])) - self.value_cache.append(torch.tensor([])) - self.key_cache.append(key_states) - self.value_cache.append(value_states) - elif ( - not self.key_cache[layer_idx].numel() # prefers not t.numel() to len(t) == 0 to export the model - ): # fills previously skipped layers; checking for tensor causes errors - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=1) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=1) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def get_seq_length(self, layer_idx: int = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - is_empty_layer = ( - len(self.key_cache) == 0 # no cache in any layer - or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it - or not self.key_cache[layer_idx].numel() # the layer has no cache - ) - layer_seq_length = self.key_cache[layer_idx].shape[1] if not is_empty_layer else 0 - return layer_seq_length - - def crop(self, max_length: int) -> None: - """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be - negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search.""" - assert max_length > 0, "max_length must be positive" - - if self.get_seq_length() <= max_length: - return - - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx].numel(): - self.key_cache[layer_idx] = self.key_cache[layer_idx][:, :max_length, ...] - self.value_cache[layer_idx] = self.value_cache[layer_idx][:, :max_length, ...] 
- - def batch_repeat_interleave(self, repeats: int) -> None: - """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx].numel(): - self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) - self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0) - - def batch_select_indices(self, indices: torch.Tensor) -> None: - """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx].numel(): - self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] - self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] - - -class KVCache: - def __init__(self, cache_size: int = 4) -> None: - self.cache_size = cache_size - self.tensor_input_field_names = [ - "input_ids", - "within_seq_position_ids", - "global_position_ids", - "sequence_ids", - "labels", - ] - self.tensor_output_field_names = ["logits", "embeddings"] - self.cache_dict: dict[str, DynamicCache] = {} - self.cache_queue: list[str] = [] - - def reset(self) -> None: - for k in list(self.cache_dict.keys()): - del self.cache_dict[k] - del self.cache_dict - self.cache_dict = {} - self.cache_queue = [] - - torch.cuda.empty_cache() - - def before_forward(self, batch: dict[str, torch.Tensor]) -> None: - contexts: list[str] | None = batch.get("context", None) - if contexts is None or "context_len" not in batch: - logger.warning_once( - "KVCache requires the batch dict to have both `context` and `context_len` keys to trigger. Skipping." - ) - return - - context_lens: list[int] = list(set(batch["context_len"])) - contexts: list[str] = list(set(contexts)) # type: ignore[no-redef] - if len(contexts) != 1 or len(context_lens) != 1: - logger.warning( - "SingleContextKVCache requires a single context and context length. 
" - "Multiple contexts or context lengths found in a single batch. Skipping." - ) - return - - batch_size = batch["input_ids"].shape[0] - - unique_context = contexts[0] - unique_context_len = context_lens[0] - batch["use_cache"] = True - - if unique_context not in self.cache_dict: - return - - self.cache_dict[unique_context].batch_repeat_interleave(batch_size) - past_key_values = self.cache_dict[unique_context] - batch["past_key_values"] = past_key_values - - # Remove context from the input fields - for field_name in self.tensor_input_field_names: - if batch.get(field_name, None) is not None: - batch[field_name] = batch[field_name][:, unique_context_len:] - - def after_forward(self, batch: dict[str, Any], outputs: ModelOutput) -> None: - contexts = batch.get("context", None) - context_lens = batch.get("context_len", []) - if contexts is None or len(set(contexts)) != 1 or len(set(context_lens)) != 1 or context_lens[0] == 0: - return - - assert batch["use_cache"] - unique_context = contexts[0] - unique_context_len = context_lens[0] - - past_key_values = getattr(outputs, "past_key_values", None) - if not isinstance(past_key_values, DynamicCache): - logger.warning_once("KVCache is incompatible with models that don't return a DynamicCache. 
Skipping.") - return - - if "past_key_values" not in batch: - if len(self.cache_queue) == self.cache_size: - last_context = self.cache_queue.pop(0) - if last_context not in self.cache_queue: - del self.cache_dict[last_context] - torch.cuda.empty_cache() - - self.cache_dict[unique_context] = past_key_values - self.cache_queue.append(unique_context) - - # Remove context from the input fields - for field_name in self.tensor_input_field_names: - if field_name in batch and batch[field_name] is not None: - batch[field_name] = batch[field_name][:, unique_context_len:] - - # Remove context from the output fields - for field_name in self.tensor_output_field_names: - if field_name in outputs and outputs[field_name] is not None: - outputs[field_name] = outputs[field_name][:, unique_context_len:] - if "hidden_states" in outputs and outputs["hidden_states"] is not None: - outputs["hidden_states"] = [h[:, unique_context_len:] for h in outputs["hidden_states"]] - - self.cache_dict[unique_context].crop(unique_context_len) - self.cache_dict[unique_context].batch_select_indices([0]) - - -class AttentionLayerType(Enum): - WITHIN_SEQ = "within_seq" - GLOBAL = "global" - - -class AttentionArgs(TypedDict, total=False): - within_seq_block_mask: BlockMask | None - block_causal_block_mask: BlockMask | None - within_seq_mask_4d: torch.Tensor | None - block_causal_mask_4d: torch.Tensor | None - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
- - The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, - num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class RotaryPositionalEmbedding(nn.Module): - def __init__( - self, dim: int, max_position_embeddings: int = 2048, base: int = 10000, device: torch.device | None = None - ): - super().__init__() - - self.dim = dim - self.base = base - self.max_position_embeddings = max_position_embeddings - inv_freq = base ** -(torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. - self._set_sin_cos_cache(seq_len=max_position_embeddings, device=self.inv_freq.device) - - @staticmethod - def rotate_half(x: torch.Tensor) -> torch.Tensor: - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - def _set_sin_cos_cache(self, seq_len: int, device: torch.device) -> None: - # Different from paper, but it uses a different permutation in order to obtain the same calculation - self.max_seq_len_cached = seq_len - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - angles = torch.outer(t, self.inv_freq.to(device)) - angles = torch.cat((angles, angles), dim=1) - self.register_buffer("cos_cached", angles.cos(), persistent=False) - self.register_buffer("sin_cached", angles.sin(), persistent=False) - - def forward( - self, q: torch.Tensor, k: torch.Tensor, position_ids: torch.LongTensor, seq_len: int | None = None - ) -> tuple[torch.Tensor, torch.Tensor]: - # x: [bsz, seq_len, num_attention_heads, head_size] - device, dtype = 
q.device, q.dtype - seq_len = position_ids.max().item() + 1 if seq_len is None else seq_len - - if seq_len > self.max_seq_len_cached: - self._set_sin_cos_cache(seq_len=seq_len, device=device) - - # angles_cached[position_ids] gets us something of shape (batch_size, seq_len, head_dim), - # so unsqueeze dimension -2 to broadcast to (batch_size, seq_len, n_heads, head_dim). - idxs = position_ids.to(device) - cos = self.cos_cached.to(device=device, dtype=dtype).unsqueeze(-2)[idxs] - sin = self.sin_cached.to(device=device, dtype=dtype).unsqueeze(-2)[idxs] - - # Apply rotary positional embeddings to q and k (treating them as complex numbers). The first half is - # Re[x exp(it)] = Re[x] cos(t) - Im[x] sin(t), while the second half is - # Im[x exp(it)] = Im[x] cos(t) + Re[x] sin(t). This works b/c both halves of cos/sin are the same. - q_embed = (q * cos) + (self.rotate_half(q) * sin) - k_embed = (k * cos) + (self.rotate_half(k) * sin) - return q_embed, k_embed - - -class Attention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper.""" - - def __init__(self, config: E1Config, layer_idx: int): - super().__init__() - self.config = config - self.layer_idx = layer_idx - - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_kv_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_kv_heads - self.max_num_seqs = config.max_num_sequences - self.clip_qkv = config.clip_qkv - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - if self.config.global_attention_every_n_layers > 0: - self.layer_type = ( - AttentionLayerType.GLOBAL - if (self.layer_idx + 1) % self.config.global_attention_every_n_layers == 0 - else AttentionLayerType.WITHIN_SEQ - ) - else: - self.layer_type = AttentionLayerType.WITHIN_SEQ - - self.rope_theta = ( - config.rope_theta_within_seq - if self.layer_type == AttentionLayerType.WITHIN_SEQ - else config.rope_theta_global - ) - self.max_position_embeddings = ( - config.max_num_positions_within_seq - if self.layer_type == AttentionLayerType.WITHIN_SEQ - else config.max_num_positions_global - ) - - self.rotary_emb = RotaryPositionalEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta - ) - - self.attn_backend = resolve_attention_backend(config.attn_backend) - - def prepare_qkv( - self, - hidden_states: torch.Tensor, - position_ids: torch.LongTensor, - past_key_value: DynamicCache | None = None, - use_cache: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - bsz, q_len, _ = hidden_states.size() - query_states: torch.Tensor = self.q_proj(hidden_states) - key_states: torch.Tensor = self.k_proj(hidden_states) - val_states: torch.Tensor = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim) - val_states = val_states.view(bsz, q_len, self.num_kv_heads, self.head_dim) - - if self.clip_qkv is not None: - query_states = query_states.clamp(-self.clip_qkv, self.clip_qkv) - key_states = key_states.clamp(-self.clip_qkv, 
self.clip_qkv) - val_states = val_states.clamp(-self.clip_qkv, self.clip_qkv) - - query_states, key_states = self.rotary_emb(query_states, key_states, position_ids) - - if use_cache and past_key_value is not None: - key_states, val_states = past_key_value.update(key_states, val_states, self.layer_idx) - - input_dtype = query_states.dtype - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - else: - target_dtype = self.q_proj.weight.dtype - if input_dtype != target_dtype: - logger.warning_once( - f"The input hidden states seems to be silently casted in {input_dtype}. " - f"This might be because you have upcasted embedding or layer norm layers " - f"in {input_dtype}. We will cast back the input in {target_dtype}." - ) - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - val_states = val_states.to(target_dtype) - - return query_states, key_states, val_states - - def forward( - self, - hidden_states: torch.Tensor, - within_seq_position_ids: torch.LongTensor, - global_position_ids: torch.LongTensor, - sequence_ids: torch.LongTensor, - attention_args: AttentionArgs | None = None, - past_key_value: DynamicCache | None = None, - output_attentions: bool = False, - output_s_max: bool = False, - use_cache: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor | None, DynamicCache | None, list[torch.Tensor] | None]: - is_cache_prefilled = ( - use_cache and past_key_value is not None and past_key_value.get_seq_length(self.layer_idx) > 0 - ) - - query_states, key_states, val_states = self.prepare_qkv( - hidden_states=hidden_states, - position_ids=within_seq_position_ids - if self.layer_type == AttentionLayerType.WITHIN_SEQ - else global_position_ids, - past_key_value=past_key_value, - use_cache=use_cache, - ) - - attn_output, attn_weights, s_max = self._attn( - query_states=query_states, - key_states=key_states, - val_states=val_states, - sequence_ids=sequence_ids, - attention_args=attention_args, - 
# --- continuation of Attention.forward (the method opens in the previous chunk) ---
            output_attentions=output_attentions,
            output_s_max=output_s_max,
            is_cache_prefilled=is_cache_prefilled,
        )

        # Project concatenated head outputs back to the model dimension.
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights, past_key_value, s_max

    def _attn(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        val_states: torch.Tensor,
        sequence_ids: torch.Tensor,
        attention_args: AttentionArgs | None = None,
        output_attentions: bool = False,
        output_s_max: bool = False,
        is_cache_prefilled: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None, list[torch.Tensor] | None]:
        """Dispatch attention to the configured backend.

        A GLOBAL layer is downgraded to WITHIN_SEQ once the KV cache is
        prefilled. When attention probabilities are requested, only the manual
        (materialized-matrix) path can return them, so it is used regardless of
        the configured backend.
        """
        effective_layer_type = self.layer_type
        if is_cache_prefilled and self.layer_type == AttentionLayerType.GLOBAL:
            effective_layer_type = AttentionLayerType.WITHIN_SEQ

        if output_attentions:
            # Manual path is the only one that materializes attention weights.
            return self._manual_attn(
                query_states, key_states, val_states,
                sequence_ids=sequence_ids,
                attention_args=attention_args,
                effective_layer_type=effective_layer_type,
                output_s_max=output_s_max,
                is_cache_prefilled=is_cache_prefilled,
            )

        if self.attn_backend == AttentionBackend.KERNELS_FLASH:
            if effective_layer_type == AttentionLayerType.WITHIN_SEQ:
                attn_output, attn_weights = self._kernels_flash_attn(
                    query_states, key_states, val_states,
                    sequence_ids=sequence_ids,
                    is_cache_prefilled=is_cache_prefilled,
                )
            else:
                # Non-within-seq layers fall back to flex attention even when
                # the flash-kernel backend is selected.
                attn_output, attn_weights = self._flex_attn(
                    query_states, key_states, val_states,
                    attention_args=attention_args,
                    effective_layer_type=effective_layer_type,
                )
        elif self.attn_backend == AttentionBackend.FLEX:
            attn_output, attn_weights = self._flex_attn(
                query_states, key_states, val_states,
                attention_args=attention_args,
                effective_layer_type=effective_layer_type,
            )
        elif self.attn_backend == AttentionBackend.SDPA:
            attn_output, attn_weights = self._sdpa_attn(
                query_states, key_states, val_states,
                sequence_ids=sequence_ids,
                attention_args=attention_args,
                effective_layer_type=effective_layer_type,
                is_cache_prefilled=is_cache_prefilled,
            )
        else:
            raise AssertionError(f"Unsupported resolved backend: {self.attn_backend}")

        s_max = self._compute_s_max(query_states, key_states) if output_s_max else None
        return attn_output, attn_weights, s_max

    @torch.no_grad()
    def _compute_s_max(
        self,
        query_states: torch.Tensor,  # (B, L, H, D)
        key_states: torch.Tensor,  # (B, L, Hkv, D)
    ) -> list[torch.Tensor]:
        """Per-head upper bound on the pre-softmax logit magnitude.

        Uses max_l |q_l| * max_l |k_l| * scale (a Cauchy-Schwarz bound on
        q.k / sqrt(D)), maxed over the batch. Returns one scalar tensor per head.
        """
        query_BHLD = query_states.transpose(1, 2).contiguous()
        key_BHLD = key_states.transpose(1, 2).contiguous()
        # Expand KV heads so the bound is computed per query head.
        key_BHLD = repeat_kv(key_BHLD, self.num_key_value_groups)
        scale = 1.0 / (self.head_dim ** 0.5)
        q_norm = torch.linalg.vector_norm(query_BHLD, dim=-1)
        k_norm = torch.linalg.vector_norm(key_BHLD, dim=-1)
        s_max_bound = (q_norm.max(dim=-1).values * k_norm.max(dim=-1).values).max(dim=0).values * scale
        return [s_max_bound[h] for h in range(self.num_heads)]

    def _kernels_flash_attn(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        val_states: torch.Tensor,
        sequence_ids: torch.Tensor,
        is_cache_prefilled: bool = False,
    ) -> tuple[torch.Tensor, None]:
        """Flash-kernel attention restricted by per-token sequence ids.

        Returns (attn_output, None) — this backend never materializes weights.
        """
        bsz, q_len = query_states.shape[0], query_states.shape[1]
        _, kv_len = key_states.shape[0], key_states.shape[1]

        if self.layer_type == AttentionLayerType.GLOBAL and not is_cache_prefilled:
            q_sequence_ids = sequence_ids
            if q_len < kv_len:
                # Cached prefix positions carry no ids here; they are assigned
                # the first token's sequence id. NOTE(review): assumes the
                # cached prefix belongs to the first sequence — confirm with
                # the cache-filling caller.
                first_token_id = sequence_ids[:, 0].unsqueeze(1)
                k_sequence_ids = torch.cat([first_token_id.expand(bsz, kv_len - q_len), sequence_ids], dim=-1)
            else:
                k_sequence_ids = sequence_ids
        else:
            if q_len < kv_len:
                # Within-seq attention ignores the cached prefix entirely.
                key_states = key_states[:, -q_len:]
                val_states = val_states[:, -q_len:]
            q_sequence_ids = k_sequence_ids = sequence_ids

        attn_output = kernels_flash_attention_func(
            query_states, key_states, val_states,
            q_sequence_ids=q_sequence_ids,
            k_sequence_ids=k_sequence_ids,
            causal=False,
        )
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        return attn_output, None

    def _flex_attn(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        val_states: torch.Tensor,
        attention_args: AttentionArgs | None = None,
        effective_layer_type: AttentionLayerType = AttentionLayerType.WITHIN_SEQ,
    ) -> tuple[torch.Tensor, None]:
        """FlexAttention with a precomputed block mask chosen by layer type.

        Falls back to unmasked attention when no attention_args are provided.
        """
        bsz, q_len = query_states.shape[0], query_states.shape[1]
        if effective_layer_type == AttentionLayerType.WITHIN_SEQ:
            block_mask = attention_args["within_seq_block_mask"] if attention_args is not None else None
        else:
            block_mask = attention_args["block_causal_block_mask"] if attention_args is not None else None
        outputs = flex_attention_func(query_states, key_states, val_states, block_mask=block_mask)
        outputs = outputs.reshape(bsz, q_len, self.hidden_size).contiguous()
        return outputs, None

    def _sdpa_attn(
        self,
        query_states: torch.Tensor,  # (B, L, H, D)
        key_states: torch.Tensor,  # (B, L, Hkv, D)
        val_states: torch.Tensor,  # (B, L, Hkv, D)
        sequence_ids: torch.Tensor,
        attention_args: AttentionArgs | None = None,
        effective_layer_type: AttentionLayerType = AttentionLayerType.WITHIN_SEQ,
        is_cache_prefilled: bool = False,
    ) -> tuple[torch.Tensor, None]:
        """torch.nn.functional.scaled_dot_product_attention backend.

        With a prefilled cache, within-seq layers drop the cached prefix and
        rebuild the 4-D mask from sequence_ids; otherwise masks come from
        attention_args (or no mask at all).
        """
        bsz, q_len = query_states.shape[:2]
        kv_len = key_states.shape[1]

        if is_cache_prefilled and q_len < kv_len:
            if effective_layer_type == AttentionLayerType.WITHIN_SEQ:
                key_states = key_states[:, -q_len:]
                val_states = val_states[:, -q_len:]
            attention_mask_4d = build_within_seq_mask_4d(sequence_ids) if effective_layer_type == AttentionLayerType.WITHIN_SEQ else None
        elif attention_args is not None:
            if effective_layer_type == AttentionLayerType.WITHIN_SEQ:
                attention_mask_4d = attention_args["within_seq_mask_4d"]
            else:
                attention_mask_4d = attention_args["block_causal_mask_4d"]
        else:
            attention_mask_4d = None

        query_BHLD = query_states.transpose(1, 2).contiguous()
        key_BHLD = key_states.transpose(1, 2).contiguous()
        val_BHLD = val_states.transpose(1, 2).contiguous()
        # GQA: replicate KV heads to match the number of query heads.
        key_BHLD = repeat_kv(key_BHLD, self.num_key_value_groups)
        val_BHLD = repeat_kv(val_BHLD, self.num_key_value_groups)
        context_BHLD = F.scaled_dot_product_attention(query_BHLD, key_BHLD, val_BHLD, attn_mask=attention_mask_4d)
        attn_output = context_BHLD.transpose(1, 2).reshape(bsz, q_len, self.hidden_size).contiguous()
        return attn_output, None

    def _manual_attn(
        self,
        query_states: torch.Tensor,  # (B, L, H, D)
        key_states: torch.Tensor,  # (B, L, Hkv, D)
        val_states: torch.Tensor,  # (B, L, Hkv, D)
        sequence_ids: torch.Tensor,
        attention_args: AttentionArgs | None = None,
        effective_layer_type: AttentionLayerType = AttentionLayerType.WITHIN_SEQ,
        output_s_max: bool = False,
        is_cache_prefilled: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]:
        """Reference attention that materializes the full attention matrix.

        Slowest path; used only when attention weights are requested. Mask
        selection mirrors `_sdpa_attn` exactly.
        """
        bsz, q_len = query_states.shape[:2]
        kv_len = key_states.shape[1]

        if is_cache_prefilled and q_len < kv_len:
            if effective_layer_type == AttentionLayerType.WITHIN_SEQ:
                key_states = key_states[:, -q_len:]
                val_states = val_states[:, -q_len:]
            attention_mask_4d = build_within_seq_mask_4d(sequence_ids) if effective_layer_type == AttentionLayerType.WITHIN_SEQ else None
        elif attention_args is not None:
            if effective_layer_type == AttentionLayerType.WITHIN_SEQ:
                attention_mask_4d = attention_args["within_seq_mask_4d"]
            else:
                attention_mask_4d = attention_args["block_causal_mask_4d"]
        else:
            attention_mask_4d = None

        query_BHLD = query_states.transpose(1, 2).contiguous()
        key_BHLD = key_states.transpose(1, 2).contiguous()
        val_BHLD = val_states.transpose(1, 2).contiguous()
        key_BHLD = repeat_kv(key_BHLD, self.num_key_value_groups)
        val_BHLD = repeat_kv(val_BHLD, self.num_key_value_groups)
        scale = 1.0 / (self.head_dim ** 0.5)
        attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale
        if attention_mask_4d is not None:
            # Disallowed positions get -inf so softmax zeroes them out.
            attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf"))
        attn_weights = F.softmax(attn_weights, dim=-1)
        context_BHLD = torch.matmul(attn_weights, val_BHLD)
        attn_output = context_BHLD.transpose(1, 2).reshape(bsz, q_len, self.hidden_size).contiguous()
        s_max = self._compute_s_max(query_states, key_states) if output_s_max else None
        return attn_output, attn_weights, s_max


class MLP(nn.Module):
    """Plain two-layer feed-forward block (no gating)."""

    def __init__(self, config: E1Config):
        super().__init__()
        self.ffn_dim = config.intermediate_size
        self.hidden_dim = config.hidden_size
        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.w2(self.act_fn(self.w1(hidden_states)))


class GLUMLP(nn.Module):
    """Gated-linear-unit feed-forward block: w2(act(w1(x)) * w3(x))."""

    def __init__(self, config: E1Config):
        super().__init__()
        self.ffn_dim = config.intermediate_size
        self.hidden_dim = config.hidden_size
        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
        hidden_states = self.w2(hidden_states)
        return hidden_states


class FFN(nn.Module):
    """Feed-forward wrapper selecting gated vs. plain MLP from the config."""

    def __init__(self, config: E1Config):
        super().__init__()
        mlp_cls = GLUMLP if config.gated_mlp else MLP
        self.mlp = mlp_cls(config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.mlp(hidden_states)


@dataclass
class E1ModelOutputWithPast(ModelOutput):
    """Base class for model's outputs, with potential hidden states and attentions.

    Attributes:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        past_key_values (`DynamicCache`, *optional*, returned when `use_cache=True` is passed or when
            `config.use_cache=True`):
            Pre-computed key/value states of the self-attention blocks that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is
            passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding
            layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed
            or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads,
            sequence_length, sequence_length)` holding post-softmax attention weights.
        s_max (`tuple(list(torch.Tensor))`, *optional*, returned when `output_s_max=True` is passed):
            One list per layer of per-head upper bounds on the attention logits.
    """

    last_hidden_state: torch.FloatTensor | None = None
    past_key_values: DynamicCache | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    s_max: tuple[list[torch.Tensor], ...] | None = None
@dataclass
class E1MaskedLMOutputWithPast(ModelOutput):
    """Output of `E1ForMaskedLM`: total and MLM losses, token logits, plus the
    encoder outputs (hidden states, attentions, cache, s_max) passed through."""

    loss: torch.FloatTensor | None = None
    mlm_loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    last_hidden_state: torch.FloatTensor | None = None
    past_key_values: DynamicCache | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    s_max: tuple[list[torch.Tensor], ...] | None = None


@dataclass
class E1ClassificationOutputWithPast(ModelOutput):
    """Output of the sequence/token classification heads: loss and logits plus
    the encoder outputs passed through."""

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    last_hidden_state: torch.FloatTensor | None = None
    past_key_values: DynamicCache | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    s_max: tuple[list[torch.Tensor], ...] | None = None


class RMSNorm(nn.Module):
    """Root-mean-square layer norm with a learned scale and no bias."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
        self.hidden_size = hidden_size

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        if layer_norm is None:
            # Fused kernel not available (`layer_norm` import failed —
            # presumably an optional fused-norm extension; verify at the import
            # site): fall back to torch's built-in rms_norm.
            return torch.nn.functional.rms_norm(
                hidden_states, (self.hidden_size,), self.weight, self.variance_epsilon
            ).to(input_dtype)
        else:
            return layer_norm.rms_norm_fn(
                x=hidden_states,
                weight=self.weight,
                bias=None,  # no bias
                residual=None,
                eps=self.variance_epsilon,
                dropout_p=0.0,  # no dropout by default
                prenorm=False,
                residual_in_fp32=False,
            ).to(input_dtype)


class NormAttentionNorm(nn.Module):
    """Pre-norm self-attention sub-block: norm -> attention -> residual add ->
    norm. Returns the normalized states together with the residual stream so the
    caller (DecoderLayer) can add it back after the FFN."""

    def __init__(self, config: E1Config, layer_idx: int):
        super().__init__()
        self.self_attn = Attention(config, layer_idx)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        within_seq_position_ids: torch.LongTensor,
        global_position_ids: torch.LongTensor,
        sequence_ids: torch.LongTensor,
        attention_args: AttentionArgs | None = None,
        past_key_value: DynamicCache | None = None,
        output_attentions: bool = False,
        output_s_max: bool = False,
        use_cache: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, DynamicCache | None, list[torch.Tensor] | None]:
        """Returns (post-attn-norm states, residual, attn weights, cache, s_max)."""
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, self_attn_weights, present_key_value, s_max = self.self_attn(
            hidden_states=hidden_states,
            within_seq_position_ids=within_seq_position_ids,
            global_position_ids=global_position_ids,
            sequence_ids=sequence_ids,
            attention_args=attention_args,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            output_s_max=output_s_max,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # The post-attention states become the residual stream for the FFN.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        return hidden_states, residual, self_attn_weights, present_key_value, s_max
class DecoderLayer(nn.Module):
    """One transformer block: pre-norm attention (via NormAttentionNorm)
    followed by a feed-forward network, each with a residual connection."""

    def __init__(self, config: E1Config, layer_idx: int):
        super().__init__()
        self.initializer_range = config.initializer_range
        self.hidden_size = config.hidden_size
        self.norm_attn_norm = NormAttentionNorm(config, layer_idx)
        self.ffn = FFN(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        within_seq_position_ids: torch.LongTensor,
        global_position_ids: torch.LongTensor,
        sequence_ids: torch.LongTensor,
        attention_args: AttentionArgs | None = None,
        past_key_value: DynamicCache | None = None,
        output_attentions: bool = False,
        output_s_max: bool = False,
        use_cache: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None, DynamicCache | None, list[torch.Tensor] | None]:
        """Run attention then FFN; returns (states, attn weights, cache, s_max)."""
        # Attention sub-block hands back both the post-attention-norm
        # activations and the residual stream to add after the FFN.
        normed, residual, attn_weights, present_kv, s_max = self.norm_attn_norm(
            hidden_states=hidden_states,
            within_seq_position_ids=within_seq_position_ids,
            global_position_ids=global_position_ids,
            sequence_ids=sequence_ids,
            attention_args=attention_args,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            output_s_max=output_s_max,
            use_cache=use_cache,
        )

        # Feed-forward + residual.
        out = residual + self.ffn(normed)
        return out, attn_weights, present_kv, s_max


class E1PreTrainedModel(PreTrainedModel):
    """Base class wiring E1Config into the HF PreTrainedModel machinery
    (weight init, gradient checkpointing, backend switching)."""

    config_class = E1Config
    config: E1Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DecoderLayer"]
    _transformer_layer_cls = [DecoderLayer]
    _skip_keys_device_placement = "past_key_values"
    all_tied_weights_keys = {}

    def _init_weights(self, module: nn.Module) -> None:
        """Normal(0, initializer_range) for linear/embedding weights, ones for
        RMSNorm scales; biases and padding rows are zeroed."""
        std = self.config.initializer_range
        if isinstance(module, RMSNorm):
            module.weight.data.fill_(1.0)
            return
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
            return
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()

    def _backward_compatibility_gradient_checkpointing(self) -> None:
        # Honor a legacy `gradient_checkpointing` flag stored on the config.
        if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
            self.gradient_checkpointing_enable(dict(use_reentrant=False))

    def post_init(self) -> None:
        super().post_init()

    @property
    def _device(self) -> torch.device:
        return next(self.parameters()).device

    @property
    def attn_backend(self) -> str:
        return self.config.attn_backend

    @attn_backend.setter
    def attn_backend(self, backend: str) -> None:
        """Validate, record, and propagate a new attention backend to every
        encoder and attention module in the tree."""
        assert backend in VALID_ATTENTION_BACKENDS, (
            f"Unsupported attn_backend: {backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
        )
        self.config.attn_backend = backend
        resolved_backend = resolve_attention_backend(backend)
        for sub_module in self.modules():
            if isinstance(sub_module, FAST_E1_ENCODER):
                sub_module._attn_backend = resolved_backend
            elif isinstance(sub_module, Attention):
                sub_module.attn_backend = resolved_backend
class FAST_E1_ENCODER(E1PreTrainedModel, EmbeddingMixin):
    """Stack of DecoderLayers over token + sequence-id embeddings, with a
    final RMSNorm. The core E1 encoder shared by all task heads."""

    config: E1Config
    config_class = E1Config

    def __init__(self, config: E1Config, **kwargs):
        E1PreTrainedModel.__init__(self, config, **kwargs)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        # Learned embedding for the index of the sequence a token belongs to.
        self.embed_seq_id = nn.Embedding(config.max_num_sequences, config.hidden_size)
        self.layers = nn.ModuleList([DecoderLayer(config, i) for i in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = config.gradient_checkpointing
        self.prep_tokens = E1BatchPreparer()
        self._attn_backend = resolve_attention_backend(config.attn_backend)
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.embed_tokens

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        self.embed_tokens = value

    @torch.inference_mode()
    def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
        """Tokenize `sequences` and return the final hidden states; optionally
        also a mask of valid (non-padding, sequence_id != -1) positions."""
        batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device)
        last_hidden_state = self.forward(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state
        if return_attention_mask:
            attention_mask = (batch['sequence_ids'] != -1).long()
            return last_hidden_state, attention_mask
        else:
            return last_hidden_state

    # Ignore copy
    def forward(
        self,
        input_ids: torch.LongTensor,
        within_seq_position_ids: torch.LongTensor,
        global_position_ids: torch.LongTensor,
        sequence_ids: torch.LongTensor,
        past_key_values: DynamicCache | None = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_s_max: bool = False,
        **kwargs
    ) -> E1ModelOutputWithPast:
        """
        Args:
            input_ids: (batch_size, seq_length)
            within_seq_position_ids: (batch_size, seq_length)
                This tensor contains the position of each residue within the sequence itself.
                For example, if the input is ["1ABC21DEF2", "1GH21JKL2"],
                the tensor would be [[0,1,2,3,4,5,6,0,1,2,3,4,5,6], [0,1,2,3,4,5,0,1,2,3,4,5,6,-1]]
            global_position_ids: (batch_size, seq_length)
                This tensor contains the position of each residue within the global sequence.
                For example, if the input is ["1ABC21DEF2", "1GH21JKL2"],
                the tensor would be [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, -1]]
            sequence_ids: (batch_size, seq_length)
                This tensor contains the sequence id of each residue.
                For example, if the input is ["1ABC21DEF2", "1GH21JKL2"],
                the tensor would be [[0,0,0,0,0,0,0,1,1,1,1,1,1,1], [0,0,0,0,0,0,1,1,1,1,1,1,1,-1]]
            past_key_values: DynamicCache
            use_cache: bool
            output_attentions: bool
            output_hidden_states: bool
            output_s_max: bool

        Returns:
            E1ModelOutputWithPast: Model Outputs
        """
        batch_size, seq_length = input_ids.shape

        # Gradient checkpointing recomputes activations, which conflicts with
        # holding a KV cache across the backward pass.
        if self.gradient_checkpointing and self.training and torch.is_grad_enabled():
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()
        elif not use_cache:
            past_key_values = None

        global_position_ids = global_position_ids.view(-1, seq_length).long()
        within_seq_position_ids = within_seq_position_ids.view(-1, seq_length).long()
        sequence_ids = sequence_ids.view(-1, seq_length).long()

        # -1 marks padding; anything else must fit the positional table.
        max_position_id = torch.max(within_seq_position_ids).item()
        min_position_id = torch.min(within_seq_position_ids).item()
        assert max_position_id < self.config.max_num_positions_within_seq and min_position_id >= -1, (
            f"Position ids must be in the range [-1, {self.config.max_num_positions_within_seq}); got max {max_position_id} and min {min_position_id}"
        )

        # Token embedding plus sequence-index embedding (padding rows clamp to
        # sequence id 0; they are masked out by attention anyway).
        inputs_embeds = self.embed_tokens(input_ids)
        inputs_embeds = inputs_embeds + self.embed_seq_id(sequence_ids.clamp(min=0))

        # Match the dtype the attention projections will compute in.
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        else:
            target_dtype = self.layers[0].norm_attn_norm.self_attn.q_proj.weight.dtype
        hidden_states = inputs_embeds.to(target_dtype)

        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0

        # Build only the mask variants the resolved backend will actually use.
        attn_backend = self._attn_backend
        has_global_layers = self.config.global_attention_every_n_layers > 0
        needs_4d_masks = (attn_backend == AttentionBackend.SDPA) or output_attentions
        needs_block_causal_flex = (
            (attn_backend == AttentionBackend.FLEX and has_global_layers)
            or (attn_backend == AttentionBackend.KERNELS_FLASH and has_global_layers)
        )
        needs_within_seq_flex = (attn_backend == AttentionBackend.FLEX)

        # Masks are only needed on the first (non-cached) pass.
        attention_args: AttentionArgs | None = None
        if past_key_values_length == 0:
            attention_args = AttentionArgs(
                block_causal_block_mask=create_block_causal_mask_optimized(sequence_ids) if needs_block_causal_flex else None,
                within_seq_block_mask=create_within_seq_block_mask(sequence_ids) if needs_within_seq_flex else None,
                within_seq_mask_4d=build_within_seq_mask_4d(sequence_ids) if needs_4d_masks else None,
                block_causal_mask_4d=build_block_causal_mask_4d(sequence_ids) if needs_4d_masks else None,
            )

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        full_s_max = () if output_s_max else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)  # type: ignore[operator]

            if self.gradient_checkpointing and self.training and torch.is_grad_enabled():
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    within_seq_position_ids,
                    global_position_ids,
                    sequence_ids,
                    attention_args,
                    past_key_values,
                    output_attentions,
                    output_s_max,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    within_seq_position_ids=within_seq_position_ids,
                    global_position_ids=global_position_ids,
                    sequence_ids=sequence_ids,
                    attention_args=attention_args,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    output_s_max=output_s_max,
                    use_cache=use_cache,
                )

            hidden_states, self_attn_weights, present_key_value, s_max = layer_outputs

            if use_cache:
                next_decoder_cache = past_key_values = present_key_value

            if output_attentions:
                all_self_attns += (self_attn_weights,)  # type: ignore[operator]

            if full_s_max is not None:
                full_s_max += (s_max,)  # type: ignore[operator]

        hidden_states = self.norm(hidden_states)

        # Append the final (post-norm) states as the last hidden-state entry.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)  # type: ignore[operator]

        next_cache = next_decoder_cache if use_cache else None

        return E1ModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            s_max=full_s_max,
        )


class E1Model(E1PreTrainedModel, EmbeddingMixin):
    """Thin wrapper exposing FAST_E1_ENCODER under the standard `model` prefix."""

    config: E1Config
    config_class = E1Config

    def __init__(self, config: E1Config, **kwargs):
        E1PreTrainedModel.__init__(self, config, **kwargs)
        self.model: FAST_E1_ENCODER = FAST_E1_ENCODER(config, **kwargs)
        self.prep_tokens = self.model.prep_tokens
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        self.model.set_input_embeddings(value)

    @torch.inference_mode()
    def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
        return self.model._embed(sequences, return_attention_mask=return_attention_mask, **kwargs)

    def forward(
        self,
        input_ids: torch.LongTensor,
        within_seq_position_ids: torch.LongTensor,
        global_position_ids: torch.LongTensor,
        sequence_ids: torch.LongTensor,
        past_key_values: DynamicCache | None = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_s_max: bool = False,
        **kwargs,
    ) -> E1ModelOutputWithPast:
        """Delegates directly to the wrapped encoder."""
        return self.model(
            input_ids=input_ids,
            within_seq_position_ids=within_seq_position_ids,
            global_position_ids=global_position_ids,
            sequence_ids=sequence_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_s_max=output_s_max,
            **kwargs,
        )
class E1ForMaskedLM(E1PreTrainedModel, EmbeddingMixin):
    """E1 encoder topped with a masked-language-modeling head."""

    config: E1Config
    config_class = E1Config

    def __init__(self, config: E1Config, **kwargs):
        E1PreTrainedModel.__init__(self, config, **kwargs)
        self.model: FAST_E1_ENCODER = FAST_E1_ENCODER(config, **kwargs)
        self.vocab_size = config.vocab_size
        # Two-layer MLM head: dense -> GELU -> LayerNorm -> vocab projection.
        self.mlm_head = torch.nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size, bias=True),
            nn.GELU(),
            nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps),
            nn.Linear(config.hidden_size, config.vocab_size, bias=True),
        )
        self.gradient_checkpointing = config.gradient_checkpointing
        self.prep_tokens = self.model.prep_tokens
        self.post_init()

    @property
    def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh:
        return self.model.device_mesh

    @torch.inference_mode()
    def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
        """Tokenize and encode `sequences`; optionally return a validity mask."""
        batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device)
        encoded = self.model(**batch, output_hidden_states=False, output_attentions=False)
        hidden = encoded.last_hidden_state
        if not return_attention_mask:
            return hidden
        valid_mask = (batch['sequence_ids'] != -1).long()
        return hidden, valid_mask

    def forward(
        self,
        input_ids: torch.LongTensor,
        within_seq_position_ids: torch.LongTensor,
        global_position_ids: torch.LongTensor,
        sequence_ids: torch.LongTensor,
        labels: torch.LongTensor | None = None,
        past_key_values: DynamicCache | None = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_s_max: bool = False,
        **kwargs,
    ) -> E1MaskedLMOutputWithPast:
        """Encode the batch and, when `labels` are given, compute the MLM loss.

        Args:
            input_ids: (batch_size, seq_length)
            within_seq_position_ids: (batch_size, seq_length) per-sequence positions
                (e.g. ["1ABC21DEF2", "1GH21JKL2"] ->
                [[0,1,2,3,4,5,6,0,1,2,3,4,5,6], [0,1,2,3,4,5,0,1,2,3,4,5,6,-1]])
            global_position_ids: (batch_size, seq_length) positions in the packed batch
            sequence_ids: (batch_size, seq_length) index of the sequence each token
                belongs to; -1 marks padding
            labels: (batch_size, seq_length) MLM targets; positions equal to the
                padding index are excluded from the loss
            past_key_values / use_cache / output_*: forwarded to the encoder

        Returns:
            E1MaskedLMOutputWithPast with loss, mlm_loss, logits, and encoder outputs.
        """
        encoder_out: E1ModelOutputWithPast = self.model(
            input_ids=input_ids,
            within_seq_position_ids=within_seq_position_ids,
            global_position_ids=global_position_ids,
            sequence_ids=sequence_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_s_max=output_s_max,
        )

        hidden = encoder_out.last_hidden_state
        mlm_logits = self.mlm_head(hidden).float()

        loss = None
        mlm_loss = 0.0
        if labels is not None:
            flat_logits = mlm_logits.contiguous().view(-1, self.config.vocab_size)
            flat_labels = labels.to(flat_logits.device).contiguous().view(-1)
            per_token_loss = F.cross_entropy(flat_logits, flat_labels, reduction="none")
            # Average only over non-padding targets; guard the empty case.
            valid = flat_labels != self.model.padding_idx
            n_valid = valid.sum()
            mlm_loss = (per_token_loss * valid.to(per_token_loss)).sum() / (1 if n_valid == 0 else n_valid)
            loss = 0.0 + mlm_loss

        return E1MaskedLMOutputWithPast(
            loss=loss,
            mlm_loss=mlm_loss,
            logits=mlm_logits,
            last_hidden_state=hidden,
            past_key_values=encoder_out.past_key_values,
            hidden_states=encoder_out.hidden_states,
            attentions=encoder_out.attentions,
            s_max=encoder_out.s_max,
        )
class E1ForSequenceClassification(E1PreTrainedModel, EmbeddingMixin):
    """E1 encoder with a pooled sequence-classification/regression head.

    Hidden states are pooled with the configured `pooling_types` (default
    ``['mean', 'var']``), concatenated, and fed to a small MLP classifier.
    """

    config: E1Config
    config_class = E1Config

    def __init__(self, config: E1Config, **kwargs):
        E1PreTrainedModel.__init__(self, config, **kwargs)
        self.model: FAST_E1_ENCODER = FAST_E1_ENCODER(config, **kwargs)
        self.vocab_size = config.vocab_size
        self.num_labels = config.num_labels

        # BUGFIX: the original guard used `isinstance(kwargs['pooling_types'],
        # List[str])`, which raises `TypeError: Subscripted generics cannot be
        # used with class and instance checks` the moment a caller passes
        # `pooling_types`. Validate with runtime types instead.
        pooling_types = kwargs.get('pooling_types')
        if not (isinstance(pooling_types, list)
                and len(pooling_types) > 0
                and all(isinstance(p, str) for p in pooling_types)):
            pooling_types = ['mean', 'var']
        self.pooler = Pooler(pooling_types)

        # Generalization: size the classifier input by the number of pooling
        # types instead of hard-coding 2, so non-default `pooling_types` work.
        # Assumes each pooling type contributes hidden_size features (true for
        # the default ['mean', 'var'], which keeps the original
        # hidden_size * 2 input width) — confirm for exotic pooling types.
        classifier_in = config.hidden_size * len(pooling_types)
        self.classifier = nn.Sequential(
            nn.Linear(classifier_in, config.hidden_size * 4),
            nn.GELU(),
            nn.LayerNorm(config.hidden_size * 4),
            nn.Linear(config.hidden_size * 4, config.num_labels),
        )
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.gradient_checkpointing = config.gradient_checkpointing
        self.prep_tokens = self.model.prep_tokens
        self.post_init()

    @property
    def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh:
        return self.model.device_mesh

    @torch.inference_mode()
    def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
        """Tokenize and encode `sequences`; optionally return a validity mask."""
        batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device)
        last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state
        if return_attention_mask:
            attention_mask = (batch['sequence_ids'] != -1).long()
            return last_hidden_state, attention_mask
        else:
            return last_hidden_state

    def forward(
        self,
        input_ids: torch.LongTensor,
        within_seq_position_ids: torch.LongTensor,
        global_position_ids: torch.LongTensor,
        sequence_ids: torch.LongTensor,
        labels: torch.LongTensor | None = None,
        past_key_values: DynamicCache | None = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_s_max: bool = False,
        **kwargs,
    ) -> E1ClassificationOutputWithPast:
        """Encode, pool, classify; loss type follows `config.problem_type`
        (inferred from `num_labels` and label dtype on first call).

        Returns:
            E1ClassificationOutputWithPast with loss (if labels given), logits,
            and the encoder outputs passed through.
        """
        outputs: E1ModelOutputWithPast = self.model(
            input_ids=input_ids,
            within_seq_position_ids=within_seq_position_ids,
            global_position_ids=global_position_ids,
            sequence_ids=sequence_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_s_max=output_s_max,
        )

        # Padding positions carry sequence id -1; exclude them from pooling.
        attention_mask = (sequence_ids != -1).long()
        x = outputs.last_hidden_state
        features = self.pooler(x, attention_mask)
        logits = self.classifier(features)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer the problem type once, mirroring HF conventions.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = self.mse(logits.flatten(), labels.flatten())
                else:
                    loss = self.mse(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss = self.bce(logits, labels)

        return E1ClassificationOutputWithPast(
            loss=loss,
            logits=logits,
            last_hidden_state=x,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            s_max=outputs.s_max,
        )
class E1ForTokenClassification(E1PreTrainedModel, EmbeddingMixin):
    """E1 encoder with a per-token classification head (no pooling)."""

    config: E1Config
    config_class = E1Config

    def __init__(self, config: E1Config, **kwargs):
        E1PreTrainedModel.__init__(self, config, **kwargs)
        self.model: FAST_E1_ENCODER = FAST_E1_ENCODER(config, **kwargs)
        self.vocab_size = config.vocab_size
        self.num_labels = config.num_labels
        # BUGFIX: the first classifier layer was nn.Linear(hidden_size * 2, ...),
        # but forward() feeds it the raw last_hidden_state of width hidden_size
        # (token-level path has no pooled concatenation, unlike
        # E1ForSequenceClassification), so every forward pass failed with a
        # shape mismatch. The input width must be hidden_size.
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size * 4),
            nn.GELU(),
            nn.LayerNorm(config.hidden_size * 4),
            nn.Linear(config.hidden_size * 4, config.num_labels),
        )
        self.loss_fct = nn.CrossEntropyLoss()
        self.gradient_checkpointing = config.gradient_checkpointing
        self.prep_tokens = self.model.prep_tokens
        self.post_init()

    @property
    def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh:
        return self.model.device_mesh

    @torch.inference_mode()
    def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
        """Tokenize and encode `sequences`; optionally return a validity mask."""
        batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device)
        last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state
        if return_attention_mask:
            attention_mask = (batch['sequence_ids'] != -1).long()
            return last_hidden_state, attention_mask
        else:
            return last_hidden_state

    def forward(
        self,
        input_ids: torch.LongTensor,
        within_seq_position_ids: torch.LongTensor,
        global_position_ids: torch.LongTensor,
        sequence_ids: torch.LongTensor,
        labels: torch.LongTensor | None = None,
        past_key_values: DynamicCache | None = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_s_max: bool = False,
        **kwargs,
    ) -> E1ClassificationOutputWithPast:
        """Encode and classify every token position.

        Args:
            labels: (batch_size, seq_length) class index per token; positions
                with the CrossEntropyLoss default ignore_index (-100) are
                skipped in the loss.

        Returns:
            E1ClassificationOutputWithPast with per-token logits of shape
            (batch_size, seq_length, num_labels).
        """
        outputs: E1ModelOutputWithPast = self.model(
            input_ids=input_ids,
            within_seq_position_ids=within_seq_position_ids,
            global_position_ids=global_position_ids,
            sequence_ids=sequence_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_s_max=output_s_max,
        )

        x = outputs.last_hidden_state
        logits = self.classifier(x)
        loss = None
        if labels is not None:
            loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return E1ClassificationOutputWithPast(
            loss=loss,
            logits=logits,
            last_hidden_state=x,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            s_max=outputs.s_max,
        )


if __name__ == "__main__":
    # Smoke test: build a tiny random E1ForMaskedLM, run one forward pass, and
    # print the shapes of every tensor in the batch and the output.
    import random

    import torch

    from torch import Tensor

    def print_tensor_shapes(prefix: str, obj):
        """Recursively print the shape of every tensor reachable from `obj`."""
        if isinstance(obj, Tensor):
            print(f"{prefix}{obj.shape}")
        elif isinstance(obj, dict):
            for name, value in obj.items():
                print_tensor_shapes(f"{prefix}{name}.", value)
        elif isinstance(obj, list):
            for idx, value in enumerate(obj):
                print_tensor_shapes(f"{prefix}[{idx}].", value)
        elif isinstance(obj, tuple):
            for idx, value in enumerate(obj):
                print_tensor_shapes(f"{prefix}[{idx}].", value)
        elif hasattr(obj, "__dict__"):
            for name, value in vars(obj).items():
                if name.startswith("_"):
                    continue
                print_tensor_shapes(f"{prefix}{name}.", value)
        else:
            print(f"{prefix}{type(obj)}")

    def get_e1_batch(tokenizer, sequences: list[str], device: torch.device):
        """Tokenize `sequences` into the model's expected batch kwargs."""
        preparer = E1BatchPreparer(data_prep_config=DataPrepConfig(max_num_positions_within_seq=64), tokenizer=tokenizer)
        return preparer.get_batch_kwargs(sequences=sequences, device=device)

    random.seed(0)
    torch.manual_seed(0)

    num_attention_heads = random.choice([2, 4])
    config = E1Config(
        hidden_size=16 * num_attention_heads,
        intermediate_size=64 * num_attention_heads,
        num_hidden_layers=random.choice([1, 2]),
        num_attention_heads=num_attention_heads,
        num_key_value_heads=num_attention_heads,
        max_num_positions_within_seq=128,
        max_num_positions_global=256,
        max_num_sequences=8,
        dtype="float32",
    )
    model = E1ForMaskedLM(config=config).eval()
    tokenizer = get_tokenizer()
    batch = get_e1_batch(tokenizer=tokenizer, sequences=["ACDEFG", "MKTW"], device=torch.device("cpu"))
    batch["labels"] = batch["labels"].clone()

    with torch.no_grad():
        output = model(
            input_ids=batch["input_ids"],
            within_seq_position_ids=batch["within_seq_position_ids"],
            global_position_ids=batch["global_position_ids"],
            sequence_ids=batch["sequence_ids"],
            labels=batch["labels"],
        )

    print("Batch shape:")
    print_tensor_shapes("", batch)
    print("Output shape:")
    print_tensor_shapes("", output)
isinstance(obj, Tensor): - print(f"{prefix}{obj.shape}") - elif isinstance(obj, dict): - for name, value in obj.items(): - print_tensor_shapes(f"{prefix}{name}.", value) - elif isinstance(obj, list): - for idx, value in enumerate(obj): - print_tensor_shapes(f"{prefix}[{idx}].", value) - elif isinstance(obj, tuple): - for idx, value in enumerate(obj): - print_tensor_shapes(f"{prefix}[{idx}].", value) - elif hasattr(obj, "__dict__"): - for name, value in vars(obj).items(): - if name.startswith("_"): - continue - print_tensor_shapes(f"{prefix}{name}.", value) - else: - print(f"{prefix}{type(obj)}") - - def get_e1_batch(tokenizer, sequences: list[str], device: torch.device): - preparer = E1BatchPreparer(data_prep_config=DataPrepConfig(max_num_positions_within_seq=64), tokenizer=tokenizer) - return preparer.get_batch_kwargs(sequences=sequences, device=device) - - random.seed(0) - torch.manual_seed(0) - - num_attention_heads = random.choice([2, 4]) - config = E1Config( - hidden_size=16 * num_attention_heads, - intermediate_size=64 * num_attention_heads, - num_hidden_layers=random.choice([1, 2]), - num_attention_heads=num_attention_heads, - num_key_value_heads=num_attention_heads, - max_num_positions_within_seq=128, - max_num_positions_global=256, - max_num_sequences=8, - dtype="float32", - ) - model = E1ForMaskedLM(config=config).eval() - tokenizer = get_tokenizer() - batch = get_e1_batch(tokenizer=tokenizer, sequences=["ACDEFG", "MKTW"], device=torch.device("cpu")) - batch["labels"] = batch["labels"].clone() - - with torch.no_grad(): - output = model( - input_ids=batch["input_ids"], - within_seq_position_ids=batch["within_seq_position_ids"], - global_position_ids=batch["global_position_ids"], - sequence_ids=batch["sequence_ids"], - labels=batch["labels"], - ) - - print("Batch shape:") - print_tensor_shapes("", batch) - print("Output shape:") - print_tensor_shapes("", output) - - +### Embedding Mixin + Pooler +import os +import sqlite3 +import networkx as nx +import 
class Pooler:
    """Pool per-token embeddings (b, L, d) into per-sequence embeddings (b, d).

    ``pooling_types`` selects one or more strategies; ``__call__`` applies each
    and concatenates the results along the feature axis, yielding a
    (b, n_pooling_types * d) tensor.
    """

    def __init__(self, pooling_types: List[str]):
        self.pooling_types = pooling_types
        self.pooling_options = {
            'mean': self.mean_pooling,
            'max': self.max_pooling,
            'norm': self.norm_pooling,
            'median': self.median_pooling,
            'std': self.std_pooling,
            'var': self.var_pooling,
            'cls': self.cls_pooling,
            'parti': self._pool_parti,
        }

    def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor:
        # Collapse dim=1 (heads/layers stacked by the caller) with an element-wise max.
        maxed_attentions = torch.max(attentions, dim=1)[0]
        return maxed_attentions

    def _page_rank(self, attention_matrix, personalization=None, nstart=None, prune_type="top_k_outdegree"):
        """Run PageRank on the attention matrix converted to a directed graph.

        Raises if the graph size does not match the token count or if the graph
        has no edges. Returns a dict mapping node index -> PageRank score.
        """
        G = self._convert_to_graph(attention_matrix)
        if G.number_of_nodes() != attention_matrix.shape[0]:
            raise Exception(
                f"The number of nodes in the graph should be equal to the number of tokens in sequence! You have {G.number_of_nodes()} nodes for {attention_matrix.shape[0]} tokens.")
        if G.number_of_edges() == 0:
            raise Exception(f"You don't seem to have any attention edges left in the graph.")

        return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100)

    def _convert_to_graph(self, matrix):
        # Each matrix entry becomes a weighted directed edge (attention i -> j).
        G = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
        return G

    def _calculate_importance_weights(self, dict_importance, attention_mask: Optional[torch.Tensor] = None):
        # Drop nodes whose attention_mask entry is 0 (padding), then normalize
        # the surviving PageRank scores so they sum to 1.
        if attention_mask is not None:
            for k in list(dict_importance.keys()):
                if attention_mask[k] == 0:
                    del dict_importance[k]
        total = sum(dict_importance.values())
        return np.array([v / total for _, v in dict_importance.items()])

    def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):  # (b, L, d) -> (b, d)
        # PageRank-weighted average of token embeddings ("parti" pooling).
        # NOTE(review): requires both `attentions` and `attention_mask`; passing
        # None for either will fail — confirm callers always supply them.
        maxed_attentions = self._create_pooled_matrices_across_layers(attentions).numpy()
        # emb is (b, L, d), maxed_attentions is (b, L, L)
        emb_pooled = []
        for e, a, mask in zip(emb, maxed_attentions, attention_mask):
            dict_importance = self._page_rank(a)
            importance_weights = self._calculate_importance_weights(dict_importance, mask)
            num_tokens = int(mask.sum().item())
            emb_pooled.append(np.average(e[:num_tokens], weights=importance_weights, axis=0))
        pooled = torch.tensor(np.array(emb_pooled))
        return pooled

    def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        # Average over valid (non-padded) positions only.
        if attention_mask is None:
            return emb.mean(dim=1)
        attention_mask = attention_mask.unsqueeze(-1)
        return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)

    def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        if attention_mask is None:
            return emb.max(dim=1).values
        # FIX: the previous masking multiplied by the mask, which zeroes padded
        # slots; whenever every valid activation in a feature was negative the
        # injected 0 incorrectly won the max. Filling padded slots with -inf
        # guarantees only valid positions can be selected.
        mask = attention_mask.unsqueeze(-1).bool()
        return emb.masked_fill(~mask, float('-inf')).max(dim=1).values

    def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        if attention_mask is None:
            return emb.norm(dim=1, p=2)
        # Multiplying by the mask is exact here: zeroed pads contribute nothing
        # to an L2 norm.
        attention_mask = attention_mask.unsqueeze(-1)
        return (emb * attention_mask).norm(dim=1, p=2)

    def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        if attention_mask is None:
            return emb.median(dim=1).values
        # FIX: multiplying by the mask inserted zeros into the population and
        # skewed the median toward 0. Mark padded slots NaN and use nanmedian
        # so only valid positions are considered.
        mask = attention_mask.unsqueeze(-1).bool()
        return emb.masked_fill(~mask, float('nan')).nanmedian(dim=1).values

    def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        if attention_mask is None:
            return emb.std(dim=1)
        # Compute variance correctly over non-masked positions, then take sqrt.
        var = self.var_pooling(emb, attention_mask, **kwargs)
        return torch.sqrt(var)

    def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        if attention_mask is None:
            return emb.var(dim=1)
        # Population variance restricted to non-masked positions.
        attention_mask = attention_mask.unsqueeze(-1)  # (b, L, 1)
        mean = (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)  # (b, d)
        mean = mean.unsqueeze(1)  # (b, 1, d)
        squared_diff = (emb - mean) ** 2  # (b, L, d)
        var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)  # (b, d)
        return var

    def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        # First token — assumes position 0 holds a CLS-style summary token.
        return emb[:, 0, :]

    def __call__(
        self,
        emb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attentions: Optional[torch.Tensor] = None
    ):  # e.g. pooling_types=['mean', 'max']
        final_emb = []
        for pooling_type in self.pooling_types:
            final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions))  # (b, d)
        return torch.cat(final_emb, dim=-1)  # (b, n_pooling_types * d)


class ProteinDataset(TorchDataset):
    """Simple dataset for protein sequences."""
    def __init__(self, sequences: list[str]):
        self.sequences = sequences

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> str:
        return self.sequences[idx]


def build_collator(tokenizer: "PreTrainedTokenizerBase") -> Callable[[list[str]], dict[str, torch.Tensor]]:
    """Return a DataLoader collate_fn that tokenizes a batch with longest-padding."""
    # NOTE: the annotation is a forward reference so this module does not hard-
    # require `transformers` at definition time.
    def _collate_fn(sequences: list[str]) -> dict[str, torch.Tensor]:
        return tokenizer(sequences, return_tensors="pt", padding='longest')
    return _collate_fn


def parse_fasta(fasta_path: str) -> List[str]:
    """Parse a FASTA file and return its sequences (headers discarded).

    Multi-line records are joined; blank lines are ignored.
    """
    assert os.path.exists(fasta_path), f"FASTA file does not exist: {fasta_path}"
    sequences = []
    current_seq = []
    with open(fasta_path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            if line.startswith('>'):
                # A header starts a new record; flush the previous one.
                if current_seq:
                    sequences.append(''.join(current_seq))
                    current_seq = []
            else:
                current_seq.append(line)
    if current_seq:
        sequences.append(''.join(current_seq))
    return sequences


class EmbeddingMixin:
    """Mixin adding bulk-embedding helpers (SQLite / .pth caching) to a model.

    Subclasses must implement ``_embed``.
    """

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        raise NotImplementedError

    @property
    def device(self) -> torch.device:
        """Get the device of the model."""
        return next(self.parameters()).device

    def _read_sequences_from_db(self, db_path: str) -> set[str]:
        """Read the set of already-embedded sequences from a SQLite database."""
        sequences = []
        with sqlite3.connect(db_path) as conn:
            c = conn.cursor()
            c.execute("SELECT sequence FROM embeddings")
            while True:
                row = c.fetchone()
                if row is None:
                    break
                sequences.append(row[0])
        return set(sequences)

    def _ensure_embeddings_table(self, conn: sqlite3.Connection) -> None:
        # Create the embeddings table if needed and migrate older databases
        # that predate the shape/dtype metadata columns.
        cursor = conn.cursor()
        cursor.execute(
            "CREATE TABLE IF NOT EXISTS embeddings ("
            "sequence TEXT PRIMARY KEY, "
            "embedding BLOB NOT NULL, "
            "shape TEXT, "
            "dtype TEXT"
            ")"
        )
        cursor.execute("PRAGMA table_info(embeddings)")
        rows = cursor.fetchall()
        column_names = [row[1] for row in rows]
        if "shape" not in column_names:
            cursor.execute("ALTER TABLE embeddings ADD COLUMN shape TEXT")
        if "dtype" not in column_names:
            cursor.execute("ALTER TABLE embeddings ADD COLUMN dtype TEXT")
        conn.commit()
")" + ) + cursor.execute("PRAGMA table_info(embeddings)") + rows = cursor.fetchall() + column_names = [row[1] for row in rows] + if "shape" not in column_names: + cursor.execute("ALTER TABLE embeddings ADD COLUMN shape TEXT") + if "dtype" not in column_names: + cursor.execute("ALTER TABLE embeddings ADD COLUMN dtype TEXT") + conn.commit() + + def load_embeddings_from_pth(self, save_path: str) -> dict[str, torch.Tensor]: + assert os.path.exists(save_path), f"Embedding file does not exist: {save_path}" + payload = torch.load(save_path, map_location="cpu", weights_only=True) + assert isinstance(payload, dict), "Expected .pth embeddings file to contain a dictionary." + for sequence, tensor in payload.items(): + assert isinstance(sequence, str), "Expected embedding dictionary keys to be sequences (str)." + assert isinstance(tensor, torch.Tensor), "Expected embedding dictionary values to be tensors." + return payload + + def load_embeddings_from_db(self, db_path: str, sequences: Optional[List[str]] = None) -> dict[str, torch.Tensor]: + assert os.path.exists(db_path), f"Embedding database does not exist: {db_path}" + loaded: dict[str, torch.Tensor] = {} + with sqlite3.connect(db_path) as conn: + self._ensure_embeddings_table(conn) + cursor = conn.cursor() + if sequences is None: + cursor.execute("SELECT sequence, embedding, shape, dtype FROM embeddings") + else: + if len(sequences) == 0: + return loaded + placeholders = ",".join(["?"] * len(sequences)) + cursor.execute( + f"SELECT sequence, embedding, shape, dtype FROM embeddings WHERE sequence IN ({placeholders})", + tuple(sequences), + ) + + rows = cursor.fetchall() + for row in rows: + sequence = row[0] + embedding_bytes = row[1] + shape_text = row[2] + dtype_text = row[3] + assert shape_text is not None, "Missing shape metadata in embeddings table." + assert dtype_text is not None, "Missing dtype metadata in embeddings table." 
+ shape_values = [int(value) for value in shape_text.split(",") if len(value) > 0] + assert len(shape_values) > 0, f"Invalid shape metadata for sequence: {sequence}" + expected_size = int(np.prod(shape_values)) + np_dtype = np.dtype(dtype_text) + array = np.frombuffer(embedding_bytes, dtype=np_dtype) + assert array.size == expected_size, f"Shape mismatch while reading sequence: {sequence}" + reshaped = array.copy().reshape(tuple(shape_values)) + loaded[sequence] = torch.from_numpy(reshaped) + return loaded + + def embed_dataset( + self, + sequences: Optional[List[str]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + batch_size: int = 2, + max_len: int = 512, + truncate: bool = True, + full_embeddings: bool = False, + embed_dtype: torch.dtype = torch.float32, + pooling_types: List[str] = ['mean'], + num_workers: int = 0, + sql: bool = False, + save: bool = True, + sql_db_path: str = 'embeddings.db', + save_path: str = 'embeddings.pth', + fasta_path: Optional[str] = None, + **kwargs, + ) -> Optional[dict[str, torch.Tensor]]: + """ + Embed a dataset of protein sequences. + + Supports two modes: + - Tokenizer mode (ESM2/ESM++): provide `tokenizer`, `_embed(input_ids, attention_mask)` is used. + - Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used. + + Sequences can be supplied as a list via `sequences`, parsed from a FASTA file via + `fasta_path`, or both (the two sources are combined). At least one must be provided. + """ + if fasta_path is not None: + fasta_sequences = parse_fasta(fasta_path) + sequences = list(sequences or []) + fasta_sequences + assert sequences is not None and len(sequences) > 0, \ + "Must provide at least one sequence via `sequences` or `fasta_path`." 
+ sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences])) + sequences = sorted(sequences, key=len, reverse=True) + hidden_size = self.config.hidden_size + pooler = Pooler(pooling_types) if not full_embeddings else None + tokenizer_mode = tokenizer is not None + if tokenizer_mode: + collate_fn = build_collator(tokenizer) + device = self.device + else: + collate_fn = None + device = None + + def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + if full_embeddings or residue_embeddings.ndim == 2: + return residue_embeddings + return pooler(residue_embeddings, attention_mask) + + def iter_batches(to_embed: List[str]): + if tokenizer_mode: + assert collate_fn is not None + assert device is not None + dataset = ProteinDataset(to_embed) + dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn, shuffle=False) + for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'): + seqs = to_embed[i * batch_size:(i + 1) * batch_size] + input_ids = batch['input_ids'].to(device) + attention_mask = batch['attention_mask'].to(device) + residue_embeddings = self._embed(input_ids, attention_mask) + yield seqs, residue_embeddings, attention_mask + else: + for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'): + seqs = to_embed[batch_start:batch_start + batch_size] + batch_output = self._embed(seqs, return_attention_mask=True, **kwargs) + assert isinstance(batch_output, tuple), "Sequence mode _embed must return (last_hidden_state, attention_mask)." + assert len(batch_output) == 2, "Sequence mode _embed must return exactly two values." + residue_embeddings, attention_mask = batch_output + assert isinstance(attention_mask, torch.Tensor), "Sequence mode _embed must return attention_mask as a torch.Tensor." 
+ yield seqs, residue_embeddings, attention_mask + + if sql: + conn = sqlite3.connect(sql_db_path) + self._ensure_embeddings_table(conn) + c = conn.cursor() + already_embedded = self._read_sequences_from_db(sql_db_path) + to_embed = [seq for seq in sequences if seq not in already_embedded] + print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}") + print(f"Embedding {len(to_embed)} new sequences") + if len(to_embed) > 0: + with torch.no_grad(): + for i, (seqs, residue_embeddings, attention_mask) in enumerate(iter_batches(to_embed)): + embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype) + for seq, emb, mask in zip(seqs, embeddings, attention_mask): + if full_embeddings: + emb = emb[mask.bool()].reshape(-1, hidden_size) + emb_np = emb.cpu().numpy() + emb_shape = ",".join([str(dim) for dim in emb_np.shape]) + emb_dtype = str(emb_np.dtype) + c.execute( + "INSERT OR REPLACE INTO embeddings (sequence, embedding, shape, dtype) VALUES (?, ?, ?, ?)", + (seq, emb_np.tobytes(), emb_shape, emb_dtype), + ) + if tokenizer_mode and (i + 1) % 100 == 0: + conn.commit() + conn.commit() + conn.close() + return None + + embeddings_dict = {} + if os.path.exists(save_path): + embeddings_dict = self.load_embeddings_from_pth(save_path) + to_embed = [seq for seq in sequences if seq not in embeddings_dict] + print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}") + print(f"Embedding {len(to_embed)} new sequences") + else: + to_embed = sequences + print(f"Embedding {len(to_embed)} new sequences") + + if len(to_embed) > 0: + with torch.no_grad(): + for seqs, residue_embeddings, attention_mask in iter_batches(to_embed): + embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype) + for seq, emb, mask in zip(seqs, embeddings, attention_mask): + if full_embeddings: + emb = emb[mask.bool()].reshape(-1, hidden_size) + embeddings_dict[seq] = emb.cpu() + + if save: + torch.save(embeddings_dict, 
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from einops import rearrange, repeat
from enum import Enum
from typing import Any, TypedDict, Callable, List
from dataclasses import dataclass
from tokenizers import Tokenizer
from transformers import PretrainedConfig, PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_outputs import ModelOutput
from transformers.utils import logging


logger = logging.get_logger(__name__)

### Kernels Flash Attention Detection
def _infer_kernels_flash_variant(kernel) -> str | None:
    # Distinguish flash-attn2 vs flash-attn3 kernels by the API surface they
    # expose; returns None when neither API shape is recognized.
    if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
        return "flash_attn2"
    if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
        return "flash_attn3"
    return None


def _try_get_kernels_flash():
    # Best-effort loading of a flash-attention kernel from the `kernels` hub:
    # prefer flash-attn3, fall back to flash-attn2, otherwise (None, None).
    # The asserts are deliberately caught by the surrounding except blocks so
    # an unrecognized kernel API triggers the fallback path.
    try:
        from kernels import get_kernel
    except ImportError:
        return None, None

    flash_kernel = None
    flash_kernel_variant = None
    try:
        flash_kernel = get_kernel("kernels-community/flash-attn3")
        flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel)
        assert flash_kernel_variant is not None, "Loaded flash-attn3 kernel does not expose a supported API."
    except Exception:
        try:
            flash_kernel = get_kernel("kernels-community/flash-attn2")
            flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel)
            assert flash_kernel_variant is not None, "Loaded flash-attn2 kernel does not expose a supported API."
        except Exception:
            flash_kernel = None
            flash_kernel_variant = None
    return flash_kernel, flash_kernel_variant


# Resolved once at import time; (None, None) when no kernel is available.
FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash()


def _kernels_flash_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    causal: bool = False,
) -> torch.Tensor:
    # Dense (padded) flash-attention forward, dispatching on the loaded variant.
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    if FLASH_KERNEL_VARIANT == "flash_attn2":
        return FLASH_KERNEL.fwd(q=query_states, k=key_states, v=value_states, is_causal=causal)[0]
    if FLASH_KERNEL_VARIANT == "flash_attn3":
        # Try the keyword API first; older builds only accept positional args.
        try:
            output = FLASH_KERNEL.flash_attn_func(q=query_states, k=key_states, v=value_states, causal=causal)
        except TypeError:
            output = FLASH_KERNEL.flash_attn_func(query_states, key_states, value_states, 0.0, None, causal)
        if isinstance(output, tuple):
            return output[0]
        return output
    raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")


def _kernels_flash_varlen_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_in_batch_q: int,
    max_seqlen_in_batch_k: int,
    causal: bool = False,
) -> torch.Tensor:
    # Variable-length (unpadded) flash-attention forward over packed tokens,
    # using cumulative sequence-length offsets.
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    if FLASH_KERNEL_VARIANT == "flash_attn2":
        return FLASH_KERNEL.varlen_fwd(
            q=query_states, k=key_states, v=value_states,
            cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
            is_causal=causal,
        )[0]
    if FLASH_KERNEL_VARIANT == "flash_attn3":
        # Keyword API first; positional fallback for older builds.
        try:
            output = FLASH_KERNEL.flash_attn_varlen_func(
                q=query_states, k=key_states, v=value_states,
                cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
                causal=causal,
            )
        except TypeError:
            output = FLASH_KERNEL.flash_attn_varlen_func(
                query_states, key_states, value_states,
                cu_seqlens_q, cu_seqlens_k,
                max_seqlen_in_batch_q, max_seqlen_in_batch_k,
                0.0, None, causal,
            )
        if isinstance(output, tuple):
            return output[0]
        return output
    raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")


from torch.nn.attention.flex_attention import (
    BlockMask,
    create_block_mask,
    flex_attention,
    _create_sparse_block_from_block_mask
)

# Lazily-populated cache of torch.compile(flex_attention).
_compiled_flex_attention = None


def _get_flex_attention_fn():
    """Return flex_attention callable: compiled (fused kernel) by default, or eager when debug flag is set."""
    global _compiled_flex_attention
    flex_mod = torch.nn.attention.flex_attention
    if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False):
        return flex_attention
    if _compiled_flex_attention is None:
        _compiled_flex_attention = torch.compile(flex_attention)
    return _compiled_flex_attention


# Optional triton layer-norm kernel; None falls back to PyTorch RMSNorm.
try:
    from kernels import get_kernel
    layer_norm = get_kernel("kernels-community/triton-layer-norm")
except Exception as e:
    logger.warning(f"Failed to load triton layer norm kernel: {e}; Will be using PyTorch RMSNorm instead")
    layer_norm = None


### Attention Backend Enum & Resolution
class AttentionBackend(Enum):
    # Closed set of attention implementations selectable from config.
    AUTO = "auto"
    KERNELS_FLASH = "kernels_flash"
    FLEX = "flex"
    SDPA = "sdpa"
"sdpa" + + +VALID_ATTENTION_BACKENDS = tuple(b.value for b in AttentionBackend) + + +_BACKEND_CONFIRMED = False + + +def resolve_attention_backend(requested_backend: str) -> AttentionBackend: + global _BACKEND_CONFIRMED + assert requested_backend in VALID_ATTENTION_BACKENDS, ( + f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}." + ) + if requested_backend == AttentionBackend.AUTO.value: + if FLASH_KERNEL is not None: + resolved = AttentionBackend.KERNELS_FLASH + elif flex_attention is not None: + resolved = AttentionBackend.FLEX + else: + resolved = AttentionBackend.SDPA + elif requested_backend == AttentionBackend.KERNELS_FLASH.value: + assert FLASH_KERNEL is not None, "Kernels Flash Attention is not available in this environment." + resolved = AttentionBackend.KERNELS_FLASH + elif requested_backend == AttentionBackend.FLEX.value: + assert flex_attention is not None, "Flex Attention is not available in this environment." + resolved = AttentionBackend.FLEX + elif requested_backend == AttentionBackend.SDPA.value: + resolved = AttentionBackend.SDPA + else: + raise AssertionError(f"Unsupported attention backend: {requested_backend}") + if not _BACKEND_CONFIRMED: + print(f"Attention backend: config='{requested_backend}' -> resolved='{resolved.value}'") + _BACKEND_CONFIRMED = True + return resolved + + +def create_block_causal_mask_optimized(sequence_ids: torch.Tensor) -> BlockMask: + # Assumes sequence_ids is sorted in increasing order for each batch item, except for + # the -1 values, which are used to indicate the padding tokens. 
+ def document_mask(b, h, q_idx, kv_idx): # type: ignore[no-untyped-def] + return ( + (sequence_ids[b, q_idx] >= sequence_ids[b, kv_idx]) + & (sequence_ids[b, q_idx] != -1) + & (sequence_ids[b, kv_idx] != -1) + ) + + batch_size, seqlen = sequence_ids.shape + return create_block_mask(document_mask, batch_size, 1, seqlen, seqlen, device=sequence_ids.device) + + +def create_within_seq_block_mask(sequence_ids: torch.Tensor) -> BlockMask: + def document_mask(b, h, q_idx, kv_idx): # type: ignore[no-untyped-def] + return ( + (sequence_ids[b, q_idx] == sequence_ids[b, kv_idx]) + & (sequence_ids[b, q_idx] != -1) + & (sequence_ids[b, kv_idx] != -1) + ) + + batch_size, seqlen = sequence_ids.shape + return create_block_mask(document_mask, batch_size, 1, seqlen, seqlen, device=sequence_ids.device) + + +def build_within_seq_mask_4d(sequence_ids: torch.Tensor) -> torch.Tensor: + not_pad = (sequence_ids != -1) + same_seq = sequence_ids.unsqueeze(-1) == sequence_ids.unsqueeze(-2) + valid = not_pad.unsqueeze(-1) & not_pad.unsqueeze(-2) + return (same_seq & valid).unsqueeze(1) + + +def build_block_causal_mask_4d(sequence_ids: torch.Tensor) -> torch.Tensor: + not_pad = (sequence_ids != -1) + causal = sequence_ids.unsqueeze(-1) >= sequence_ids.unsqueeze(-2) + valid = not_pad.unsqueeze(-1) & not_pad.unsqueeze(-2) + return (causal & valid).unsqueeze(1) + + +def flex_attention_func( + query_states: torch.Tensor, # (bs, seqlen, nh, hs) + key_states: torch.Tensor, # (bs, seqlen, nkv, hs) + value_states: torch.Tensor, # (bs, seqlen, nkv, hs) + score_mod: Callable | None = None, + block_mask: BlockMask | None = None, +) -> torch.Tensor: + assert flex_attention is not None, "Flex Attention is not available in this environment" + assert score_mod is None, "Score mod is not supported yet" + query_states = query_states.transpose(1, 2).contiguous() # (bs, nh, seqlen, hs) + key_states = key_states.transpose(1, 2).contiguous() # (bs, nkv, seqlen, hs) + value_states = value_states.transpose(1, 
2).contiguous() # (bs, nkv, seqlen, hs) + + fn = _get_flex_attention_fn() + outputs = fn( + query_states, + key_states, + value_states, + block_mask=block_mask, + score_mod=score_mod, + enable_gqa=query_states.shape[1] != key_states.shape[1], # if nkv != nh + ) + + outputs = outputs.transpose(1, 2) # (bs, seqlen, nh, hs) + return outputs + + +def kernels_flash_attention_func( + query_states: torch.Tensor, # (bs, seqlen, nh, hs) + key_states: torch.Tensor, # (bs, seqlen, nkv, hs) + value_states: torch.Tensor, # (bs, seqlen, nkv, hs) + q_sequence_ids: torch.Tensor, + k_sequence_ids: torch.Tensor, + causal: bool = False, +) -> torch.Tensor: # (bs, seqlen, nh, hs) + assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment." + + if not causal: + batch_size, q_len = query_states.shape[0], query_states.shape[1] + ( + query_states, + key_states, + value_states, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) = _unpad_input(query_states, key_states, value_states, q_sequence_ids, k_sequence_ids) + + attn_output_unpad = _kernels_flash_varlen_forward( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_in_batch_q=max_seqlen_in_batch_q, + max_seqlen_in_batch_k=max_seqlen_in_batch_k, + causal=False, + ) + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, q_len) + + else: + attn_output = _kernels_flash_forward(query_states, key_states, value_states, causal=True) + + return attn_output + + +class IndexFirstAxis(torch.autograd.Function): + @staticmethod + def forward(ctx, input, indices) -> torch.Tensor: # type: ignore[no-untyped-def] + ctx.save_for_backward(indices) + assert input.ndim >= 2 + ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] + second_dim = other_shape.numel() + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. 
def block_min_max_seq_ids(SLEN: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
    # For each fixed-size block of the packed token stream, compute the first
    # (MIN) and last (MAX) sequence index whose tokens fall inside the block.
    # SLEN holds per-sequence lengths; a synthetic padding "sequence" is
    # appended so the total is a multiple of block_size.
    device = SLEN.device
    total_tokens = torch.sum(SLEN)
    B = (total_tokens + block_size - 1) // block_size
    padding_tokens = B * block_size - total_tokens
    SLEN = torch.cat([SLEN, padding_tokens.reshape(1).to(device=device, dtype=SLEN.dtype)], dim=0)

    assert torch.sum(SLEN) == B * block_size

    # Cumulative ends (exclusive) for each sequence; cum[i] == end offset of seq i
    cum = torch.cumsum(SLEN.to(torch.long), dim=0)  # (N,)
    total_tokens = cum[-1].item()

    # Block start/end offsets [start, end) in token index space
    block_starts = torch.arange(0, B * block_size, block_size, device=device, dtype=torch.long)  # (B,)
    block_ends = torch.minimum(block_starts + block_size, torch.tensor(total_tokens, device=device))  # (B,)

    # MIN_SEQ_ID[i] = first sequence whose end > block_start
    # searchsorted with right=True returns first index where cum > value
    MIN_SEQ_ID = torch.searchsorted(cum, block_starts, right=True)

    # MAX_SEQ_ID[i] = sequence containing the last token in the block (block_end - 1)
    # For empty tail beyond total_tokens we already clipped block_ends.
    last_token_in_block = torch.clamp(block_ends - 1, min=0)  # valid only if block has at least 1 token
    MAX_SEQ_ID = torch.searchsorted(cum, last_token_in_block, right=True)

    return MIN_SEQ_ID, MAX_SEQ_ID


def get_overlapping_blocks(SLEN_Q: torch.Tensor, SLEN_K: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Classify each (q-block, k-block) pair: `full_blocks` are pairs fully
    # inside one sequence (no per-token mask needed); `partial_blocks` overlap
    # a sequence boundary and need the fine-grained mask.
    MIN_Q, MAX_Q = block_min_max_seq_ids(SLEN_Q)
    MIN_K, MAX_K = block_min_max_seq_ids(SLEN_K)

    cond1 = MIN_Q.unsqueeze(1) <= MAX_K.unsqueeze(0)
    cond2 = MIN_K.unsqueeze(0) <= MAX_Q.unsqueeze(1)
    overlap = cond1 & cond2

    cond1 = (MIN_Q == MAX_Q).unsqueeze(1)
    cond2 = (MIN_K == MAX_K).unsqueeze(0)
    same_seq_in_qk = cond1 & cond2

    full_blocks = overlap & same_seq_in_qk
    partial_blocks = overlap & ~same_seq_in_qk

    return full_blocks, partial_blocks


@torch.compiler.disable
def direct_block_mask(SLEN_Q: torch.Tensor, SLEN_K: torch.Tensor) -> BlockMask:
    # Build a sparse BlockMask directly from precomputed full/partial block
    # classifications, avoiding create_block_mask's dense evaluation.
    full_blocks, partial_blocks = get_overlapping_blocks(SLEN_Q, SLEN_K)
    partial_blocks = partial_blocks[None, None]
    full_blocks = full_blocks[None, None]

    # repeat_interleave over the length vector yields a per-token document id.
    q_doc_id = torch.repeat_interleave(SLEN_Q)
    k_doc_id = torch.repeat_interleave(SLEN_K)

    def doc_mask(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor) -> torch.Tensor:
        return q_doc_id[q_idx] == k_doc_id[kv_idx]

    total_q_len = q_doc_id.shape[0]
    total_k_len = k_doc_id.shape[0]

    return _create_sparse_block_from_block_mask(
        (partial_blocks, full_blocks),
        doc_mask,
        seq_lengths=(total_q_len, total_k_len),
        Q_BLOCK_SIZE=128,
        KV_BLOCK_SIZE=128,
    )


@torch.compiler.disable
def doc_id_mask(SLEN_Q: torch.Tensor, SLEN_K: torch.Tensor) -> BlockMask:
    # Simpler alternative to direct_block_mask: let create_block_mask derive
    # the block structure from the per-token document-id predicate.
    q_doc_id = torch.repeat_interleave(SLEN_Q)
    k_doc_id = torch.repeat_interleave(SLEN_K)

    def doc_mask(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor) -> torch.Tensor:
        return q_doc_id[q_idx] == k_doc_id[kv_idx]

    total_q_len = q_doc_id.shape[0]
    total_k_len = k_doc_id.shape[0]

    return create_block_mask(doc_mask, 1, 1, total_q_len, total_k_len, BLOCK_SIZE=128, device=SLEN_Q.device)


def varlen_flex_attention_func(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    q_sequence_ids: torch.Tensor,
    k_sequence_ids: torch.Tensor,
) -> torch.Tensor:
    # Variable-length flex attention: unpad to a single packed "batch", build a
    # document block mask, attend, then re-pad.
    batch_size, q_len = query_states.shape[0], query_states.shape[1]
    (
        query_states,
        key_states,
        value_states,
        indices_q,
        (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
    ) = _unpad_input(query_states, key_states, value_states, q_sequence_ids, k_sequence_ids)

    query_states = query_states.unsqueeze(0).transpose(1, 2).contiguous()
    key_states = key_states.unsqueeze(0).transpose(1, 2).contiguous()
    value_states = value_states.unsqueeze(0).transpose(1, 2).contiguous()

    seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
    seqlens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
    # NOTE(review): `block_mask_creator` is not defined anywhere in this chunk —
    # presumably it should be `direct_block_mask` or `doc_id_mask` (both defined
    # above, same signature); as written this line would raise NameError. Verify
    # against the rest of the file before relying on this path.
    block_mask = block_mask_creator(seqlens_q, seqlens_k)

    fn = _get_flex_attention_fn()
    attn_output_unpad = fn(
        query_states,
        key_states,
        value_states,
        block_mask=block_mask,
        enable_gqa=query_states.shape[1] != key_states.shape[1],
    )

    attn_output = pad_input(attn_output_unpad.transpose(1, 2).squeeze(0), indices_q, batch_size, q_len)

    return attn_output


class IndexPutFirstAxis(torch.autograd.Function):
    # Differentiable scatter along the first axis; inverse of IndexFirstAxis,
    # used to re-pad packed token streams.
    @staticmethod
    def forward(ctx, values, indices, first_axis_dim) -> torch.Tensor:  # type: ignore[no-untyped-def]
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        output[indices] = values
        # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
        return output

    @staticmethod
    def backward(ctx, grad_output) -> tuple[torch.Tensor, None, None]:  # type: ignore[no-untyped-def]
        (indices,) = ctx.saved_tensors
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        grad_values = grad_output[indices]
        # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
        return grad_values, None, None


index_put_first_axis = IndexPutFirstAxis.apply


def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
    """
    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
        batch: int, batch size for the padded sequence.
        seqlen: int, maximum sequence length for the padded sequence.
    Return:
        hidden_states: (batch, seqlen, ...)
    """
    # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
    # output[indices] = hidden_states
    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
    return rearrange(output, "(b s) ... -> b s ...", b=batch)


def _get_unpad_data(sequence_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
    # Flatten the batch, drop padding (-1), and compute per-sequence lengths and
    # cumulative offsets for varlen attention kernels.
    non_pad_indices = sequence_ids != -1
    non_pad_indices = torch.nonzero(non_pad_indices.flatten(), as_tuple=False).flatten()
    # Offset each batch row by row_index * 1e5 so equal sequence ids in
    # different rows are not merged by unique_consecutive.
    # NOTE(review): this assumes each row has fewer than 1e5 distinct sequence
    # ids — confirm; the float offset also relies on exact representation.
    sequence_ids = sequence_ids + torch.arange(len(sequence_ids), device=sequence_ids.device)[:, None] * 1e5
    sequence_ids = sequence_ids.flatten()[non_pad_indices]
    _, seqlens_in_batch = torch.unique_consecutive(sequence_ids, return_counts=True)
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # NOTE(review): `torch.torch.int32` resolves to torch.int32 (torch.torch is
    # the module itself) — works, but looks like a typo.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
    return non_pad_indices, cu_seqlens, max_seqlen_in_batch


def _unpad_input(
    query_layer: torch.Tensor,
    key_layer: torch.Tensor,
    value_layer: torch.Tensor,
    q_sequence_ids: torch.Tensor,
    k_sequence_ids: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor], tuple[int, int]]:
    # Remove padding from q/k/v and return packed tensors plus the varlen
    # bookkeeping (indices, cu_seqlens, max lengths).
    batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
    query_length, num_q_heads = query_layer.shape[1], query_layer.shape[2]
    assert query_layer.shape[:2] == q_sequence_ids.shape, (
        f"Shape mismatch between query layer and query sequence ids: {query_layer.shape[:2]} != {q_sequence_ids.shape}"
    )
    assert key_layer.shape[:2] == k_sequence_ids.shape, (
        f"Shape mismatch between key layer and key sequence ids: {key_layer.shape[:2]} != {k_sequence_ids.shape}"
    )
    assert query_length <= kv_seq_len, (
        f"Query length should be less than or equal to KV sequence length: {query_length} <= {kv_seq_len}"
    )

    indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(k_sequence_ids)

    key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
    value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)

    if torch.equal(q_sequence_ids, k_sequence_ids):
        indices_q = indices_k
        # (definition continues beyond this chunk)
cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + else: + indices_q, cu_seqlens_q, max_seqlen_in_batch_q = _get_unpad_data(q_sequence_ids) + + query_layer = index_first_axis(query_layer.reshape(batch_size * query_length, num_q_heads, head_dim), indices_q) + + assert cu_seqlens_q.shape == cu_seqlens_k.shape, ( + f"Query and KV should have the same number of sequences: {cu_seqlens_q.shape} != {cu_seqlens_k.shape}" + ) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +index_first_axis = IndexFirstAxis.apply +block_mask_creator = direct_block_mask if os.getenv("FAST_BLOCK_MASK", "1") == "1" else doc_id_mask +PAD_TOKEN_ID = 0 + + +def get_tokenizer() -> Tokenizer: + try: + fname = os.path.join(os.path.dirname(__file__), "tokenizer.json") + tokenizer: Tokenizer = Tokenizer.from_file(fname) + except Exception: + print("E1 Tokenizer not found in local directory, downloading from Hugging Face") + from huggingface_hub import hf_hub_download + fname = hf_hub_download(repo_id="Synthyra/Profluent-E1-150M", filename="tokenizer.json") + tokenizer: Tokenizer = Tokenizer.from_file(fname) + assert tokenizer.padding["pad_id"] == PAD_TOKEN_ID, ( + f"Padding token id must be {PAD_TOKEN_ID}, but got {tokenizer.padding['pad_id']}" + ) + + return tokenizer + + +@dataclass +class DataPrepConfig: + max_num_sequences: int = 512 + max_num_positions_within_seq: int = 8192 + remove_X_tokens: bool = False + + +def get_context(sequence: str) -> str | None: + if "," in sequence: + return sequence.rsplit(",", 1)[0] + return None + + +class E1BatchPreparer: + def __init__( + self, + data_prep_config: DataPrepConfig | None = None, + tokenizer: Tokenizer | None = None, + preserve_context_labels: bool = False, + ): + self.tokenizer = tokenizer or get_tokenizer() + self.data_prep_config = data_prep_config or DataPrepConfig() + self.pad_token_id = 
self.tokenizer.token_to_id("") + self.preserve_context_labels = preserve_context_labels + device = torch.cuda.current_device() if torch.cuda.is_available() else torch.device("cpu") + self.boundary_token_ids = torch.tensor( + [self.tokenizer.token_to_id(token) for token in ["", "", "1", "2", ""]], device=device + ).long() + self.mask_token = "?" # nosec + self.mask_token_id = self.tokenizer.token_to_id(self.mask_token) + self.X_token_id = self.tokenizer.token_to_id("X") + self.vocab = self.tokenizer.get_vocab() + + def get_batch_kwargs( # type: ignore[override] + self, sequences: list[str], device: torch.device = torch.device("cpu"), non_blocking: bool = False + ) -> dict[str, torch.Tensor | list[str] | list[int]]: + sequence_encodings = [self.prepare_multiseq(sequence) for sequence in sequences] + return self.pad_encodings(sequence_encodings, device, non_blocking) + + def pad_encodings( + self, + sequence_encodings: list[dict[str, torch.Tensor]], + device: torch.device = torch.device("cpu"), + non_blocking: bool = False, + ) -> dict[str, torch.Tensor | list[str] | list[int]]: + non_blocking = non_blocking and device.type == "cuda" + padded_encodings = {} + # Note: We use -1 as the padding value for sequence and position ids because the 0 value + # is a valid value for sequence and position ids. -1 is then used to distinguish valid + # tokens from padding tokens, for example, when doing padding/unpadding for flash attention. 
+ for key, padding_value in { + "input_ids": self.pad_token_id, + "sequence_ids": -1, + "within_seq_position_ids": -1, + "global_position_ids": -1, + "labels": self.pad_token_id, + }.items(): + padded_encodings[key] = pad_sequence( + [enc[key] for enc in sequence_encodings], batch_first=True, padding_value=padding_value + ).to(device=device, dtype=torch.long, non_blocking=non_blocking) + + padded_encodings["context"] = [enc["context"] for enc in sequence_encodings] + padded_encodings["context_len"] = [enc["context_len"] for enc in sequence_encodings] + + return padded_encodings + + def prepare_multiseq(self, sequence: str) -> dict[str, torch.Tensor | str | int]: + single_sequences = sequence.split(",") + if len(single_sequences) > self.data_prep_config.max_num_sequences: + raise ValueError( + f"Number of sequences {len(single_sequences)} exceeds max number of sequences {self.data_prep_config.max_num_sequences}" + " in the provided multi-sequence instance. Please remove some homologous sequences before trying again." 
+ ) + + single_sequence_encodings = [self.prepare_singleseq(sequence) for sequence in single_sequences] + + num_tokens = [len(x["input_ids"]) for x in single_sequence_encodings] + input_ids = torch.cat([x["input_ids"] for x in single_sequence_encodings]) + labels = torch.cat([x["labels"] for x in single_sequence_encodings]) + + within_seq_position_ids = torch.cat([encoding["position_ids"] for encoding in single_sequence_encodings]) + global_position_ids, ctx_len = [], 0 + for encoding in single_sequence_encodings: + global_position_ids.append(encoding["position_ids"] + ctx_len) + ctx_len = max(ctx_len, encoding["position_ids"].max().item() + ctx_len + 1) + global_position_ids = torch.cat(global_position_ids) + + sequence_ids = torch.repeat_interleave(torch.tensor(num_tokens)) + + # Get multi-seq context & mask out all but last sequence in multi-seq instance if desired + context_len = sum(num_tokens[:-1]) + context = self.tokenizer.decode(input_ids[:context_len].tolist(), skip_special_tokens=False) + if not self.preserve_context_labels: + labels[:context_len] = self.pad_token_id + + assert ( + input_ids.shape + == sequence_ids.shape + == within_seq_position_ids.shape + == global_position_ids.shape + == labels.shape + ), "Input ids, sequence ids, within seq position ids, global position ids, and labels must have the same shape" + + assert input_ids.shape[0] >= context_len, "Input ids must have at least as many tokens as the context length" + + return { + "input_ids": input_ids, + "sequence_ids": sequence_ids, + "within_seq_position_ids": within_seq_position_ids, + "global_position_ids": global_position_ids, + "labels": labels, + "context": context, + "context_len": context_len, + } + + def prepare_singleseq(self, sequence: str) -> dict[str, torch.Tensor]: + if not self.validate_sequence(sequence): + raise ValueError(f"Invalid sequence: {sequence}; Input sequence should contain [A-Z] or ? 
characters only") + + if len(sequence) > self.data_prep_config.max_num_positions_within_seq: + raise ValueError( + f"Sequence length {len(sequence)} exceeds max length {self.data_prep_config.max_num_positions_within_seq}" + ) + + # Can also use `tokens = torch.tensor(self.tokenizer.encode(f"1{sequence}2").ids)` + # but following is faster since our vocabulary is simple. + tokens = torch.tensor([self.vocab[token] for token in ["", "1", *sequence, "2", ""]]) + position_ids = torch.arange(len(tokens)) + + if self.data_prep_config.remove_X_tokens: + X_positions = torch.where(tokens != self.X_token_id)[0] + tokens = tokens[X_positions] + position_ids = position_ids[X_positions] + + return {"input_ids": tokens, "labels": tokens, "position_ids": position_ids} + + def get_boundary_token_mask(self, tokens: torch.Tensor) -> torch.BoolTensor: + return torch.isin(tokens, self.boundary_token_ids.to(tokens.device)) + + def get_mask_positions_mask(self, tokens: torch.Tensor) -> torch.BoolTensor: + return tokens == self.mask_token_id + + def validate_sequence(self, sequence: str) -> bool: + assert isinstance(sequence, str), "Sequence must be a string" + sequence = sequence.replace(self.mask_token, "") + return sequence.isalpha() and sequence.isupper() + + +class E1Config(PretrainedConfig): + model_type = "E1" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( # type: ignore + self, + # Model architecture/initialization + vocab_size=None, + hidden_size=4096, + intermediate_size=16384, + gated_mlp=False, + num_hidden_layers=40, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + rms_norm_eps=1e-5, + initializer_range=0.02, + dtype="bfloat16", + gradient_checkpointing=False, + no_ffn_gradient_checkpointing=False, + # Tokenization + pad_token_id=None, + bos_token_id=None, + eos_token_id=None, + tie_word_embeddings=False, + # Attention implementation & rotary positional embeddings + global_attention_every_n_layers=0, + max_num_sequences=512, + 
max_num_positions_within_seq=8192, + max_num_positions_global=1024 * 128, + rope_theta_within_seq=10000.0, + rope_theta_global=100000.0, + clip_qkv=None, + attn_backend="sdpa", + **kwargs, + ) -> None: + tokenizer = get_tokenizer() + super().__init__( + pad_token_id=tokenizer.token_to_id(""), + bos_token_id=tokenizer.token_to_id(""), + eos_token_id=tokenizer.token_to_id(""), + tie_word_embeddings=tie_word_embeddings, + dtype=dtype, + **kwargs, + ) + + self.hidden_size = hidden_size + if intermediate_size is None: + intermediate_size = 3 * hidden_size if gated_mlp else 4 * hidden_size + self.intermediate_size = intermediate_size + self.gated_mlp = gated_mlp + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_num_positions_within_seq = max_num_positions_within_seq + self.max_num_positions_global = max_num_positions_global + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.rope_theta_within_seq = rope_theta_within_seq + self.rope_theta_global = rope_theta_global + self.max_num_sequences = max_num_sequences + assert clip_qkv is None or clip_qkv > 0 + self.clip_qkv = clip_qkv + self.global_attention_every_n_layers = global_attention_every_n_layers + + self.vocab_size = tokenizer.get_vocab_size() + self.gradient_checkpointing = gradient_checkpointing + self.no_ffn_gradient_checkpointing = no_ffn_gradient_checkpointing + self.attn_backend = attn_backend + + if vocab_size is not None: + if vocab_size < self.vocab_size: + logger.warning( + f"Using vocab_size {vocab_size} smaller than {self.vocab_size} from tokenizer. MAKE SURE THIS IS INTENTIONAL." 
+ ) + self.vocab_size = vocab_size + elif vocab_size > self.vocab_size: + logger.warning(f"Using vocab_size {vocab_size} instead of smaller {self.vocab_size} from tokenizer.") + self.vocab_size = vocab_size + if pad_token_id is not None and pad_token_id != self.pad_token_id: + logger.warning(f"Ignoring pad_token_id. Using {self.pad_token_id} from tokenizer") + if bos_token_id is not None and bos_token_id != self.bos_token_id: + logger.warning(f"Ignoring bos_token_id. Using {self.bos_token_id} from tokenizer") + if eos_token_id is not None and eos_token_id != self.eos_token_id: + logger.warning(f"Ignoring eos_token_id. Using {self.eos_token_id} from tokenizer") + + +class DynamicCache: + """ + A cache layer that grows dynamically as more tokens are generated. This is the default for generative models. + It stores the key and value states as tensors of shape `[batch_size, seq_len, num_heads, head_dim]`. + + Args: + key_cache (`list[torch.Tensor]`): The list of key states. + value_cache (`list[torch.Tensor]`): The list of value states. + """ + + def __init__(self) -> None: + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + + def update( + self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Update the key and value caches in-place, and return the necessary keys and value states. + + Args: + key_states (`torch.Tensor`): The new key states to cache of shape [batch_size, seq_len, num_heads, head_dim] + value_states (`torch.Tensor`): The new value states to cache of shape [batch_size, seq_len, num_heads, head_dim] + layer_idx (`int`): The index of the layer to update. + + Returns: + tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states of shape [batch_size, seq_len, num_heads, head_dim]. 
+ """ + # Lazy initialization + if len(self.key_cache) <= layer_idx: + # There may be skipped layers, fill them with empty lists + for _ in range(len(self.key_cache), layer_idx): + self.key_cache.append(torch.tensor([])) + self.value_cache.append(torch.tensor([])) + self.key_cache.append(key_states) + self.value_cache.append(value_states) + elif ( + not self.key_cache[layer_idx].numel() # prefers not t.numel() to len(t) == 0 to export the model + ): # fills previously skipped layers; checking for tensor causes errors + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=1) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=1) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def get_seq_length(self, layer_idx: int = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + is_empty_layer = ( + len(self.key_cache) == 0 # no cache in any layer + or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it + or not self.key_cache[layer_idx].numel() # the layer has no cache + ) + layer_seq_length = self.key_cache[layer_idx].shape[1] if not is_empty_layer else 0 + return layer_seq_length + + def crop(self, max_length: int) -> None: + """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be + negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search.""" + assert max_length > 0, "max_length must be positive" + + if self.get_seq_length() <= max_length: + return + + for layer_idx in range(len(self.key_cache)): + if self.key_cache[layer_idx].numel(): + self.key_cache[layer_idx] = self.key_cache[layer_idx][:, :max_length, ...] + self.value_cache[layer_idx] = self.value_cache[layer_idx][:, :max_length, ...] 
+ + def batch_repeat_interleave(self, repeats: int) -> None: + """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" + for layer_idx in range(len(self.key_cache)): + if self.key_cache[layer_idx].numel(): + self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) + self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0) + + def batch_select_indices(self, indices: torch.Tensor) -> None: + """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" + for layer_idx in range(len(self.key_cache)): + if self.key_cache[layer_idx].numel(): + self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] + self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] + + +class KVCache: + def __init__(self, cache_size: int = 4) -> None: + self.cache_size = cache_size + self.tensor_input_field_names = [ + "input_ids", + "within_seq_position_ids", + "global_position_ids", + "sequence_ids", + "labels", + ] + self.tensor_output_field_names = ["logits", "embeddings"] + self.cache_dict: dict[str, DynamicCache] = {} + self.cache_queue: list[str] = [] + + def reset(self) -> None: + for k in list(self.cache_dict.keys()): + del self.cache_dict[k] + del self.cache_dict + self.cache_dict = {} + self.cache_queue = [] + + torch.cuda.empty_cache() + + def before_forward(self, batch: dict[str, torch.Tensor]) -> None: + contexts: list[str] | None = batch.get("context", None) + if contexts is None or "context_len" not in batch: + logger.warning_once( + "KVCache requires the batch dict to have both `context` and `context_len` keys to trigger. Skipping." + ) + return + + context_lens: list[int] = list(set(batch["context_len"])) + contexts: list[str] = list(set(contexts)) # type: ignore[no-redef] + if len(contexts) != 1 or len(context_lens) != 1: + logger.warning( + "SingleContextKVCache requires a single context and context length. 
" + "Multiple contexts or context lengths found in a single batch. Skipping." + ) + return + + batch_size = batch["input_ids"].shape[0] + + unique_context = contexts[0] + unique_context_len = context_lens[0] + batch["use_cache"] = True + + if unique_context not in self.cache_dict: + return + + self.cache_dict[unique_context].batch_repeat_interleave(batch_size) + past_key_values = self.cache_dict[unique_context] + batch["past_key_values"] = past_key_values + + # Remove context from the input fields + for field_name in self.tensor_input_field_names: + if batch.get(field_name, None) is not None: + batch[field_name] = batch[field_name][:, unique_context_len:] + + def after_forward(self, batch: dict[str, Any], outputs: ModelOutput) -> None: + contexts = batch.get("context", None) + context_lens = batch.get("context_len", []) + if contexts is None or len(set(contexts)) != 1 or len(set(context_lens)) != 1 or context_lens[0] == 0: + return + + assert batch["use_cache"] + unique_context = contexts[0] + unique_context_len = context_lens[0] + + past_key_values = getattr(outputs, "past_key_values", None) + if not isinstance(past_key_values, DynamicCache): + logger.warning_once("KVCache is incompatible with models that don't return a DynamicCache. 
Skipping.") + return + + if "past_key_values" not in batch: + if len(self.cache_queue) == self.cache_size: + last_context = self.cache_queue.pop(0) + if last_context not in self.cache_queue: + del self.cache_dict[last_context] + torch.cuda.empty_cache() + + self.cache_dict[unique_context] = past_key_values + self.cache_queue.append(unique_context) + + # Remove context from the input fields + for field_name in self.tensor_input_field_names: + if field_name in batch and batch[field_name] is not None: + batch[field_name] = batch[field_name][:, unique_context_len:] + + # Remove context from the output fields + for field_name in self.tensor_output_field_names: + if field_name in outputs and outputs[field_name] is not None: + outputs[field_name] = outputs[field_name][:, unique_context_len:] + if "hidden_states" in outputs and outputs["hidden_states"] is not None: + outputs["hidden_states"] = [h[:, unique_context_len:] for h in outputs["hidden_states"]] + + self.cache_dict[unique_context].crop(unique_context_len) + self.cache_dict[unique_context].batch_select_indices([0]) + + +class AttentionLayerType(Enum): + WITHIN_SEQ = "within_seq" + GLOBAL = "global" + + +class AttentionArgs(TypedDict, total=False): + within_seq_block_mask: BlockMask | None + block_causal_block_mask: BlockMask | None + within_seq_mask_4d: torch.Tensor | None + block_causal_mask_4d: torch.Tensor | None + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
+ + The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, + num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class RotaryPositionalEmbedding(nn.Module): + def __init__( + self, dim: int, max_position_embeddings: int = 2048, base: int = 10000, device: torch.device | None = None + ): + super().__init__() + + self.dim = dim + self.base = base + self.max_position_embeddings = max_position_embeddings + inv_freq = base ** -(torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_sin_cos_cache(seq_len=max_position_embeddings, device=self.inv_freq.device) + + @staticmethod + def rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def _set_sin_cos_cache(self, seq_len: int, device: torch.device) -> None: + # Different from paper, but it uses a different permutation in order to obtain the same calculation + self.max_seq_len_cached = seq_len + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + angles = torch.outer(t, self.inv_freq.to(device)) + angles = torch.cat((angles, angles), dim=1) + self.register_buffer("cos_cached", angles.cos(), persistent=False) + self.register_buffer("sin_cached", angles.sin(), persistent=False) + + def forward( + self, q: torch.Tensor, k: torch.Tensor, position_ids: torch.LongTensor, seq_len: int | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + # x: [bsz, seq_len, num_attention_heads, head_size] + device, dtype = 
q.device, q.dtype + seq_len = position_ids.max().item() + 1 if seq_len is None else seq_len + + if seq_len > self.max_seq_len_cached: + self._set_sin_cos_cache(seq_len=seq_len, device=device) + + # angles_cached[position_ids] gets us something of shape (batch_size, seq_len, head_dim), + # so unsqueeze dimension -2 to broadcast to (batch_size, seq_len, n_heads, head_dim). + idxs = position_ids.to(device) + cos = self.cos_cached.to(device=device, dtype=dtype).unsqueeze(-2)[idxs] + sin = self.sin_cached.to(device=device, dtype=dtype).unsqueeze(-2)[idxs] + + # Apply rotary positional embeddings to q and k (treating them as complex numbers). The first half is + # Re[x exp(it)] = Re[x] cos(t) - Im[x] sin(t), while the second half is + # Im[x exp(it)] = Im[x] cos(t) + Re[x] sin(t). This works b/c both halves of cos/sin are the same. + q_embed = (q * cos) + (self.rotate_half(q) * sin) + k_embed = (k * cos) + (self.rotate_half(k) * sin) + return q_embed, k_embed + + +class Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper.""" + + def __init__(self, config: E1Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_kv_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_kv_heads + self.max_num_seqs = config.max_num_sequences + self.clip_qkv = config.clip_qkv + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + if self.config.global_attention_every_n_layers > 0: + self.layer_type = ( + AttentionLayerType.GLOBAL + if (self.layer_idx + 1) % self.config.global_attention_every_n_layers == 0 + else AttentionLayerType.WITHIN_SEQ + ) + else: + self.layer_type = AttentionLayerType.WITHIN_SEQ + + self.rope_theta = ( + config.rope_theta_within_seq + if self.layer_type == AttentionLayerType.WITHIN_SEQ + else config.rope_theta_global + ) + self.max_position_embeddings = ( + config.max_num_positions_within_seq + if self.layer_type == AttentionLayerType.WITHIN_SEQ + else config.max_num_positions_global + ) + + self.rotary_emb = RotaryPositionalEmbedding( + self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta + ) + + self.attn_backend = resolve_attention_backend(config.attn_backend) + + def prepare_qkv( + self, + hidden_states: torch.Tensor, + position_ids: torch.LongTensor, + past_key_value: DynamicCache | None = None, + use_cache: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + bsz, q_len, _ = hidden_states.size() + query_states: torch.Tensor = self.q_proj(hidden_states) + key_states: torch.Tensor = self.k_proj(hidden_states) + val_states: torch.Tensor = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim) + val_states = val_states.view(bsz, q_len, self.num_kv_heads, self.head_dim) + + if self.clip_qkv is not None: + query_states = query_states.clamp(-self.clip_qkv, self.clip_qkv) + key_states = key_states.clamp(-self.clip_qkv, 
self.clip_qkv) + val_states = val_states.clamp(-self.clip_qkv, self.clip_qkv) + + query_states, key_states = self.rotary_emb(query_states, key_states, position_ids) + + if use_cache and past_key_value is not None: + key_states, val_states = past_key_value.update(key_states, val_states, self.layer_idx) + + input_dtype = query_states.dtype + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + else: + target_dtype = self.q_proj.weight.dtype + if input_dtype != target_dtype: + logger.warning_once( + f"The input hidden states seems to be silently casted in {input_dtype}. " + f"This might be because you have upcasted embedding or layer norm layers " + f"in {input_dtype}. We will cast back the input in {target_dtype}." + ) + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + val_states = val_states.to(target_dtype) + + return query_states, key_states, val_states + + def forward( + self, + hidden_states: torch.Tensor, + within_seq_position_ids: torch.LongTensor, + global_position_ids: torch.LongTensor, + sequence_ids: torch.LongTensor, + attention_args: AttentionArgs | None = None, + past_key_value: DynamicCache | None = None, + output_attentions: bool = False, + output_s_max: bool = False, + use_cache: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor | None, DynamicCache | None, list[torch.Tensor] | None]: + is_cache_prefilled = ( + use_cache and past_key_value is not None and past_key_value.get_seq_length(self.layer_idx) > 0 + ) + + query_states, key_states, val_states = self.prepare_qkv( + hidden_states=hidden_states, + position_ids=within_seq_position_ids + if self.layer_type == AttentionLayerType.WITHIN_SEQ + else global_position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + ) + + attn_output, attn_weights, s_max = self._attn( + query_states=query_states, + key_states=key_states, + val_states=val_states, + sequence_ids=sequence_ids, + attention_args=attention_args, + 
            output_attentions=output_attentions,
            output_s_max=output_s_max,
            is_cache_prefilled=is_cache_prefilled,
        )

        # Final output projection back to model width.
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights, past_key_value, s_max

    def _attn(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        val_states: torch.Tensor,
        sequence_ids: torch.Tensor,
        attention_args: AttentionArgs | None = None,
        output_attentions: bool = False,
        output_s_max: bool = False,
        is_cache_prefilled: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None, list[torch.Tensor] | None]:
        """Dispatch attention to the configured backend.

        Returns (attn_output, attn_weights, s_max). `attn_weights` is only
        populated on the manual path; `s_max` only when `output_s_max=True`.
        """
        # Once the KV cache is prefilled, global layers are treated as
        # within-sequence layers for the newly appended tokens.
        effective_layer_type = self.layer_type
        if is_cache_prefilled and self.layer_type == AttentionLayerType.GLOBAL:
            effective_layer_type = AttentionLayerType.WITHIN_SEQ

        # The manual (eager) implementation is the only one that materializes
        # the full attention matrix, so it is forced when weights are requested.
        if output_attentions:
            return self._manual_attn(
                query_states, key_states, val_states,
                sequence_ids=sequence_ids,
                attention_args=attention_args,
                effective_layer_type=effective_layer_type,
                output_s_max=output_s_max,
                is_cache_prefilled=is_cache_prefilled,
            )

        if self.attn_backend == AttentionBackend.KERNELS_FLASH:
            if effective_layer_type == AttentionLayerType.WITHIN_SEQ:
                attn_output, attn_weights = self._kernels_flash_attn(
                    query_states, key_states, val_states,
                    sequence_ids=sequence_ids,
                    is_cache_prefilled=is_cache_prefilled,
                )
            else:
                # The flash-kernel path only covers within-seq masking; global
                # (block-causal) layers fall back to flex attention.
                attn_output, attn_weights = self._flex_attn(
                    query_states, key_states, val_states,
                    attention_args=attention_args,
                    effective_layer_type=effective_layer_type,
                )
        elif self.attn_backend == AttentionBackend.FLEX:
            attn_output, attn_weights = self._flex_attn(
                query_states, key_states, val_states,
                attention_args=attention_args,
                effective_layer_type=effective_layer_type,
            )
        elif self.attn_backend == AttentionBackend.SDPA:
            attn_output, attn_weights = self._sdpa_attn(
                query_states, key_states, val_states,
                sequence_ids=sequence_ids,
                attention_args=attention_args,
                effective_layer_type=effective_layer_type,
                is_cache_prefilled=is_cache_prefilled,
            )
        else:
            raise AssertionError(f"Unsupported resolved backend: {self.attn_backend}")

        s_max = self._compute_s_max(query_states, key_states) if output_s_max else None
        return attn_output, attn_weights, s_max

    @torch.no_grad()
    def _compute_s_max(
        self,
        query_states: torch.Tensor,  # (B, L, H, D)
        key_states: torch.Tensor,  # (B, L, Hkv, D)
    ) -> list[torch.Tensor]:
        """Per-head upper bound on the pre-softmax attention logit.

        Uses Cauchy-Schwarz (|q.k| <= ||q||*||k||): the product of the largest
        query/key norms, scaled by 1/sqrt(head_dim), reduced over batch and
        sequence. Returns one scalar tensor per attention head.
        """
        query_BHLD = query_states.transpose(1, 2).contiguous()
        key_BHLD = key_states.transpose(1, 2).contiguous()
        key_BHLD = repeat_kv(key_BHLD, self.num_key_value_groups)
        scale = 1.0 / (self.head_dim ** 0.5)
        q_norm = torch.linalg.vector_norm(query_BHLD, dim=-1)
        k_norm = torch.linalg.vector_norm(key_BHLD, dim=-1)
        # max over sequence (dim=-1), then max over batch (dim=0) -> (H,)
        s_max_bound = (q_norm.max(dim=-1).values * k_norm.max(dim=-1).values).max(dim=0).values * scale
        return [s_max_bound[h] for h in range(self.num_heads)]

    def _kernels_flash_attn(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        val_states: torch.Tensor,
        sequence_ids: torch.Tensor,
        is_cache_prefilled: bool = False,
    ) -> tuple[torch.Tensor, None]:
        """Flash-attention kernel path; never returns attention weights."""
        bsz, q_len = query_states.shape[0], query_states.shape[1]
        _, kv_len = key_states.shape[0], key_states.shape[1]

        if self.layer_type == AttentionLayerType.GLOBAL and not is_cache_prefilled:
            q_sequence_ids = sequence_ids
            if q_len < kv_len:
                # Left-pad key-side ids so cached positions share the first
                # token's sequence id (cache holds earlier positions).
                first_token_id = sequence_ids[:, 0].unsqueeze(1)
                k_sequence_ids = torch.cat([first_token_id.expand(bsz, kv_len - q_len), sequence_ids], dim=-1)
            else:
                k_sequence_ids = sequence_ids
        else:
            if q_len < kv_len:
                # Within-seq attention only looks at the most recent q_len keys.
                key_states = key_states[:, -q_len:]
                val_states = val_states[:, -q_len:]
            q_sequence_ids = k_sequence_ids = sequence_ids

        attn_output = kernels_flash_attention_func(
            query_states, key_states, val_states,
            q_sequence_ids=q_sequence_ids,
            k_sequence_ids=k_sequence_ids,
            causal=False,
        )
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        return attn_output, None

    def _flex_attn(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        val_states: torch.Tensor,
        attention_args: AttentionArgs | None = None,
        effective_layer_type: AttentionLayerType = AttentionLayerType.WITHIN_SEQ,
    ) -> tuple[torch.Tensor, None]:
        """Flex-attention path; masking comes from precomputed block masks."""
        bsz, q_len = query_states.shape[0], query_states.shape[1]
        if effective_layer_type == AttentionLayerType.WITHIN_SEQ:
            block_mask = attention_args["within_seq_block_mask"] if attention_args is not None else None
        else:
            block_mask = attention_args["block_causal_block_mask"] if attention_args is not None else None
        outputs = flex_attention_func(query_states, key_states, val_states, block_mask=block_mask)
        outputs = outputs.reshape(bsz, q_len, self.hidden_size).contiguous()
        return outputs, None

    def _sdpa_attn(
        self,
        query_states: torch.Tensor,  # (B, L, H, D)
        key_states: torch.Tensor,  # (B, L, Hkv, D)
        val_states: torch.Tensor,  # (B, L, Hkv, D)
        sequence_ids: torch.Tensor,
        attention_args: AttentionArgs | None = None,
        effective_layer_type: AttentionLayerType = AttentionLayerType.WITHIN_SEQ,
        is_cache_prefilled: bool = False,
    ) -> tuple[torch.Tensor, None]:
        """F.scaled_dot_product_attention path with an explicit 4-D mask."""
        bsz, q_len = query_states.shape[:2]
        kv_len = key_states.shape[1]

        # Decode step against a prefilled cache: within-seq layers drop the
        # cached keys and rebuild the mask from the current sequence_ids;
        # global layers attend to the full cache unmasked.
        if is_cache_prefilled and q_len < kv_len:
            if effective_layer_type == AttentionLayerType.WITHIN_SEQ:
                key_states = key_states[:, -q_len:]
                val_states = val_states[:, -q_len:]
            attention_mask_4d = build_within_seq_mask_4d(sequence_ids) if effective_layer_type == AttentionLayerType.WITHIN_SEQ else None
        elif attention_args is not None:
            if effective_layer_type == AttentionLayerType.WITHIN_SEQ:
                attention_mask_4d = attention_args["within_seq_mask_4d"]
            else:
                attention_mask_4d = attention_args["block_causal_mask_4d"]
        else:
            attention_mask_4d = None

        query_BHLD = query_states.transpose(1, 2).contiguous()
        key_BHLD = key_states.transpose(1, 2).contiguous()
        val_BHLD = val_states.transpose(1, 2).contiguous()
        key_BHLD = repeat_kv(key_BHLD, self.num_key_value_groups)
        val_BHLD = repeat_kv(val_BHLD, self.num_key_value_groups)
        context_BHLD = F.scaled_dot_product_attention(query_BHLD, key_BHLD, val_BHLD, attn_mask=attention_mask_4d)
        attn_output = context_BHLD.transpose(1, 2).reshape(bsz, q_len, self.hidden_size).contiguous()
        return attn_output, None

    def _manual_attn(
        self,
        query_states: torch.Tensor,  # (B, L, H, D)
        key_states: torch.Tensor,  # (B, L, Hkv, D)
        val_states: torch.Tensor,  # (B, L, Hkv, D)
        sequence_ids: torch.Tensor,
        attention_args: AttentionArgs | None = None,
        effective_layer_type: AttentionLayerType = AttentionLayerType.WITHIN_SEQ,
        output_s_max: bool = False,
        is_cache_prefilled: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]:
        """Eager attention that materializes the full weight matrix.

        Used when `output_attentions=True`; mask selection mirrors _sdpa_attn.
        """
        bsz, q_len = query_states.shape[:2]
        kv_len = key_states.shape[1]

        if is_cache_prefilled and q_len < kv_len:
            if effective_layer_type == AttentionLayerType.WITHIN_SEQ:
                key_states = key_states[:, -q_len:]
                val_states = val_states[:, -q_len:]
            attention_mask_4d = build_within_seq_mask_4d(sequence_ids) if effective_layer_type == AttentionLayerType.WITHIN_SEQ else None
        elif attention_args is not None:
            if effective_layer_type == AttentionLayerType.WITHIN_SEQ:
                attention_mask_4d = attention_args["within_seq_mask_4d"]
            else:
                attention_mask_4d = attention_args["block_causal_mask_4d"]
        else:
            attention_mask_4d = None

        query_BHLD = query_states.transpose(1, 2).contiguous()
        key_BHLD = key_states.transpose(1, 2).contiguous()
        val_BHLD = val_states.transpose(1, 2).contiguous()
        key_BHLD = repeat_kv(key_BHLD, self.num_key_value_groups)
        val_BHLD = repeat_kv(val_BHLD, self.num_key_value_groups)
        scale = 1.0 / (self.head_dim ** 0.5)
        attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale
        if attention_mask_4d is not None:
            # Boolean mask convention: True = attend; disallowed positions
            # are filled with -inf before the softmax.
            attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf"))
        attn_weights = F.softmax(attn_weights, dim=-1)
        context_BHLD = torch.matmul(attn_weights, val_BHLD)
        attn_output = context_BHLD.transpose(1, 2).reshape(bsz, q_len, self.hidden_size).contiguous()
        s_max = self._compute_s_max(query_states, key_states) if output_s_max else None
        return attn_output, attn_weights, s_max


class MLP(nn.Module):
    """Plain two-layer feed-forward block (no gating)."""

    def __init__(self, config: E1Config):
        super().__init__()
        self.ffn_dim = config.intermediate_size
        self.hidden_dim = config.hidden_size
        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.w2(self.act_fn(self.w1(hidden_states)))


class GLUMLP(nn.Module):
    """Gated-linear-unit feed-forward block: act(w1(x)) * w3(x) -> w2."""

    def __init__(self, config: E1Config):
        super().__init__()
        self.ffn_dim = config.intermediate_size
        self.hidden_dim = config.hidden_size
        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
        hidden_states = self.w2(hidden_states)
        return hidden_states


class FFN(nn.Module):
    """Feed-forward wrapper that picks gated vs. plain MLP from the config."""

    def __init__(self, config: E1Config):
        super().__init__()
        mlp_cls = GLUMLP if config.gated_mlp else MLP
        self.mlp = mlp_cls(config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.mlp(hidden_states)


@dataclass
class E1ModelOutputWithPast(ModelOutput):
    """Base class for model's outputs, with potential hidden states and attentions.

    Attributes:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        past_key_values (`DynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Pre-computed key and value hidden states of the self-attention blocks that can
            be used (see the `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer,
            + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attention weights after the softmax.
        s_max (`tuple(list(torch.Tensor))`, *optional*, returned when `output_s_max=True` is passed):
            Per-layer list of per-head upper bounds on the attention logits.
    """

    last_hidden_state: torch.FloatTensor | None = None
    past_key_values: DynamicCache | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    s_max: tuple[list[torch.Tensor], ...] | None = None


@dataclass
class E1MaskedLMOutputWithPast(ModelOutput):
    """Masked-LM output: adds loss/logits on top of the base model output."""

    loss: torch.FloatTensor | None = None
    mlm_loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    last_hidden_state: torch.FloatTensor | None = None
    past_key_values: DynamicCache | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    s_max: tuple[list[torch.Tensor], ...] | None = None


@dataclass
class E1ClassificationOutputWithPast(ModelOutput):
    """Classification output (sequence- or token-level)."""

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    last_hidden_state: torch.FloatTensor | None = None
    past_key_values: DynamicCache | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    s_max: tuple[list[torch.Tensor], ...] | None = None


class RMSNorm(nn.Module):
    """Root-mean-square layer norm (no mean subtraction, no bias)."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
        self.hidden_size = hidden_size

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        # `layer_norm` is an optional fused-kernel module (import resolved at
        # file top); fall back to torch's built-in rms_norm when unavailable.
        if layer_norm is None:
            return torch.nn.functional.rms_norm(
                hidden_states, (self.hidden_size,), self.weight, self.variance_epsilon
            ).to(input_dtype)
        else:
            return layer_norm.rms_norm_fn(
                x=hidden_states,
                weight=self.weight,
                bias=None,  # no bias
                residual=None,
                eps=self.variance_epsilon,
                dropout_p=0.0,  # no dropout by default
                prenorm=False,
                residual_in_fp32=False,
            ).to(input_dtype)


class NormAttentionNorm(nn.Module):
    """Pre-norm attention sub-block: norm -> attn -> residual add -> norm.

    Returns the post-attention-norm states *and* the pre-norm residual so the
    caller (DecoderLayer) can apply the FFN and add the residual afterwards.
    """

    def __init__(self, config: E1Config, layer_idx: int):
        super().__init__()
        self.self_attn = Attention(config, layer_idx)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        within_seq_position_ids: torch.LongTensor,
        global_position_ids: torch.LongTensor,
        sequence_ids: torch.LongTensor,
        attention_args: AttentionArgs | None = None,
        past_key_value: DynamicCache | None = None,
        output_attentions: bool = False,
        output_s_max: bool = False,
        use_cache: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, DynamicCache | None, list[torch.Tensor] | None]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, self_attn_weights, present_key_value, s_max = self.self_attn(
            hidden_states=hidden_states,
            within_seq_position_ids=within_seq_position_ids,
            global_position_ids=global_position_ids,
            sequence_ids=sequence_ids,
            attention_args=attention_args,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            output_s_max=output_s_max,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        return hidden_states, residual, self_attn_weights, present_key_value, s_max


class DecoderLayer(nn.Module):
    """One transformer layer: NormAttentionNorm followed by the FFN."""

    def __init__(self, config: E1Config, layer_idx: int):
        super().__init__()
        self.initializer_range = config.initializer_range
        self.hidden_size = config.hidden_size
        self.norm_attn_norm = NormAttentionNorm(config, layer_idx)
        self.ffn = FFN(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        within_seq_position_ids: torch.LongTensor,
        global_position_ids: torch.LongTensor,
        sequence_ids: torch.LongTensor,
        attention_args: AttentionArgs | None = None,
        past_key_value: DynamicCache | None = None,
        output_attentions: bool = False,
        output_s_max: bool = False,
        use_cache: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None, DynamicCache | None, list[torch.Tensor] | None]:
        hidden_states, residual, self_attn_weights, present_key_value, s_max = self.norm_attn_norm(
            hidden_states=hidden_states,
            within_seq_position_ids=within_seq_position_ids,
            global_position_ids=global_position_ids,
            sequence_ids=sequence_ids,
            attention_args=attention_args,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            output_s_max=output_s_max,
            use_cache=use_cache,
        )

        # Fully Connected
        hidden_states = self.ffn(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, self_attn_weights, present_key_value, s_max


class E1PreTrainedModel(PreTrainedModel):
    """Common base: weight init, gradient checkpointing, backend switching."""

    config_class = E1Config
    config: E1Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DecoderLayer"]
    _transformer_layer_cls = [DecoderLayer]
    _skip_keys_device_placement = "past_key_values"
    all_tied_weights_keys = {}

    def _init_weights(self, module: nn.Module) -> None:
        # Normal(0, initializer_range) for linear/embedding weights;
        # zeros for biases and the padding embedding row; ones for RMSNorm.
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, RMSNorm):
            module.weight.data.fill_(1.0)

    def _backward_compatibility_gradient_checkpointing(self) -> None:
        if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
            self.gradient_checkpointing_enable(dict(use_reentrant=False))

    def post_init(self) -> None:
        super().post_init()

    @property
    def _device(self) -> torch.device:
        return next(self.parameters()).device

    @property
    def attn_backend(self) -> str:
        return self.config.attn_backend

    @attn_backend.setter
    def attn_backend(self, backend: str) -> None:
        # Propagates the resolved backend to every encoder/attention module.
        assert backend in VALID_ATTENTION_BACKENDS, (
            f"Unsupported attn_backend: {backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
        )
        self.config.attn_backend = backend
        resolved = resolve_attention_backend(backend)
        for module in self.modules():
            if isinstance(module, FAST_E1_ENCODER):
                module._attn_backend = resolved
            elif isinstance(module, Attention):
                module.attn_backend = resolved


class FAST_E1_ENCODER(E1PreTrainedModel, EmbeddingMixin):
    """Core E1 encoder: token + sequence-id embeddings, decoder stack, norm."""

    config: E1Config
    config_class = E1Config
    def __init__(self, config: E1Config, **kwargs):
        E1PreTrainedModel.__init__(self, config, **kwargs)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.embed_seq_id = nn.Embedding(config.max_num_sequences, config.hidden_size)
        self.layers = nn.ModuleList([DecoderLayer(config, i) for i in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = config.gradient_checkpointing
        self.prep_tokens = E1BatchPreparer()
        self._attn_backend = resolve_attention_backend(config.attn_backend)
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.embed_tokens

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        self.embed_tokens = value

    @torch.inference_mode()
    def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
        """Tokenize raw sequences and return last-layer hidden states."""
        batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device)
        last_hidden_state = self.forward(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state
        if return_attention_mask:
            # -1 sequence id marks padding positions.
            attention_mask = (batch['sequence_ids'] != -1).long()
            return last_hidden_state, attention_mask
        else:
            return last_hidden_state

    # Ignore copy
    def forward(
        self,
        input_ids: torch.LongTensor,
        within_seq_position_ids: torch.LongTensor,
        global_position_ids: torch.LongTensor,
        sequence_ids: torch.LongTensor,
        past_key_values: DynamicCache | None = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_s_max: bool = False,
        **kwargs
    ) -> E1ModelOutputWithPast:
        """
        Args:
            input_ids: (batch_size, seq_length)
            within_seq_position_ids: (batch_size, seq_length)
                This tensor contains the position of each residue within the sequence itself.
                For example, if the input is ["1ABC21DEF2", "1GH21JKL2"],
                the tensor would be [[0,1,2,3,4,5,6,0,1,2,3,4,5,6], [0,1,2,3,4,5,0,1,2,3,4,5,6,-1]]
            global_position_ids: (batch_size, seq_length)
                This tensor contains the position of each residue within the global sequence.
                For example, if the input is ["1ABC21DEF2", "1GH21JKL2"],
                the tensor would be [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, -1]]
            sequence_ids: (batch_size, seq_length)
                This tensor contains the sequence id of each residue.
                For example, if the input is ["1ABC21DEF2", "1GH21JKL2"],
                the tensor would be [[0,0,0,0,0,0,0,1,1,1,1,1,1,1], [0,0,0,0,0,0,1,1,1,1,1,1,1,-1]]
            past_key_values: DynamicCache
            use_cache: bool
            output_attentions: bool
            output_hidden_states: bool
            output_s_max: bool

        Returns:
            E1ModelOutputWithPast: Model Outputs
        """
        batch_size, seq_length = input_ids.shape

        if self.gradient_checkpointing and self.training and torch.is_grad_enabled():
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()
        elif not use_cache:
            past_key_values = None

        global_position_ids = global_position_ids.view(-1, seq_length).long()
        within_seq_position_ids = within_seq_position_ids.view(-1, seq_length).long()
        sequence_ids = sequence_ids.view(-1, seq_length).long()

        max_position_id = torch.max(within_seq_position_ids).item()
        min_position_id = torch.min(within_seq_position_ids).item()
        assert max_position_id < self.config.max_num_positions_within_seq and min_position_id >= -1, (
            f"Position ids must be in the range [-1, {self.config.max_num_positions_within_seq}); got max {max_position_id} and min {min_position_id}"
        )

        # Token embedding plus a learned per-sequence-id embedding; padding
        # ids (-1) are clamped to 0 before the lookup.
        inputs_embeds = self.embed_tokens(input_ids)
        inputs_embeds = inputs_embeds + self.embed_seq_id(sequence_ids.clamp(min=0))

        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        else:
            target_dtype = self.layers[0].norm_attn_norm.self_attn.q_proj.weight.dtype
        hidden_states = inputs_embeds.to(target_dtype)

        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0

        # Precompute only the mask variants the resolved backend will consume.
        attn_backend = self._attn_backend
        has_global_layers = self.config.global_attention_every_n_layers > 0
        needs_4d_masks = (attn_backend == AttentionBackend.SDPA) or output_attentions
        needs_block_causal_flex = (
            (attn_backend == AttentionBackend.FLEX and has_global_layers)
            or (attn_backend == AttentionBackend.KERNELS_FLASH and has_global_layers)
        )
        needs_within_seq_flex = (attn_backend == AttentionBackend.FLEX)

        attention_args: AttentionArgs | None = None
        if past_key_values_length == 0:
            attention_args = AttentionArgs(
                block_causal_block_mask=create_block_causal_mask_optimized(sequence_ids) if needs_block_causal_flex else None,
                within_seq_block_mask=create_within_seq_block_mask(sequence_ids) if needs_within_seq_flex else None,
                within_seq_mask_4d=build_within_seq_mask_4d(sequence_ids) if needs_4d_masks else None,
                block_causal_mask_4d=build_block_causal_mask_4d(sequence_ids) if needs_4d_masks else None,
            )

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        full_s_max = () if output_s_max else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)  # type: ignore[operator]

            if self.gradient_checkpointing and self.training and torch.is_grad_enabled():
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    within_seq_position_ids,
                    global_position_ids,
                    sequence_ids,
                    attention_args,
                    past_key_values,
                    output_attentions,
                    output_s_max,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    within_seq_position_ids=within_seq_position_ids,
                    global_position_ids=global_position_ids,
                    sequence_ids=sequence_ids,
                    attention_args=attention_args,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    output_s_max=output_s_max,
                    use_cache=use_cache,
                )

            hidden_states, self_attn_weights, present_key_value, s_max = layer_outputs

            if use_cache:
                next_decoder_cache = past_key_values = present_key_value

            if output_attentions:
                all_self_attns += (self_attn_weights,)  # type: ignore[operator]

            if full_s_max is not None:
                full_s_max += (s_max,)  # type: ignore[operator]

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)  # type: ignore[operator]

        next_cache = next_decoder_cache if use_cache else None

        return E1ModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            s_max=full_s_max,
        )


class E1Model(E1PreTrainedModel, EmbeddingMixin):
    """Thin wrapper exposing FAST_E1_ENCODER under the `model` prefix."""

    config: E1Config
    config_class = E1Config

    def __init__(self, config: E1Config, **kwargs):
        E1PreTrainedModel.__init__(self, config, **kwargs)
        self.model: FAST_E1_ENCODER = FAST_E1_ENCODER(config, **kwargs)
        self.prep_tokens = self.model.prep_tokens
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        self.model.set_input_embeddings(value)

    @torch.inference_mode()
    def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
        # Delegates to the inner encoder's tokenize-and-encode helper.
        return self.model._embed(sequences, return_attention_mask=return_attention_mask, **kwargs)

    def forward(
        self,
        input_ids: torch.LongTensor,
        within_seq_position_ids: torch.LongTensor,
        global_position_ids: torch.LongTensor,
        sequence_ids: torch.LongTensor,
        past_key_values: DynamicCache | None = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_s_max: bool = False,
        **kwargs,
    ) -> E1ModelOutputWithPast:
        return self.model(
            input_ids=input_ids,
            within_seq_position_ids=within_seq_position_ids,
            global_position_ids=global_position_ids,
            sequence_ids=sequence_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_s_max=output_s_max,
            **kwargs,
        )


class E1ForMaskedLM(E1PreTrainedModel, EmbeddingMixin):
    """E1 encoder with a masked-language-modeling head."""

    config: E1Config
    config_class = E1Config
    def __init__(self, config: E1Config, **kwargs):
        E1PreTrainedModel.__init__(self, config, **kwargs)
        self.model: FAST_E1_ENCODER = FAST_E1_ENCODER(config, **kwargs)
        self.vocab_size = config.vocab_size
        # Dense -> GELU -> LayerNorm -> vocab projection.
        self.mlm_head = torch.nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size, bias=True),
            nn.GELU(),
            nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps),
            nn.Linear(config.hidden_size, config.vocab_size, bias=True),
        )
        self.gradient_checkpointing = config.gradient_checkpointing
        self.prep_tokens = self.model.prep_tokens
        self.post_init()

    @property
    def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh:
        return self.model.device_mesh

    @torch.inference_mode()
    def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
        """Tokenize raw sequences and return last-layer hidden states."""
        batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device)
        last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state
        if return_attention_mask:
            # -1 sequence id marks padding positions.
            attention_mask = (batch['sequence_ids'] != -1).long()
            return last_hidden_state, attention_mask
        else:
            return last_hidden_state

    def forward(
        self,
        input_ids: torch.LongTensor,
        within_seq_position_ids: torch.LongTensor,
        global_position_ids: torch.LongTensor,
        sequence_ids: torch.LongTensor,
        labels: torch.LongTensor | None = None,
        past_key_values: DynamicCache | None = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_s_max: bool = False,
        **kwargs,
    ) -> E1MaskedLMOutputWithPast:
        """
        Args:
            input_ids: (batch_size, seq_length)
            within_seq_position_ids: (batch_size, seq_length)
                This tensor contains the position of each residue within the sequence itself.
                For example, if the input is ["1ABC21DEF2", "1GH21JKL2"],
                the tensor would be [[0,1,2,3,4,5,6,0,1,2,3,4,5,6], [0,1,2,3,4,5,0,1,2,3,4,5,6,-1]]
            global_position_ids: (batch_size, seq_length)
                This tensor contains the position of each residue within the global sequence.
                For example, if the input is ["1ABC21DEF2", "1GH21JKL2"],
                the tensor would be [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, -1]]
            sequence_ids: (batch_size, seq_length)
                This tensor contains the sequence id of each residue.
+ For example, if the input is ["1ABC21DEF2", "1GH21JKL2"], + the tensor would be [[0,0,0,0,0,0,0,1,1,1,1,1,1,1], [0,0,0,0,0,0,1,1,1,1,1,1,1,-1]] + labels: (batch_size, seq_length) + past_key_values: DynamicCache + use_cache: bool + output_attentions: bool + output_hidden_states: bool + output_s_max: bool + + Returns: + E1MaskedLMOutputWithPast: Model Outputs + """ + outputs: E1ModelOutputWithPast = self.model( + input_ids=input_ids, + within_seq_position_ids=within_seq_position_ids, + global_position_ids=global_position_ids, + sequence_ids=sequence_ids, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_s_max=output_s_max, + ) + + last_hidden_state = outputs.last_hidden_state + loss = None + + mlm_logits = self.mlm_head(last_hidden_state).float() + mlm_loss = 0.0 + if labels is not None: + mlm_logits_flat = mlm_logits.contiguous().view(-1, self.config.vocab_size) + mlm_labels_flat = labels.to(mlm_logits_flat.device).contiguous().view(-1) + mlm_loss = F.cross_entropy(mlm_logits_flat, mlm_labels_flat, reduction="none") + mask = mlm_labels_flat != self.model.padding_idx + n_mlm = mask.sum() + mlm_loss = (mlm_loss * mask.to(mlm_loss)).sum() / (1 if n_mlm == 0 else n_mlm) + loss = 0.0 + loss += mlm_loss + + return E1MaskedLMOutputWithPast( + loss=loss, + mlm_loss=mlm_loss, + logits=mlm_logits, + last_hidden_state=last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + s_max=outputs.s_max, + ) + + +class E1ForSequenceClassification(E1PreTrainedModel, EmbeddingMixin): + config: E1Config + config_class = E1Config + def __init__(self, config: E1Config, **kwargs): + E1PreTrainedModel.__init__(self, config, **kwargs) + self.model: FAST_E1_ENCODER = FAST_E1_ENCODER(config, **kwargs) + self.vocab_size = config.vocab_size + self.num_labels = config.num_labels + self.classifier = nn.Sequential( + 
nn.Linear(config.hidden_size * 2, config.hidden_size * 4), + nn.GELU(), + nn.LayerNorm(config.hidden_size * 4), + nn.Linear(config.hidden_size * 4, config.num_labels), + ) + self.mse = nn.MSELoss() + self.ce = nn.CrossEntropyLoss() + self.bce = nn.BCEWithLogitsLoss() + self.gradient_checkpointing = config.gradient_checkpointing + self.prep_tokens = self.model.prep_tokens + + if 'pooling_types' in kwargs and isinstance(kwargs['pooling_types'], List[str]) and len(kwargs['pooling_types']) > 0: + pooling_types = kwargs['pooling_types'] + else: + pooling_types = ['mean', 'var'] + self.pooler = Pooler(pooling_types) + self.post_init() + + @property + def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh: + return self.model.device_mesh + + @torch.inference_mode() + def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor: + batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device) + last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state + if return_attention_mask: + attention_mask = (batch['sequence_ids'] != -1).long() + return last_hidden_state, attention_mask + else: + return last_hidden_state + + def forward( + self, + input_ids: torch.LongTensor, + within_seq_position_ids: torch.LongTensor, + global_position_ids: torch.LongTensor, + sequence_ids: torch.LongTensor, + labels: torch.LongTensor | None = None, + past_key_values: DynamicCache | None = None, + use_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + output_s_max: bool = False, + **kwargs, + ) -> E1ClassificationOutputWithPast: + outputs: E1ModelOutputWithPast = self.model( + input_ids=input_ids, + within_seq_position_ids=within_seq_position_ids, + global_position_ids=global_position_ids, + sequence_ids=sequence_ids, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + 
output_hidden_states=output_hidden_states, + output_s_max=output_s_max, + ) + + attention_mask = (sequence_ids != -1).long() + x = outputs.last_hidden_state + features = self.pooler(x, attention_mask) + logits = self.classifier(features) + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + if self.num_labels == 1: + loss = self.mse(logits.flatten(), labels.flatten()) + else: + loss = self.mse(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss = self.bce(logits, labels) + + return E1ClassificationOutputWithPast( + loss=loss, + logits=logits, + last_hidden_state=x, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + s_max=outputs.s_max, + ) + + +class E1ForTokenClassification(E1PreTrainedModel, EmbeddingMixin): + config: E1Config + config_class = E1Config + def __init__(self, config: E1Config, **kwargs): + E1PreTrainedModel.__init__(self, config, **kwargs) + self.model: FAST_E1_ENCODER = FAST_E1_ENCODER(config, **kwargs) + self.vocab_size = config.vocab_size + self.num_labels = config.num_labels + self.classifier = nn.Sequential( + nn.Linear(config.hidden_size * 2, config.hidden_size * 4), + nn.GELU(), + nn.LayerNorm(config.hidden_size * 4), + nn.Linear(config.hidden_size * 4, config.num_labels), + ) + self.loss_fct = nn.CrossEntropyLoss() + self.gradient_checkpointing = config.gradient_checkpointing + self.prep_tokens = self.model.prep_tokens + 
        self.post_init()

    @property
    def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh:
        return self.model.device_mesh

    @torch.inference_mode()
    def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
        """Tokenize raw sequences and return last-layer hidden states."""
        batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device)
        last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state
        if return_attention_mask:
            # -1 sequence id marks padding positions.
            attention_mask = (batch['sequence_ids'] != -1).long()
            return last_hidden_state, attention_mask
        else:
            return last_hidden_state

    def forward(
        self,
        input_ids: torch.LongTensor,
        within_seq_position_ids: torch.LongTensor,
        global_position_ids: torch.LongTensor,
        sequence_ids: torch.LongTensor,
        labels: torch.LongTensor | None = None,
        past_key_values: DynamicCache | None = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_s_max: bool = False,
        **kwargs,
    ) -> E1ClassificationOutputWithPast:
        """Per-token logits from the last hidden states; CE loss if labels given."""
        outputs: E1ModelOutputWithPast = self.model(
            input_ids=input_ids,
            within_seq_position_ids=within_seq_position_ids,
            global_position_ids=global_position_ids,
            sequence_ids=sequence_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_s_max=output_s_max,
        )

        x = outputs.last_hidden_state
        # NOTE(review): classifier in_features must equal hidden_size here,
        # since it consumes the per-token states directly -- verify the head
        # defined in __init__ matches.
        logits = self.classifier(x)
        loss = None
        if labels is not None:
            loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return E1ClassificationOutputWithPast(
            loss=loss,
            logits=logits,
            last_hidden_state=x,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            s_max=outputs.s_max,
        )


if __name__ == "__main__":
    # Smoke test: build a tiny randomly-configured model, run one forward
    # pass on two short sequences, and print every tensor shape involved.
    import random

    import torch

    from torch import Tensor

    def print_tensor_shapes(prefix: str, obj):
        # Recursively walk tensors, dicts, lists/tuples, and plain objects.
        if isinstance(obj, Tensor):
            print(f"{prefix}{obj.shape}")
        elif isinstance(obj, dict):
            for name, value in obj.items():
                print_tensor_shapes(f"{prefix}{name}.", value)
        elif isinstance(obj, list):
            for idx, value in enumerate(obj):
                print_tensor_shapes(f"{prefix}[{idx}].", value)
        elif isinstance(obj, tuple):
            for idx, value in enumerate(obj):
                print_tensor_shapes(f"{prefix}[{idx}].", value)
        elif hasattr(obj, "__dict__"):
            for name, value in vars(obj).items():
                if name.startswith("_"):
                    continue
                print_tensor_shapes(f"{prefix}{name}.", value)
        else:
            print(f"{prefix}{type(obj)}")

    def get_e1_batch(tokenizer, sequences: list[str], device: torch.device):
        # Prepare model kwargs (ids, positions, sequence ids) for raw strings.
        preparer = E1BatchPreparer(data_prep_config=DataPrepConfig(max_num_positions_within_seq=64), tokenizer=tokenizer)
        return preparer.get_batch_kwargs(sequences=sequences, device=device)

    random.seed(0)
    torch.manual_seed(0)

    num_attention_heads = random.choice([2, 4])
    config = E1Config(
        hidden_size=16 * num_attention_heads,
        intermediate_size=64 * num_attention_heads,
        num_hidden_layers=random.choice([1, 2]),
        num_attention_heads=num_attention_heads,
        num_key_value_heads=num_attention_heads,
        max_num_positions_within_seq=128,
        max_num_positions_global=256,
        max_num_sequences=8,
        dtype="float32",
    )
    model = E1ForMaskedLM(config=config).eval()
    tokenizer = get_tokenizer()
    batch = get_e1_batch(tokenizer=tokenizer, sequences=["ACDEFG", "MKTW"], device=torch.device("cpu"))
    batch["labels"] = batch["labels"].clone()

    with torch.no_grad():
        output = model(
            input_ids=batch["input_ids"],
            within_seq_position_ids=batch["within_seq_position_ids"],
            global_position_ids=batch["global_position_ids"],
            sequence_ids=batch["sequence_ids"],
            labels=batch["labels"],
        )

    print("Batch shape:")
    print_tensor_shapes("", batch)
    print("Output shape:")
    print_tensor_shapes("", output)