diff --git "a/modeling_e1.py" "b/modeling_e1.py" --- "a/modeling_e1.py" +++ "b/modeling_e1.py" @@ -1,2135 +1,2137 @@ -import os -os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" - -import numpy as np -import networkx as nx -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.nn.utils.rnn import pad_sequence - -from einops import rearrange, repeat -from enum import Enum -from typing import Any, TypedDict, Callable, Optional, List -from dataclasses import dataclass -from tokenizers import Tokenizer -from transformers import PretrainedConfig, PreTrainedModel -from transformers.activations import ACT2FN -from transformers.modeling_outputs import ModelOutput -from transformers.utils import logging -from tqdm.auto import tqdm - - -logger = logging.get_logger(__name__) - -### Establish attention compatibility -try: - from flash_attn import flash_attn_func, flash_attn_varlen_func -except ImportError: - logger.warning("Failed to import flash attention; Will be using PyTorch attention instead") - flash_attn_func = None - flash_attn_varlen_func = None - -try: - from torch.nn.attention.flex_attention import ( - BlockMask, - create_block_mask, - flex_attention, - _create_sparse_block_from_block_mask - ) - - if torch.cuda.is_available(): - # if on linux, compile the flex attention function - if os.name == 'posix': - print("Compiling flex attention") - flex_attention = torch.compile(flex_attention, dynamic=True) - else: - print("Not compiling flex attention, detected non-Linux environment") - -except ImportError: - logger.warning("Failed to import flex attention; Will be using PyTorch attention instead") - flex_attention = None - -try: - from kernels import get_kernel - layer_norm = get_kernel("kernels-community/triton-layer-norm") -except Exception as e: - logger.warning(f"Failed to load triton layer norm kernel: {e}; Will be using PyTorch RMSNorm instead") - layer_norm = None - - -def is_flash_attention_available() -> bool: - return ( - flash_attn_func is not None and flash_attn_varlen_func is not None and (os.getenv("USE_FLASH_ATTN", "1") == "1") - ) - - -class FlexAttentionArgs(TypedDict, total=False): - block_mask: BlockMask | None - score_mod: Callable | None - - -def create_block_causal_mask_optimized(sequence_ids: torch.Tensor) -> BlockMask: - # Assumes sequence_ids is sorted in increasing order for each batch item, except for - # the -1 values, which are used to indicate the padding tokens. - def document_mask(b, h, q_idx, kv_idx): # type: ignore[no-untyped-def] - return ( - (sequence_ids[b, q_idx] >= sequence_ids[b, kv_idx]) - & (sequence_ids[b, q_idx] != -1) - & (sequence_ids[b, kv_idx] != -1) - ) - - batch_size, seqlen = sequence_ids.shape - return create_block_mask(document_mask, batch_size, 1, seqlen, seqlen, device=sequence_ids.device) - - -def flex_attention_func( - query_states: torch.Tensor, # (bs, seqlen, nh, hs) - key_states: torch.Tensor, # (bs, seqlen, nkv, hs) - value_states: torch.Tensor, # (bs, seqlen, nkv, hs) - score_mod: Callable | None = None, - block_mask: BlockMask | None = None, -) -> torch.Tensor: - assert flex_attention is not None, "Flex Attention is not available in this environment" - assert score_mod is None, "Score mod is not supported yet" - query_states = query_states.transpose(1, 2).contiguous() # (bs, nh, seqlen, hs) - key_states = key_states.transpose(1, 2).contiguous() # (bs, nkv, seqlen, hs) - value_states = value_states.transpose(1, 2).contiguous() # (bs, nkv, seqlen, hs) - - outputs = flex_attention( - query_states, - key_states, - value_states, - block_mask=block_mask, - score_mod=score_mod, - enable_gqa=query_states.shape[1] != key_states.shape[1], # if nkv != nh - ) - - outputs = outputs.transpose(1, 2) # (bs, seqlen, nh, hs) - return outputs - - -def flash_attention_func( - query_states: torch.Tensor, # (bs, seqlen, nh, hs) - key_states: torch.Tensor, # (bs, seqlen, nkv, hs) - value_states: torch.Tensor, # (bs, seqlen, nkv, hs) - q_sequence_ids: torch.Tensor, - k_sequence_ids: torch.Tensor, - causal: bool = False, -) -> torch.Tensor: # (bs, seqlen, nh, hs) - # Contains at least one padding token in the sequence. Note: ignore attention mask if causal. - if not is_flash_attention_available(): - raise ImportError("Flash Attention is not available. Please install flash-attn.") - - if not causal: - batch_size, q_len = query_states.shape[0], query_states.shape[1] - ( - query_states, - key_states, - value_states, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) = _unpad_input(query_states, key_states, value_states, q_sequence_ids, k_sequence_ids) - - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - causal=False, - ) - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, q_len) - - else: - attn_output = flash_attn_func(query_states, key_states, value_states, causal=True) - - return attn_output - - -class IndexFirstAxis(torch.autograd.Function): - @staticmethod - def forward(ctx, input, indices) -> torch.Tensor: # type: ignore[no-untyped-def] - ctx.save_for_backward(indices) - assert input.ndim >= 2 - ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] - second_dim = other_shape.numel() - # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. - # return input[indices] - return torch.gather(rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)).reshape( - -1, *other_shape - ) - - @staticmethod - def backward(ctx, grad_output) -> tuple[torch.Tensor, None]: # type: ignore[no-untyped-def] - (indices,) = ctx.saved_tensors - assert grad_output.ndim >= 2 - other_shape = grad_output.shape[1:] - grad_output = rearrange(grad_output, "b ... -> b (...)") - grad_input = torch.zeros( - [ctx.first_axis_dim, grad_output.shape[1]], device=grad_output.device, dtype=grad_output.dtype - ) - # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. - # grad_input[indices] = grad_output - grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output) - return grad_input.reshape(ctx.first_axis_dim, *other_shape), None - - -def block_min_max_seq_ids(SLEN: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]: - device = SLEN.device - total_tokens = torch.sum(SLEN) - B = (total_tokens + block_size - 1) // block_size - padding_tokens = B * block_size - total_tokens - SLEN = torch.cat([SLEN, torch.Tensor([padding_tokens]).to(device)], dim=0) - - assert torch.sum(SLEN) == B * block_size - - # Cumulative ends (exclusive) for each sequence; cum[i] == end offset of seq i - cum = torch.cumsum(SLEN.to(torch.long), dim=0) # (N,) - total_tokens = cum[-1].item() - - # Block start/end offsets [start, end) in token index space - block_starts = torch.arange(0, B * block_size, block_size, device=device, dtype=torch.long) # (B,) - block_ends = torch.minimum(block_starts + block_size, torch.tensor(total_tokens, device=device)) # (B,) - - # MIN_SEQ_ID[i] = first sequence whose end > block_start - # searchsorted with right=True returns first index where cum > value - MIN_SEQ_ID = torch.searchsorted(cum, block_starts, right=True) - - # MAX_SEQ_ID[i] = sequence containing the last token in the block (block_end - 1) - # For empty tail beyond total_tokens we already clipped block_ends. - last_token_in_block = torch.clamp(block_ends - 1, min=0) # valid only if block has at least 1 token - MAX_SEQ_ID = torch.searchsorted(cum, last_token_in_block, right=True) - - return MIN_SEQ_ID, MAX_SEQ_ID - - -def get_overlapping_blocks(SLEN_Q: torch.Tensor, SLEN_K: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - MIN_Q, MAX_Q = block_min_max_seq_ids(SLEN_Q) - MIN_K, MAX_K = block_min_max_seq_ids(SLEN_K) - - cond1 = MIN_Q.unsqueeze(1) <= MAX_K.unsqueeze(0) - cond2 = MIN_K.unsqueeze(0) <= MAX_Q.unsqueeze(1) - overlap = cond1 & cond2 - - cond1 = (MIN_Q == MAX_Q).unsqueeze(1) - cond2 = (MIN_K == MAX_K).unsqueeze(0) - same_seq_in_qk = cond1 & cond2 - - full_blocks = overlap & same_seq_in_qk - partial_blocks = overlap & ~same_seq_in_qk - - return full_blocks, partial_blocks - - -def direct_block_mask(SLEN_Q: torch.Tensor, SLEN_K: torch.Tensor) -> BlockMask: - full_blocks, partial_blocks = get_overlapping_blocks(SLEN_Q, SLEN_K) - partial_blocks = partial_blocks[None, None] - full_blocks = full_blocks[None, None] - - q_doc_id = torch.repeat_interleave(SLEN_Q) - k_doc_id = torch.repeat_interleave(SLEN_K) - - def doc_mask(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor) -> torch.Tensor: - return q_doc_id[q_idx] == k_doc_id[kv_idx] - - total_q_len = q_doc_id.shape[0] - total_k_len = k_doc_id.shape[0] - - return _create_sparse_block_from_block_mask( - (partial_blocks, full_blocks), - doc_mask, - seq_lengths=(total_q_len, total_k_len), - Q_BLOCK_SIZE=128, - KV_BLOCK_SIZE=128, - ) - - -def doc_id_mask(SLEN_Q: torch.Tensor, SLEN_K: torch.Tensor) -> BlockMask: - q_doc_id = torch.repeat_interleave(SLEN_Q) - k_doc_id = torch.repeat_interleave(SLEN_K) - - def doc_mask(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor) -> torch.Tensor: - return q_doc_id[q_idx] == k_doc_id[kv_idx] - - total_q_len = q_doc_id.shape[0] - total_k_len = k_doc_id.shape[0] - - return create_block_mask(doc_mask, 1, 1, total_q_len, total_k_len, BLOCK_SIZE=128, device=SLEN_Q.device) - - -def varlen_flex_attention_func( - query_states: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - q_sequence_ids: torch.Tensor, - k_sequence_ids: torch.Tensor, -) -> torch.Tensor: - batch_size, q_len = query_states.shape[0], query_states.shape[1] - ( - query_states, - key_states, - value_states, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) = _unpad_input(query_states, key_states, value_states, q_sequence_ids, k_sequence_ids) - - query_states = query_states.unsqueeze(0).transpose(1, 2).contiguous() - key_states = key_states.unsqueeze(0).transpose(1, 2).contiguous() - value_states = value_states.unsqueeze(0).transpose(1, 2).contiguous() - - seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] - seqlens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1] - block_mask = block_mask_creator(seqlens_q, seqlens_k) - - attn_output_unpad = flex_attention( - query_states, - key_states, - value_states, - block_mask=block_mask, - enable_gqa=query_states.shape[1] != key_states.shape[1], - ) - - attn_output = pad_input(attn_output_unpad.transpose(1, 2).squeeze(0), indices_q, batch_size, q_len) - - return attn_output - - -class IndexPutFirstAxis(torch.autograd.Function): - @staticmethod - def forward(ctx, values, indices, first_axis_dim) -> torch.Tensor: # type: ignore[no-untyped-def] - ctx.save_for_backward(indices) - assert indices.ndim == 1 - assert values.ndim >= 2 - output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype) - # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. - output[indices] = values - # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) - return output - - @staticmethod - def backward(ctx, grad_output) -> tuple[torch.Tensor, None, None]: # type: ignore[no-untyped-def] - (indices,) = ctx.saved_tensors - # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. - grad_values = grad_output[indices] - # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1])) - return grad_values, None, None - - -index_put_first_axis = IndexPutFirstAxis.apply - - -def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor: - """ - Arguments: - hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. - indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. - batch: int, batch size for the padded sequence. - seqlen: int, maximum sequence length for the padded sequence. - Return: - hidden_states: (batch, seqlen, ...) - """ - # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype) - # output[indices] = hidden_states - output = index_put_first_axis(hidden_states, indices, batch * seqlen) - return rearrange(output, "(b s) ... -> b s ...", b=batch) - - -def _get_unpad_data(sequence_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: - non_pad_indices = sequence_ids != -1 - non_pad_indices = torch.nonzero(non_pad_indices.flatten(), as_tuple=False).flatten() - sequence_ids = sequence_ids + torch.arange(len(sequence_ids), device=sequence_ids.device)[:, None] * 1e5 - sequence_ids = sequence_ids.flatten()[non_pad_indices] - _, seqlens_in_batch = torch.unique_consecutive(sequence_ids, return_counts=True) - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) - return non_pad_indices, cu_seqlens, max_seqlen_in_batch - - -def _unpad_input( - query_layer: torch.Tensor, - key_layer: torch.Tensor, - value_layer: torch.Tensor, - q_sequence_ids: torch.Tensor, - k_sequence_ids: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor], tuple[int, int]]: - batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape - query_length, num_q_heads = query_layer.shape[1], query_layer.shape[2] - assert query_layer.shape[:2] == q_sequence_ids.shape, ( - f"Shape mismatch between query layer and query sequence ids: {query_layer.shape[:2]} != {q_sequence_ids.shape}" - ) - assert key_layer.shape[:2] == k_sequence_ids.shape, ( - f"Shape mismatch between key layer and key sequence ids: {key_layer.shape[:2]} != {k_sequence_ids.shape}" - ) - assert query_length <= kv_seq_len, ( - f"Query length should be less than or equal to KV sequence length: {query_length} <= {kv_seq_len}" - ) - - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(k_sequence_ids) - - key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) - value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) - - if torch.equal(q_sequence_ids, k_sequence_ids): - indices_q = indices_k - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - else: - indices_q, cu_seqlens_q, max_seqlen_in_batch_q = _get_unpad_data(q_sequence_ids) - - query_layer = index_first_axis(query_layer.reshape(batch_size * query_length, num_q_heads, head_dim), indices_q) - - assert cu_seqlens_q.shape == cu_seqlens_k.shape, ( - f"Query and KV should have the same number of sequences: {cu_seqlens_q.shape} != {cu_seqlens_k.shape}" - ) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -index_first_axis = IndexFirstAxis.apply -block_mask_creator = direct_block_mask if os.getenv("FAST_BLOCK_MASK", "1") == "1" else doc_id_mask -PAD_TOKEN_ID = 0 - - -def get_tokenizer() -> Tokenizer: - try: - fname = os.path.join(os.path.dirname(__file__), "tokenizer.json") - tokenizer: Tokenizer = Tokenizer.from_file(fname) - except: - print("E1 Tokenizer not found in local directory, downloading from Hugging Face") - from huggingface_hub import hf_hub_download - fname = hf_hub_download(repo_id="Synthyra/Profluent-E1-150M", filename="tokenizer.json") - tokenizer: Tokenizer = Tokenizer.from_file(fname) - assert tokenizer.padding["pad_id"] == PAD_TOKEN_ID, ( - f"Padding token id must be {PAD_TOKEN_ID}, but got {tokenizer.padding['pad_id']}" - ) - - return tokenizer - - -@dataclass -class DataPrepConfig: - max_num_sequences: int = 512 - max_num_positions_within_seq: int = 8192 - remove_X_tokens: bool = False - - -def get_context(sequence: str) -> str | None: - if "," in sequence: - return sequence.rsplit(",", 1)[0] - return None - - -class E1BatchPreparer: - def __init__( - self, - data_prep_config: DataPrepConfig | None = None, - tokenizer: Tokenizer | None = None, - preserve_context_labels: bool = False, - ): - self.tokenizer = tokenizer or get_tokenizer() - self.data_prep_config = data_prep_config or DataPrepConfig() - self.pad_token_id = self.tokenizer.token_to_id("") - self.preserve_context_labels = preserve_context_labels - device = torch.cuda.current_device() if torch.cuda.is_available() else torch.device("cpu") - self.boundary_token_ids = torch.tensor( - [self.tokenizer.token_to_id(token) for token in ["", "", "1", "2", ""]], device=device - ).long() - self.mask_token = "?" # nosec - self.mask_token_id = self.tokenizer.token_to_id(self.mask_token) - self.X_token_id = self.tokenizer.token_to_id("X") - self.vocab = self.tokenizer.get_vocab() - - def get_batch_kwargs( # type: ignore[override] - self, sequences: list[str], device: torch.device = torch.device("cpu"), non_blocking: bool = False - ) -> dict[str, torch.Tensor | list[str] | list[int]]: - sequence_encodings = [self.prepare_multiseq(sequence) for sequence in sequences] - return self.pad_encodings(sequence_encodings, device, non_blocking) - - def pad_encodings( - self, - sequence_encodings: list[dict[str, torch.Tensor]], - device: torch.device = torch.device("cpu"), - non_blocking: bool = False, - ) -> dict[str, torch.Tensor | list[str] | list[int]]: - non_blocking = non_blocking and device.type == "cuda" - padded_encodings = {} - # Note: We use -1 as the padding value for sequence and position ids because the 0 value - # is a valid value for sequence and position ids. -1 is then used to distinguish valid - # tokens from padding tokens, for example, when doing padding/unpadding for flash attention. - for key, padding_value in { - "input_ids": self.pad_token_id, - "sequence_ids": -1, - "within_seq_position_ids": -1, - "global_position_ids": -1, - "labels": self.pad_token_id, - }.items(): - padded_encodings[key] = pad_sequence( - [enc[key] for enc in sequence_encodings], batch_first=True, padding_value=padding_value - ).to(device=device, dtype=torch.long, non_blocking=non_blocking) - - padded_encodings["context"] = [enc["context"] for enc in sequence_encodings] - padded_encodings["context_len"] = [enc["context_len"] for enc in sequence_encodings] - - return padded_encodings - - def prepare_multiseq(self, sequence: str) -> dict[str, torch.Tensor | str | int]: - single_sequences = sequence.split(",") - if len(single_sequences) > self.data_prep_config.max_num_sequences: - raise ValueError( - f"Number of sequences {len(single_sequences)} exceeds max number of sequences {self.data_prep_config.max_num_sequences}" - " in the provided multi-sequence instance. Please remove some homologous sequences before trying again." - ) - - single_sequence_encodings = [self.prepare_singleseq(sequence) for sequence in single_sequences] - - num_tokens = [len(x["input_ids"]) for x in single_sequence_encodings] - input_ids = torch.cat([x["input_ids"] for x in single_sequence_encodings]) - labels = torch.cat([x["labels"] for x in single_sequence_encodings]) - - within_seq_position_ids = torch.cat([encoding["position_ids"] for encoding in single_sequence_encodings]) - global_position_ids, ctx_len = [], 0 - for encoding in single_sequence_encodings: - global_position_ids.append(encoding["position_ids"] + ctx_len) - ctx_len = max(ctx_len, encoding["position_ids"].max().item() + ctx_len + 1) - global_position_ids = torch.cat(global_position_ids) - - sequence_ids = torch.repeat_interleave(torch.tensor(num_tokens)) - - # Get multi-seq context & mask out all but last sequence in multi-seq instance if desired - context_len = sum(num_tokens[:-1]) - context = self.tokenizer.decode(input_ids[:context_len].tolist(), skip_special_tokens=False) - if not self.preserve_context_labels: - labels[:context_len] = self.pad_token_id - - assert ( - input_ids.shape - == sequence_ids.shape - == within_seq_position_ids.shape - == global_position_ids.shape - == labels.shape - ), "Input ids, sequence ids, within seq position ids, global position ids, and labels must have the same shape" - - assert input_ids.shape[0] >= context_len, "Input ids must have at least as many tokens as the context length" - - return { - "input_ids": input_ids, - "sequence_ids": sequence_ids, - "within_seq_position_ids": within_seq_position_ids, - "global_position_ids": global_position_ids, - "labels": labels, - "context": context, - "context_len": context_len, - } - - def prepare_singleseq(self, sequence: str) -> dict[str, torch.Tensor]: - if not self.validate_sequence(sequence): - raise ValueError(f"Invalid sequence: {sequence}; Input sequence should contain [A-Z] or ? characters only") - - if len(sequence) > self.data_prep_config.max_num_positions_within_seq: - raise ValueError( - f"Sequence length {len(sequence)} exceeds max length {self.data_prep_config.max_num_positions_within_seq}" - ) - - # Can also use `tokens = torch.tensor(self.tokenizer.encode(f"1{sequence}2").ids)` - # but following is faster since our vocabulary is simple. - tokens = torch.tensor([self.vocab[token] for token in ["", "1", *sequence, "2", ""]]) - position_ids = torch.arange(len(tokens)) - - if self.data_prep_config.remove_X_tokens: - X_positions = torch.where(tokens != self.X_token_id)[0] - tokens = tokens[X_positions] - position_ids = position_ids[X_positions] - - return {"input_ids": tokens, "labels": tokens, "position_ids": position_ids} - - def get_boundary_token_mask(self, tokens: torch.Tensor) -> torch.BoolTensor: - return torch.isin(tokens, self.boundary_token_ids) - - def get_mask_positions_mask(self, tokens: torch.Tensor) -> torch.BoolTensor: - return tokens == self.mask_token_id - - def validate_sequence(self, sequence: str) -> bool: - assert isinstance(sequence, str), "Sequence must be a string" - sequence = sequence.replace(self.mask_token, "") - return sequence.isalpha() and sequence.isupper() - - - -class E1Config(PretrainedConfig): - model_type = "E1" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( # type: ignore - self, - # Model architecture/initialization - vocab_size=None, - hidden_size=4096, - intermediate_size=16384, - gated_mlp=False, - num_hidden_layers=40, - num_attention_heads=32, - num_key_value_heads=8, - hidden_act="silu", - rms_norm_eps=1e-5, - initializer_range=0.02, - torch_dtype="bfloat16", - gradient_checkpointing=False, - no_ffn_gradient_checkpointing=False, - # Tokenization - pad_token_id=None, - bos_token_id=None, - eos_token_id=None, - tie_word_embeddings=False, - # Attention implementation & rotary positional embeddings - global_attention_every_n_layers=0, - max_num_sequences=512, - max_num_positions_within_seq=8192, - max_num_positions_global=1024 * 128, - rope_theta_within_seq=10000.0, - rope_theta_global=100000.0, - clip_qkv=None, - **kwargs, - ) -> None: - tokenizer = get_tokenizer() - super().__init__( - pad_token_id=tokenizer.token_to_id(""), - bos_token_id=tokenizer.token_to_id(""), - eos_token_id=tokenizer.token_to_id(""), - tie_word_embeddings=tie_word_embeddings, - torch_dtype=torch_dtype, - **kwargs, - ) - - self.hidden_size = hidden_size - if intermediate_size is None: - intermediate_size = 3 * hidden_size if gated_mlp else 4 * hidden_size - self.intermediate_size = intermediate_size - self.gated_mlp = gated_mlp - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.max_num_positions_within_seq = max_num_positions_within_seq - self.max_num_positions_global = max_num_positions_global - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.rope_theta_within_seq = rope_theta_within_seq - self.rope_theta_global = rope_theta_global - self.max_num_sequences = max_num_sequences - assert clip_qkv is None or clip_qkv > 0 - self.clip_qkv = clip_qkv - self.global_attention_every_n_layers = global_attention_every_n_layers - - self.vocab_size = tokenizer.get_vocab_size() - self.gradient_checkpointing = gradient_checkpointing - self.no_ffn_gradient_checkpointing = no_ffn_gradient_checkpointing - - if vocab_size is not None: - if vocab_size < self.vocab_size: - logger.warning( - f"Using vocab_size {vocab_size} smaller than {self.vocab_size} from tokenizer. MAKE SURE THIS IS INTENTIONAL." - ) - self.vocab_size = vocab_size - elif vocab_size > self.vocab_size: - logger.warning(f"Using vocab_size {vocab_size} instead of smaller {self.vocab_size} from tokenizer.") - self.vocab_size = vocab_size - if pad_token_id is not None and pad_token_id != self.pad_token_id: - logger.warning(f"Ignoring pad_token_id. Using {self.pad_token_id} from tokenizer") - if bos_token_id is not None and bos_token_id != self.bos_token_id: - logger.warning(f"Ignoring bos_token_id. Using {self.bos_token_id} from tokenizer") - if eos_token_id is not None and eos_token_id != self.eos_token_id: - logger.warning(f"Ignoring eos_token_id. Using {self.eos_token_id} from tokenizer") - - -class DynamicCache: - """ - A cache layer that grows dynamically as more tokens are generated. This is the default for generative models. - It stores the key and value states as tensors of shape `[batch_size, seq_len, num_heads, head_dim]`. - - Args: - key_cache (`list[torch.Tensor]`): The list of key states. - value_cache (`list[torch.Tensor]`): The list of value states. - """ - - def __init__(self) -> None: - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - - def update( - self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Update the key and value caches in-place, and return the necessary keys and value states. - - Args: - key_states (`torch.Tensor`): The new key states to cache of shape [batch_size, seq_len, num_heads, head_dim] - value_states (`torch.Tensor`): The new value states to cache of shape [batch_size, seq_len, num_heads, head_dim] - layer_idx (`int`): The index of the layer to update. - - Returns: - tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states of shape [batch_size, seq_len, num_heads, head_dim]. - """ - # Lazy initialization - if len(self.key_cache) <= layer_idx: - # There may be skipped layers, fill them with empty lists - for _ in range(len(self.key_cache), layer_idx): - self.key_cache.append(torch.tensor([])) - self.value_cache.append(torch.tensor([])) - self.key_cache.append(key_states) - self.value_cache.append(value_states) - elif ( - not self.key_cache[layer_idx].numel() # prefers not t.numel() to len(t) == 0 to export the model - ): # fills previously skipped layers; checking for tensor causes errors - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=1) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=1) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def get_seq_length(self, layer_idx: int = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - is_empty_layer = ( - len(self.key_cache) == 0 # no cache in any layer - or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it - or not self.key_cache[layer_idx].numel() # the layer has no cache - ) - layer_seq_length = self.key_cache[layer_idx].shape[1] if not is_empty_layer else 0 - return layer_seq_length - - def crop(self, max_length: int) -> None: - """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be - negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search.""" - assert max_length > 0, "max_length must be positive" - - if self.get_seq_length() <= max_length: - return - - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx].numel(): - self.key_cache[layer_idx] = self.key_cache[layer_idx][:, :max_length, ...] - self.value_cache[layer_idx] = self.value_cache[layer_idx][:, :max_length, ...] - - def batch_repeat_interleave(self, repeats: int) -> None: - """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx].numel(): - self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) - self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0) - - def batch_select_indices(self, indices: torch.Tensor) -> None: - """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx].numel(): - self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] - self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] - - -class KVCache: - def __init__(self, cache_size: int = 4) -> None: - self.cache_size = cache_size - self.tensor_input_field_names = [ - "input_ids", - "within_seq_position_ids", - "global_position_ids", - "sequence_ids", - "labels", - ] - self.tensor_output_field_names = ["logits", "embeddings"] - self.cache_dict: dict[str, DynamicCache] = {} - self.cache_queue: list[str] = [] - - def reset(self) -> None: - for k in list(self.cache_dict.keys()): - del self.cache_dict[k] - del self.cache_dict - self.cache_dict = {} - self.cache_queue = [] - - torch.cuda.empty_cache() - - def before_forward(self, batch: dict[str, torch.Tensor]) -> None: - contexts: list[str] | None = batch.get("context", None) - if contexts is None or "context_len" not in batch: - logger.warning_once( - "KVCache requires the batch dict to have both `context` and `context_len` keys to trigger. Skipping." - ) - return - - context_lens: list[int] = list(set(batch["context_len"])) - contexts: list[str] = list(set(contexts)) # type: ignore[no-redef] - if len(contexts) != 1 or len(context_lens) != 1: - logger.warning( - "SingleContextKVCache requires a single context and context length. " - "Multiple contexts or context lengths found in a single batch. Skipping." - ) - return - - batch_size = batch["input_ids"].shape[0] - - unique_context = contexts[0] - unique_context_len = context_lens[0] - batch["use_cache"] = True - - if unique_context not in self.cache_dict: - return - - self.cache_dict[unique_context].batch_repeat_interleave(batch_size) - past_key_values = self.cache_dict[unique_context] - batch["past_key_values"] = past_key_values - - # Remove context from the input fields - for field_name in self.tensor_input_field_names: - if batch.get(field_name, None) is not None: - batch[field_name] = batch[field_name][:, unique_context_len:] - - def after_forward(self, batch: dict[str, Any], outputs: ModelOutput) -> None: - contexts = batch.get("context", None) - context_lens = batch.get("context_len", []) - if contexts is None or len(set(contexts)) != 1 or len(set(context_lens)) != 1 or context_lens[0] == 0: - return - - assert batch["use_cache"] - unique_context = contexts[0] - unique_context_len = context_lens[0] - - past_key_values = getattr(outputs, "past_key_values", None) - if not isinstance(past_key_values, DynamicCache): - logger.warning_once("KVCache is incompatible with models that don't return a DynamicCache. Skipping.") - return - - if "past_key_values" not in batch: - if len(self.cache_queue) == self.cache_size: - last_context = self.cache_queue.pop(0) - if last_context not in self.cache_queue: - del self.cache_dict[last_context] - torch.cuda.empty_cache() - - self.cache_dict[unique_context] = past_key_values - self.cache_queue.append(unique_context) - - # Remove context from the input fields - for field_name in self.tensor_input_field_names: - if field_name in batch and batch[field_name] is not None: - batch[field_name] = batch[field_name][:, unique_context_len:] - - # Remove context from the output fields - for field_name in self.tensor_output_field_names: - if field_name in outputs and outputs[field_name] is not None: - outputs[field_name] = outputs[field_name][:, unique_context_len:] - if "hidden_states" in outputs and outputs["hidden_states"] is not None: - outputs["hidden_states"] = [h[:, unique_context_len:] for h in outputs["hidden_states"]] - - self.cache_dict[unique_context].crop(unique_context_len) - self.cache_dict[unique_context].batch_select_indices([0]) - - -class AttentionMethod(Enum): - FLASH = "flash" - FLEX = "flex" - - -class AttentionLayerType(Enum): - WITHIN_SEQ = "within_seq" - GLOBAL = "global" - - -class AttentionArgs(TypedDict, total=False): - flex_attention_args: FlexAttentionArgs - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). - - The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, - num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class RotaryPositionalEmbedding(nn.Module): - def __init__( - self, dim: int, max_position_embeddings: int = 2048, base: int = 10000, device: torch.device | None = None - ): - super().__init__() - - self.dim = dim - self.base = base - self.max_position_embeddings = max_position_embeddings - inv_freq = base ** -(torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. - self._set_sin_cos_cache(seq_len=max_position_embeddings, device=self.inv_freq.device) - - @staticmethod - def rotate_half(x: torch.Tensor) -> torch.Tensor: - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - def _set_sin_cos_cache(self, seq_len: int, device: torch.device) -> None: - # Different from paper, but it uses a different permutation in order to obtain the same calculation - self.max_seq_len_cached = seq_len - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - angles = torch.outer(t, self.inv_freq.to(device)) - angles = torch.cat((angles, angles), dim=1) - self.register_buffer("cos_cached", angles.cos(), persistent=False) - self.register_buffer("sin_cached", angles.sin(), persistent=False) - - def forward( - self, q: torch.Tensor, k: torch.Tensor, position_ids: torch.LongTensor, seq_len: int | None = None - ) -> tuple[torch.Tensor, torch.Tensor]: - # x: [bsz, seq_len, num_attention_heads, head_size] - device, dtype = q.device, q.dtype - seq_len = position_ids.max().item() + 1 if seq_len is None else seq_len - - if seq_len > self.max_seq_len_cached: - self._set_sin_cos_cache(seq_len=seq_len, device=device) - - # angles_cached[position_ids] gets us something of shape (batch_size, seq_len, head_dim), - # so unsqueeze dimension -2 to broadcast to (batch_size, seq_len, n_heads, head_dim). - idxs = position_ids.to(device) - cos = self.cos_cached.to(device=device, dtype=dtype).unsqueeze(-2)[idxs] - sin = self.sin_cached.to(device=device, dtype=dtype).unsqueeze(-2)[idxs] - - # Apply rotary positional embeddings to q and k (treating them as complex numbers). The first half is - # Re[x exp(it)] = Re[x] cos(t) - Im[x] sin(t), while the second half is - # Im[x exp(it)] = Im[x] cos(t) + Re[x] sin(t). This works b/c both halves of cos/sin are the same. - q_embed = (q * cos) + (self.rotate_half(q) * sin) - k_embed = (k * cos) + (self.rotate_half(k) * sin) - return q_embed, k_embed - - -class Attention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper.""" - - def __init__(self, config: E1Config, layer_idx: int): - super().__init__() - self.config = config - self.layer_idx = layer_idx - - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_kv_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_kv_heads - self.max_num_seqs = config.max_num_sequences - self.clip_qkv = config.clip_qkv - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - if self.config.global_attention_every_n_layers > 0: - self.layer_type = ( - AttentionLayerType.GLOBAL - if (self.layer_idx + 1) % self.config.global_attention_every_n_layers == 0 - else AttentionLayerType.WITHIN_SEQ - ) - else: - self.layer_type = AttentionLayerType.WITHIN_SEQ - - self.rope_theta = ( - config.rope_theta_within_seq - if self.layer_type == AttentionLayerType.WITHIN_SEQ - else config.rope_theta_global - ) - self.max_position_embeddings = ( - config.max_num_positions_within_seq - if self.layer_type == AttentionLayerType.WITHIN_SEQ - else config.max_num_positions_global - ) - - self.rotary_emb = RotaryPositionalEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta - ) - - def prepare_qkv( - self, - hidden_states: torch.Tensor, - position_ids: torch.LongTensor, - past_key_value: DynamicCache | None = None, - use_cache: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - bsz, q_len, _ = hidden_states.size() - query_states: torch.Tensor = self.q_proj(hidden_states) - key_states: torch.Tensor = self.k_proj(hidden_states) - val_states: torch.Tensor = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim) - val_states = val_states.view(bsz, q_len, self.num_kv_heads, self.head_dim) - - if self.clip_qkv is not None: - query_states = query_states.clamp(-self.clip_qkv, self.clip_qkv) - key_states = key_states.clamp(-self.clip_qkv, self.clip_qkv) - val_states = val_states.clamp(-self.clip_qkv, self.clip_qkv) - - query_states, key_states = self.rotary_emb(query_states, key_states, position_ids) - - if use_cache and past_key_value is not None: - key_states, val_states = past_key_value.update(key_states, val_states, self.layer_idx) - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - else: - target_dtype = self.q_proj.weight.dtype - if input_dtype != target_dtype: - logger.warning_once( - f"The input hidden states seems to be silently casted in {input_dtype}. " - f"This might be because you have upcasted embedding or layer norm layers " - f"in {input_dtype}. We will cast back the input in {target_dtype}." - ) - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - val_states = val_states.to(target_dtype) - - return query_states, key_states, val_states - - def forward( - self, - hidden_states: torch.Tensor, - within_seq_position_ids: torch.LongTensor, - global_position_ids: torch.LongTensor, - sequence_ids: torch.LongTensor, - attention_args: AttentionArgs | None = None, - past_key_value: DynamicCache | None = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor | None, DynamicCache | None]: - is_cache_prefilled = ( - use_cache and past_key_value is not None and past_key_value.get_seq_length(self.layer_idx) > 0 - ) - - query_states, key_states, val_states = self.prepare_qkv( - hidden_states=hidden_states, - position_ids=within_seq_position_ids - if self.layer_type == AttentionLayerType.WITHIN_SEQ - else global_position_ids, - past_key_value=past_key_value, - use_cache=use_cache, - ) - - # Note: We fallback to using flash attention in inference mode when cache is filled with kv values - # for global attention layers instead of flex attention. This is because once the cache is filled, - # the last sequence attends to everything in the cache, so we can make things faster by using a - # bidirectional flash attention instead of block-causal flex attention. - if self.layer_type == AttentionLayerType.WITHIN_SEQ or is_cache_prefilled: - attention_type = AttentionMethod.FLASH - else: - attention_type = AttentionMethod.FLEX - - attn_output, attn_weights = self._attn( - attention_type=attention_type, - query_states=query_states, - key_states=key_states, - val_states=val_states, - sequence_ids=sequence_ids, - attention_args=attention_args, - output_attentions=output_attentions, - ) - - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value - - def _attn( - self, - attention_type: AttentionMethod, - query_states: torch.Tensor, - key_states: torch.Tensor, - val_states: torch.Tensor, - sequence_ids: torch.Tensor, - attention_args: AttentionArgs | None = None, - output_attentions: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - match attention_type: - case AttentionMethod.FLASH: - f = self._flash_attn - case AttentionMethod.FLEX: - f = self._flex_attn - case _: - raise ValueError(f"No attention implementation found for {attention_type}") - return f( - query_states=query_states, - key_states=key_states, - val_states=val_states, - sequence_ids=sequence_ids, - attention_args=attention_args, - output_attentions=output_attentions, - ) - - def _flash_attn( - self, - query_states: torch.Tensor, - key_states: torch.Tensor, - val_states: torch.Tensor, - sequence_ids: torch.Tensor, - attention_args: AttentionArgs | None = None, - output_attentions: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - """Flash attention implementation. - - Calls the public API of flash attention and deals with padding tokens if any are present. - """ - assert not output_attentions, "Flash attention doesn't support returning attention masks" - bsz, q_len = query_states.shape[0], query_states.shape[1] - _, kv_len = key_states.shape[0], key_states.shape[1] - - if self.layer_type == AttentionLayerType.GLOBAL: # Only happens in inference - q_sequence_ids = sequence_ids - if q_len < kv_len: - # Assumes query contain only one sequence - # and all tokens in query (except padding) will attend to all tokens in KV - first_token_id = sequence_ids[:, 0].unsqueeze(1) - k_sequence_ids = torch.cat([first_token_id.expand(bsz, kv_len - q_len), sequence_ids], dim=-1) - else: - k_sequence_ids = sequence_ids - else: - if q_len < kv_len: # Only happens in inference - key_states = key_states[:, -q_len:] - val_states = val_states[:, -q_len:] - q_sequence_ids = k_sequence_ids = sequence_ids - - if is_flash_attention_available(): - attn_output = flash_attention_func( - query_states, - key_states, - val_states, - q_sequence_ids=q_sequence_ids, - k_sequence_ids=k_sequence_ids, - causal=False, - ) - else: - attn_output = varlen_flex_attention_func( - query_states, key_states, val_states, q_sequence_ids=q_sequence_ids, k_sequence_ids=k_sequence_ids - ) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - return attn_output, None - - def _flex_attn( - self, - query_states: torch.Tensor, - key_states: torch.Tensor, - val_states: torch.Tensor, - sequence_ids: torch.Tensor, - attention_args: AttentionArgs | None = None, - output_attentions: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - bsz, q_len = query_states.shape[0], query_states.shape[1] - flex_attention_args = attention_args.get("flex_attention_args", None) if attention_args is not None else None - block_mask = flex_attention_args.get("block_mask", None) if flex_attention_args is not None else None - score_mod = flex_attention_args.get("score_mod", None) if flex_attention_args is not None else None - outputs = flex_attention_func(query_states, key_states, val_states, score_mod=score_mod, block_mask=block_mask) - - outputs = outputs.reshape(bsz, q_len, self.hidden_size).contiguous() - return outputs, None - - -class MLP(nn.Module): - def __init__(self, config: E1Config): - super().__init__() - self.ffn_dim = config.intermediate_size - self.hidden_dim = config.hidden_size - self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - return self.w2(self.act_fn(self.w1(hidden_states))) - - -class GLUMLP(nn.Module): - def __init__(self, config: E1Config): - super().__init__() - self.ffn_dim = config.intermediate_size - self.hidden_dim = config.hidden_size - self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) - self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) - hidden_states = self.w2(hidden_states) - return hidden_states - - -class FFN(nn.Module): - def __init__(self, config: E1Config): - super().__init__() - mlp_cls = GLUMLP if config.gated_mlp else MLP - self.mlp = mlp_cls(config) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - return self.mlp(hidden_states) - - -@dataclass -class E1ModelOutputWithPast(ModelOutput): - """Base class for model's outputs, with potential hidden states and attentions. - - Attributes: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if - `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` - input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - last_hidden_state: torch.FloatTensor | None = None - past_key_values: DynamicCache | None = None - hidden_states: tuple[torch.FloatTensor, ...] | None = None - attentions: tuple[torch.FloatTensor, ...] | None = None - - -@dataclass -class E1MaskedLMOutputWithPast(ModelOutput): - loss: torch.FloatTensor | None = None - mlm_loss: torch.FloatTensor | None = None - logits: torch.FloatTensor | None = None - last_hidden_state: torch.FloatTensor | None = None - past_key_values: DynamicCache | None = None - hidden_states: tuple[torch.FloatTensor, ...] | None = None - attentions: tuple[torch.FloatTensor, ...] | None = None - - -@dataclass -class E1ClassificationOutputWithPast(ModelOutput): - loss: torch.FloatTensor | None = None - logits: torch.FloatTensor | None = None - last_hidden_state: torch.FloatTensor | None = None - past_key_values: DynamicCache | None = None - hidden_states: tuple[torch.FloatTensor, ...] | None = None - attentions: tuple[torch.FloatTensor, ...] | None = None - - -class RMSNorm(nn.Module): - def __init__(self, hidden_size: int, eps: float = 1e-6): - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - self.hidden_size = hidden_size - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - input_dtype = hidden_states.dtype - if layer_norm is None: - return torch.nn.functional.rms_norm( - hidden_states, (self.hidden_size,), self.weight, self.variance_epsilon - ).to(input_dtype) - else: - return layer_norm.rms_norm_fn( - x=hidden_states, - weight=self.weight, - bias=None, # no bias - residual=None, - eps=self.variance_epsilon, - dropout_p=0.0, # no dropout by default - prenorm=False, - residual_in_fp32=False, - ).to(input_dtype) - - -class NormAttentionNorm(nn.Module): - def __init__(self, config: E1Config, layer_idx: int): - super().__init__() - self.self_attn = Attention(config, layer_idx) - self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - within_seq_position_ids: torch.LongTensor, - global_position_ids: torch.LongTensor, - sequence_ids: torch.LongTensor, - attention_args: AttentionArgs | None = None, - past_key_value: DynamicCache | None = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, DynamicCache | None]: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - within_seq_position_ids=within_seq_position_ids, - global_position_ids=global_position_ids, - sequence_ids=sequence_ids, - attention_args=attention_args, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - return hidden_states, residual, self_attn_weights, present_key_value - - -class DecoderLayer(nn.Module): - def __init__(self, config: E1Config, layer_idx: int): - super().__init__() - self.initializer_range = config.initializer_range - self.hidden_size = config.hidden_size - self.norm_attn_norm = NormAttentionNorm(config, layer_idx) - self.ffn = FFN(config) - - def forward( - self, - hidden_states: torch.Tensor, - within_seq_position_ids: torch.LongTensor, - global_position_ids: torch.LongTensor, - sequence_ids: torch.LongTensor, - attention_args: AttentionArgs | None = None, - past_key_value: DynamicCache | None = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor | None, DynamicCache | None]: - hidden_states, residual, self_attn_weights, present_key_value = self.norm_attn_norm( - hidden_states=hidden_states, - within_seq_position_ids=within_seq_position_ids, - global_position_ids=global_position_ids, - sequence_ids=sequence_ids, - attention_args=attention_args, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - # Fully Connected - hidden_states = self.ffn(hidden_states) - hidden_states = residual + hidden_states - - return hidden_states, self_attn_weights, present_key_value - - -### Support for embedding datasets with low code -class Pooler: - def __init__(self, pooling_types: List[str]): - self.pooling_types = pooling_types - self.pooling_options = { - 'mean': self.mean_pooling, - 'max': self.max_pooling, - 'norm': self.norm_pooling, - 'median': self.median_pooling, - 'std': self.std_pooling, - 'var': self.var_pooling, - 'cls': self.cls_pooling, - 'parti': self._pool_parti, - } - - def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor: - maxed_attentions = torch.max(attentions, dim=1)[0] - return maxed_attentions - - def _page_rank(self, attention_matrix, personalization=None, nstart=None, prune_type="top_k_outdegree"): - # Run PageRank on the attention matrix converted to a graph. - # Raises exceptions if the graph doesn't match the token sequence or has no edges. - # Returns the PageRank scores for each token node. - G = self._convert_to_graph(attention_matrix) - if G.number_of_nodes() != attention_matrix.shape[0]: - raise Exception( - f"The number of nodes in the graph should be equal to the number of tokens in sequence! You have {G.number_of_nodes()} nodes for {attention_matrix.shape[0]} tokens.") - if G.number_of_edges() == 0: - raise Exception(f"You don't seem to have any attention edges left in the graph.") - - return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100) - - def _convert_to_graph(self, matrix): - # Convert a matrix (e.g., attention scores) to a directed graph using networkx. - # Each element in the matrix represents a directed edge with a weight. - G = nx.from_numpy_array(matrix, create_using=nx.DiGraph) - return G - - def _calculate_importance_weights(self, dict_importance, attention_mask: Optional[torch.Tensor] = None): - # Remove keys where attention_mask is 0 - if attention_mask is not None: - for k in list(dict_importance.keys()): - if attention_mask[k] == 0: - del dict_importance[k] - - #dict_importance[0] # remove cls - #dict_importance[-1] # remove eos - total = sum(dict_importance.values()) - return np.array([v / total for _, v in dict_importance.items()]) - - def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d) - maxed_attentions = self._create_pooled_matrices_across_layers(attentions).numpy() - # emb is (b, L, d), maxed_attentions is (b, L, L) - emb_pooled = [] - for e, a, mask in zip(emb, maxed_attentions, attention_mask): - dict_importance = self._page_rank(a) - importance_weights = self._calculate_importance_weights(dict_importance, mask) - num_tokens = int(mask.sum().item()) - emb_pooled.append(np.average(e[:num_tokens], weights=importance_weights, axis=0)) - pooled = torch.tensor(np.array(emb_pooled)) - return pooled - - def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) - if attention_mask is None: - return emb.mean(dim=1) - else: - attention_mask = attention_mask.unsqueeze(-1) - return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) - - def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) - if attention_mask is None: - return emb.max(dim=1).values - else: - attention_mask = attention_mask.unsqueeze(-1) - return (emb * attention_mask).max(dim=1).values - - def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) - if attention_mask is None: - return emb.norm(dim=1, p=2) - else: - attention_mask = attention_mask.unsqueeze(-1) - return (emb * attention_mask).norm(dim=1, p=2) - - def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) - if attention_mask is None: - return emb.median(dim=1).values - else: - attention_mask = attention_mask.unsqueeze(-1) - return (emb * attention_mask).median(dim=1).values - - def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) - if attention_mask is None: - return emb.std(dim=1) - else: - # Compute variance correctly over non-masked positions, then take sqrt - var = self.var_pooling(emb, attention_mask, **kwargs) - return torch.sqrt(var) - - def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) - if attention_mask is None: - return emb.var(dim=1) - else: - # Correctly compute variance over only non-masked positions - attention_mask = attention_mask.unsqueeze(-1) # (b, L, 1) - # Compute mean over non-masked positions - mean = (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d) - mean = mean.unsqueeze(1) # (b, 1, d) - # Compute squared differences from mean, only over non-masked positions - squared_diff = (emb - mean) ** 2 # (b, L, d) - # Sum squared differences over non-masked positions and divide by count - var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d) - return var - - def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) - return emb[:, 0, :] - - def __call__( - self, - emb: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - attentions: Optional[torch.Tensor] = None - ): # [mean, max] - final_emb = [] - for pooling_type in self.pooling_types: - final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions)) # (b, d) - return torch.cat(final_emb, dim=-1) # (b, n_pooling_types * d) - - -class EmbeddingMixin: - def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - raise NotImplementedError - - @property - def device(self) -> torch.device: - """Get the device of the model.""" - return next(self.parameters()).device - - def _read_sequences_from_db(self, db_path: str) -> set[str]: - """Read sequences from SQLite database.""" - import sqlite3 - sequences = [] - with sqlite3.connect(db_path) as conn: - c = conn.cursor() - c.execute("SELECT sequence FROM embeddings") - while True: - row = c.fetchone() - if row is None: - break - sequences.append(row[0]) - return set(sequences) - - def embed_dataset( - self, - sequences: List[str], - #tokenizer: PreTrainedTokenizerBase, # For E1, the tokenizing is handled by _embed - batch_size: int = 2, - max_len: int = 512, - truncate: bool = True, - full_embeddings: bool = False, - embed_dtype: torch.dtype = torch.float32, - pooling_types: List[str] = ['mean'], - sql: bool = False, - save: bool = True, - sql_db_path: str = 'embeddings.db', - save_path: str = 'embeddings.pth', - **kwargs, - ) -> Optional[dict[str, torch.Tensor]]: - """Embed a dataset of protein sequences. - - Args: - sequences: List of protein sequences - batch_size: Batch size for processing - max_len: Maximum sequence length - full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False) - pooling_type: Type of pooling ('mean' or 'cls') - sql: Whether to store embeddings in SQLite database - will be stored in float32 - sql_db_path: Path to SQLite database - - Returns: - Dictionary mapping sequences to embeddings, or None if sql=True - - Note: - - If sql=True, embeddings can only be stored in float32 - - sql is ideal if you need to stream a very large dataset for training in real-time - - save=True is ideal if you can store the entire embedding dictionary in RAM - - sql will be used if it is True and save is True or False - - If your sql database or .pth file is already present, they will be scanned first for already embedded sequences - - Sequences will be truncated to max_len and sorted by length in descending order for faster processing - - Example: - >>> embedder = EmbeddingMixin() - >>> embedding_dict = embedder.embed_dataset( - sequences=[ - 'MALWMRLLPLLALLALWGPDPAAA', ... # list of protein sequences - ], - batch_size=2, # adjust for your GPU memory - max_len=512, # adjust for your needs - full_embeddings=False, # if True, no pooling is performed - embed_dtype=torch.float32, # cast to what dtype you want - pooling_type=['mean', 'cls'], # more than one pooling type will be concatenated together - sql=False, # if True, embeddings will be stored in SQLite database - sql_db_path='embeddings.db', - save=True, # if True, embeddings will be saved as a .pth file - save_path='embeddings.pth', - ) - >>> # embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql - """ - sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences])) - sequences = sorted(sequences, key=len, reverse=True) - hidden_size = self.config.hidden_size - pooler = Pooler(pooling_types) if not full_embeddings else None - - def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - if full_embeddings or residue_embeddings.ndim == 2: # if already pooled or want residue-wise embeddings - return residue_embeddings - else: - return pooler(residue_embeddings, attention_mask) - - if sql: - import sqlite3 - conn = sqlite3.connect(sql_db_path) - c = conn.cursor() - c.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence text PRIMARY KEY, embedding blob)') - already_embedded = self._read_sequences_from_db(sql_db_path) - to_embed = [seq for seq in sequences if seq not in already_embedded] - print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}") - print(f"Embedding {len(to_embed)} new sequences") - if len(to_embed) > 0: - with torch.no_grad(): - for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'): - seqs = to_embed[batch_start:batch_start + batch_size] - input_ids, attention_mask = self._embed(seqs, return_attention_mask=True) - embeddings = get_embeddings(input_ids, attention_mask).float() # sql requires float32 - for seq, emb, mask in zip(seqs, embeddings, attention_mask): - if full_embeddings: - emb = emb[mask.bool()].reshape(-1, hidden_size) - c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)", (seq, emb.cpu().numpy().tobytes())) - conn.commit() - conn.commit() - conn.close() - return None - - embeddings_dict = {} - if os.path.exists(save_path): - embeddings_dict = torch.load(save_path, map_location='cpu', weights_only=True) - to_embed = [seq for seq in sequences if seq not in embeddings_dict] - print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}") - print(f"Embedding {len(to_embed)} new sequences") - else: - to_embed = sequences - print(f"Embedding {len(to_embed)} new sequences") - - if len(to_embed) > 0: - with torch.no_grad(): - for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'): - seqs = to_embed[batch_start:batch_start + batch_size] - last_hidden_state, attention_mask = self._embed(seqs, return_attention_mask=True) - embeddings = get_embeddings(last_hidden_state, attention_mask).to(embed_dtype) - for seq, emb, mask in zip(seqs, embeddings, attention_mask): - if full_embeddings: - emb = emb[mask.bool()].reshape(-1, hidden_size) - embeddings_dict[seq] = emb.cpu() - - if save: - torch.save(embeddings_dict, save_path) - - return embeddings_dict - - -class E1PreTrainedModel(PreTrainedModel): - config_class = E1Config - config: E1Config - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["DecoderLayer"] - _transformer_layer_cls = [DecoderLayer] - _skip_keys_device_placement = "past_key_values" - - def _init_weights(self, module: nn.Module) -> None: - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, RMSNorm): - module.weight.data.fill_(1.0) - - def post_init(self) -> None: - super().post_init() - - def _backward_compatibility_gradient_checkpointing(self) -> None: - if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False): - self.gradient_checkpointing_enable(dict(use_reentrant=False)) - - @property - def _device(self) -> torch.device: - return next(self.parameters()).device - - @classmethod - def from_pretrained( # type: ignore[no-untyped-def] - cls, pretrained_model_name_or_path: str | os.PathLike | None, *args, **kwargs - ) -> "E1PreTrainedModel": - return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs) - - -class E1Model(E1PreTrainedModel, EmbeddingMixin): - config: E1Config - config_class = E1Config - def __init__(self, config: E1Config, **kwargs): - E1PreTrainedModel.__init__(self, config, **kwargs) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.embed_seq_id = nn.Embedding(config.max_num_sequences, config.hidden_size) - self.layers = nn.ModuleList([DecoderLayer(config, i) for i in range(config.num_hidden_layers)]) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.gradient_checkpointing = config.gradient_checkpointing - self.prep_tokens = E1BatchPreparer() - self.post_init() - - def get_input_embeddings(self) -> nn.Embedding: - return self.embed_tokens - - def set_input_embeddings(self, value: nn.Embedding) -> None: - self.embed_tokens = value - - @torch.inference_mode() - def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor: - batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device) - last_hidden_state = self.forward(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state - if return_attention_mask: - attention_mask = (batch['sequence_ids'] != -1).long() - return last_hidden_state, attention_mask - else: - return last_hidden_state - - # Ignore copy - def forward( - self, - input_ids: torch.LongTensor, - within_seq_position_ids: torch.LongTensor, - global_position_ids: torch.LongTensor, - sequence_ids: torch.LongTensor, - past_key_values: DynamicCache | None = None, - use_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - **kwargs - ) -> E1ModelOutputWithPast: - """ - Args: - input_ids: (batch_size, seq_length) - within_seq_position_ids: (batch_size, seq_length) - This tensor contains the position of each residue within the sequence itself. - For example, if the input is ["1ABC21DEF2", "1GH21JKL2"], - the tensor would be [[0,1,2,3,4,5,6,0,1,2,3,4,5,6], [0,1,2,3,4,5,0,1,2,3,4,5,6,-1]] - global_position_ids: (batch_size, seq_length) - This tensor contains the position of each residue within the global sequence. - For example, if the input is ["1ABC21DEF2", "1GH21JKL2"], - the tensor would be [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, -1]] - sequence_ids: (batch_size, seq_length) - This tensor contains the sequence id of each residue. - For example, if the input is ["1ABC21DEF2", "1GH21JKL2"], - the tensor would be [[0,0,0,0,0,0,0,1,1,1,1,1,1,1], [0,0,0,0,0,0,1,1,1,1,1,1,1,-1]] - past_key_values: DynamicCache - use_cache: bool - output_attentions: bool - output_hidden_states: bool - - Returns: - E1ModelOutputWithPast: Model Outputs - """ - batch_size, seq_length = input_ids.shape - - if self.gradient_checkpointing and self.training and torch.is_grad_enabled(): - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - if use_cache and past_key_values is None: - past_key_values = DynamicCache() - elif not use_cache: - # To avoid weirdness with gradient checkpointing: https://github.com/huggingface/transformers/issues/28499 - past_key_values = None - - global_position_ids = global_position_ids.view(-1, seq_length).long() - within_seq_position_ids = within_seq_position_ids.view(-1, seq_length).long() - sequence_ids = sequence_ids.view(-1, seq_length).long() - - max_position_id = torch.max(within_seq_position_ids).item() - min_position_id = torch.min(within_seq_position_ids).item() - assert max_position_id < self.config.max_num_positions_within_seq and min_position_id >= -1, ( - f"Position ids must be in the range [-1, {self.config.max_num_positions_within_seq}); got max {max_position_id} and min {min_position_id}" - ) - - inputs_embeds = self.embed_tokens(input_ids) - # -1 is used to indicate padding tokens, so we need to clamp the sequence ids to 0 - inputs_embeds = inputs_embeds + self.embed_seq_id(sequence_ids.clamp(min=0)) - - # In case we need to do any manual typecasting - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - else: - target_dtype = self.layers[0].norm_attn_norm.self_attn.q_proj.weight.dtype - hidden_states = inputs_embeds.to(target_dtype) - - # (batch_size, query_length, keyval_length) - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 - - # Create block mask for flex attention - attention_args: AttentionArgs | None = None - if past_key_values_length == 0: - block_mask = create_block_causal_mask_optimized(sequence_ids) - flex_attention_args = FlexAttentionArgs(block_mask=block_mask) - attention_args = AttentionArgs(flex_attention_args=flex_attention_args) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) # type: ignore[operator] - - if self.gradient_checkpointing and self.training and torch.is_grad_enabled(): - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - within_seq_position_ids, - global_position_ids, - sequence_ids, - attention_args, - past_key_values, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - within_seq_position_ids=within_seq_position_ids, - global_position_ids=global_position_ids, - sequence_ids=sequence_ids, - attention_args=attention_args, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states, self_attn_weights, present_key_value = layer_outputs - - if use_cache: - # NOTE: it's necessary to re-assign past_key_values because FSDP2 - # passes certain arguments by value, not by reference. - # See https://github.com/huggingface/transformers/issues/38190#issuecomment-2914016168 - next_decoder_cache = past_key_values = present_key_value - - if output_attentions: - all_self_attns += (self_attn_weights,) # type: ignore[operator] - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) # type: ignore[operator] - - next_cache = next_decoder_cache if use_cache else None - - return E1ModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class E1ForMaskedLM(E1PreTrainedModel, EmbeddingMixin): - config: E1Config - config_class = E1Config - def __init__(self, config: E1Config, **kwargs): - E1PreTrainedModel.__init__(self, config, **kwargs) - self.model: E1Model = E1Model(config) - self.vocab_size = config.vocab_size - self.mlm_head = torch.nn.Sequential( - nn.Linear(config.hidden_size, config.hidden_size, bias=True), - nn.GELU(), - nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps), - nn.Linear(config.hidden_size, config.vocab_size, bias=True), - ) - self.gradient_checkpointing = config.gradient_checkpointing - self.prep_tokens = E1BatchPreparer() - self.post_init() - - @property - def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh: - return self.model.device_mesh - - @torch.inference_mode() - def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor: - batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device) - last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state - if return_attention_mask: - attention_mask = (batch['sequence_ids'] != -1).long() - return last_hidden_state, attention_mask - else: - return last_hidden_state - - def forward( - self, - input_ids: torch.LongTensor, - within_seq_position_ids: torch.LongTensor, - global_position_ids: torch.LongTensor, - sequence_ids: torch.LongTensor, - labels: torch.LongTensor | None = None, - past_key_values: DynamicCache | None = None, - use_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - **kwargs, - ) -> E1MaskedLMOutputWithPast: - """ - Args: - input_ids: (batch_size, seq_length) - within_seq_position_ids: (batch_size, seq_length) - This tensor contains the position of each residue within the sequence itself. - For example, if the input is ["1ABC21DEF2", "1GH21JKL2"], - the tensor would be [[0,1,2,3,4,5,6,0,1,2,3,4,5,6], [0,1,2,3,4,5,0,1,2,3,4,5,6,-1]] - global_position_ids: (batch_size, seq_length) - This tensor contains the position of each residue within the global sequence. - For example, if the input is ["1ABC21DEF2", "1GH21JKL2"], - the tensor would be [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, -1]] - sequence_ids: (batch_size, seq_length) - This tensor contains the sequence id of each residue. - For example, if the input is ["1ABC21DEF2", "1GH21JKL2"], - the tensor would be [[0,0,0,0,0,0,0,1,1,1,1,1,1,1], [0,0,0,0,0,0,1,1,1,1,1,1,1,-1]] - labels: (batch_size, seq_length) - past_key_values: DynamicCache - use_cache: bool - output_attentions: bool - output_hidden_states: bool - - Returns: - E1MaskedLMOutputWithPast: Model Outputs - """ - outputs: E1ModelOutputWithPast = self.model( - input_ids=input_ids, - within_seq_position_ids=within_seq_position_ids, - global_position_ids=global_position_ids, - sequence_ids=sequence_ids, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - x = outputs.last_hidden_state - loss = None - - # Compute masked language modeling loss - mlm_logits = self.mlm_head(x).float() - mlm_loss = 0.0 - if labels is not None: - mlm_logits_flat = mlm_logits.contiguous().view(-1, self.config.vocab_size) - mlm_labels_flat = labels.to(mlm_logits_flat.device).contiguous().view(-1) - mlm_loss = F.cross_entropy(mlm_logits_flat, mlm_labels_flat, reduction="none") - mask = mlm_labels_flat != self.model.padding_idx - n_mlm = mask.sum() - mlm_loss = (mlm_loss * mask.to(mlm_loss)).sum() / (1 if n_mlm == 0 else n_mlm) - loss = 0.0 - loss += mlm_loss - - return E1MaskedLMOutputWithPast( - loss=loss, - mlm_loss=mlm_loss, - logits=mlm_logits, - last_hidden_state=x, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -class E1ForSequenceClassification(E1PreTrainedModel, EmbeddingMixin): - config: E1Config - config_class = E1Config - def __init__(self, config: E1Config, **kwargs): - E1PreTrainedModel.__init__(self, config, **kwargs) - self.model: E1Model = E1Model(config) - self.vocab_size = config.vocab_size - self.num_labels = config.num_labels - self.classifier = nn.Sequential( - nn.Linear(config.hidden_size * 2, config.hidden_size * 4), - nn.GELU(), - nn.LayerNorm(config.hidden_size * 4), - nn.Linear(config.hidden_size * 4, config.num_labels), - ) - self.mse = nn.MSELoss() - self.ce = nn.CrossEntropyLoss() - self.bce = nn.BCEWithLogitsLoss() - self.gradient_checkpointing = config.gradient_checkpointing - self.prep_tokens = E1BatchPreparer() - - if 'pooling_types' in kwargs and isinstance(kwargs['pooling_types'], List[str]) and len(kwargs['pooling_types']) > 0: - pooling_types = kwargs['pooling_types'] - else: - pooling_types = ['mean', 'var'] - self.pooler = Pooler(pooling_types) - self.post_init() - - @property - def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh: - return self.model.device_mesh - - @torch.inference_mode() - def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor: - batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device) - last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state - if return_attention_mask: - attention_mask = (batch['sequence_ids'] != -1).long() - return last_hidden_state, attention_mask - else: - return last_hidden_state - - def forward( - self, - input_ids: torch.LongTensor, - within_seq_position_ids: torch.LongTensor, - global_position_ids: torch.LongTensor, - sequence_ids: torch.LongTensor, - labels: torch.LongTensor | None = None, - past_key_values: DynamicCache | None = None, - use_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - **kwargs, - ) -> E1ClassificationOutputWithPast: - outputs: E1ModelOutputWithPast = self.model( - input_ids=input_ids, - within_seq_position_ids=within_seq_position_ids, - global_position_ids=global_position_ids, - sequence_ids=sequence_ids, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - attention_mask = (sequence_ids != -1).long() - x = outputs.last_hidden_state - features = self.pooler(x, attention_mask) - logits = self.classifier(features) - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - if self.num_labels == 1: - loss = self.mse(logits.flatten(), labels.flatten()) - else: - loss = self.mse(logits, labels) - elif self.config.problem_type == "single_label_classification": - loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss = self.bce(logits, labels) - - return E1ClassificationOutputWithPast( - loss=loss, - logits=logits, - last_hidden_state=x, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -class E1ForTokenClassification(E1PreTrainedModel, EmbeddingMixin): - config: E1Config - config_class = E1Config - def __init__(self, config: E1Config, **kwargs): - E1PreTrainedModel.__init__(self, config, **kwargs) - self.model: E1Model = E1Model(config) - self.vocab_size = config.vocab_size - self.num_labels = config.num_labels - self.classifier = nn.Sequential( - nn.Linear(config.hidden_size * 2, config.hidden_size * 4), - nn.GELU(), - nn.LayerNorm(config.hidden_size * 4), - nn.Linear(config.hidden_size * 4, config.num_labels), - ) - self.loss_fct = nn.CrossEntropyLoss() - self.gradient_checkpointing = config.gradient_checkpointing - self.prep_tokens = E1BatchPreparer() - self.post_init() - - @property - def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh: - return self.model.device_mesh - - @torch.inference_mode() - def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor: - batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device) - last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state - if return_attention_mask: - attention_mask = (batch['sequence_ids'] != -1).long() - return last_hidden_state, attention_mask - else: - return last_hidden_state - - def forward( - self, - input_ids: torch.LongTensor, - within_seq_position_ids: torch.LongTensor, - global_position_ids: torch.LongTensor, - sequence_ids: torch.LongTensor, - labels: torch.LongTensor | None = None, - past_key_values: DynamicCache | None = None, - use_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - **kwargs, - ) -> E1ClassificationOutputWithPast: - outputs: E1ModelOutputWithPast = self.model( - input_ids=input_ids, - within_seq_position_ids=within_seq_position_ids, - global_position_ids=global_position_ids, - sequence_ids=sequence_ids, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - x = outputs.last_hidden_state - logits = self.classifier(x) - loss = None - if labels is not None: - loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - return E1ClassificationOutputWithPast( - loss=loss, - logits=logits, - last_hidden_state=x, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -if __name__ == "__main__": - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model = E1ForSequenceClassification.from_pretrained("Profluent-Bio/E1-150m", dtype=torch.bfloat16, num_labels=1).eval().to(device) - print(model) - - seqs = [ - "MRHGDISSSNDTVGVAVVNYKMPRLHTAAEVLDNARKIAEMIVGMKQGLPGMDLVVFPEYSLQGIMYDPAEMMETAVAIPGEETE", - "IFSRACRKANVWGVFSLTGERHEEHPRKAPYNTLVLIDNNGEIVQKYRKIIPWCPIEGWYPGGQTYVSEGPKGMKISLIICDDGNY", - "PEIWRDCAMKGAELIVRCQGYMYPAKDQQVMMAKAMAWANNCYVAVANAAGFDGVYSYFGHSAIIGFDGRTLGECGEEEMGIQYAQL", - "SLSQIRDARANDQSQNHLFKILHRGYSGLQASGDGDRGLAECPFEFYRTWVTDAEKARENVERLTRSTTGVAQCPVGRLPYEGLEKEA", - ] - - batch = model.prep_tokens.get_batch_kwargs(seqs, device=device) - batch['labels'] = torch.tensor([0.0, 0.0, 0.0, 0.0], device=device) - - last_hidden_state = model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state - print(last_hidden_state.shape) +import os +os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" + +import numpy as np +import networkx as nx +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence + +from einops import rearrange, repeat +from enum import Enum +from typing import Any, TypedDict, Callable, Optional, List +from dataclasses import dataclass +from tokenizers import Tokenizer +from transformers import PretrainedConfig, PreTrainedModel +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ModelOutput +from transformers.utils import logging +from tqdm.auto import tqdm +from pooler import EmbeddingMixin, Pooler + + +logger = logging.get_logger(__name__) + +### Establish attention compatibility +try: + from flash_attn import flash_attn_func, flash_attn_varlen_func +except ImportError: + logger.warning("Failed to import flash attention; Will be using PyTorch attention instead") + flash_attn_func = None + flash_attn_varlen_func = None + +try: + from torch.nn.attention.flex_attention import ( + BlockMask, + create_block_mask, + flex_attention, + _create_sparse_block_from_block_mask + ) + + if torch.cuda.is_available(): + # if on linux, compile the flex attention function + if os.name == 'posix': + print("Compiling flex attention") + flex_attention = torch.compile(flex_attention, dynamic=True) + else: + print("Not compiling flex attention, detected non-Linux environment") + +except ImportError: + logger.warning("Failed to import flex attention; Will be using PyTorch attention instead") + flex_attention = None + +try: + from kernels import get_kernel + layer_norm = get_kernel("kernels-community/triton-layer-norm") +except Exception as e: + logger.warning(f"Failed to load triton layer norm kernel: {e}; Will be using PyTorch RMSNorm instead") + layer_norm = None + + +def is_flash_attention_available() -> bool: + return ( + flash_attn_func is not None and flash_attn_varlen_func is not None and (os.getenv("USE_FLASH_ATTN", "1") == "1") + ) + + +class FlexAttentionArgs(TypedDict, total=False): + block_mask: BlockMask | None + score_mod: Callable | None + + +def create_block_causal_mask_optimized(sequence_ids: torch.Tensor) -> BlockMask: + # Assumes sequence_ids is sorted in increasing order for each batch item, except for + # the -1 values, which are used to indicate the padding tokens. + def document_mask(b, h, q_idx, kv_idx): # type: ignore[no-untyped-def] + return ( + (sequence_ids[b, q_idx] >= sequence_ids[b, kv_idx]) + & (sequence_ids[b, q_idx] != -1) + & (sequence_ids[b, kv_idx] != -1) + ) + + batch_size, seqlen = sequence_ids.shape + return create_block_mask(document_mask, batch_size, 1, seqlen, seqlen, device=sequence_ids.device) + + +def flex_attention_func( + query_states: torch.Tensor, # (bs, seqlen, nh, hs) + key_states: torch.Tensor, # (bs, seqlen, nkv, hs) + value_states: torch.Tensor, # (bs, seqlen, nkv, hs) + score_mod: Callable | None = None, + block_mask: BlockMask | None = None, +) -> torch.Tensor: + assert flex_attention is not None, "Flex Attention is not available in this environment" + assert score_mod is None, "Score mod is not supported yet" + query_states = query_states.transpose(1, 2).contiguous() # (bs, nh, seqlen, hs) + key_states = key_states.transpose(1, 2).contiguous() # (bs, nkv, seqlen, hs) + value_states = value_states.transpose(1, 2).contiguous() # (bs, nkv, seqlen, hs) + + outputs = flex_attention( + query_states, + key_states, + value_states, + block_mask=block_mask, + score_mod=score_mod, + enable_gqa=query_states.shape[1] != key_states.shape[1], # if nkv != nh + ) + + outputs = outputs.transpose(1, 2) # (bs, seqlen, nh, hs) + return outputs + + +def flash_attention_func( + query_states: torch.Tensor, # (bs, seqlen, nh, hs) + key_states: torch.Tensor, # (bs, seqlen, nkv, hs) + value_states: torch.Tensor, # (bs, seqlen, nkv, hs) + q_sequence_ids: torch.Tensor, + k_sequence_ids: torch.Tensor, + causal: bool = False, +) -> torch.Tensor: # (bs, seqlen, nh, hs) + # Contains at least one padding token in the sequence. Note: ignore attention mask if causal. + if not is_flash_attention_available(): + raise ImportError("Flash Attention is not available. Please install flash-attn.") + + if not causal: + batch_size, q_len = query_states.shape[0], query_states.shape[1] + ( + query_states, + key_states, + value_states, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) = _unpad_input(query_states, key_states, value_states, q_sequence_ids, k_sequence_ids) + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + causal=False, + ) + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, q_len) + + else: + attn_output = flash_attn_func(query_states, key_states, value_states, causal=True) + + return attn_output + + +class IndexFirstAxis(torch.autograd.Function): + @staticmethod + def forward(ctx, input, indices) -> torch.Tensor: # type: ignore[no-untyped-def] + ctx.save_for_backward(indices) + assert input.ndim >= 2 + ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] + second_dim = other_shape.numel() + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + # return input[indices] + return torch.gather(rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)).reshape( + -1, *other_shape + ) + + @staticmethod + def backward(ctx, grad_output) -> tuple[torch.Tensor, None]: # type: ignore[no-untyped-def] + (indices,) = ctx.saved_tensors + assert grad_output.ndim >= 2 + other_shape = grad_output.shape[1:] + grad_output = rearrange(grad_output, "b ... -> b (...)") + grad_input = torch.zeros( + [ctx.first_axis_dim, grad_output.shape[1]], device=grad_output.device, dtype=grad_output.dtype + ) + # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. + # grad_input[indices] = grad_output + grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output) + return grad_input.reshape(ctx.first_axis_dim, *other_shape), None + + +def block_min_max_seq_ids(SLEN: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]: + device = SLEN.device + total_tokens = torch.sum(SLEN) + B = (total_tokens + block_size - 1) // block_size + padding_tokens = B * block_size - total_tokens + SLEN = torch.cat([SLEN, torch.Tensor([padding_tokens]).to(device)], dim=0) + + assert torch.sum(SLEN) == B * block_size + + # Cumulative ends (exclusive) for each sequence; cum[i] == end offset of seq i + cum = torch.cumsum(SLEN.to(torch.long), dim=0) # (N,) + total_tokens = cum[-1].item() + + # Block start/end offsets [start, end) in token index space + block_starts = torch.arange(0, B * block_size, block_size, device=device, dtype=torch.long) # (B,) + block_ends = torch.minimum(block_starts + block_size, torch.tensor(total_tokens, device=device)) # (B,) + + # MIN_SEQ_ID[i] = first sequence whose end > block_start + # searchsorted with right=True returns first index where cum > value + MIN_SEQ_ID = torch.searchsorted(cum, block_starts, right=True) + + # MAX_SEQ_ID[i] = sequence containing the last token in the block (block_end - 1) + # For empty tail beyond total_tokens we already clipped block_ends. + last_token_in_block = torch.clamp(block_ends - 1, min=0) # valid only if block has at least 1 token + MAX_SEQ_ID = torch.searchsorted(cum, last_token_in_block, right=True) + + return MIN_SEQ_ID, MAX_SEQ_ID + + +def get_overlapping_blocks(SLEN_Q: torch.Tensor, SLEN_K: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + MIN_Q, MAX_Q = block_min_max_seq_ids(SLEN_Q) + MIN_K, MAX_K = block_min_max_seq_ids(SLEN_K) + + cond1 = MIN_Q.unsqueeze(1) <= MAX_K.unsqueeze(0) + cond2 = MIN_K.unsqueeze(0) <= MAX_Q.unsqueeze(1) + overlap = cond1 & cond2 + + cond1 = (MIN_Q == MAX_Q).unsqueeze(1) + cond2 = (MIN_K == MAX_K).unsqueeze(0) + same_seq_in_qk = cond1 & cond2 + + full_blocks = overlap & same_seq_in_qk + partial_blocks = overlap & ~same_seq_in_qk + + return full_blocks, partial_blocks + + +def direct_block_mask(SLEN_Q: torch.Tensor, SLEN_K: torch.Tensor) -> BlockMask: + full_blocks, partial_blocks = get_overlapping_blocks(SLEN_Q, SLEN_K) + partial_blocks = partial_blocks[None, None] + full_blocks = full_blocks[None, None] + + q_doc_id = torch.repeat_interleave(SLEN_Q) + k_doc_id = torch.repeat_interleave(SLEN_K) + + def doc_mask(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor) -> torch.Tensor: + return q_doc_id[q_idx] == k_doc_id[kv_idx] + + total_q_len = q_doc_id.shape[0] + total_k_len = k_doc_id.shape[0] + + return _create_sparse_block_from_block_mask( + (partial_blocks, full_blocks), + doc_mask, + seq_lengths=(total_q_len, total_k_len), + Q_BLOCK_SIZE=128, + KV_BLOCK_SIZE=128, + ) + + +def doc_id_mask(SLEN_Q: torch.Tensor, SLEN_K: torch.Tensor) -> BlockMask: + q_doc_id = torch.repeat_interleave(SLEN_Q) + k_doc_id = torch.repeat_interleave(SLEN_K) + + def doc_mask(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor) -> torch.Tensor: + return q_doc_id[q_idx] == k_doc_id[kv_idx] + + total_q_len = q_doc_id.shape[0] + total_k_len = k_doc_id.shape[0] + + return create_block_mask(doc_mask, 1, 1, total_q_len, total_k_len, BLOCK_SIZE=128, device=SLEN_Q.device) + + +def varlen_flex_attention_func( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + q_sequence_ids: torch.Tensor, + k_sequence_ids: torch.Tensor, +) -> torch.Tensor: + batch_size, q_len = query_states.shape[0], query_states.shape[1] + ( + query_states, + key_states, + value_states, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) = _unpad_input(query_states, key_states, value_states, q_sequence_ids, k_sequence_ids) + + query_states = query_states.unsqueeze(0).transpose(1, 2).contiguous() + key_states = key_states.unsqueeze(0).transpose(1, 2).contiguous() + value_states = value_states.unsqueeze(0).transpose(1, 2).contiguous() + + seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + seqlens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1] + block_mask = block_mask_creator(seqlens_q, seqlens_k) + + attn_output_unpad = flex_attention( + query_states, + key_states, + value_states, + block_mask=block_mask, + enable_gqa=query_states.shape[1] != key_states.shape[1], + ) + + attn_output = pad_input(attn_output_unpad.transpose(1, 2).squeeze(0), indices_q, batch_size, q_len) + + return attn_output + + +class IndexPutFirstAxis(torch.autograd.Function): + @staticmethod + def forward(ctx, values, indices, first_axis_dim) -> torch.Tensor: # type: ignore[no-untyped-def] + ctx.save_for_backward(indices) + assert indices.ndim == 1 + assert values.ndim >= 2 + output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype) + # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. + output[indices] = values + # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) + return output + + @staticmethod + def backward(ctx, grad_output) -> tuple[torch.Tensor, None, None]: # type: ignore[no-untyped-def] + (indices,) = ctx.saved_tensors + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + grad_values = grad_output[indices] + # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1])) + return grad_values, None, None + + +index_put_first_axis = IndexPutFirstAxis.apply + + +def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor: + """ + Arguments: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. + batch: int, batch size for the padded sequence. + seqlen: int, maximum sequence length for the padded sequence. + Return: + hidden_states: (batch, seqlen, ...) + """ + # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype) + # output[indices] = hidden_states + output = index_put_first_axis(hidden_states, indices, batch * seqlen) + return rearrange(output, "(b s) ... -> b s ...", b=batch) + + +def _get_unpad_data(sequence_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: + non_pad_indices = sequence_ids != -1 + non_pad_indices = torch.nonzero(non_pad_indices.flatten(), as_tuple=False).flatten() + sequence_ids = sequence_ids + torch.arange(len(sequence_ids), device=sequence_ids.device)[:, None] * 1e5 + sequence_ids = sequence_ids.flatten()[non_pad_indices] + _, seqlens_in_batch = torch.unique_consecutive(sequence_ids, return_counts=True) + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return non_pad_indices, cu_seqlens, max_seqlen_in_batch + + +def _unpad_input( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + q_sequence_ids: torch.Tensor, + k_sequence_ids: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor], tuple[int, int]]: + batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape + query_length, num_q_heads = query_layer.shape[1], query_layer.shape[2] + assert query_layer.shape[:2] == q_sequence_ids.shape, ( + f"Shape mismatch between query layer and query sequence ids: {query_layer.shape[:2]} != {q_sequence_ids.shape}" + ) + assert key_layer.shape[:2] == k_sequence_ids.shape, ( + f"Shape mismatch between key layer and key sequence ids: {key_layer.shape[:2]} != {k_sequence_ids.shape}" + ) + assert query_length <= kv_seq_len, ( + f"Query length should be less than or equal to KV sequence length: {query_length} <= {kv_seq_len}" + ) + + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(k_sequence_ids) + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + + if torch.equal(q_sequence_ids, k_sequence_ids): + indices_q = indices_k + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + else: + indices_q, cu_seqlens_q, max_seqlen_in_batch_q = _get_unpad_data(q_sequence_ids) + + query_layer = index_first_axis(query_layer.reshape(batch_size * query_length, num_q_heads, head_dim), indices_q) + + assert cu_seqlens_q.shape == cu_seqlens_k.shape, ( + f"Query and KV should have the same number of sequences: {cu_seqlens_q.shape} != {cu_seqlens_k.shape}" + ) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +index_first_axis = IndexFirstAxis.apply +block_mask_creator = direct_block_mask if os.getenv("FAST_BLOCK_MASK", "1") == "1" else doc_id_mask +PAD_TOKEN_ID = 0 + + +def get_tokenizer() -> Tokenizer: + try: + fname = os.path.join(os.path.dirname(__file__), "tokenizer.json") + tokenizer: Tokenizer = Tokenizer.from_file(fname) + except: + print("E1 Tokenizer not found in local directory, downloading from Hugging Face") + from huggingface_hub import hf_hub_download + fname = hf_hub_download(repo_id="Synthyra/Profluent-E1-150M", filename="tokenizer.json") + tokenizer: Tokenizer = Tokenizer.from_file(fname) + assert tokenizer.padding["pad_id"] == PAD_TOKEN_ID, ( + f"Padding token id must be {PAD_TOKEN_ID}, but got {tokenizer.padding['pad_id']}" + ) + + return tokenizer + + +@dataclass +class DataPrepConfig: + max_num_sequences: int = 512 + max_num_positions_within_seq: int = 8192 + remove_X_tokens: bool = False + + +def get_context(sequence: str) -> str | None: + if "," in sequence: + return sequence.rsplit(",", 1)[0] + return None + + +class E1BatchPreparer: + def __init__( + self, + data_prep_config: DataPrepConfig | None = None, + tokenizer: Tokenizer | None = None, + preserve_context_labels: bool = False, + ): + self.tokenizer = tokenizer or get_tokenizer() + self.data_prep_config = data_prep_config or DataPrepConfig() + self.pad_token_id = self.tokenizer.token_to_id("") + self.preserve_context_labels = preserve_context_labels + device = torch.cuda.current_device() if torch.cuda.is_available() else torch.device("cpu") + self.boundary_token_ids = torch.tensor( + [self.tokenizer.token_to_id(token) for token in ["", "", "1", "2", ""]], device=device + ).long() + self.mask_token = "?" # nosec + self.mask_token_id = self.tokenizer.token_to_id(self.mask_token) + self.X_token_id = self.tokenizer.token_to_id("X") + self.vocab = self.tokenizer.get_vocab() + + def get_batch_kwargs( # type: ignore[override] + self, sequences: list[str], device: torch.device = torch.device("cpu"), non_blocking: bool = False + ) -> dict[str, torch.Tensor | list[str] | list[int]]: + sequence_encodings = [self.prepare_multiseq(sequence) for sequence in sequences] + return self.pad_encodings(sequence_encodings, device, non_blocking) + + def pad_encodings( + self, + sequence_encodings: list[dict[str, torch.Tensor]], + device: torch.device = torch.device("cpu"), + non_blocking: bool = False, + ) -> dict[str, torch.Tensor | list[str] | list[int]]: + non_blocking = non_blocking and device.type == "cuda" + padded_encodings = {} + # Note: We use -1 as the padding value for sequence and position ids because the 0 value + # is a valid value for sequence and position ids. -1 is then used to distinguish valid + # tokens from padding tokens, for example, when doing padding/unpadding for flash attention. + for key, padding_value in { + "input_ids": self.pad_token_id, + "sequence_ids": -1, + "within_seq_position_ids": -1, + "global_position_ids": -1, + "labels": self.pad_token_id, + }.items(): + padded_encodings[key] = pad_sequence( + [enc[key] for enc in sequence_encodings], batch_first=True, padding_value=padding_value + ).to(device=device, dtype=torch.long, non_blocking=non_blocking) + + padded_encodings["context"] = [enc["context"] for enc in sequence_encodings] + padded_encodings["context_len"] = [enc["context_len"] for enc in sequence_encodings] + + return padded_encodings + + def prepare_multiseq(self, sequence: str) -> dict[str, torch.Tensor | str | int]: + single_sequences = sequence.split(",") + if len(single_sequences) > self.data_prep_config.max_num_sequences: + raise ValueError( + f"Number of sequences {len(single_sequences)} exceeds max number of sequences {self.data_prep_config.max_num_sequences}" + " in the provided multi-sequence instance. Please remove some homologous sequences before trying again." + ) + + single_sequence_encodings = [self.prepare_singleseq(sequence) for sequence in single_sequences] + + num_tokens = [len(x["input_ids"]) for x in single_sequence_encodings] + input_ids = torch.cat([x["input_ids"] for x in single_sequence_encodings]) + labels = torch.cat([x["labels"] for x in single_sequence_encodings]) + + within_seq_position_ids = torch.cat([encoding["position_ids"] for encoding in single_sequence_encodings]) + global_position_ids, ctx_len = [], 0 + for encoding in single_sequence_encodings: + global_position_ids.append(encoding["position_ids"] + ctx_len) + ctx_len = max(ctx_len, encoding["position_ids"].max().item() + ctx_len + 1) + global_position_ids = torch.cat(global_position_ids) + + sequence_ids = torch.repeat_interleave(torch.tensor(num_tokens)) + + # Get multi-seq context & mask out all but last sequence in multi-seq instance if desired + context_len = sum(num_tokens[:-1]) + context = self.tokenizer.decode(input_ids[:context_len].tolist(), skip_special_tokens=False) + if not self.preserve_context_labels: + labels[:context_len] = self.pad_token_id + + assert ( + input_ids.shape + == sequence_ids.shape + == within_seq_position_ids.shape + == global_position_ids.shape + == labels.shape + ), "Input ids, sequence ids, within seq position ids, global position ids, and labels must have the same shape" + + assert input_ids.shape[0] >= context_len, "Input ids must have at least as many tokens as the context length" + + return { + "input_ids": input_ids, + "sequence_ids": sequence_ids, + "within_seq_position_ids": within_seq_position_ids, + "global_position_ids": global_position_ids, + "labels": labels, + "context": context, + "context_len": context_len, + } + + def prepare_singleseq(self, sequence: str) -> dict[str, torch.Tensor]: + if not self.validate_sequence(sequence): + raise ValueError(f"Invalid sequence: {sequence}; Input sequence should contain [A-Z] or ? characters only") + + if len(sequence) > self.data_prep_config.max_num_positions_within_seq: + raise ValueError( + f"Sequence length {len(sequence)} exceeds max length {self.data_prep_config.max_num_positions_within_seq}" + ) + + # Can also use `tokens = torch.tensor(self.tokenizer.encode(f"1{sequence}2").ids)` + # but following is faster since our vocabulary is simple. + tokens = torch.tensor([self.vocab[token] for token in ["", "1", *sequence, "2", ""]]) + position_ids = torch.arange(len(tokens)) + + if self.data_prep_config.remove_X_tokens: + X_positions = torch.where(tokens != self.X_token_id)[0] + tokens = tokens[X_positions] + position_ids = position_ids[X_positions] + + return {"input_ids": tokens, "labels": tokens, "position_ids": position_ids} + + def get_boundary_token_mask(self, tokens: torch.Tensor) -> torch.BoolTensor: + return torch.isin(tokens, self.boundary_token_ids) + + def get_mask_positions_mask(self, tokens: torch.Tensor) -> torch.BoolTensor: + return tokens == self.mask_token_id + + def validate_sequence(self, sequence: str) -> bool: + assert isinstance(sequence, str), "Sequence must be a string" + sequence = sequence.replace(self.mask_token, "") + return sequence.isalpha() and sequence.isupper() + + + +class E1Config(PretrainedConfig): + model_type = "E1" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( # type: ignore + self, + # Model architecture/initialization + vocab_size=None, + hidden_size=4096, + intermediate_size=16384, + gated_mlp=False, + num_hidden_layers=40, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + rms_norm_eps=1e-5, + initializer_range=0.02, + torch_dtype="bfloat16", + gradient_checkpointing=False, + no_ffn_gradient_checkpointing=False, + # Tokenization + pad_token_id=None, + bos_token_id=None, + eos_token_id=None, + tie_word_embeddings=False, + # Attention implementation & rotary positional embeddings + global_attention_every_n_layers=0, + max_num_sequences=512, + max_num_positions_within_seq=8192, + max_num_positions_global=1024 * 128, + rope_theta_within_seq=10000.0, + rope_theta_global=100000.0, + clip_qkv=None, + **kwargs, + ) -> None: + tokenizer = get_tokenizer() + super().__init__( + pad_token_id=tokenizer.token_to_id(""), + bos_token_id=tokenizer.token_to_id(""), + eos_token_id=tokenizer.token_to_id(""), + tie_word_embeddings=tie_word_embeddings, + torch_dtype=torch_dtype, + **kwargs, + ) + + self.hidden_size = hidden_size + if intermediate_size is None: + intermediate_size = 3 * hidden_size if gated_mlp else 4 * hidden_size + self.intermediate_size = intermediate_size + self.gated_mlp = gated_mlp + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_num_positions_within_seq = max_num_positions_within_seq + self.max_num_positions_global = max_num_positions_global + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.rope_theta_within_seq = rope_theta_within_seq + self.rope_theta_global = rope_theta_global + self.max_num_sequences = max_num_sequences + assert clip_qkv is None or clip_qkv > 0 + self.clip_qkv = clip_qkv + self.global_attention_every_n_layers = global_attention_every_n_layers + + self.vocab_size = tokenizer.get_vocab_size() + self.gradient_checkpointing = gradient_checkpointing + self.no_ffn_gradient_checkpointing = no_ffn_gradient_checkpointing + + if vocab_size is not None: + if vocab_size < self.vocab_size: + logger.warning( + f"Using vocab_size {vocab_size} smaller than {self.vocab_size} from tokenizer. MAKE SURE THIS IS INTENTIONAL." + ) + self.vocab_size = vocab_size + elif vocab_size > self.vocab_size: + logger.warning(f"Using vocab_size {vocab_size} instead of smaller {self.vocab_size} from tokenizer.") + self.vocab_size = vocab_size + if pad_token_id is not None and pad_token_id != self.pad_token_id: + logger.warning(f"Ignoring pad_token_id. Using {self.pad_token_id} from tokenizer") + if bos_token_id is not None and bos_token_id != self.bos_token_id: + logger.warning(f"Ignoring bos_token_id. Using {self.bos_token_id} from tokenizer") + if eos_token_id is not None and eos_token_id != self.eos_token_id: + logger.warning(f"Ignoring eos_token_id. Using {self.eos_token_id} from tokenizer") + + +class DynamicCache: + """ + A cache layer that grows dynamically as more tokens are generated. This is the default for generative models. + It stores the key and value states as tensors of shape `[batch_size, seq_len, num_heads, head_dim]`. + + Args: + key_cache (`list[torch.Tensor]`): The list of key states. + value_cache (`list[torch.Tensor]`): The list of value states. + """ + + def __init__(self) -> None: + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + + def update( + self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Update the key and value caches in-place, and return the necessary keys and value states. + + Args: + key_states (`torch.Tensor`): The new key states to cache of shape [batch_size, seq_len, num_heads, head_dim] + value_states (`torch.Tensor`): The new value states to cache of shape [batch_size, seq_len, num_heads, head_dim] + layer_idx (`int`): The index of the layer to update. + + Returns: + tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states of shape [batch_size, seq_len, num_heads, head_dim]. + """ + # Lazy initialization + if len(self.key_cache) <= layer_idx: + # There may be skipped layers, fill them with empty lists + for _ in range(len(self.key_cache), layer_idx): + self.key_cache.append(torch.tensor([])) + self.value_cache.append(torch.tensor([])) + self.key_cache.append(key_states) + self.value_cache.append(value_states) + elif ( + not self.key_cache[layer_idx].numel() # prefers not t.numel() to len(t) == 0 to export the model + ): # fills previously skipped layers; checking for tensor causes errors + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=1) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=1) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def get_seq_length(self, layer_idx: int = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + is_empty_layer = ( + len(self.key_cache) == 0 # no cache in any layer + or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it + or not self.key_cache[layer_idx].numel() # the layer has no cache + ) + layer_seq_length = self.key_cache[layer_idx].shape[1] if not is_empty_layer else 0 + return layer_seq_length + + def crop(self, max_length: int) -> None: + """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be + negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search.""" + assert max_length > 0, "max_length must be positive" + + if self.get_seq_length() <= max_length: + return + + for layer_idx in range(len(self.key_cache)): + if self.key_cache[layer_idx].numel(): + self.key_cache[layer_idx] = self.key_cache[layer_idx][:, :max_length, ...] + self.value_cache[layer_idx] = self.value_cache[layer_idx][:, :max_length, ...] + + def batch_repeat_interleave(self, repeats: int) -> None: + """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" + for layer_idx in range(len(self.key_cache)): + if self.key_cache[layer_idx].numel(): + self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) + self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0) + + def batch_select_indices(self, indices: torch.Tensor) -> None: + """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" + for layer_idx in range(len(self.key_cache)): + if self.key_cache[layer_idx].numel(): + self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] + self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] + + +class KVCache: + def __init__(self, cache_size: int = 4) -> None: + self.cache_size = cache_size + self.tensor_input_field_names = [ + "input_ids", + "within_seq_position_ids", + "global_position_ids", + "sequence_ids", + "labels", + ] + self.tensor_output_field_names = ["logits", "embeddings"] + self.cache_dict: dict[str, DynamicCache] = {} + self.cache_queue: list[str] = [] + + def reset(self) -> None: + for k in list(self.cache_dict.keys()): + del self.cache_dict[k] + del self.cache_dict + self.cache_dict = {} + self.cache_queue = [] + + torch.cuda.empty_cache() + + def before_forward(self, batch: dict[str, torch.Tensor]) -> None: + contexts: list[str] | None = batch.get("context", None) + if contexts is None or "context_len" not in batch: + logger.warning_once( + "KVCache requires the batch dict to have both `context` and `context_len` keys to trigger. Skipping." + ) + return + + context_lens: list[int] = list(set(batch["context_len"])) + contexts: list[str] = list(set(contexts)) # type: ignore[no-redef] + if len(contexts) != 1 or len(context_lens) != 1: + logger.warning( + "SingleContextKVCache requires a single context and context length. " + "Multiple contexts or context lengths found in a single batch. Skipping." + ) + return + + batch_size = batch["input_ids"].shape[0] + + unique_context = contexts[0] + unique_context_len = context_lens[0] + batch["use_cache"] = True + + if unique_context not in self.cache_dict: + return + + self.cache_dict[unique_context].batch_repeat_interleave(batch_size) + past_key_values = self.cache_dict[unique_context] + batch["past_key_values"] = past_key_values + + # Remove context from the input fields + for field_name in self.tensor_input_field_names: + if batch.get(field_name, None) is not None: + batch[field_name] = batch[field_name][:, unique_context_len:] + + def after_forward(self, batch: dict[str, Any], outputs: ModelOutput) -> None: + contexts = batch.get("context", None) + context_lens = batch.get("context_len", []) + if contexts is None or len(set(contexts)) != 1 or len(set(context_lens)) != 1 or context_lens[0] == 0: + return + + assert batch["use_cache"] + unique_context = contexts[0] + unique_context_len = context_lens[0] + + past_key_values = getattr(outputs, "past_key_values", None) + if not isinstance(past_key_values, DynamicCache): + logger.warning_once("KVCache is incompatible with models that don't return a DynamicCache. Skipping.") + return + + if "past_key_values" not in batch: + if len(self.cache_queue) == self.cache_size: + last_context = self.cache_queue.pop(0) + if last_context not in self.cache_queue: + del self.cache_dict[last_context] + torch.cuda.empty_cache() + + self.cache_dict[unique_context] = past_key_values + self.cache_queue.append(unique_context) + + # Remove context from the input fields + for field_name in self.tensor_input_field_names: + if field_name in batch and batch[field_name] is not None: + batch[field_name] = batch[field_name][:, unique_context_len:] + + # Remove context from the output fields + for field_name in self.tensor_output_field_names: + if field_name in outputs and outputs[field_name] is not None: + outputs[field_name] = outputs[field_name][:, unique_context_len:] + if "hidden_states" in outputs and outputs["hidden_states"] is not None: + outputs["hidden_states"] = [h[:, unique_context_len:] for h in outputs["hidden_states"]] + + self.cache_dict[unique_context].crop(unique_context_len) + self.cache_dict[unique_context].batch_select_indices([0]) + + +class AttentionMethod(Enum): + FLASH = "flash" + FLEX = "flex" + + +class AttentionLayerType(Enum): + WITHIN_SEQ = "within_seq" + GLOBAL = "global" + + +class AttentionArgs(TypedDict, total=False): + flex_attention_args: FlexAttentionArgs + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). + + The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, + num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class RotaryPositionalEmbedding(nn.Module): + def __init__( + self, dim: int, max_position_embeddings: int = 2048, base: int = 10000, device: torch.device | None = None + ): + super().__init__() + + self.dim = dim + self.base = base + self.max_position_embeddings = max_position_embeddings + inv_freq = base ** -(torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_sin_cos_cache(seq_len=max_position_embeddings, device=self.inv_freq.device) + + @staticmethod + def rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def _set_sin_cos_cache(self, seq_len: int, device: torch.device) -> None: + # Different from paper, but it uses a different permutation in order to obtain the same calculation + self.max_seq_len_cached = seq_len + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + angles = torch.outer(t, self.inv_freq.to(device)) + angles = torch.cat((angles, angles), dim=1) + self.register_buffer("cos_cached", angles.cos(), persistent=False) + self.register_buffer("sin_cached", angles.sin(), persistent=False) + + def forward( + self, q: torch.Tensor, k: torch.Tensor, position_ids: torch.LongTensor, seq_len: int | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + # x: [bsz, seq_len, num_attention_heads, head_size] + device, dtype = q.device, q.dtype + seq_len = position_ids.max().item() + 1 if seq_len is None else seq_len + + if seq_len > self.max_seq_len_cached: + self._set_sin_cos_cache(seq_len=seq_len, device=device) + + # angles_cached[position_ids] gets us something of shape (batch_size, seq_len, head_dim), + # so unsqueeze dimension -2 to broadcast to (batch_size, seq_len, n_heads, head_dim). + idxs = position_ids.to(device) + cos = self.cos_cached.to(device=device, dtype=dtype).unsqueeze(-2)[idxs] + sin = self.sin_cached.to(device=device, dtype=dtype).unsqueeze(-2)[idxs] + + # Apply rotary positional embeddings to q and k (treating them as complex numbers). The first half is + # Re[x exp(it)] = Re[x] cos(t) - Im[x] sin(t), while the second half is + # Im[x exp(it)] = Im[x] cos(t) + Re[x] sin(t). This works b/c both halves of cos/sin are the same. + q_embed = (q * cos) + (self.rotate_half(q) * sin) + k_embed = (k * cos) + (self.rotate_half(k) * sin) + return q_embed, k_embed + + +class Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper.""" + + def __init__(self, config: E1Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_kv_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_kv_heads + self.max_num_seqs = config.max_num_sequences + self.clip_qkv = config.clip_qkv + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + if self.config.global_attention_every_n_layers > 0: + self.layer_type = ( + AttentionLayerType.GLOBAL + if (self.layer_idx + 1) % self.config.global_attention_every_n_layers == 0 + else AttentionLayerType.WITHIN_SEQ + ) + else: + self.layer_type = AttentionLayerType.WITHIN_SEQ + + self.rope_theta = ( + config.rope_theta_within_seq + if self.layer_type == AttentionLayerType.WITHIN_SEQ + else config.rope_theta_global + ) + self.max_position_embeddings = ( + config.max_num_positions_within_seq + if self.layer_type == AttentionLayerType.WITHIN_SEQ + else config.max_num_positions_global + ) + + self.rotary_emb = RotaryPositionalEmbedding( + self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta + ) + + def prepare_qkv( + self, + hidden_states: torch.Tensor, + position_ids: torch.LongTensor, + past_key_value: DynamicCache | None = None, + use_cache: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + bsz, q_len, _ = hidden_states.size() + query_states: torch.Tensor = self.q_proj(hidden_states) + key_states: torch.Tensor = self.k_proj(hidden_states) + val_states: torch.Tensor = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim) + val_states = val_states.view(bsz, q_len, self.num_kv_heads, self.head_dim) + + if self.clip_qkv is not None: + query_states = query_states.clamp(-self.clip_qkv, self.clip_qkv) + key_states = key_states.clamp(-self.clip_qkv, self.clip_qkv) + val_states = val_states.clamp(-self.clip_qkv, self.clip_qkv) + + query_states, key_states = self.rotary_emb(query_states, key_states, position_ids) + + if use_cache and past_key_value is not None: + key_states, val_states = past_key_value.update(key_states, val_states, self.layer_idx) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + else: + target_dtype = self.q_proj.weight.dtype + if input_dtype != target_dtype: + logger.warning_once( + f"The input hidden states seems to be silently casted in {input_dtype}. " + f"This might be because you have upcasted embedding or layer norm layers " + f"in {input_dtype}. We will cast back the input in {target_dtype}." + ) + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + val_states = val_states.to(target_dtype) + + return query_states, key_states, val_states + + def forward( + self, + hidden_states: torch.Tensor, + within_seq_position_ids: torch.LongTensor, + global_position_ids: torch.LongTensor, + sequence_ids: torch.LongTensor, + attention_args: AttentionArgs | None = None, + past_key_value: DynamicCache | None = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor | None, DynamicCache | None]: + is_cache_prefilled = ( + use_cache and past_key_value is not None and past_key_value.get_seq_length(self.layer_idx) > 0 + ) + + query_states, key_states, val_states = self.prepare_qkv( + hidden_states=hidden_states, + position_ids=within_seq_position_ids + if self.layer_type == AttentionLayerType.WITHIN_SEQ + else global_position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + ) + + # Note: We fallback to using flash attention in inference mode when cache is filled with kv values + # for global attention layers instead of flex attention. This is because once the cache is filled, + # the last sequence attends to everything in the cache, so we can make things faster by using a + # bidirectional flash attention instead of block-causal flex attention. + if self.layer_type == AttentionLayerType.WITHIN_SEQ or is_cache_prefilled: + attention_type = AttentionMethod.FLASH + else: + attention_type = AttentionMethod.FLEX + + attn_output, attn_weights = self._attn( + attention_type=attention_type, + query_states=query_states, + key_states=key_states, + val_states=val_states, + sequence_ids=sequence_ids, + attention_args=attention_args, + output_attentions=output_attentions, + ) + + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value + + def _attn( + self, + attention_type: AttentionMethod, + query_states: torch.Tensor, + key_states: torch.Tensor, + val_states: torch.Tensor, + sequence_ids: torch.Tensor, + attention_args: AttentionArgs | None = None, + output_attentions: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + match attention_type: + case AttentionMethod.FLASH: + f = self._flash_attn + case AttentionMethod.FLEX: + f = self._flex_attn + case _: + raise ValueError(f"No attention implementation found for {attention_type}") + return f( + query_states=query_states, + key_states=key_states, + val_states=val_states, + sequence_ids=sequence_ids, + attention_args=attention_args, + output_attentions=output_attentions, + ) + + def _flash_attn( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + val_states: torch.Tensor, + sequence_ids: torch.Tensor, + attention_args: AttentionArgs | None = None, + output_attentions: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Flash attention implementation. + + Calls the public API of flash attention and deals with padding tokens if any are present. + """ + assert not output_attentions, "Flash attention doesn't support returning attention masks" + bsz, q_len = query_states.shape[0], query_states.shape[1] + _, kv_len = key_states.shape[0], key_states.shape[1] + + if self.layer_type == AttentionLayerType.GLOBAL: # Only happens in inference + q_sequence_ids = sequence_ids + if q_len < kv_len: + # Assumes query contain only one sequence + # and all tokens in query (except padding) will attend to all tokens in KV + first_token_id = sequence_ids[:, 0].unsqueeze(1) + k_sequence_ids = torch.cat([first_token_id.expand(bsz, kv_len - q_len), sequence_ids], dim=-1) + else: + k_sequence_ids = sequence_ids + else: + if q_len < kv_len: # Only happens in inference + key_states = key_states[:, -q_len:] + val_states = val_states[:, -q_len:] + q_sequence_ids = k_sequence_ids = sequence_ids + + if is_flash_attention_available(): + attn_output = flash_attention_func( + query_states, + key_states, + val_states, + q_sequence_ids=q_sequence_ids, + k_sequence_ids=k_sequence_ids, + causal=False, + ) + else: + attn_output = varlen_flex_attention_func( + query_states, key_states, val_states, q_sequence_ids=q_sequence_ids, k_sequence_ids=k_sequence_ids + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + return attn_output, None + + def _flex_attn( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + val_states: torch.Tensor, + sequence_ids: torch.Tensor, + attention_args: AttentionArgs | None = None, + output_attentions: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + bsz, q_len = query_states.shape[0], query_states.shape[1] + flex_attention_args = attention_args.get("flex_attention_args", None) if attention_args is not None else None + block_mask = flex_attention_args.get("block_mask", None) if flex_attention_args is not None else None + score_mod = flex_attention_args.get("score_mod", None) if flex_attention_args is not None else None + outputs = flex_attention_func(query_states, key_states, val_states, score_mod=score_mod, block_mask=block_mask) + + outputs = outputs.reshape(bsz, q_len, self.hidden_size).contiguous() + return outputs, None + + +class MLP(nn.Module): + def __init__(self, config: E1Config): + super().__init__() + self.ffn_dim = config.intermediate_size + self.hidden_dim = config.hidden_size + self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.w2(self.act_fn(self.w1(hidden_states))) + + +class GLUMLP(nn.Module): + def __init__(self, config: E1Config): + super().__init__() + self.ffn_dim = config.intermediate_size + self.hidden_dim = config.hidden_size + self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) + hidden_states = self.w2(hidden_states) + return hidden_states + + +class FFN(nn.Module): + def __init__(self, config: E1Config): + super().__init__() + mlp_cls = GLUMLP if config.gated_mlp else MLP + self.mlp = mlp_cls(config) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.mlp(hidden_states) + + +@dataclass +class E1ModelOutputWithPast(ModelOutput): + """Base class for model's outputs, with potential hidden states and attentions. + + Attributes: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: torch.FloatTensor | None = None + past_key_values: DynamicCache | None = None + hidden_states: tuple[torch.FloatTensor, ...] | None = None + attentions: tuple[torch.FloatTensor, ...] | None = None + + +@dataclass +class E1MaskedLMOutputWithPast(ModelOutput): + loss: torch.FloatTensor | None = None + mlm_loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + last_hidden_state: torch.FloatTensor | None = None + past_key_values: DynamicCache | None = None + hidden_states: tuple[torch.FloatTensor, ...] | None = None + attentions: tuple[torch.FloatTensor, ...] | None = None + + +@dataclass +class E1ClassificationOutputWithPast(ModelOutput): + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + last_hidden_state: torch.FloatTensor | None = None + past_key_values: DynamicCache | None = None + hidden_states: tuple[torch.FloatTensor, ...] | None = None + attentions: tuple[torch.FloatTensor, ...] | None = None + + +class RMSNorm(nn.Module): + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + self.hidden_size = hidden_size + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + if layer_norm is None: + return torch.nn.functional.rms_norm( + hidden_states, (self.hidden_size,), self.weight, self.variance_epsilon + ).to(input_dtype) + else: + return layer_norm.rms_norm_fn( + x=hidden_states, + weight=self.weight, + bias=None, # no bias + residual=None, + eps=self.variance_epsilon, + dropout_p=0.0, # no dropout by default + prenorm=False, + residual_in_fp32=False, + ).to(input_dtype) + + +class NormAttentionNorm(nn.Module): + def __init__(self, config: E1Config, layer_idx: int): + super().__init__() + self.self_attn = Attention(config, layer_idx) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + within_seq_position_ids: torch.LongTensor, + global_position_ids: torch.LongTensor, + sequence_ids: torch.LongTensor, + attention_args: AttentionArgs | None = None, + past_key_value: DynamicCache | None = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, DynamicCache | None]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + within_seq_position_ids=within_seq_position_ids, + global_position_ids=global_position_ids, + sequence_ids=sequence_ids, + attention_args=attention_args, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + return hidden_states, residual, self_attn_weights, present_key_value + + +class DecoderLayer(nn.Module): + def __init__(self, config: E1Config, layer_idx: int): + super().__init__() + self.initializer_range = config.initializer_range + self.hidden_size = config.hidden_size + self.norm_attn_norm = NormAttentionNorm(config, layer_idx) + self.ffn = FFN(config) + + def forward( + self, + hidden_states: torch.Tensor, + within_seq_position_ids: torch.LongTensor, + global_position_ids: torch.LongTensor, + sequence_ids: torch.LongTensor, + attention_args: AttentionArgs | None = None, + past_key_value: DynamicCache | None = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor | None, DynamicCache | None]: + hidden_states, residual, self_attn_weights, present_key_value = self.norm_attn_norm( + hidden_states=hidden_states, + within_seq_position_ids=within_seq_position_ids, + global_position_ids=global_position_ids, + sequence_ids=sequence_ids, + attention_args=attention_args, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + # Fully Connected + hidden_states = self.ffn(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, self_attn_weights, present_key_value + + +### Support for embedding datasets with low code +class _LegacyPooler: + def __init__(self, pooling_types: List[str]): + self.pooling_types = pooling_types + self.pooling_options = { + 'mean': self.mean_pooling, + 'max': self.max_pooling, + 'norm': self.norm_pooling, + 'median': self.median_pooling, + 'std': self.std_pooling, + 'var': self.var_pooling, + 'cls': self.cls_pooling, + 'parti': self._pool_parti, + } + + def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor: + maxed_attentions = torch.max(attentions, dim=1)[0] + return maxed_attentions + + def _page_rank(self, attention_matrix, personalization=None, nstart=None, prune_type="top_k_outdegree"): + # Run PageRank on the attention matrix converted to a graph. + # Raises exceptions if the graph doesn't match the token sequence or has no edges. + # Returns the PageRank scores for each token node. + G = self._convert_to_graph(attention_matrix) + if G.number_of_nodes() != attention_matrix.shape[0]: + raise Exception( + f"The number of nodes in the graph should be equal to the number of tokens in sequence! You have {G.number_of_nodes()} nodes for {attention_matrix.shape[0]} tokens.") + if G.number_of_edges() == 0: + raise Exception(f"You don't seem to have any attention edges left in the graph.") + + return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100) + + def _convert_to_graph(self, matrix): + # Convert a matrix (e.g., attention scores) to a directed graph using networkx. + # Each element in the matrix represents a directed edge with a weight. + G = nx.from_numpy_array(matrix, create_using=nx.DiGraph) + return G + + def _calculate_importance_weights(self, dict_importance, attention_mask: Optional[torch.Tensor] = None): + # Remove keys where attention_mask is 0 + if attention_mask is not None: + for k in list(dict_importance.keys()): + if attention_mask[k] == 0: + del dict_importance[k] + + #dict_importance[0] # remove cls + #dict_importance[-1] # remove eos + total = sum(dict_importance.values()) + return np.array([v / total for _, v in dict_importance.items()]) + + def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d) + maxed_attentions = self._create_pooled_matrices_across_layers(attentions).numpy() + # emb is (b, L, d), maxed_attentions is (b, L, L) + emb_pooled = [] + for e, a, mask in zip(emb, maxed_attentions, attention_mask): + dict_importance = self._page_rank(a) + importance_weights = self._calculate_importance_weights(dict_importance, mask) + num_tokens = int(mask.sum().item()) + emb_pooled.append(np.average(e[:num_tokens], weights=importance_weights, axis=0)) + pooled = torch.tensor(np.array(emb_pooled)) + return pooled + + def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) + if attention_mask is None: + return emb.mean(dim=1) + else: + attention_mask = attention_mask.unsqueeze(-1) + return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) + + def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) + if attention_mask is None: + return emb.max(dim=1).values + else: + attention_mask = attention_mask.unsqueeze(-1) + return (emb * attention_mask).max(dim=1).values + + def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) + if attention_mask is None: + return emb.norm(dim=1, p=2) + else: + attention_mask = attention_mask.unsqueeze(-1) + return (emb * attention_mask).norm(dim=1, p=2) + + def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) + if attention_mask is None: + return emb.median(dim=1).values + else: + attention_mask = attention_mask.unsqueeze(-1) + return (emb * attention_mask).median(dim=1).values + + def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) + if attention_mask is None: + return emb.std(dim=1) + else: + # Compute variance correctly over non-masked positions, then take sqrt + var = self.var_pooling(emb, attention_mask, **kwargs) + return torch.sqrt(var) + + def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) + if attention_mask is None: + return emb.var(dim=1) + else: + # Correctly compute variance over only non-masked positions + attention_mask = attention_mask.unsqueeze(-1) # (b, L, 1) + # Compute mean over non-masked positions + mean = (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d) + mean = mean.unsqueeze(1) # (b, 1, d) + # Compute squared differences from mean, only over non-masked positions + squared_diff = (emb - mean) ** 2 # (b, L, d) + # Sum squared differences over non-masked positions and divide by count + var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d) + return var + + def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) + return emb[:, 0, :] + + def __call__( + self, + emb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + attentions: Optional[torch.Tensor] = None + ): # [mean, max] + final_emb = [] + for pooling_type in self.pooling_types: + final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions)) # (b, d) + return torch.cat(final_emb, dim=-1) # (b, n_pooling_types * d) + + +class _LegacyEmbeddingMixin: + def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + raise NotImplementedError + + @property + def device(self) -> torch.device: + """Get the device of the model.""" + return next(self.parameters()).device + + def _read_sequences_from_db(self, db_path: str) -> set[str]: + """Read sequences from SQLite database.""" + import sqlite3 + sequences = [] + with sqlite3.connect(db_path) as conn: + c = conn.cursor() + c.execute("SELECT sequence FROM embeddings") + while True: + row = c.fetchone() + if row is None: + break + sequences.append(row[0]) + return set(sequences) + + def embed_dataset( + self, + sequences: List[str], + #tokenizer: PreTrainedTokenizerBase, # For E1, the tokenizing is handled by _embed + batch_size: int = 2, + max_len: int = 512, + truncate: bool = True, + full_embeddings: bool = False, + embed_dtype: torch.dtype = torch.float32, + pooling_types: List[str] = ['mean'], + sql: bool = False, + save: bool = True, + sql_db_path: str = 'embeddings.db', + save_path: str = 'embeddings.pth', + **kwargs, + ) -> Optional[dict[str, torch.Tensor]]: + """Embed a dataset of protein sequences. + + Args: + sequences: List of protein sequences + batch_size: Batch size for processing + max_len: Maximum sequence length + full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False) + pooling_type: Type of pooling ('mean' or 'cls') + sql: Whether to store embeddings in SQLite database - will be stored in float32 + sql_db_path: Path to SQLite database + + Returns: + Dictionary mapping sequences to embeddings, or None if sql=True + + Note: + - If sql=True, embeddings can only be stored in float32 + - sql is ideal if you need to stream a very large dataset for training in real-time + - save=True is ideal if you can store the entire embedding dictionary in RAM + - sql will be used if it is True and save is True or False + - If your sql database or .pth file is already present, they will be scanned first for already embedded sequences + - Sequences will be truncated to max_len and sorted by length in descending order for faster processing + + Example: + >>> embedder = EmbeddingMixin() + >>> embedding_dict = embedder.embed_dataset( + sequences=[ + 'MALWMRLLPLLALLALWGPDPAAA', ... # list of protein sequences + ], + batch_size=2, # adjust for your GPU memory + max_len=512, # adjust for your needs + full_embeddings=False, # if True, no pooling is performed + embed_dtype=torch.float32, # cast to what dtype you want + pooling_type=['mean', 'cls'], # more than one pooling type will be concatenated together + sql=False, # if True, embeddings will be stored in SQLite database + sql_db_path='embeddings.db', + save=True, # if True, embeddings will be saved as a .pth file + save_path='embeddings.pth', + ) + >>> # embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql + """ + sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences])) + sequences = sorted(sequences, key=len, reverse=True) + hidden_size = self.config.hidden_size + pooler = Pooler(pooling_types) if not full_embeddings else None + + def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + if full_embeddings or residue_embeddings.ndim == 2: # if already pooled or want residue-wise embeddings + return residue_embeddings + else: + return pooler(residue_embeddings, attention_mask) + + if sql: + import sqlite3 + conn = sqlite3.connect(sql_db_path) + c = conn.cursor() + c.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence text PRIMARY KEY, embedding blob)') + already_embedded = self._read_sequences_from_db(sql_db_path) + to_embed = [seq for seq in sequences if seq not in already_embedded] + print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}") + print(f"Embedding {len(to_embed)} new sequences") + if len(to_embed) > 0: + with torch.no_grad(): + for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'): + seqs = to_embed[batch_start:batch_start + batch_size] + input_ids, attention_mask = self._embed(seqs, return_attention_mask=True) + embeddings = get_embeddings(input_ids, attention_mask).float() # sql requires float32 + for seq, emb, mask in zip(seqs, embeddings, attention_mask): + if full_embeddings: + emb = emb[mask.bool()].reshape(-1, hidden_size) + c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)", (seq, emb.cpu().numpy().tobytes())) + conn.commit() + conn.commit() + conn.close() + return None + + embeddings_dict = {} + if os.path.exists(save_path): + embeddings_dict = torch.load(save_path, map_location='cpu', weights_only=True) + to_embed = [seq for seq in sequences if seq not in embeddings_dict] + print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}") + print(f"Embedding {len(to_embed)} new sequences") + else: + to_embed = sequences + print(f"Embedding {len(to_embed)} new sequences") + + if len(to_embed) > 0: + with torch.no_grad(): + for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'): + seqs = to_embed[batch_start:batch_start + batch_size] + last_hidden_state, attention_mask = self._embed(seqs, return_attention_mask=True) + embeddings = get_embeddings(last_hidden_state, attention_mask).to(embed_dtype) + for seq, emb, mask in zip(seqs, embeddings, attention_mask): + if full_embeddings: + emb = emb[mask.bool()].reshape(-1, hidden_size) + embeddings_dict[seq] = emb.cpu() + + if save: + torch.save(embeddings_dict, save_path) + + return embeddings_dict + + +class E1PreTrainedModel(PreTrainedModel): + config_class = E1Config + config: E1Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["DecoderLayer"] + _transformer_layer_cls = [DecoderLayer] + _skip_keys_device_placement = "past_key_values" + all_tied_weights_keys = {} + + def _init_weights(self, module: nn.Module) -> None: + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, RMSNorm): + module.weight.data.fill_(1.0) + + def post_init(self) -> None: + super().post_init() + + def _backward_compatibility_gradient_checkpointing(self) -> None: + if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False): + self.gradient_checkpointing_enable(dict(use_reentrant=False)) + + @property + def _device(self) -> torch.device: + return next(self.parameters()).device + + @classmethod + def from_pretrained( # type: ignore[no-untyped-def] + cls, pretrained_model_name_or_path: str | os.PathLike | None, *args, **kwargs + ) -> "E1PreTrainedModel": + return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs) + + +class E1Model(E1PreTrainedModel, EmbeddingMixin): + config: E1Config + config_class = E1Config + def __init__(self, config: E1Config, **kwargs): + E1PreTrainedModel.__init__(self, config, **kwargs) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.embed_seq_id = nn.Embedding(config.max_num_sequences, config.hidden_size) + self.layers = nn.ModuleList([DecoderLayer(config, i) for i in range(config.num_hidden_layers)]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = config.gradient_checkpointing + self.prep_tokens = E1BatchPreparer() + self.post_init() + + def get_input_embeddings(self) -> nn.Embedding: + return self.embed_tokens + + def set_input_embeddings(self, value: nn.Embedding) -> None: + self.embed_tokens = value + + @torch.inference_mode() + def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor: + batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device) + last_hidden_state = self.forward(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state + if return_attention_mask: + attention_mask = (batch['sequence_ids'] != -1).long() + return last_hidden_state, attention_mask + else: + return last_hidden_state + + # Ignore copy + def forward( + self, + input_ids: torch.LongTensor, + within_seq_position_ids: torch.LongTensor, + global_position_ids: torch.LongTensor, + sequence_ids: torch.LongTensor, + past_key_values: DynamicCache | None = None, + use_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + **kwargs + ) -> E1ModelOutputWithPast: + """ + Args: + input_ids: (batch_size, seq_length) + within_seq_position_ids: (batch_size, seq_length) + This tensor contains the position of each residue within the sequence itself. + For example, if the input is ["1ABC21DEF2", "1GH21JKL2"], + the tensor would be [[0,1,2,3,4,5,6,0,1,2,3,4,5,6], [0,1,2,3,4,5,0,1,2,3,4,5,6,-1]] + global_position_ids: (batch_size, seq_length) + This tensor contains the position of each residue within the global sequence. + For example, if the input is ["1ABC21DEF2", "1GH21JKL2"], + the tensor would be [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, -1]] + sequence_ids: (batch_size, seq_length) + This tensor contains the sequence id of each residue. + For example, if the input is ["1ABC21DEF2", "1GH21JKL2"], + the tensor would be [[0,0,0,0,0,0,0,1,1,1,1,1,1,1], [0,0,0,0,0,0,1,1,1,1,1,1,1,-1]] + past_key_values: DynamicCache + use_cache: bool + output_attentions: bool + output_hidden_states: bool + + Returns: + E1ModelOutputWithPast: Model Outputs + """ + batch_size, seq_length = input_ids.shape + + if self.gradient_checkpointing and self.training and torch.is_grad_enabled(): + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + elif not use_cache: + # To avoid weirdness with gradient checkpointing: https://github.com/huggingface/transformers/issues/28499 + past_key_values = None + + global_position_ids = global_position_ids.view(-1, seq_length).long() + within_seq_position_ids = within_seq_position_ids.view(-1, seq_length).long() + sequence_ids = sequence_ids.view(-1, seq_length).long() + + max_position_id = torch.max(within_seq_position_ids).item() + min_position_id = torch.min(within_seq_position_ids).item() + assert max_position_id < self.config.max_num_positions_within_seq and min_position_id >= -1, ( + f"Position ids must be in the range [-1, {self.config.max_num_positions_within_seq}); got max {max_position_id} and min {min_position_id}" + ) + + inputs_embeds = self.embed_tokens(input_ids) + # -1 is used to indicate padding tokens, so we need to clamp the sequence ids to 0 + inputs_embeds = inputs_embeds + self.embed_seq_id(sequence_ids.clamp(min=0)) + + # In case we need to do any manual typecasting + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + else: + target_dtype = self.layers[0].norm_attn_norm.self_attn.q_proj.weight.dtype + hidden_states = inputs_embeds.to(target_dtype) + + # (batch_size, query_length, keyval_length) + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + + # Create block mask for flex attention + attention_args: AttentionArgs | None = None + if past_key_values_length == 0: + block_mask = create_block_causal_mask_optimized(sequence_ids) + flex_attention_args = FlexAttentionArgs(block_mask=block_mask) + attention_args = AttentionArgs(flex_attention_args=flex_attention_args) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) # type: ignore[operator] + + if self.gradient_checkpointing and self.training and torch.is_grad_enabled(): + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + within_seq_position_ids, + global_position_ids, + sequence_ids, + attention_args, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + within_seq_position_ids=within_seq_position_ids, + global_position_ids=global_position_ids, + sequence_ids=sequence_ids, + attention_args=attention_args, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states, self_attn_weights, present_key_value = layer_outputs + + if use_cache: + # NOTE: it's necessary to re-assign past_key_values because FSDP2 + # passes certain arguments by value, not by reference. + # See https://github.com/huggingface/transformers/issues/38190#issuecomment-2914016168 + next_decoder_cache = past_key_values = present_key_value + + if output_attentions: + all_self_attns += (self_attn_weights,) # type: ignore[operator] + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) # type: ignore[operator] + + next_cache = next_decoder_cache if use_cache else None + + return E1ModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class E1ForMaskedLM(E1PreTrainedModel, EmbeddingMixin): + config: E1Config + config_class = E1Config + def __init__(self, config: E1Config, **kwargs): + E1PreTrainedModel.__init__(self, config, **kwargs) + self.model: E1Model = E1Model(config) + self.vocab_size = config.vocab_size + self.mlm_head = torch.nn.Sequential( + nn.Linear(config.hidden_size, config.hidden_size, bias=True), + nn.GELU(), + nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps), + nn.Linear(config.hidden_size, config.vocab_size, bias=True), + ) + self.gradient_checkpointing = config.gradient_checkpointing + self.prep_tokens = E1BatchPreparer() + self.post_init() + + @property + def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh: + return self.model.device_mesh + + @torch.inference_mode() + def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor: + batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device) + last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state + if return_attention_mask: + attention_mask = (batch['sequence_ids'] != -1).long() + return last_hidden_state, attention_mask + else: + return last_hidden_state + + def forward( + self, + input_ids: torch.LongTensor, + within_seq_position_ids: torch.LongTensor, + global_position_ids: torch.LongTensor, + sequence_ids: torch.LongTensor, + labels: torch.LongTensor | None = None, + past_key_values: DynamicCache | None = None, + use_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + **kwargs, + ) -> E1MaskedLMOutputWithPast: + """ + Args: + input_ids: (batch_size, seq_length) + within_seq_position_ids: (batch_size, seq_length) + This tensor contains the position of each residue within the sequence itself. + For example, if the input is ["1ABC21DEF2", "1GH21JKL2"], + the tensor would be [[0,1,2,3,4,5,6,0,1,2,3,4,5,6], [0,1,2,3,4,5,0,1,2,3,4,5,6,-1]] + global_position_ids: (batch_size, seq_length) + This tensor contains the position of each residue within the global sequence. + For example, if the input is ["1ABC21DEF2", "1GH21JKL2"], + the tensor would be [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, -1]] + sequence_ids: (batch_size, seq_length) + This tensor contains the sequence id of each residue. + For example, if the input is ["1ABC21DEF2", "1GH21JKL2"], + the tensor would be [[0,0,0,0,0,0,0,1,1,1,1,1,1,1], [0,0,0,0,0,0,1,1,1,1,1,1,1,-1]] + labels: (batch_size, seq_length) + past_key_values: DynamicCache + use_cache: bool + output_attentions: bool + output_hidden_states: bool + + Returns: + E1MaskedLMOutputWithPast: Model Outputs + """ + outputs: E1ModelOutputWithPast = self.model( + input_ids=input_ids, + within_seq_position_ids=within_seq_position_ids, + global_position_ids=global_position_ids, + sequence_ids=sequence_ids, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + x = outputs.last_hidden_state + loss = None + + # Compute masked language modeling loss + mlm_logits = self.mlm_head(x).float() + mlm_loss = 0.0 + if labels is not None: + mlm_logits_flat = mlm_logits.contiguous().view(-1, self.config.vocab_size) + mlm_labels_flat = labels.to(mlm_logits_flat.device).contiguous().view(-1) + mlm_loss = F.cross_entropy(mlm_logits_flat, mlm_labels_flat, reduction="none") + mask = mlm_labels_flat != self.model.padding_idx + n_mlm = mask.sum() + mlm_loss = (mlm_loss * mask.to(mlm_loss)).sum() / (1 if n_mlm == 0 else n_mlm) + loss = 0.0 + loss += mlm_loss + + return E1MaskedLMOutputWithPast( + loss=loss, + mlm_loss=mlm_loss, + logits=mlm_logits, + last_hidden_state=x, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class E1ForSequenceClassification(E1PreTrainedModel, EmbeddingMixin): + config: E1Config + config_class = E1Config + def __init__(self, config: E1Config, **kwargs): + E1PreTrainedModel.__init__(self, config, **kwargs) + self.model: E1Model = E1Model(config) + self.vocab_size = config.vocab_size + self.num_labels = config.num_labels + self.classifier = nn.Sequential( + nn.Linear(config.hidden_size * 2, config.hidden_size * 4), + nn.GELU(), + nn.LayerNorm(config.hidden_size * 4), + nn.Linear(config.hidden_size * 4, config.num_labels), + ) + self.mse = nn.MSELoss() + self.ce = nn.CrossEntropyLoss() + self.bce = nn.BCEWithLogitsLoss() + self.gradient_checkpointing = config.gradient_checkpointing + self.prep_tokens = E1BatchPreparer() + + if 'pooling_types' in kwargs and isinstance(kwargs['pooling_types'], List[str]) and len(kwargs['pooling_types']) > 0: + pooling_types = kwargs['pooling_types'] + else: + pooling_types = ['mean', 'var'] + self.pooler = Pooler(pooling_types) + self.post_init() + + @property + def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh: + return self.model.device_mesh + + @torch.inference_mode() + def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor: + batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device) + last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state + if return_attention_mask: + attention_mask = (batch['sequence_ids'] != -1).long() + return last_hidden_state, attention_mask + else: + return last_hidden_state + + def forward( + self, + input_ids: torch.LongTensor, + within_seq_position_ids: torch.LongTensor, + global_position_ids: torch.LongTensor, + sequence_ids: torch.LongTensor, + labels: torch.LongTensor | None = None, + past_key_values: DynamicCache | None = None, + use_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + **kwargs, + ) -> E1ClassificationOutputWithPast: + outputs: E1ModelOutputWithPast = self.model( + input_ids=input_ids, + within_seq_position_ids=within_seq_position_ids, + global_position_ids=global_position_ids, + sequence_ids=sequence_ids, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + attention_mask = (sequence_ids != -1).long() + x = outputs.last_hidden_state + features = self.pooler(x, attention_mask) + logits = self.classifier(features) + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + if self.num_labels == 1: + loss = self.mse(logits.flatten(), labels.flatten()) + else: + loss = self.mse(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss = self.bce(logits, labels) + + return E1ClassificationOutputWithPast( + loss=loss, + logits=logits, + last_hidden_state=x, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class E1ForTokenClassification(E1PreTrainedModel, EmbeddingMixin): + config: E1Config + config_class = E1Config + def __init__(self, config: E1Config, **kwargs): + E1PreTrainedModel.__init__(self, config, **kwargs) + self.model: E1Model = E1Model(config) + self.vocab_size = config.vocab_size + self.num_labels = config.num_labels + self.classifier = nn.Sequential( + nn.Linear(config.hidden_size * 2, config.hidden_size * 4), + nn.GELU(), + nn.LayerNorm(config.hidden_size * 4), + nn.Linear(config.hidden_size * 4, config.num_labels), + ) + self.loss_fct = nn.CrossEntropyLoss() + self.gradient_checkpointing = config.gradient_checkpointing + self.prep_tokens = E1BatchPreparer() + self.post_init() + + @property + def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh: + return self.model.device_mesh + + @torch.inference_mode() + def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor: + batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device) + last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state + if return_attention_mask: + attention_mask = (batch['sequence_ids'] != -1).long() + return last_hidden_state, attention_mask + else: + return last_hidden_state + + def forward( + self, + input_ids: torch.LongTensor, + within_seq_position_ids: torch.LongTensor, + global_position_ids: torch.LongTensor, + sequence_ids: torch.LongTensor, + labels: torch.LongTensor | None = None, + past_key_values: DynamicCache | None = None, + use_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + **kwargs, + ) -> E1ClassificationOutputWithPast: + outputs: E1ModelOutputWithPast = self.model( + input_ids=input_ids, + within_seq_position_ids=within_seq_position_ids, + global_position_ids=global_position_ids, + sequence_ids=sequence_ids, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + x = outputs.last_hidden_state + logits = self.classifier(x) + loss = None + if labels is not None: + loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + return E1ClassificationOutputWithPast( + loss=loss, + logits=logits, + last_hidden_state=x, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +if __name__ == "__main__": + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = E1ForSequenceClassification.from_pretrained("Profluent-Bio/E1-150m", dtype=torch.bfloat16, num_labels=1).eval().to(device) + print(model) + + seqs = [ + "MRHGDISSSNDTVGVAVVNYKMPRLHTAAEVLDNARKIAEMIVGMKQGLPGMDLVVFPEYSLQGIMYDPAEMMETAVAIPGEETE", + "IFSRACRKANVWGVFSLTGERHEEHPRKAPYNTLVLIDNNGEIVQKYRKIIPWCPIEGWYPGGQTYVSEGPKGMKISLIICDDGNY", + "PEIWRDCAMKGAELIVRCQGYMYPAKDQQVMMAKAMAWANNCYVAVANAAGFDGVYSYFGHSAIIGFDGRTLGECGEEEMGIQYAQL", + "SLSQIRDARANDQSQNHLFKILHRGYSGLQASGDGDRGLAECPFEFYRTWVTDAEKARENVERLTRSTTGVAQCPVGRLPYEGLEKEA", + ] + + batch = model.prep_tokens.get_batch_kwargs(seqs, device=device) + batch['labels'] = torch.tensor([0.0, 0.0, 0.0, 0.0], device=device) + + last_hidden_state = model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state + print(last_hidden_state.shape)