import os os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" import numpy as np import networkx as nx import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.utils.rnn import pad_sequence from einops import rearrange, repeat from enum import Enum from typing import Any, TypedDict, Callable, Optional, List from dataclasses import dataclass from tokenizers import Tokenizer from transformers import PretrainedConfig, PreTrainedModel from transformers.activations import ACT2FN from transformers.modeling_outputs import ModelOutput from transformers.utils import logging from tqdm.auto import tqdm logger = logging.get_logger(__name__) ### Establish attention compatibility try: from flash_attn import flash_attn_func, flash_attn_varlen_func except ImportError: logger.warning("Failed to import flash attention; Will be using PyTorch attention instead") flash_attn_func = None flash_attn_varlen_func = None try: from torch.nn.attention.flex_attention import ( BlockMask, create_block_mask, flex_attention, _create_sparse_block_from_block_mask ) if torch.cuda.is_available(): # if on linux, compile the flex attention function if os.name == 'posix': print("Compiling flex attention") flex_attention = torch.compile(flex_attention, dynamic=True) else: print("Not compiling flex attention, detected non-Linux environment") except ImportError: logger.warning("Failed to import flex attention; Will be using PyTorch attention instead") flex_attention = None try: from kernels import get_kernel layer_norm = get_kernel("kernels-community/triton-layer-norm") except Exception as e: logger.warning(f"Failed to load triton layer norm kernel: {e}; Will be using PyTorch RMSNorm instead") layer_norm = None def is_flash_attention_available() -> bool: return ( flash_attn_func is not None and flash_attn_varlen_func is not None and (os.getenv("USE_FLASH_ATTN", "1") == "1") ) class FlexAttentionArgs(TypedDict, total=False): block_mask: BlockMask | None score_mod: Callable | None def create_block_causal_mask_optimized(sequence_ids: torch.Tensor) -> BlockMask: # Assumes sequence_ids is sorted in increasing order for each batch item, except for # the -1 values, which are used to indicate the padding tokens. def document_mask(b, h, q_idx, kv_idx): # type: ignore[no-untyped-def] return ( (sequence_ids[b, q_idx] >= sequence_ids[b, kv_idx]) & (sequence_ids[b, q_idx] != -1) & (sequence_ids[b, kv_idx] != -1) ) batch_size, seqlen = sequence_ids.shape return create_block_mask(document_mask, batch_size, 1, seqlen, seqlen, device=sequence_ids.device) def flex_attention_func( query_states: torch.Tensor, # (bs, seqlen, nh, hs) key_states: torch.Tensor, # (bs, seqlen, nkv, hs) value_states: torch.Tensor, # (bs, seqlen, nkv, hs) score_mod: Callable | None = None, block_mask: BlockMask | None = None, ) -> torch.Tensor: assert flex_attention is not None, "Flex Attention is not available in this environment" assert score_mod is None, "Score mod is not supported yet" query_states = query_states.transpose(1, 2).contiguous() # (bs, nh, seqlen, hs) key_states = key_states.transpose(1, 2).contiguous() # (bs, nkv, seqlen, hs) value_states = value_states.transpose(1, 2).contiguous() # (bs, nkv, seqlen, hs) outputs = flex_attention( query_states, key_states, value_states, block_mask=block_mask, score_mod=score_mod, enable_gqa=query_states.shape[1] != key_states.shape[1], # if nkv != nh ) outputs = outputs.transpose(1, 2) # (bs, seqlen, nh, hs) return outputs def flash_attention_func( query_states: torch.Tensor, # (bs, seqlen, nh, hs) key_states: torch.Tensor, # (bs, seqlen, nkv, hs) value_states: torch.Tensor, # (bs, seqlen, nkv, hs) q_sequence_ids: torch.Tensor, k_sequence_ids: torch.Tensor, causal: bool = False, ) -> torch.Tensor: # (bs, seqlen, nh, hs) # Contains at least one padding token in the sequence. Note: ignore attention mask if causal. if not is_flash_attention_available(): raise ImportError("Flash Attention is not available. Please install flash-attn.") if not causal: batch_size, q_len = query_states.shape[0], query_states.shape[1] ( query_states, key_states, value_states, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) = _unpad_input(query_states, key_states, value_states, q_sequence_ids, k_sequence_ids) attn_output_unpad = flash_attn_varlen_func( query_states, key_states, value_states, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k, causal=False, ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, q_len) else: attn_output = flash_attn_func(query_states, key_states, value_states, causal=True) return attn_output class IndexFirstAxis(torch.autograd.Function): @staticmethod def forward(ctx, input, indices) -> torch.Tensor: # type: ignore[no-untyped-def] ctx.save_for_backward(indices) assert input.ndim >= 2 ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] second_dim = other_shape.numel() # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. # return input[indices] return torch.gather(rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)).reshape( -1, *other_shape ) @staticmethod def backward(ctx, grad_output) -> tuple[torch.Tensor, None]: # type: ignore[no-untyped-def] (indices,) = ctx.saved_tensors assert grad_output.ndim >= 2 other_shape = grad_output.shape[1:] grad_output = rearrange(grad_output, "b ... -> b (...)") grad_input = torch.zeros( [ctx.first_axis_dim, grad_output.shape[1]], device=grad_output.device, dtype=grad_output.dtype ) # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. # grad_input[indices] = grad_output grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output) return grad_input.reshape(ctx.first_axis_dim, *other_shape), None def block_min_max_seq_ids(SLEN: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]: device = SLEN.device total_tokens = torch.sum(SLEN) B = (total_tokens + block_size - 1) // block_size padding_tokens = B * block_size - total_tokens SLEN = torch.cat([SLEN, torch.Tensor([padding_tokens]).to(device)], dim=0) assert torch.sum(SLEN) == B * block_size # Cumulative ends (exclusive) for each sequence; cum[i] == end offset of seq i cum = torch.cumsum(SLEN.to(torch.long), dim=0) # (N,) total_tokens = cum[-1].item() # Block start/end offsets [start, end) in token index space block_starts = torch.arange(0, B * block_size, block_size, device=device, dtype=torch.long) # (B,) block_ends = torch.minimum(block_starts + block_size, torch.tensor(total_tokens, device=device)) # (B,) # MIN_SEQ_ID[i] = first sequence whose end > block_start # searchsorted with right=True returns first index where cum > value MIN_SEQ_ID = torch.searchsorted(cum, block_starts, right=True) # MAX_SEQ_ID[i] = sequence containing the last token in the block (block_end - 1) # For empty tail beyond total_tokens we already clipped block_ends. last_token_in_block = torch.clamp(block_ends - 1, min=0) # valid only if block has at least 1 token MAX_SEQ_ID = torch.searchsorted(cum, last_token_in_block, right=True) return MIN_SEQ_ID, MAX_SEQ_ID def get_overlapping_blocks(SLEN_Q: torch.Tensor, SLEN_K: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: MIN_Q, MAX_Q = block_min_max_seq_ids(SLEN_Q) MIN_K, MAX_K = block_min_max_seq_ids(SLEN_K) cond1 = MIN_Q.unsqueeze(1) <= MAX_K.unsqueeze(0) cond2 = MIN_K.unsqueeze(0) <= MAX_Q.unsqueeze(1) overlap = cond1 & cond2 cond1 = (MIN_Q == MAX_Q).unsqueeze(1) cond2 = (MIN_K == MAX_K).unsqueeze(0) same_seq_in_qk = cond1 & cond2 full_blocks = overlap & same_seq_in_qk partial_blocks = overlap & ~same_seq_in_qk return full_blocks, partial_blocks def direct_block_mask(SLEN_Q: torch.Tensor, SLEN_K: torch.Tensor) -> BlockMask: full_blocks, partial_blocks = get_overlapping_blocks(SLEN_Q, SLEN_K) partial_blocks = partial_blocks[None, None] full_blocks = full_blocks[None, None] q_doc_id = torch.repeat_interleave(SLEN_Q) k_doc_id = torch.repeat_interleave(SLEN_K) def doc_mask(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor) -> torch.Tensor: return q_doc_id[q_idx] == k_doc_id[kv_idx] total_q_len = q_doc_id.shape[0] total_k_len = k_doc_id.shape[0] return _create_sparse_block_from_block_mask( (partial_blocks, full_blocks), doc_mask, seq_lengths=(total_q_len, total_k_len), Q_BLOCK_SIZE=128, KV_BLOCK_SIZE=128, ) def doc_id_mask(SLEN_Q: torch.Tensor, SLEN_K: torch.Tensor) -> BlockMask: q_doc_id = torch.repeat_interleave(SLEN_Q) k_doc_id = torch.repeat_interleave(SLEN_K) def doc_mask(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor) -> torch.Tensor: return q_doc_id[q_idx] == k_doc_id[kv_idx] total_q_len = q_doc_id.shape[0] total_k_len = k_doc_id.shape[0] return create_block_mask(doc_mask, 1, 1, total_q_len, total_k_len, BLOCK_SIZE=128, device=SLEN_Q.device) def varlen_flex_attention_func( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, q_sequence_ids: torch.Tensor, k_sequence_ids: torch.Tensor, ) -> torch.Tensor: batch_size, q_len = query_states.shape[0], query_states.shape[1] ( query_states, key_states, value_states, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) = _unpad_input(query_states, key_states, value_states, q_sequence_ids, k_sequence_ids) query_states = query_states.unsqueeze(0).transpose(1, 2).contiguous() key_states = key_states.unsqueeze(0).transpose(1, 2).contiguous() value_states = value_states.unsqueeze(0).transpose(1, 2).contiguous() seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] seqlens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1] block_mask = block_mask_creator(seqlens_q, seqlens_k) attn_output_unpad = flex_attention( query_states, key_states, value_states, block_mask=block_mask, enable_gqa=query_states.shape[1] != key_states.shape[1], ) attn_output = pad_input(attn_output_unpad.transpose(1, 2).squeeze(0), indices_q, batch_size, q_len) return attn_output class IndexPutFirstAxis(torch.autograd.Function): @staticmethod def forward(ctx, values, indices, first_axis_dim) -> torch.Tensor: # type: ignore[no-untyped-def] ctx.save_for_backward(indices) assert indices.ndim == 1 assert values.ndim >= 2 output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype) # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. output[indices] = values # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) return output @staticmethod def backward(ctx, grad_output) -> tuple[torch.Tensor, None, None]: # type: ignore[no-untyped-def] (indices,) = ctx.saved_tensors # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. grad_values = grad_output[indices] # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1])) return grad_values, None, None index_put_first_axis = IndexPutFirstAxis.apply def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor: """ Arguments: hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. batch: int, batch size for the padded sequence. seqlen: int, maximum sequence length for the padded sequence. Return: hidden_states: (batch, seqlen, ...) """ # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype) # output[indices] = hidden_states output = index_put_first_axis(hidden_states, indices, batch * seqlen) return rearrange(output, "(b s) ... -> b s ...", b=batch) def _get_unpad_data(sequence_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: non_pad_indices = sequence_ids != -1 non_pad_indices = torch.nonzero(non_pad_indices.flatten(), as_tuple=False).flatten() sequence_ids = sequence_ids + torch.arange(len(sequence_ids), device=sequence_ids.device)[:, None] * 1e5 sequence_ids = sequence_ids.flatten()[non_pad_indices] _, seqlens_in_batch = torch.unique_consecutive(sequence_ids, return_counts=True) max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) return non_pad_indices, cu_seqlens, max_seqlen_in_batch def _unpad_input( query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor, q_sequence_ids: torch.Tensor, k_sequence_ids: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor], tuple[int, int]]: batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape query_length, num_q_heads = query_layer.shape[1], query_layer.shape[2] assert query_layer.shape[:2] == q_sequence_ids.shape, ( f"Shape mismatch between query layer and query sequence ids: {query_layer.shape[:2]} != {q_sequence_ids.shape}" ) assert key_layer.shape[:2] == k_sequence_ids.shape, ( f"Shape mismatch between key layer and key sequence ids: {key_layer.shape[:2]} != {k_sequence_ids.shape}" ) assert query_length <= kv_seq_len, ( f"Query length should be less than or equal to KV sequence length: {query_length} <= {kv_seq_len}" ) indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(k_sequence_ids) key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) if torch.equal(q_sequence_ids, k_sequence_ids): indices_q = indices_k cu_seqlens_q = cu_seqlens_k max_seqlen_in_batch_q = max_seqlen_in_batch_k else: indices_q, cu_seqlens_q, max_seqlen_in_batch_q = _get_unpad_data(q_sequence_ids) query_layer = index_first_axis(query_layer.reshape(batch_size * query_length, num_q_heads, head_dim), indices_q) assert cu_seqlens_q.shape == cu_seqlens_k.shape, ( f"Query and KV should have the same number of sequences: {cu_seqlens_q.shape} != {cu_seqlens_k.shape}" ) return ( query_layer, key_layer, value_layer, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) index_first_axis = IndexFirstAxis.apply block_mask_creator = direct_block_mask if os.getenv("FAST_BLOCK_MASK", "1") == "1" else doc_id_mask PAD_TOKEN_ID = 0 def get_tokenizer() -> Tokenizer: try: fname = os.path.join(os.path.dirname(__file__), "tokenizer.json") tokenizer: Tokenizer = Tokenizer.from_file(fname) except: print("E1 Tokenizer not found in local directory, downloading from Hugging Face") from huggingface_hub import hf_hub_download fname = hf_hub_download(repo_id="Synthyra/Profluent-E1-150M", filename="tokenizer.json") tokenizer: Tokenizer = Tokenizer.from_file(fname) assert tokenizer.padding["pad_id"] == PAD_TOKEN_ID, ( f"Padding token id must be {PAD_TOKEN_ID}, but got {tokenizer.padding['pad_id']}" ) return tokenizer @dataclass class DataPrepConfig: max_num_sequences: int = 512 max_num_positions_within_seq: int = 8192 remove_X_tokens: bool = False def get_context(sequence: str) -> str | None: if "," in sequence: return sequence.rsplit(",", 1)[0] return None class E1BatchPreparer: def __init__( self, data_prep_config: DataPrepConfig | None = None, tokenizer: Tokenizer | None = None, preserve_context_labels: bool = False, ): self.tokenizer = tokenizer or get_tokenizer() self.data_prep_config = data_prep_config or DataPrepConfig() self.pad_token_id = self.tokenizer.token_to_id("") self.preserve_context_labels = preserve_context_labels device = torch.cuda.current_device() if torch.cuda.is_available() else torch.device("cpu") self.boundary_token_ids = torch.tensor( [self.tokenizer.token_to_id(token) for token in ["", "", "1", "2", ""]], device=device ).long() self.mask_token = "?" # nosec self.mask_token_id = self.tokenizer.token_to_id(self.mask_token) self.X_token_id = self.tokenizer.token_to_id("X") self.vocab = self.tokenizer.get_vocab() def get_batch_kwargs( # type: ignore[override] self, sequences: list[str], device: torch.device = torch.device("cpu"), non_blocking: bool = False ) -> dict[str, torch.Tensor | list[str] | list[int]]: sequence_encodings = [self.prepare_multiseq(sequence) for sequence in sequences] return self.pad_encodings(sequence_encodings, device, non_blocking) def pad_encodings( self, sequence_encodings: list[dict[str, torch.Tensor]], device: torch.device = torch.device("cpu"), non_blocking: bool = False, ) -> dict[str, torch.Tensor | list[str] | list[int]]: non_blocking = non_blocking and device.type == "cuda" padded_encodings = {} # Note: We use -1 as the padding value for sequence and position ids because the 0 value # is a valid value for sequence and position ids. -1 is then used to distinguish valid # tokens from padding tokens, for example, when doing padding/unpadding for flash attention. for key, padding_value in { "input_ids": self.pad_token_id, "sequence_ids": -1, "within_seq_position_ids": -1, "global_position_ids": -1, "labels": self.pad_token_id, }.items(): padded_encodings[key] = pad_sequence( [enc[key] for enc in sequence_encodings], batch_first=True, padding_value=padding_value ).to(device=device, dtype=torch.long, non_blocking=non_blocking) padded_encodings["context"] = [enc["context"] for enc in sequence_encodings] padded_encodings["context_len"] = [enc["context_len"] for enc in sequence_encodings] return padded_encodings def prepare_multiseq(self, sequence: str) -> dict[str, torch.Tensor | str | int]: single_sequences = sequence.split(",") if len(single_sequences) > self.data_prep_config.max_num_sequences: raise ValueError( f"Number of sequences {len(single_sequences)} exceeds max number of sequences {self.data_prep_config.max_num_sequences}" " in the provided multi-sequence instance. Please remove some homologous sequences before trying again." ) single_sequence_encodings = [self.prepare_singleseq(sequence) for sequence in single_sequences] num_tokens = [len(x["input_ids"]) for x in single_sequence_encodings] input_ids = torch.cat([x["input_ids"] for x in single_sequence_encodings]) labels = torch.cat([x["labels"] for x in single_sequence_encodings]) within_seq_position_ids = torch.cat([encoding["position_ids"] for encoding in single_sequence_encodings]) global_position_ids, ctx_len = [], 0 for encoding in single_sequence_encodings: global_position_ids.append(encoding["position_ids"] + ctx_len) ctx_len = max(ctx_len, encoding["position_ids"].max().item() + ctx_len + 1) global_position_ids = torch.cat(global_position_ids) sequence_ids = torch.repeat_interleave(torch.tensor(num_tokens)) # Get multi-seq context & mask out all but last sequence in multi-seq instance if desired context_len = sum(num_tokens[:-1]) context = self.tokenizer.decode(input_ids[:context_len].tolist(), skip_special_tokens=False) if not self.preserve_context_labels: labels[:context_len] = self.pad_token_id assert ( input_ids.shape == sequence_ids.shape == within_seq_position_ids.shape == global_position_ids.shape == labels.shape ), "Input ids, sequence ids, within seq position ids, global position ids, and labels must have the same shape" assert input_ids.shape[0] >= context_len, "Input ids must have at least as many tokens as the context length" return { "input_ids": input_ids, "sequence_ids": sequence_ids, "within_seq_position_ids": within_seq_position_ids, "global_position_ids": global_position_ids, "labels": labels, "context": context, "context_len": context_len, } def prepare_singleseq(self, sequence: str) -> dict[str, torch.Tensor]: if not self.validate_sequence(sequence): raise ValueError(f"Invalid sequence: {sequence}; Input sequence should contain [A-Z] or ? characters only") if len(sequence) > self.data_prep_config.max_num_positions_within_seq: raise ValueError( f"Sequence length {len(sequence)} exceeds max length {self.data_prep_config.max_num_positions_within_seq}" ) # Can also use `tokens = torch.tensor(self.tokenizer.encode(f"1{sequence}2").ids)` # but following is faster since our vocabulary is simple. tokens = torch.tensor([self.vocab[token] for token in ["", "1", *sequence, "2", ""]]) position_ids = torch.arange(len(tokens)) if self.data_prep_config.remove_X_tokens: X_positions = torch.where(tokens != self.X_token_id)[0] tokens = tokens[X_positions] position_ids = position_ids[X_positions] return {"input_ids": tokens, "labels": tokens, "position_ids": position_ids} def get_boundary_token_mask(self, tokens: torch.Tensor) -> torch.BoolTensor: return torch.isin(tokens, self.boundary_token_ids) def get_mask_positions_mask(self, tokens: torch.Tensor) -> torch.BoolTensor: return tokens == self.mask_token_id def validate_sequence(self, sequence: str) -> bool: assert isinstance(sequence, str), "Sequence must be a string" sequence = sequence.replace(self.mask_token, "") return sequence.isalpha() and sequence.isupper() class E1Config(PretrainedConfig): model_type = "E1" keys_to_ignore_at_inference = ["past_key_values"] def __init__( # type: ignore self, # Model architecture/initialization vocab_size=None, hidden_size=4096, intermediate_size=16384, gated_mlp=False, num_hidden_layers=40, num_attention_heads=32, num_key_value_heads=8, hidden_act="silu", rms_norm_eps=1e-5, initializer_range=0.02, torch_dtype="bfloat16", gradient_checkpointing=False, no_ffn_gradient_checkpointing=False, # Tokenization pad_token_id=None, bos_token_id=None, eos_token_id=None, tie_word_embeddings=False, # Attention implementation & rotary positional embeddings global_attention_every_n_layers=0, max_num_sequences=512, max_num_positions_within_seq=8192, max_num_positions_global=1024 * 128, rope_theta_within_seq=10000.0, rope_theta_global=100000.0, clip_qkv=None, **kwargs, ) -> None: tokenizer = get_tokenizer() super().__init__( pad_token_id=tokenizer.token_to_id(""), bos_token_id=tokenizer.token_to_id(""), eos_token_id=tokenizer.token_to_id(""), tie_word_embeddings=tie_word_embeddings, torch_dtype=torch_dtype, **kwargs, ) self.hidden_size = hidden_size if intermediate_size is None: intermediate_size = 3 * hidden_size if gated_mlp else 4 * hidden_size self.intermediate_size = intermediate_size self.gated_mlp = gated_mlp self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.max_num_positions_within_seq = max_num_positions_within_seq self.max_num_positions_global = max_num_positions_global # for backward compatibility if num_key_value_heads is None: num_key_value_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads self.hidden_act = hidden_act self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps self.rope_theta_within_seq = rope_theta_within_seq self.rope_theta_global = rope_theta_global self.max_num_sequences = max_num_sequences assert clip_qkv is None or clip_qkv > 0 self.clip_qkv = clip_qkv self.global_attention_every_n_layers = global_attention_every_n_layers self.vocab_size = tokenizer.get_vocab_size() self.gradient_checkpointing = gradient_checkpointing self.no_ffn_gradient_checkpointing = no_ffn_gradient_checkpointing if vocab_size is not None: if vocab_size < self.vocab_size: logger.warning( f"Using vocab_size {vocab_size} smaller than {self.vocab_size} from tokenizer. MAKE SURE THIS IS INTENTIONAL." ) self.vocab_size = vocab_size elif vocab_size > self.vocab_size: logger.warning(f"Using vocab_size {vocab_size} instead of smaller {self.vocab_size} from tokenizer.") self.vocab_size = vocab_size if pad_token_id is not None and pad_token_id != self.pad_token_id: logger.warning(f"Ignoring pad_token_id. Using {self.pad_token_id} from tokenizer") if bos_token_id is not None and bos_token_id != self.bos_token_id: logger.warning(f"Ignoring bos_token_id. Using {self.bos_token_id} from tokenizer") if eos_token_id is not None and eos_token_id != self.eos_token_id: logger.warning(f"Ignoring eos_token_id. Using {self.eos_token_id} from tokenizer") class DynamicCache: """ A cache layer that grows dynamically as more tokens are generated. This is the default for generative models. It stores the key and value states as tensors of shape `[batch_size, seq_len, num_heads, head_dim]`. Args: key_cache (`list[torch.Tensor]`): The list of key states. value_cache (`list[torch.Tensor]`): The list of value states. """ def __init__(self) -> None: self.key_cache: list[torch.Tensor] = [] self.value_cache: list[torch.Tensor] = [] def update( self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int ) -> tuple[torch.Tensor, torch.Tensor]: """ Update the key and value caches in-place, and return the necessary keys and value states. Args: key_states (`torch.Tensor`): The new key states to cache of shape [batch_size, seq_len, num_heads, head_dim] value_states (`torch.Tensor`): The new value states to cache of shape [batch_size, seq_len, num_heads, head_dim] layer_idx (`int`): The index of the layer to update. Returns: tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states of shape [batch_size, seq_len, num_heads, head_dim]. """ # Lazy initialization if len(self.key_cache) <= layer_idx: # There may be skipped layers, fill them with empty lists for _ in range(len(self.key_cache), layer_idx): self.key_cache.append(torch.tensor([])) self.value_cache.append(torch.tensor([])) self.key_cache.append(key_states) self.value_cache.append(value_states) elif ( not self.key_cache[layer_idx].numel() # prefers not t.numel() to len(t) == 0 to export the model ): # fills previously skipped layers; checking for tensor causes errors self.key_cache[layer_idx] = key_states self.value_cache[layer_idx] = value_states else: self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=1) self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=1) return self.key_cache[layer_idx], self.value_cache[layer_idx] def get_seq_length(self, layer_idx: int = 0) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" is_empty_layer = ( len(self.key_cache) == 0 # no cache in any layer or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it or not self.key_cache[layer_idx].numel() # the layer has no cache ) layer_seq_length = self.key_cache[layer_idx].shape[1] if not is_empty_layer else 0 return layer_seq_length def crop(self, max_length: int) -> None: """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search.""" assert max_length > 0, "max_length must be positive" if self.get_seq_length() <= max_length: return for layer_idx in range(len(self.key_cache)): if self.key_cache[layer_idx].numel(): self.key_cache[layer_idx] = self.key_cache[layer_idx][:, :max_length, ...] self.value_cache[layer_idx] = self.value_cache[layer_idx][:, :max_length, ...] def batch_repeat_interleave(self, repeats: int) -> None: """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" for layer_idx in range(len(self.key_cache)): if self.key_cache[layer_idx].numel(): self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0) def batch_select_indices(self, indices: torch.Tensor) -> None: """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" for layer_idx in range(len(self.key_cache)): if self.key_cache[layer_idx].numel(): self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] class KVCache: def __init__(self, cache_size: int = 4) -> None: self.cache_size = cache_size self.tensor_input_field_names = [ "input_ids", "within_seq_position_ids", "global_position_ids", "sequence_ids", "labels", ] self.tensor_output_field_names = ["logits", "embeddings"] self.cache_dict: dict[str, DynamicCache] = {} self.cache_queue: list[str] = [] def reset(self) -> None: for k in list(self.cache_dict.keys()): del self.cache_dict[k] del self.cache_dict self.cache_dict = {} self.cache_queue = [] torch.cuda.empty_cache() def before_forward(self, batch: dict[str, torch.Tensor]) -> None: contexts: list[str] | None = batch.get("context", None) if contexts is None or "context_len" not in batch: logger.warning_once( "KVCache requires the batch dict to have both `context` and `context_len` keys to trigger. Skipping." ) return context_lens: list[int] = list(set(batch["context_len"])) contexts: list[str] = list(set(contexts)) # type: ignore[no-redef] if len(contexts) != 1 or len(context_lens) != 1: logger.warning( "SingleContextKVCache requires a single context and context length. " "Multiple contexts or context lengths found in a single batch. Skipping." ) return batch_size = batch["input_ids"].shape[0] unique_context = contexts[0] unique_context_len = context_lens[0] batch["use_cache"] = True if unique_context not in self.cache_dict: return self.cache_dict[unique_context].batch_repeat_interleave(batch_size) past_key_values = self.cache_dict[unique_context] batch["past_key_values"] = past_key_values # Remove context from the input fields for field_name in self.tensor_input_field_names: if batch.get(field_name, None) is not None: batch[field_name] = batch[field_name][:, unique_context_len:] def after_forward(self, batch: dict[str, Any], outputs: ModelOutput) -> None: contexts = batch.get("context", None) context_lens = batch.get("context_len", []) if contexts is None or len(set(contexts)) != 1 or len(set(context_lens)) != 1 or context_lens[0] == 0: return assert batch["use_cache"] unique_context = contexts[0] unique_context_len = context_lens[0] past_key_values = getattr(outputs, "past_key_values", None) if not isinstance(past_key_values, DynamicCache): logger.warning_once("KVCache is incompatible with models that don't return a DynamicCache. Skipping.") return if "past_key_values" not in batch: if len(self.cache_queue) == self.cache_size: last_context = self.cache_queue.pop(0) if last_context not in self.cache_queue: del self.cache_dict[last_context] torch.cuda.empty_cache() self.cache_dict[unique_context] = past_key_values self.cache_queue.append(unique_context) # Remove context from the input fields for field_name in self.tensor_input_field_names: if field_name in batch and batch[field_name] is not None: batch[field_name] = batch[field_name][:, unique_context_len:] # Remove context from the output fields for field_name in self.tensor_output_field_names: if field_name in outputs and outputs[field_name] is not None: outputs[field_name] = outputs[field_name][:, unique_context_len:] if "hidden_states" in outputs and outputs["hidden_states"] is not None: outputs["hidden_states"] = [h[:, unique_context_len:] for h in outputs["hidden_states"]] self.cache_dict[unique_context].crop(unique_context_len) self.cache_dict[unique_context].batch_select_indices([0]) class AttentionMethod(Enum): FLASH = "flash" FLEX = "flex" class AttentionLayerType(Enum): WITHIN_SEQ = "within_seq" GLOBAL = "global" class AttentionArgs(TypedDict, total=False): flex_attention_args: FlexAttentionArgs def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) class RotaryPositionalEmbedding(nn.Module): def __init__( self, dim: int, max_position_embeddings: int = 2048, base: int = 10000, device: torch.device | None = None ): super().__init__() self.dim = dim self.base = base self.max_position_embeddings = max_position_embeddings inv_freq = base ** -(torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. self._set_sin_cos_cache(seq_len=max_position_embeddings, device=self.inv_freq.device) @staticmethod def rotate_half(x: torch.Tensor) -> torch.Tensor: """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def _set_sin_cos_cache(self, seq_len: int, device: torch.device) -> None: # Different from paper, but it uses a different permutation in order to obtain the same calculation self.max_seq_len_cached = seq_len t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) angles = torch.outer(t, self.inv_freq.to(device)) angles = torch.cat((angles, angles), dim=1) self.register_buffer("cos_cached", angles.cos(), persistent=False) self.register_buffer("sin_cached", angles.sin(), persistent=False) def forward( self, q: torch.Tensor, k: torch.Tensor, position_ids: torch.LongTensor, seq_len: int | None = None ) -> tuple[torch.Tensor, torch.Tensor]: # x: [bsz, seq_len, num_attention_heads, head_size] device, dtype = q.device, q.dtype seq_len = position_ids.max().item() + 1 if seq_len is None else seq_len if seq_len > self.max_seq_len_cached: self._set_sin_cos_cache(seq_len=seq_len, device=device) # angles_cached[position_ids] gets us something of shape (batch_size, seq_len, head_dim), # so unsqueeze dimension -2 to broadcast to (batch_size, seq_len, n_heads, head_dim). idxs = position_ids.to(device) cos = self.cos_cached.to(device=device, dtype=dtype).unsqueeze(-2)[idxs] sin = self.sin_cached.to(device=device, dtype=dtype).unsqueeze(-2)[idxs] # Apply rotary positional embeddings to q and k (treating them as complex numbers). The first half is # Re[x exp(it)] = Re[x] cos(t) - Im[x] sin(t), while the second half is # Im[x exp(it)] = Im[x] cos(t) + Re[x] sin(t). This works b/c both halves of cos/sin are the same. q_embed = (q * cos) + (self.rotate_half(q) * sin) k_embed = (k * cos) + (self.rotate_half(k) * sin) return q_embed, k_embed class Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper.""" def __init__(self, config: E1Config, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads self.num_kv_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_kv_heads self.max_num_seqs = config.max_num_sequences self.clip_qkv = config.clip_qkv if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" f" and `num_heads`: {self.num_heads})." ) self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) if self.config.global_attention_every_n_layers > 0: self.layer_type = ( AttentionLayerType.GLOBAL if (self.layer_idx + 1) % self.config.global_attention_every_n_layers == 0 else AttentionLayerType.WITHIN_SEQ ) else: self.layer_type = AttentionLayerType.WITHIN_SEQ self.rope_theta = ( config.rope_theta_within_seq if self.layer_type == AttentionLayerType.WITHIN_SEQ else config.rope_theta_global ) self.max_position_embeddings = ( config.max_num_positions_within_seq if self.layer_type == AttentionLayerType.WITHIN_SEQ else config.max_num_positions_global ) self.rotary_emb = RotaryPositionalEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta ) def prepare_qkv( self, hidden_states: torch.Tensor, position_ids: torch.LongTensor, past_key_value: DynamicCache | None = None, use_cache: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: bsz, q_len, _ = hidden_states.size() query_states: torch.Tensor = self.q_proj(hidden_states) key_states: torch.Tensor = self.k_proj(hidden_states) val_states: torch.Tensor = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim) val_states = val_states.view(bsz, q_len, self.num_kv_heads, self.head_dim) if self.clip_qkv is not None: query_states = query_states.clamp(-self.clip_qkv, self.clip_qkv) key_states = key_states.clamp(-self.clip_qkv, self.clip_qkv) val_states = val_states.clamp(-self.clip_qkv, self.clip_qkv) query_states, key_states = self.rotary_emb(query_states, key_states, position_ids) if use_cache and past_key_value is not None: key_states, val_states = past_key_value.update(key_states, val_states, self.layer_idx) # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in float16 just to be sure everything works as expected. input_dtype = query_states.dtype if torch.is_autocast_enabled(): target_dtype = torch.get_autocast_gpu_dtype() else: target_dtype = self.q_proj.weight.dtype if input_dtype != target_dtype: logger.warning_once( f"The input hidden states seems to be silently casted in {input_dtype}. " f"This might be because you have upcasted embedding or layer norm layers " f"in {input_dtype}. We will cast back the input in {target_dtype}." ) query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) val_states = val_states.to(target_dtype) return query_states, key_states, val_states def forward( self, hidden_states: torch.Tensor, within_seq_position_ids: torch.LongTensor, global_position_ids: torch.LongTensor, sequence_ids: torch.LongTensor, attention_args: AttentionArgs | None = None, past_key_value: DynamicCache | None = None, output_attentions: bool = False, use_cache: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None, DynamicCache | None]: is_cache_prefilled = ( use_cache and past_key_value is not None and past_key_value.get_seq_length(self.layer_idx) > 0 ) query_states, key_states, val_states = self.prepare_qkv( hidden_states=hidden_states, position_ids=within_seq_position_ids if self.layer_type == AttentionLayerType.WITHIN_SEQ else global_position_ids, past_key_value=past_key_value, use_cache=use_cache, ) # Note: We fallback to using flash attention in inference mode when cache is filled with kv values # for global attention layers instead of flex attention. This is because once the cache is filled, # the last sequence attends to everything in the cache, so we can make things faster by using a # bidirectional flash attention instead of block-causal flex attention. if self.layer_type == AttentionLayerType.WITHIN_SEQ or is_cache_prefilled: attention_type = AttentionMethod.FLASH else: attention_type = AttentionMethod.FLEX attn_output, attn_weights = self._attn( attention_type=attention_type, query_states=query_states, key_states=key_states, val_states=val_states, sequence_ids=sequence_ids, attention_args=attention_args, output_attentions=output_attentions, ) attn_output = self.o_proj(attn_output) return attn_output, attn_weights, past_key_value def _attn( self, attention_type: AttentionMethod, query_states: torch.Tensor, key_states: torch.Tensor, val_states: torch.Tensor, sequence_ids: torch.Tensor, attention_args: AttentionArgs | None = None, output_attentions: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: match attention_type: case AttentionMethod.FLASH: f = self._flash_attn case AttentionMethod.FLEX: f = self._flex_attn case _: raise ValueError(f"No attention implementation found for {attention_type}") return f( query_states=query_states, key_states=key_states, val_states=val_states, sequence_ids=sequence_ids, attention_args=attention_args, output_attentions=output_attentions, ) def _flash_attn( self, query_states: torch.Tensor, key_states: torch.Tensor, val_states: torch.Tensor, sequence_ids: torch.Tensor, attention_args: AttentionArgs | None = None, output_attentions: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: """Flash attention implementation. Calls the public API of flash attention and deals with padding tokens if any are present. """ assert not output_attentions, "Flash attention doesn't support returning attention masks" bsz, q_len = query_states.shape[0], query_states.shape[1] _, kv_len = key_states.shape[0], key_states.shape[1] if self.layer_type == AttentionLayerType.GLOBAL: # Only happens in inference q_sequence_ids = sequence_ids if q_len < kv_len: # Assumes query contain only one sequence # and all tokens in query (except padding) will attend to all tokens in KV first_token_id = sequence_ids[:, 0].unsqueeze(1) k_sequence_ids = torch.cat([first_token_id.expand(bsz, kv_len - q_len), sequence_ids], dim=-1) else: k_sequence_ids = sequence_ids else: if q_len < kv_len: # Only happens in inference key_states = key_states[:, -q_len:] val_states = val_states[:, -q_len:] q_sequence_ids = k_sequence_ids = sequence_ids if is_flash_attention_available(): attn_output = flash_attention_func( query_states, key_states, val_states, q_sequence_ids=q_sequence_ids, k_sequence_ids=k_sequence_ids, causal=False, ) else: attn_output = varlen_flex_attention_func( query_states, key_states, val_states, q_sequence_ids=q_sequence_ids, k_sequence_ids=k_sequence_ids ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() return attn_output, None def _flex_attn( self, query_states: torch.Tensor, key_states: torch.Tensor, val_states: torch.Tensor, sequence_ids: torch.Tensor, attention_args: AttentionArgs | None = None, output_attentions: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: bsz, q_len = query_states.shape[0], query_states.shape[1] flex_attention_args = attention_args.get("flex_attention_args", None) if attention_args is not None else None block_mask = flex_attention_args.get("block_mask", None) if flex_attention_args is not None else None score_mod = flex_attention_args.get("score_mod", None) if flex_attention_args is not None else None outputs = flex_attention_func(query_states, key_states, val_states, score_mod=score_mod, block_mask=block_mask) outputs = outputs.reshape(bsz, q_len, self.hidden_size).contiguous() return outputs, None class MLP(nn.Module): def __init__(self, config: E1Config): super().__init__() self.ffn_dim = config.intermediate_size self.hidden_dim = config.hidden_size self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.w2(self.act_fn(self.w1(hidden_states))) class GLUMLP(nn.Module): def __init__(self, config: E1Config): super().__init__() self.ffn_dim = config.intermediate_size self.hidden_dim = config.hidden_size self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) hidden_states = self.w2(hidden_states) return hidden_states class FFN(nn.Module): def __init__(self, config: E1Config): super().__init__() mlp_cls = GLUMLP if config.gated_mlp else MLP self.mlp = mlp_cls(config) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.mlp(hidden_states) @dataclass class E1ModelOutputWithPast(ModelOutput): """Base class for model's outputs, with potential hidden states and attentions. Attributes: last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ last_hidden_state: torch.FloatTensor | None = None past_key_values: DynamicCache | None = None hidden_states: tuple[torch.FloatTensor, ...] | None = None attentions: tuple[torch.FloatTensor, ...] | None = None @dataclass class E1MaskedLMOutputWithPast(ModelOutput): loss: torch.FloatTensor | None = None mlm_loss: torch.FloatTensor | None = None logits: torch.FloatTensor | None = None last_hidden_state: torch.FloatTensor | None = None past_key_values: DynamicCache | None = None hidden_states: tuple[torch.FloatTensor, ...] | None = None attentions: tuple[torch.FloatTensor, ...] | None = None @dataclass class E1ClassificationOutputWithPast(ModelOutput): loss: torch.FloatTensor | None = None logits: torch.FloatTensor | None = None last_hidden_state: torch.FloatTensor | None = None past_key_values: DynamicCache | None = None hidden_states: tuple[torch.FloatTensor, ...] | None = None attentions: tuple[torch.FloatTensor, ...] | None = None class RMSNorm(nn.Module): def __init__(self, hidden_size: int, eps: float = 1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps self.hidden_size = hidden_size def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: input_dtype = hidden_states.dtype if layer_norm is None: return torch.nn.functional.rms_norm( hidden_states, (self.hidden_size,), self.weight, self.variance_epsilon ).to(input_dtype) else: return layer_norm.rms_norm_fn( x=hidden_states, weight=self.weight, bias=None, # no bias residual=None, eps=self.variance_epsilon, dropout_p=0.0, # no dropout by default prenorm=False, residual_in_fp32=False, ).to(input_dtype) class NormAttentionNorm(nn.Module): def __init__(self, config: E1Config, layer_idx: int): super().__init__() self.self_attn = Attention(config, layer_idx) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, within_seq_position_ids: torch.LongTensor, global_position_ids: torch.LongTensor, sequence_ids: torch.LongTensor, attention_args: AttentionArgs | None = None, past_key_value: DynamicCache | None = None, output_attentions: bool = False, use_cache: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, DynamicCache | None]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, within_seq_position_ids=within_seq_position_ids, global_position_ids=global_position_ids, sequence_ids=sequence_ids, attention_args=attention_args, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, ) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) return hidden_states, residual, self_attn_weights, present_key_value class DecoderLayer(nn.Module): def __init__(self, config: E1Config, layer_idx: int): super().__init__() self.initializer_range = config.initializer_range self.hidden_size = config.hidden_size self.norm_attn_norm = NormAttentionNorm(config, layer_idx) self.ffn = FFN(config) def forward( self, hidden_states: torch.Tensor, within_seq_position_ids: torch.LongTensor, global_position_ids: torch.LongTensor, sequence_ids: torch.LongTensor, attention_args: AttentionArgs | None = None, past_key_value: DynamicCache | None = None, output_attentions: bool = False, use_cache: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None, DynamicCache | None]: hidden_states, residual, self_attn_weights, present_key_value = self.norm_attn_norm( hidden_states=hidden_states, within_seq_position_ids=within_seq_position_ids, global_position_ids=global_position_ids, sequence_ids=sequence_ids, attention_args=attention_args, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, ) # Fully Connected hidden_states = self.ffn(hidden_states) hidden_states = residual + hidden_states return hidden_states, self_attn_weights, present_key_value ### Support for embedding datasets with low code class Pooler: def __init__(self, pooling_types: List[str]): self.pooling_types = pooling_types self.pooling_options = { 'mean': self.mean_pooling, 'max': self.max_pooling, 'norm': self.norm_pooling, 'median': self.median_pooling, 'std': self.std_pooling, 'var': self.var_pooling, 'cls': self.cls_pooling, 'parti': self._pool_parti, } def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor: maxed_attentions = torch.max(attentions, dim=1)[0] return maxed_attentions def _page_rank(self, attention_matrix, personalization=None, nstart=None, prune_type="top_k_outdegree"): # Run PageRank on the attention matrix converted to a graph. # Raises exceptions if the graph doesn't match the token sequence or has no edges. # Returns the PageRank scores for each token node. G = self._convert_to_graph(attention_matrix) if G.number_of_nodes() != attention_matrix.shape[0]: raise Exception( f"The number of nodes in the graph should be equal to the number of tokens in sequence! You have {G.number_of_nodes()} nodes for {attention_matrix.shape[0]} tokens.") if G.number_of_edges() == 0: raise Exception(f"You don't seem to have any attention edges left in the graph.") return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100) def _convert_to_graph(self, matrix): # Convert a matrix (e.g., attention scores) to a directed graph using networkx. # Each element in the matrix represents a directed edge with a weight. G = nx.from_numpy_array(matrix, create_using=nx.DiGraph) return G def _calculate_importance_weights(self, dict_importance, attention_mask: Optional[torch.Tensor] = None): # Remove keys where attention_mask is 0 if attention_mask is not None: for k in list(dict_importance.keys()): if attention_mask[k] == 0: del dict_importance[k] #dict_importance[0] # remove cls #dict_importance[-1] # remove eos total = sum(dict_importance.values()) return np.array([v / total for _, v in dict_importance.items()]) def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d) maxed_attentions = self._create_pooled_matrices_across_layers(attentions).numpy() # emb is (b, L, d), maxed_attentions is (b, L, L) emb_pooled = [] for e, a, mask in zip(emb, maxed_attentions, attention_mask): dict_importance = self._page_rank(a) importance_weights = self._calculate_importance_weights(dict_importance, mask) num_tokens = int(mask.sum().item()) emb_pooled.append(np.average(e[:num_tokens], weights=importance_weights, axis=0)) pooled = torch.tensor(np.array(emb_pooled)) return pooled def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) if attention_mask is None: return emb.mean(dim=1) else: attention_mask = attention_mask.unsqueeze(-1) return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) if attention_mask is None: return emb.max(dim=1).values else: attention_mask = attention_mask.unsqueeze(-1) return (emb * attention_mask).max(dim=1).values def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) if attention_mask is None: return emb.norm(dim=1, p=2) else: attention_mask = attention_mask.unsqueeze(-1) return (emb * attention_mask).norm(dim=1, p=2) def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) if attention_mask is None: return emb.median(dim=1).values else: attention_mask = attention_mask.unsqueeze(-1) return (emb * attention_mask).median(dim=1).values def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) if attention_mask is None: return emb.std(dim=1) else: # Compute variance correctly over non-masked positions, then take sqrt var = self.var_pooling(emb, attention_mask, **kwargs) return torch.sqrt(var) def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) if attention_mask is None: return emb.var(dim=1) else: # Correctly compute variance over only non-masked positions attention_mask = attention_mask.unsqueeze(-1) # (b, L, 1) # Compute mean over non-masked positions mean = (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d) mean = mean.unsqueeze(1) # (b, 1, d) # Compute squared differences from mean, only over non-masked positions squared_diff = (emb - mean) ** 2 # (b, L, d) # Sum squared differences over non-masked positions and divide by count var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d) return var def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) return emb[:, 0, :] def __call__( self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, attentions: Optional[torch.Tensor] = None ): # [mean, max] final_emb = [] for pooling_type in self.pooling_types: final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions)) # (b, d) return torch.cat(final_emb, dim=-1) # (b, n_pooling_types * d) class EmbeddingMixin: def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: raise NotImplementedError @property def device(self) -> torch.device: """Get the device of the model.""" return next(self.parameters()).device def _read_sequences_from_db(self, db_path: str) -> set[str]: """Read sequences from SQLite database.""" import sqlite3 sequences = [] with sqlite3.connect(db_path) as conn: c = conn.cursor() c.execute("SELECT sequence FROM embeddings") while True: row = c.fetchone() if row is None: break sequences.append(row[0]) return set(sequences) def embed_dataset( self, sequences: List[str], #tokenizer: PreTrainedTokenizerBase, # For E1, the tokenizing is handled by _embed batch_size: int = 2, max_len: int = 512, truncate: bool = True, full_embeddings: bool = False, embed_dtype: torch.dtype = torch.float32, pooling_types: List[str] = ['mean'], sql: bool = False, save: bool = True, sql_db_path: str = 'embeddings.db', save_path: str = 'embeddings.pth', **kwargs, ) -> Optional[dict[str, torch.Tensor]]: """Embed a dataset of protein sequences. Args: sequences: List of protein sequences batch_size: Batch size for processing max_len: Maximum sequence length full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False) pooling_type: Type of pooling ('mean' or 'cls') sql: Whether to store embeddings in SQLite database - will be stored in float32 sql_db_path: Path to SQLite database Returns: Dictionary mapping sequences to embeddings, or None if sql=True Note: - If sql=True, embeddings can only be stored in float32 - sql is ideal if you need to stream a very large dataset for training in real-time - save=True is ideal if you can store the entire embedding dictionary in RAM - sql will be used if it is True and save is True or False - If your sql database or .pth file is already present, they will be scanned first for already embedded sequences - Sequences will be truncated to max_len and sorted by length in descending order for faster processing Example: >>> embedder = EmbeddingMixin() >>> embedding_dict = embedder.embed_dataset( sequences=[ 'MALWMRLLPLLALLALWGPDPAAA', ... # list of protein sequences ], batch_size=2, # adjust for your GPU memory max_len=512, # adjust for your needs full_embeddings=False, # if True, no pooling is performed embed_dtype=torch.float32, # cast to what dtype you want pooling_type=['mean', 'cls'], # more than one pooling type will be concatenated together sql=False, # if True, embeddings will be stored in SQLite database sql_db_path='embeddings.db', save=True, # if True, embeddings will be saved as a .pth file save_path='embeddings.pth', ) >>> # embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql """ sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences])) sequences = sorted(sequences, key=len, reverse=True) hidden_size = self.config.hidden_size pooler = Pooler(pooling_types) if not full_embeddings else None def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: if full_embeddings or residue_embeddings.ndim == 2: # if already pooled or want residue-wise embeddings return residue_embeddings else: return pooler(residue_embeddings, attention_mask) if sql: import sqlite3 conn = sqlite3.connect(sql_db_path) c = conn.cursor() c.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence text PRIMARY KEY, embedding blob)') already_embedded = self._read_sequences_from_db(sql_db_path) to_embed = [seq for seq in sequences if seq not in already_embedded] print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}") print(f"Embedding {len(to_embed)} new sequences") if len(to_embed) > 0: with torch.no_grad(): for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'): seqs = to_embed[batch_start:batch_start + batch_size] input_ids, attention_mask = self._embed(seqs, return_attention_mask=True) embeddings = get_embeddings(input_ids, attention_mask).float() # sql requires float32 for seq, emb, mask in zip(seqs, embeddings, attention_mask): if full_embeddings: emb = emb[mask.bool()].reshape(-1, hidden_size) c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)", (seq, emb.cpu().numpy().tobytes())) conn.commit() conn.commit() conn.close() return None embeddings_dict = {} if os.path.exists(save_path): embeddings_dict = torch.load(save_path, map_location='cpu', weights_only=True) to_embed = [seq for seq in sequences if seq not in embeddings_dict] print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}") print(f"Embedding {len(to_embed)} new sequences") else: to_embed = sequences print(f"Embedding {len(to_embed)} new sequences") if len(to_embed) > 0: with torch.no_grad(): for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'): seqs = to_embed[batch_start:batch_start + batch_size] last_hidden_state, attention_mask = self._embed(seqs, return_attention_mask=True) embeddings = get_embeddings(last_hidden_state, attention_mask).to(embed_dtype) for seq, emb, mask in zip(seqs, embeddings, attention_mask): if full_embeddings: emb = emb[mask.bool()].reshape(-1, hidden_size) embeddings_dict[seq] = emb.cpu() if save: torch.save(embeddings_dict, save_path) return embeddings_dict class E1PreTrainedModel(PreTrainedModel): config_class = E1Config config: E1Config base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["DecoderLayer"] _transformer_layer_cls = [DecoderLayer] _skip_keys_device_placement = "past_key_values" def _init_weights(self, module: nn.Module) -> None: std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, RMSNorm): module.weight.data.fill_(1.0) def post_init(self) -> None: super().post_init() def _backward_compatibility_gradient_checkpointing(self) -> None: if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False): self.gradient_checkpointing_enable(dict(use_reentrant=False)) @property def _device(self) -> torch.device: return next(self.parameters()).device @classmethod def from_pretrained( # type: ignore[no-untyped-def] cls, pretrained_model_name_or_path: str | os.PathLike | None, *args, **kwargs ) -> "E1PreTrainedModel": return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs) class E1Model(E1PreTrainedModel, EmbeddingMixin): config: E1Config config_class = E1Config def __init__(self, config: E1Config, **kwargs): E1PreTrainedModel.__init__(self, config, **kwargs) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.embed_seq_id = nn.Embedding(config.max_num_sequences, config.hidden_size) self.layers = nn.ModuleList([DecoderLayer(config, i) for i in range(config.num_hidden_layers)]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = config.gradient_checkpointing self.prep_tokens = E1BatchPreparer() self.post_init() def get_input_embeddings(self) -> nn.Embedding: return self.embed_tokens def set_input_embeddings(self, value: nn.Embedding) -> None: self.embed_tokens = value @torch.inference_mode() def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor: batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device) last_hidden_state = self.forward(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state if return_attention_mask: attention_mask = (batch['sequence_ids'] != -1).long() return last_hidden_state, attention_mask else: return last_hidden_state # Ignore copy def forward( self, input_ids: torch.LongTensor, within_seq_position_ids: torch.LongTensor, global_position_ids: torch.LongTensor, sequence_ids: torch.LongTensor, past_key_values: DynamicCache | None = None, use_cache: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, **kwargs ) -> E1ModelOutputWithPast: """ Args: input_ids: (batch_size, seq_length) within_seq_position_ids: (batch_size, seq_length) This tensor contains the position of each residue within the sequence itself. For example, if the input is ["1ABC21DEF2", "1GH21JKL2"], the tensor would be [[0,1,2,3,4,5,6,0,1,2,3,4,5,6], [0,1,2,3,4,5,0,1,2,3,4,5,6,-1]] global_position_ids: (batch_size, seq_length) This tensor contains the position of each residue within the global sequence. For example, if the input is ["1ABC21DEF2", "1GH21JKL2"], the tensor would be [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, -1]] sequence_ids: (batch_size, seq_length) This tensor contains the sequence id of each residue. For example, if the input is ["1ABC21DEF2", "1GH21JKL2"], the tensor would be [[0,0,0,0,0,0,0,1,1,1,1,1,1,1], [0,0,0,0,0,0,1,1,1,1,1,1,1,-1]] past_key_values: DynamicCache use_cache: bool output_attentions: bool output_hidden_states: bool Returns: E1ModelOutputWithPast: Model Outputs """ batch_size, seq_length = input_ids.shape if self.gradient_checkpointing and self.training and torch.is_grad_enabled(): if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False if use_cache and past_key_values is None: past_key_values = DynamicCache() elif not use_cache: # To avoid weirdness with gradient checkpointing: https://github.com/huggingface/transformers/issues/28499 past_key_values = None global_position_ids = global_position_ids.view(-1, seq_length).long() within_seq_position_ids = within_seq_position_ids.view(-1, seq_length).long() sequence_ids = sequence_ids.view(-1, seq_length).long() max_position_id = torch.max(within_seq_position_ids).item() min_position_id = torch.min(within_seq_position_ids).item() assert max_position_id < self.config.max_num_positions_within_seq and min_position_id >= -1, ( f"Position ids must be in the range [-1, {self.config.max_num_positions_within_seq}); got max {max_position_id} and min {min_position_id}" ) inputs_embeds = self.embed_tokens(input_ids) # -1 is used to indicate padding tokens, so we need to clamp the sequence ids to 0 inputs_embeds = inputs_embeds + self.embed_seq_id(sequence_ids.clamp(min=0)) # In case we need to do any manual typecasting if torch.is_autocast_enabled(): target_dtype = torch.get_autocast_gpu_dtype() else: target_dtype = self.layers[0].norm_attn_norm.self_attn.q_proj.weight.dtype hidden_states = inputs_embeds.to(target_dtype) # (batch_size, query_length, keyval_length) past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 # Create block mask for flex attention attention_args: AttentionArgs | None = None if past_key_values_length == 0: block_mask = create_block_causal_mask_optimized(sequence_ids) flex_attention_args = FlexAttentionArgs(block_mask=block_mask) attention_args = AttentionArgs(flex_attention_args=flex_attention_args) # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) # type: ignore[operator] if self.gradient_checkpointing and self.training and torch.is_grad_enabled(): layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, within_seq_position_ids, global_position_ids, sequence_ids, attention_args, past_key_values, output_attentions, use_cache, ) else: layer_outputs = decoder_layer( hidden_states, within_seq_position_ids=within_seq_position_ids, global_position_ids=global_position_ids, sequence_ids=sequence_ids, attention_args=attention_args, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, ) hidden_states, self_attn_weights, present_key_value = layer_outputs if use_cache: # NOTE: it's necessary to re-assign past_key_values because FSDP2 # passes certain arguments by value, not by reference. # See https://github.com/huggingface/transformers/issues/38190#issuecomment-2914016168 next_decoder_cache = past_key_values = present_key_value if output_attentions: all_self_attns += (self_attn_weights,) # type: ignore[operator] hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) # type: ignore[operator] next_cache = next_decoder_cache if use_cache else None return E1ModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, ) class E1ForMaskedLM(E1PreTrainedModel, EmbeddingMixin): config: E1Config config_class = E1Config def __init__(self, config: E1Config, **kwargs): E1PreTrainedModel.__init__(self, config, **kwargs) self.model: E1Model = E1Model(config) self.vocab_size = config.vocab_size self.mlm_head = torch.nn.Sequential( nn.Linear(config.hidden_size, config.hidden_size, bias=True), nn.GELU(), nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps), nn.Linear(config.hidden_size, config.vocab_size, bias=True), ) self.gradient_checkpointing = config.gradient_checkpointing self.prep_tokens = E1BatchPreparer() self.post_init() @property def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh: return self.model.device_mesh @torch.inference_mode() def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor: batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device) last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state if return_attention_mask: attention_mask = (batch['sequence_ids'] != -1).long() return last_hidden_state, attention_mask else: return last_hidden_state def forward( self, input_ids: torch.LongTensor, within_seq_position_ids: torch.LongTensor, global_position_ids: torch.LongTensor, sequence_ids: torch.LongTensor, labels: torch.LongTensor | None = None, past_key_values: DynamicCache | None = None, use_cache: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, **kwargs, ) -> E1MaskedLMOutputWithPast: """ Args: input_ids: (batch_size, seq_length) within_seq_position_ids: (batch_size, seq_length) This tensor contains the position of each residue within the sequence itself. For example, if the input is ["1ABC21DEF2", "1GH21JKL2"], the tensor would be [[0,1,2,3,4,5,6,0,1,2,3,4,5,6], [0,1,2,3,4,5,0,1,2,3,4,5,6,-1]] global_position_ids: (batch_size, seq_length) This tensor contains the position of each residue within the global sequence. For example, if the input is ["1ABC21DEF2", "1GH21JKL2"], the tensor would be [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, -1]] sequence_ids: (batch_size, seq_length) This tensor contains the sequence id of each residue. For example, if the input is ["1ABC21DEF2", "1GH21JKL2"], the tensor would be [[0,0,0,0,0,0,0,1,1,1,1,1,1,1], [0,0,0,0,0,0,1,1,1,1,1,1,1,-1]] labels: (batch_size, seq_length) past_key_values: DynamicCache use_cache: bool output_attentions: bool output_hidden_states: bool Returns: E1MaskedLMOutputWithPast: Model Outputs """ outputs: E1ModelOutputWithPast = self.model( input_ids=input_ids, within_seq_position_ids=within_seq_position_ids, global_position_ids=global_position_ids, sequence_ids=sequence_ids, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) x = outputs.last_hidden_state loss = None # Compute masked language modeling loss mlm_logits = self.mlm_head(x).float() mlm_loss = 0.0 if labels is not None: mlm_logits_flat = mlm_logits.contiguous().view(-1, self.config.vocab_size) mlm_labels_flat = labels.to(mlm_logits_flat.device).contiguous().view(-1) mlm_loss = F.cross_entropy(mlm_logits_flat, mlm_labels_flat, reduction="none") mask = mlm_labels_flat != self.model.padding_idx n_mlm = mask.sum() mlm_loss = (mlm_loss * mask.to(mlm_loss)).sum() / (1 if n_mlm == 0 else n_mlm) loss = 0.0 loss += mlm_loss return E1MaskedLMOutputWithPast( loss=loss, mlm_loss=mlm_loss, logits=mlm_logits, last_hidden_state=x, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) class E1ForSequenceClassification(E1PreTrainedModel, EmbeddingMixin): config: E1Config config_class = E1Config def __init__(self, config: E1Config, **kwargs): E1PreTrainedModel.__init__(self, config, **kwargs) self.model: E1Model = E1Model(config) self.vocab_size = config.vocab_size self.num_labels = config.num_labels self.classifier = nn.Sequential( nn.Linear(config.hidden_size * 2, config.hidden_size * 4), nn.GELU(), nn.LayerNorm(config.hidden_size * 4), nn.Linear(config.hidden_size * 4, config.num_labels), ) self.mse = nn.MSELoss() self.ce = nn.CrossEntropyLoss() self.bce = nn.BCEWithLogitsLoss() self.gradient_checkpointing = config.gradient_checkpointing self.prep_tokens = E1BatchPreparer() if 'pooling_types' in kwargs and isinstance(kwargs['pooling_types'], List[str]) and len(kwargs['pooling_types']) > 0: pooling_types = kwargs['pooling_types'] else: pooling_types = ['mean', 'var'] self.pooler = Pooler(pooling_types) self.post_init() @property def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh: return self.model.device_mesh @torch.inference_mode() def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor: batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device) last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state if return_attention_mask: attention_mask = (batch['sequence_ids'] != -1).long() return last_hidden_state, attention_mask else: return last_hidden_state def forward( self, input_ids: torch.LongTensor, within_seq_position_ids: torch.LongTensor, global_position_ids: torch.LongTensor, sequence_ids: torch.LongTensor, labels: torch.LongTensor | None = None, past_key_values: DynamicCache | None = None, use_cache: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, **kwargs, ) -> E1ClassificationOutputWithPast: outputs: E1ModelOutputWithPast = self.model( input_ids=input_ids, within_seq_position_ids=within_seq_position_ids, global_position_ids=global_position_ids, sequence_ids=sequence_ids, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) attention_mask = (sequence_ids != -1).long() x = outputs.last_hidden_state features = self.pooler(x, attention_mask) logits = self.classifier(features) loss = None if labels is not None: labels = labels.to(logits.device) if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": if self.num_labels == 1: loss = self.mse(logits.flatten(), labels.flatten()) else: loss = self.mse(logits, labels) elif self.config.problem_type == "single_label_classification": loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss = self.bce(logits, labels) return E1ClassificationOutputWithPast( loss=loss, logits=logits, last_hidden_state=x, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) class E1ForTokenClassification(E1PreTrainedModel, EmbeddingMixin): config: E1Config config_class = E1Config def __init__(self, config: E1Config, **kwargs): E1PreTrainedModel.__init__(self, config, **kwargs) self.model: E1Model = E1Model(config) self.vocab_size = config.vocab_size self.num_labels = config.num_labels self.classifier = nn.Sequential( nn.Linear(config.hidden_size * 2, config.hidden_size * 4), nn.GELU(), nn.LayerNorm(config.hidden_size * 4), nn.Linear(config.hidden_size * 4, config.num_labels), ) self.loss_fct = nn.CrossEntropyLoss() self.gradient_checkpointing = config.gradient_checkpointing self.prep_tokens = E1BatchPreparer() self.post_init() @property def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh: return self.model.device_mesh @torch.inference_mode() def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor: batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device) last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state if return_attention_mask: attention_mask = (batch['sequence_ids'] != -1).long() return last_hidden_state, attention_mask else: return last_hidden_state def forward( self, input_ids: torch.LongTensor, within_seq_position_ids: torch.LongTensor, global_position_ids: torch.LongTensor, sequence_ids: torch.LongTensor, labels: torch.LongTensor | None = None, past_key_values: DynamicCache | None = None, use_cache: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, **kwargs, ) -> E1ClassificationOutputWithPast: outputs: E1ModelOutputWithPast = self.model( input_ids=input_ids, within_seq_position_ids=within_seq_position_ids, global_position_ids=global_position_ids, sequence_ids=sequence_ids, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) x = outputs.last_hidden_state logits = self.classifier(x) loss = None if labels is not None: loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) return E1ClassificationOutputWithPast( loss=loss, logits=logits, last_hidden_state=x, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) if __name__ == "__main__": device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = E1ForSequenceClassification.from_pretrained("Profluent-Bio/E1-150m", dtype=torch.bfloat16, num_labels=1).eval().to(device) print(model) seqs = [ "MRHGDISSSNDTVGVAVVNYKMPRLHTAAEVLDNARKIAEMIVGMKQGLPGMDLVVFPEYSLQGIMYDPAEMMETAVAIPGEETE", "IFSRACRKANVWGVFSLTGERHEEHPRKAPYNTLVLIDNNGEIVQKYRKIIPWCPIEGWYPGGQTYVSEGPKGMKISLIICDDGNY", "PEIWRDCAMKGAELIVRCQGYMYPAKDQQVMMAKAMAWANNCYVAVANAAGFDGVYSYFGHSAIIGFDGRTLGECGEEEMGIQYAQL", "SLSQIRDARANDQSQNHLFKILHRGYSGLQASGDGDRGLAECPFEFYRTWVTDAEKARENVERLTRSTTGVAQCPVGRLPYEGLEKEA", ] batch = model.prep_tokens.get_batch_kwargs(seqs, device=device) batch['labels'] = torch.tensor([0.0, 0.0, 0.0, 0.0], device=device) last_hidden_state = model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state print(last_hidden_state.shape)