|
|
import os
|
|
|
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
|
|
|
|
|
import numpy as np
|
|
|
import networkx as nx
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
from torch.nn.utils.rnn import pad_sequence
|
|
|
|
|
|
from einops import rearrange, repeat
|
|
|
from enum import Enum
|
|
|
from typing import Any, TypedDict, Callable, Optional, List
|
|
|
from dataclasses import dataclass
|
|
|
from tokenizers import Tokenizer
|
|
|
from transformers import PretrainedConfig, PreTrainedModel
|
|
|
from transformers.activations import ACT2FN
|
|
|
from transformers.modeling_outputs import ModelOutput
|
|
|
from transformers.utils import logging
|
|
|
from tqdm.auto import tqdm
|
|
|
|
|
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
|
|
|
|
|
try:
|
|
|
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
|
|
except ImportError:
|
|
|
logger.warning("Failed to import flash attention; Will be using PyTorch attention instead")
|
|
|
flash_attn_func = None
|
|
|
flash_attn_varlen_func = None
|
|
|
|
|
|
try:
|
|
|
from torch.nn.attention.flex_attention import (
|
|
|
BlockMask,
|
|
|
create_block_mask,
|
|
|
flex_attention,
|
|
|
_create_sparse_block_from_block_mask
|
|
|
)
|
|
|
|
|
|
if torch.cuda.is_available():
|
|
|
|
|
|
if os.name == 'posix':
|
|
|
print("Compiling flex attention")
|
|
|
flex_attention = torch.compile(flex_attention, dynamic=True)
|
|
|
else:
|
|
|
print("Not compiling flex attention, detected non-Linux environment")
|
|
|
|
|
|
except ImportError:
|
|
|
logger.warning("Failed to import flex attention; Will be using PyTorch attention instead")
|
|
|
flex_attention = None
|
|
|
|
|
|
try:
|
|
|
from kernels import get_kernel
|
|
|
layer_norm = get_kernel("kernels-community/triton-layer-norm")
|
|
|
except Exception as e:
|
|
|
logger.warning(f"Failed to load triton layer norm kernel: {e}; Will be using PyTorch RMSNorm instead")
|
|
|
layer_norm = None
|
|
|
|
|
|
|
|
|
def is_flash_attention_available() -> bool:
|
|
|
return (
|
|
|
flash_attn_func is not None and flash_attn_varlen_func is not None and (os.getenv("USE_FLASH_ATTN", "1") == "1")
|
|
|
)
|
|
|
|
|
|
|
|
|
class FlexAttentionArgs(TypedDict, total=False):
|
|
|
block_mask: BlockMask | None
|
|
|
score_mod: Callable | None
|
|
|
|
|
|
|
|
|
def create_block_causal_mask_optimized(sequence_ids: torch.Tensor) -> BlockMask:
|
|
|
|
|
|
|
|
|
def document_mask(b, h, q_idx, kv_idx):
|
|
|
return (
|
|
|
(sequence_ids[b, q_idx] >= sequence_ids[b, kv_idx])
|
|
|
& (sequence_ids[b, q_idx] != -1)
|
|
|
& (sequence_ids[b, kv_idx] != -1)
|
|
|
)
|
|
|
|
|
|
batch_size, seqlen = sequence_ids.shape
|
|
|
return create_block_mask(document_mask, batch_size, 1, seqlen, seqlen, device=sequence_ids.device)
|
|
|
|
|
|
|
|
|
def flex_attention_func(
|
|
|
query_states: torch.Tensor,
|
|
|
key_states: torch.Tensor,
|
|
|
value_states: torch.Tensor,
|
|
|
score_mod: Callable | None = None,
|
|
|
block_mask: BlockMask | None = None,
|
|
|
) -> torch.Tensor:
|
|
|
assert flex_attention is not None, "Flex Attention is not available in this environment"
|
|
|
assert score_mod is None, "Score mod is not supported yet"
|
|
|
query_states = query_states.transpose(1, 2).contiguous()
|
|
|
key_states = key_states.transpose(1, 2).contiguous()
|
|
|
value_states = value_states.transpose(1, 2).contiguous()
|
|
|
|
|
|
outputs = flex_attention(
|
|
|
query_states,
|
|
|
key_states,
|
|
|
value_states,
|
|
|
block_mask=block_mask,
|
|
|
score_mod=score_mod,
|
|
|
enable_gqa=query_states.shape[1] != key_states.shape[1],
|
|
|
)
|
|
|
|
|
|
outputs = outputs.transpose(1, 2)
|
|
|
return outputs
|
|
|
|
|
|
|
|
|
def flash_attention_func(
|
|
|
query_states: torch.Tensor,
|
|
|
key_states: torch.Tensor,
|
|
|
value_states: torch.Tensor,
|
|
|
q_sequence_ids: torch.Tensor,
|
|
|
k_sequence_ids: torch.Tensor,
|
|
|
causal: bool = False,
|
|
|
) -> torch.Tensor:
|
|
|
|
|
|
if not is_flash_attention_available():
|
|
|
raise ImportError("Flash Attention is not available. Please install flash-attn.")
|
|
|
|
|
|
if not causal:
|
|
|
batch_size, q_len = query_states.shape[0], query_states.shape[1]
|
|
|
(
|
|
|
query_states,
|
|
|
key_states,
|
|
|
value_states,
|
|
|
indices_q,
|
|
|
(cu_seqlens_q, cu_seqlens_k),
|
|
|
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
|
|
) = _unpad_input(query_states, key_states, value_states, q_sequence_ids, k_sequence_ids)
|
|
|
|
|
|
attn_output_unpad = flash_attn_varlen_func(
|
|
|
query_states,
|
|
|
key_states,
|
|
|
value_states,
|
|
|
cu_seqlens_q=cu_seqlens_q,
|
|
|
cu_seqlens_k=cu_seqlens_k,
|
|
|
max_seqlen_q=max_seqlen_in_batch_q,
|
|
|
max_seqlen_k=max_seqlen_in_batch_k,
|
|
|
causal=False,
|
|
|
)
|
|
|
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, q_len)
|
|
|
|
|
|
else:
|
|
|
attn_output = flash_attn_func(query_states, key_states, value_states, causal=True)
|
|
|
|
|
|
return attn_output
|
|
|
|
|
|
|
|
|
class IndexFirstAxis(torch.autograd.Function):
|
|
|
@staticmethod
|
|
|
def forward(ctx, input, indices) -> torch.Tensor:
|
|
|
ctx.save_for_backward(indices)
|
|
|
assert input.ndim >= 2
|
|
|
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
|
|
|
second_dim = other_shape.numel()
|
|
|
|
|
|
|
|
|
return torch.gather(rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)).reshape(
|
|
|
-1, *other_shape
|
|
|
)
|
|
|
|
|
|
@staticmethod
|
|
|
def backward(ctx, grad_output) -> tuple[torch.Tensor, None]:
|
|
|
(indices,) = ctx.saved_tensors
|
|
|
assert grad_output.ndim >= 2
|
|
|
other_shape = grad_output.shape[1:]
|
|
|
grad_output = rearrange(grad_output, "b ... -> b (...)")
|
|
|
grad_input = torch.zeros(
|
|
|
[ctx.first_axis_dim, grad_output.shape[1]], device=grad_output.device, dtype=grad_output.dtype
|
|
|
)
|
|
|
|
|
|
|
|
|
grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
|
|
|
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
|
|
|
|
|
|
|
|
|
def block_min_max_seq_ids(SLEN: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
|
|
|
device = SLEN.device
|
|
|
total_tokens = torch.sum(SLEN)
|
|
|
B = (total_tokens + block_size - 1) // block_size
|
|
|
padding_tokens = B * block_size - total_tokens
|
|
|
SLEN = torch.cat([SLEN, torch.Tensor([padding_tokens]).to(device)], dim=0)
|
|
|
|
|
|
assert torch.sum(SLEN) == B * block_size
|
|
|
|
|
|
|
|
|
cum = torch.cumsum(SLEN.to(torch.long), dim=0)
|
|
|
total_tokens = cum[-1].item()
|
|
|
|
|
|
|
|
|
block_starts = torch.arange(0, B * block_size, block_size, device=device, dtype=torch.long)
|
|
|
block_ends = torch.minimum(block_starts + block_size, torch.tensor(total_tokens, device=device))
|
|
|
|
|
|
|
|
|
|
|
|
MIN_SEQ_ID = torch.searchsorted(cum, block_starts, right=True)
|
|
|
|
|
|
|
|
|
|
|
|
last_token_in_block = torch.clamp(block_ends - 1, min=0)
|
|
|
MAX_SEQ_ID = torch.searchsorted(cum, last_token_in_block, right=True)
|
|
|
|
|
|
return MIN_SEQ_ID, MAX_SEQ_ID
|
|
|
|
|
|
|
|
|
def get_overlapping_blocks(SLEN_Q: torch.Tensor, SLEN_K: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
|
MIN_Q, MAX_Q = block_min_max_seq_ids(SLEN_Q)
|
|
|
MIN_K, MAX_K = block_min_max_seq_ids(SLEN_K)
|
|
|
|
|
|
cond1 = MIN_Q.unsqueeze(1) <= MAX_K.unsqueeze(0)
|
|
|
cond2 = MIN_K.unsqueeze(0) <= MAX_Q.unsqueeze(1)
|
|
|
overlap = cond1 & cond2
|
|
|
|
|
|
cond1 = (MIN_Q == MAX_Q).unsqueeze(1)
|
|
|
cond2 = (MIN_K == MAX_K).unsqueeze(0)
|
|
|
same_seq_in_qk = cond1 & cond2
|
|
|
|
|
|
full_blocks = overlap & same_seq_in_qk
|
|
|
partial_blocks = overlap & ~same_seq_in_qk
|
|
|
|
|
|
return full_blocks, partial_blocks
|
|
|
|
|
|
|
|
|
def direct_block_mask(SLEN_Q: torch.Tensor, SLEN_K: torch.Tensor) -> BlockMask:
|
|
|
full_blocks, partial_blocks = get_overlapping_blocks(SLEN_Q, SLEN_K)
|
|
|
partial_blocks = partial_blocks[None, None]
|
|
|
full_blocks = full_blocks[None, None]
|
|
|
|
|
|
q_doc_id = torch.repeat_interleave(SLEN_Q)
|
|
|
k_doc_id = torch.repeat_interleave(SLEN_K)
|
|
|
|
|
|
def doc_mask(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor) -> torch.Tensor:
|
|
|
return q_doc_id[q_idx] == k_doc_id[kv_idx]
|
|
|
|
|
|
total_q_len = q_doc_id.shape[0]
|
|
|
total_k_len = k_doc_id.shape[0]
|
|
|
|
|
|
return _create_sparse_block_from_block_mask(
|
|
|
(partial_blocks, full_blocks),
|
|
|
doc_mask,
|
|
|
seq_lengths=(total_q_len, total_k_len),
|
|
|
Q_BLOCK_SIZE=128,
|
|
|
KV_BLOCK_SIZE=128,
|
|
|
)
|
|
|
|
|
|
|
|
|
def doc_id_mask(SLEN_Q: torch.Tensor, SLEN_K: torch.Tensor) -> BlockMask:
|
|
|
q_doc_id = torch.repeat_interleave(SLEN_Q)
|
|
|
k_doc_id = torch.repeat_interleave(SLEN_K)
|
|
|
|
|
|
def doc_mask(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor) -> torch.Tensor:
|
|
|
return q_doc_id[q_idx] == k_doc_id[kv_idx]
|
|
|
|
|
|
total_q_len = q_doc_id.shape[0]
|
|
|
total_k_len = k_doc_id.shape[0]
|
|
|
|
|
|
return create_block_mask(doc_mask, 1, 1, total_q_len, total_k_len, BLOCK_SIZE=128, device=SLEN_Q.device)
|
|
|
|
|
|
|
|
|
def varlen_flex_attention_func(
|
|
|
query_states: torch.Tensor,
|
|
|
key_states: torch.Tensor,
|
|
|
value_states: torch.Tensor,
|
|
|
q_sequence_ids: torch.Tensor,
|
|
|
k_sequence_ids: torch.Tensor,
|
|
|
) -> torch.Tensor:
|
|
|
batch_size, q_len = query_states.shape[0], query_states.shape[1]
|
|
|
(
|
|
|
query_states,
|
|
|
key_states,
|
|
|
value_states,
|
|
|
indices_q,
|
|
|
(cu_seqlens_q, cu_seqlens_k),
|
|
|
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
|
|
) = _unpad_input(query_states, key_states, value_states, q_sequence_ids, k_sequence_ids)
|
|
|
|
|
|
query_states = query_states.unsqueeze(0).transpose(1, 2).contiguous()
|
|
|
key_states = key_states.unsqueeze(0).transpose(1, 2).contiguous()
|
|
|
value_states = value_states.unsqueeze(0).transpose(1, 2).contiguous()
|
|
|
|
|
|
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
|
|
|
seqlens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
|
|
|
block_mask = block_mask_creator(seqlens_q, seqlens_k)
|
|
|
|
|
|
attn_output_unpad = flex_attention(
|
|
|
query_states,
|
|
|
key_states,
|
|
|
value_states,
|
|
|
block_mask=block_mask,
|
|
|
enable_gqa=query_states.shape[1] != key_states.shape[1],
|
|
|
)
|
|
|
|
|
|
attn_output = pad_input(attn_output_unpad.transpose(1, 2).squeeze(0), indices_q, batch_size, q_len)
|
|
|
|
|
|
return attn_output
|
|
|
|
|
|
|
|
|
class IndexPutFirstAxis(torch.autograd.Function):
|
|
|
@staticmethod
|
|
|
def forward(ctx, values, indices, first_axis_dim) -> torch.Tensor:
|
|
|
ctx.save_for_backward(indices)
|
|
|
assert indices.ndim == 1
|
|
|
assert values.ndim >= 2
|
|
|
output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
|
|
|
|
|
|
output[indices] = values
|
|
|
|
|
|
return output
|
|
|
|
|
|
@staticmethod
|
|
|
def backward(ctx, grad_output) -> tuple[torch.Tensor, None, None]:
|
|
|
(indices,) = ctx.saved_tensors
|
|
|
|
|
|
grad_values = grad_output[indices]
|
|
|
|
|
|
return grad_values, None, None
|
|
|
|
|
|
|
|
|
index_put_first_axis = IndexPutFirstAxis.apply
|
|
|
|
|
|
|
|
|
def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
|
|
|
"""
|
|
|
Arguments:
|
|
|
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
|
|
|
indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
|
|
|
batch: int, batch size for the padded sequence.
|
|
|
seqlen: int, maximum sequence length for the padded sequence.
|
|
|
Return:
|
|
|
hidden_states: (batch, seqlen, ...)
|
|
|
"""
|
|
|
|
|
|
|
|
|
output = index_put_first_axis(hidden_states, indices, batch * seqlen)
|
|
|
return rearrange(output, "(b s) ... -> b s ...", b=batch)
|
|
|
|
|
|
|
|
|
def _get_unpad_data(sequence_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
|
|
|
non_pad_indices = sequence_ids != -1
|
|
|
non_pad_indices = torch.nonzero(non_pad_indices.flatten(), as_tuple=False).flatten()
|
|
|
sequence_ids = sequence_ids + torch.arange(len(sequence_ids), device=sequence_ids.device)[:, None] * 1e5
|
|
|
sequence_ids = sequence_ids.flatten()[non_pad_indices]
|
|
|
_, seqlens_in_batch = torch.unique_consecutive(sequence_ids, return_counts=True)
|
|
|
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
|
|
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
|
|
|
return non_pad_indices, cu_seqlens, max_seqlen_in_batch
|
|
|
|
|
|
|
|
|
def _unpad_input(
|
|
|
query_layer: torch.Tensor,
|
|
|
key_layer: torch.Tensor,
|
|
|
value_layer: torch.Tensor,
|
|
|
q_sequence_ids: torch.Tensor,
|
|
|
k_sequence_ids: torch.Tensor,
|
|
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor], tuple[int, int]]:
|
|
|
batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
|
|
|
query_length, num_q_heads = query_layer.shape[1], query_layer.shape[2]
|
|
|
assert query_layer.shape[:2] == q_sequence_ids.shape, (
|
|
|
f"Shape mismatch between query layer and query sequence ids: {query_layer.shape[:2]} != {q_sequence_ids.shape}"
|
|
|
)
|
|
|
assert key_layer.shape[:2] == k_sequence_ids.shape, (
|
|
|
f"Shape mismatch between key layer and key sequence ids: {key_layer.shape[:2]} != {k_sequence_ids.shape}"
|
|
|
)
|
|
|
assert query_length <= kv_seq_len, (
|
|
|
f"Query length should be less than or equal to KV sequence length: {query_length} <= {kv_seq_len}"
|
|
|
)
|
|
|
|
|
|
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(k_sequence_ids)
|
|
|
|
|
|
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
|
|
|
value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
|
|
|
|
|
|
if torch.equal(q_sequence_ids, k_sequence_ids):
|
|
|
indices_q = indices_k
|
|
|
cu_seqlens_q = cu_seqlens_k
|
|
|
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
|
|
else:
|
|
|
indices_q, cu_seqlens_q, max_seqlen_in_batch_q = _get_unpad_data(q_sequence_ids)
|
|
|
|
|
|
query_layer = index_first_axis(query_layer.reshape(batch_size * query_length, num_q_heads, head_dim), indices_q)
|
|
|
|
|
|
assert cu_seqlens_q.shape == cu_seqlens_k.shape, (
|
|
|
f"Query and KV should have the same number of sequences: {cu_seqlens_q.shape} != {cu_seqlens_k.shape}"
|
|
|
)
|
|
|
|
|
|
return (
|
|
|
query_layer,
|
|
|
key_layer,
|
|
|
value_layer,
|
|
|
indices_q,
|
|
|
(cu_seqlens_q, cu_seqlens_k),
|
|
|
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
|
|
)
|
|
|
|
|
|
|
|
|
index_first_axis = IndexFirstAxis.apply
|
|
|
block_mask_creator = direct_block_mask if os.getenv("FAST_BLOCK_MASK", "1") == "1" else doc_id_mask
|
|
|
PAD_TOKEN_ID = 0
|
|
|
|
|
|
|
|
|
def get_tokenizer() -> Tokenizer:
|
|
|
try:
|
|
|
fname = os.path.join(os.path.dirname(__file__), "tokenizer.json")
|
|
|
tokenizer: Tokenizer = Tokenizer.from_file(fname)
|
|
|
except:
|
|
|
print("E1 Tokenizer not found in local directory, downloading from Hugging Face")
|
|
|
from huggingface_hub import hf_hub_download
|
|
|
fname = hf_hub_download(repo_id="Synthyra/Profluent-E1-150M", filename="tokenizer.json")
|
|
|
tokenizer: Tokenizer = Tokenizer.from_file(fname)
|
|
|
assert tokenizer.padding["pad_id"] == PAD_TOKEN_ID, (
|
|
|
f"Padding token id must be {PAD_TOKEN_ID}, but got {tokenizer.padding['pad_id']}"
|
|
|
)
|
|
|
|
|
|
return tokenizer
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
class DataPrepConfig:
|
|
|
max_num_sequences: int = 512
|
|
|
max_num_positions_within_seq: int = 8192
|
|
|
remove_X_tokens: bool = False
|
|
|
|
|
|
|
|
|
def get_context(sequence: str) -> str | None:
|
|
|
if "," in sequence:
|
|
|
return sequence.rsplit(",", 1)[0]
|
|
|
return None
|
|
|
|
|
|
|
|
|
class E1BatchPreparer:
|
|
|
def __init__(
|
|
|
self,
|
|
|
data_prep_config: DataPrepConfig | None = None,
|
|
|
tokenizer: Tokenizer | None = None,
|
|
|
preserve_context_labels: bool = False,
|
|
|
):
|
|
|
self.tokenizer = tokenizer or get_tokenizer()
|
|
|
self.data_prep_config = data_prep_config or DataPrepConfig()
|
|
|
self.pad_token_id = self.tokenizer.token_to_id("<pad>")
|
|
|
self.preserve_context_labels = preserve_context_labels
|
|
|
device = torch.cuda.current_device() if torch.cuda.is_available() else torch.device("cpu")
|
|
|
self.boundary_token_ids = torch.tensor(
|
|
|
[self.tokenizer.token_to_id(token) for token in ["<bos>", "<eos>", "1", "2", "<pad>"]], device=device
|
|
|
).long()
|
|
|
self.mask_token = "?"
|
|
|
self.mask_token_id = self.tokenizer.token_to_id(self.mask_token)
|
|
|
self.X_token_id = self.tokenizer.token_to_id("X")
|
|
|
self.vocab = self.tokenizer.get_vocab()
|
|
|
|
|
|
def get_batch_kwargs(
|
|
|
self, sequences: list[str], device: torch.device = torch.device("cpu"), non_blocking: bool = False
|
|
|
) -> dict[str, torch.Tensor | list[str] | list[int]]:
|
|
|
sequence_encodings = [self.prepare_multiseq(sequence) for sequence in sequences]
|
|
|
return self.pad_encodings(sequence_encodings, device, non_blocking)
|
|
|
|
|
|
def pad_encodings(
|
|
|
self,
|
|
|
sequence_encodings: list[dict[str, torch.Tensor]],
|
|
|
device: torch.device = torch.device("cpu"),
|
|
|
non_blocking: bool = False,
|
|
|
) -> dict[str, torch.Tensor | list[str] | list[int]]:
|
|
|
non_blocking = non_blocking and device.type == "cuda"
|
|
|
padded_encodings = {}
|
|
|
|
|
|
|
|
|
|
|
|
for key, padding_value in {
|
|
|
"input_ids": self.pad_token_id,
|
|
|
"sequence_ids": -1,
|
|
|
"within_seq_position_ids": -1,
|
|
|
"global_position_ids": -1,
|
|
|
"labels": self.pad_token_id,
|
|
|
}.items():
|
|
|
padded_encodings[key] = pad_sequence(
|
|
|
[enc[key] for enc in sequence_encodings], batch_first=True, padding_value=padding_value
|
|
|
).to(device=device, dtype=torch.long, non_blocking=non_blocking)
|
|
|
|
|
|
padded_encodings["context"] = [enc["context"] for enc in sequence_encodings]
|
|
|
padded_encodings["context_len"] = [enc["context_len"] for enc in sequence_encodings]
|
|
|
|
|
|
return padded_encodings
|
|
|
|
|
|
def prepare_multiseq(self, sequence: str) -> dict[str, torch.Tensor | str | int]:
|
|
|
single_sequences = sequence.split(",")
|
|
|
if len(single_sequences) > self.data_prep_config.max_num_sequences:
|
|
|
raise ValueError(
|
|
|
f"Number of sequences {len(single_sequences)} exceeds max number of sequences {self.data_prep_config.max_num_sequences}"
|
|
|
" in the provided multi-sequence instance. Please remove some homologous sequences before trying again."
|
|
|
)
|
|
|
|
|
|
single_sequence_encodings = [self.prepare_singleseq(sequence) for sequence in single_sequences]
|
|
|
|
|
|
num_tokens = [len(x["input_ids"]) for x in single_sequence_encodings]
|
|
|
input_ids = torch.cat([x["input_ids"] for x in single_sequence_encodings])
|
|
|
labels = torch.cat([x["labels"] for x in single_sequence_encodings])
|
|
|
|
|
|
within_seq_position_ids = torch.cat([encoding["position_ids"] for encoding in single_sequence_encodings])
|
|
|
global_position_ids, ctx_len = [], 0
|
|
|
for encoding in single_sequence_encodings:
|
|
|
global_position_ids.append(encoding["position_ids"] + ctx_len)
|
|
|
ctx_len = max(ctx_len, encoding["position_ids"].max().item() + ctx_len + 1)
|
|
|
global_position_ids = torch.cat(global_position_ids)
|
|
|
|
|
|
sequence_ids = torch.repeat_interleave(torch.tensor(num_tokens))
|
|
|
|
|
|
|
|
|
context_len = sum(num_tokens[:-1])
|
|
|
context = self.tokenizer.decode(input_ids[:context_len].tolist(), skip_special_tokens=False)
|
|
|
if not self.preserve_context_labels:
|
|
|
labels[:context_len] = self.pad_token_id
|
|
|
|
|
|
assert (
|
|
|
input_ids.shape
|
|
|
== sequence_ids.shape
|
|
|
== within_seq_position_ids.shape
|
|
|
== global_position_ids.shape
|
|
|
== labels.shape
|
|
|
), "Input ids, sequence ids, within seq position ids, global position ids, and labels must have the same shape"
|
|
|
|
|
|
assert input_ids.shape[0] >= context_len, "Input ids must have at least as many tokens as the context length"
|
|
|
|
|
|
return {
|
|
|
"input_ids": input_ids,
|
|
|
"sequence_ids": sequence_ids,
|
|
|
"within_seq_position_ids": within_seq_position_ids,
|
|
|
"global_position_ids": global_position_ids,
|
|
|
"labels": labels,
|
|
|
"context": context,
|
|
|
"context_len": context_len,
|
|
|
}
|
|
|
|
|
|
def prepare_singleseq(self, sequence: str) -> dict[str, torch.Tensor]:
|
|
|
if not self.validate_sequence(sequence):
|
|
|
raise ValueError(f"Invalid sequence: {sequence}; Input sequence should contain [A-Z] or ? characters only")
|
|
|
|
|
|
if len(sequence) > self.data_prep_config.max_num_positions_within_seq:
|
|
|
raise ValueError(
|
|
|
f"Sequence length {len(sequence)} exceeds max length {self.data_prep_config.max_num_positions_within_seq}"
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
tokens = torch.tensor([self.vocab[token] for token in ["<bos>", "1", *sequence, "2", "<eos>"]])
|
|
|
position_ids = torch.arange(len(tokens))
|
|
|
|
|
|
if self.data_prep_config.remove_X_tokens:
|
|
|
X_positions = torch.where(tokens != self.X_token_id)[0]
|
|
|
tokens = tokens[X_positions]
|
|
|
position_ids = position_ids[X_positions]
|
|
|
|
|
|
return {"input_ids": tokens, "labels": tokens, "position_ids": position_ids}
|
|
|
|
|
|
def get_boundary_token_mask(self, tokens: torch.Tensor) -> torch.BoolTensor:
|
|
|
return torch.isin(tokens, self.boundary_token_ids)
|
|
|
|
|
|
def get_mask_positions_mask(self, tokens: torch.Tensor) -> torch.BoolTensor:
|
|
|
return tokens == self.mask_token_id
|
|
|
|
|
|
def validate_sequence(self, sequence: str) -> bool:
|
|
|
assert isinstance(sequence, str), "Sequence must be a string"
|
|
|
sequence = sequence.replace(self.mask_token, "")
|
|
|
return sequence.isalpha() and sequence.isupper()
|
|
|
|
|
|
|
|
|
|
|
|
class E1Config(PretrainedConfig):
|
|
|
model_type = "E1"
|
|
|
keys_to_ignore_at_inference = ["past_key_values"]
|
|
|
|
|
|
def __init__(
|
|
|
self,
|
|
|
|
|
|
vocab_size=None,
|
|
|
hidden_size=4096,
|
|
|
intermediate_size=16384,
|
|
|
gated_mlp=False,
|
|
|
num_hidden_layers=40,
|
|
|
num_attention_heads=32,
|
|
|
num_key_value_heads=8,
|
|
|
hidden_act="silu",
|
|
|
rms_norm_eps=1e-5,
|
|
|
initializer_range=0.02,
|
|
|
torch_dtype="bfloat16",
|
|
|
gradient_checkpointing=False,
|
|
|
no_ffn_gradient_checkpointing=False,
|
|
|
|
|
|
pad_token_id=None,
|
|
|
bos_token_id=None,
|
|
|
eos_token_id=None,
|
|
|
tie_word_embeddings=False,
|
|
|
|
|
|
global_attention_every_n_layers=0,
|
|
|
max_num_sequences=512,
|
|
|
max_num_positions_within_seq=8192,
|
|
|
max_num_positions_global=1024 * 128,
|
|
|
rope_theta_within_seq=10000.0,
|
|
|
rope_theta_global=100000.0,
|
|
|
clip_qkv=None,
|
|
|
**kwargs,
|
|
|
) -> None:
|
|
|
tokenizer = get_tokenizer()
|
|
|
super().__init__(
|
|
|
pad_token_id=tokenizer.token_to_id("<pad>"),
|
|
|
bos_token_id=tokenizer.token_to_id("<bos>"),
|
|
|
eos_token_id=tokenizer.token_to_id("<eos>"),
|
|
|
tie_word_embeddings=tie_word_embeddings,
|
|
|
torch_dtype=torch_dtype,
|
|
|
**kwargs,
|
|
|
)
|
|
|
|
|
|
self.hidden_size = hidden_size
|
|
|
if intermediate_size is None:
|
|
|
intermediate_size = 3 * hidden_size if gated_mlp else 4 * hidden_size
|
|
|
self.intermediate_size = intermediate_size
|
|
|
self.gated_mlp = gated_mlp
|
|
|
self.num_hidden_layers = num_hidden_layers
|
|
|
self.num_attention_heads = num_attention_heads
|
|
|
self.max_num_positions_within_seq = max_num_positions_within_seq
|
|
|
self.max_num_positions_global = max_num_positions_global
|
|
|
|
|
|
|
|
|
if num_key_value_heads is None:
|
|
|
num_key_value_heads = num_attention_heads
|
|
|
|
|
|
self.num_key_value_heads = num_key_value_heads
|
|
|
self.hidden_act = hidden_act
|
|
|
self.initializer_range = initializer_range
|
|
|
self.rms_norm_eps = rms_norm_eps
|
|
|
self.rope_theta_within_seq = rope_theta_within_seq
|
|
|
self.rope_theta_global = rope_theta_global
|
|
|
self.max_num_sequences = max_num_sequences
|
|
|
assert clip_qkv is None or clip_qkv > 0
|
|
|
self.clip_qkv = clip_qkv
|
|
|
self.global_attention_every_n_layers = global_attention_every_n_layers
|
|
|
|
|
|
self.vocab_size = tokenizer.get_vocab_size()
|
|
|
self.gradient_checkpointing = gradient_checkpointing
|
|
|
self.no_ffn_gradient_checkpointing = no_ffn_gradient_checkpointing
|
|
|
|
|
|
if vocab_size is not None:
|
|
|
if vocab_size < self.vocab_size:
|
|
|
logger.warning(
|
|
|
f"Using vocab_size {vocab_size} smaller than {self.vocab_size} from tokenizer. MAKE SURE THIS IS INTENTIONAL."
|
|
|
)
|
|
|
self.vocab_size = vocab_size
|
|
|
elif vocab_size > self.vocab_size:
|
|
|
logger.warning(f"Using vocab_size {vocab_size} instead of smaller {self.vocab_size} from tokenizer.")
|
|
|
self.vocab_size = vocab_size
|
|
|
if pad_token_id is not None and pad_token_id != self.pad_token_id:
|
|
|
logger.warning(f"Ignoring pad_token_id. Using {self.pad_token_id} from tokenizer")
|
|
|
if bos_token_id is not None and bos_token_id != self.bos_token_id:
|
|
|
logger.warning(f"Ignoring bos_token_id. Using {self.bos_token_id} from tokenizer")
|
|
|
if eos_token_id is not None and eos_token_id != self.eos_token_id:
|
|
|
logger.warning(f"Ignoring eos_token_id. Using {self.eos_token_id} from tokenizer")
|
|
|
|
|
|
|
|
|
class DynamicCache:
|
|
|
"""
|
|
|
A cache layer that grows dynamically as more tokens are generated. This is the default for generative models.
|
|
|
It stores the key and value states as tensors of shape `[batch_size, seq_len, num_heads, head_dim]`.
|
|
|
|
|
|
Args:
|
|
|
key_cache (`list[torch.Tensor]`): The list of key states.
|
|
|
value_cache (`list[torch.Tensor]`): The list of value states.
|
|
|
"""
|
|
|
|
|
|
def __init__(self) -> None:
|
|
|
self.key_cache: list[torch.Tensor] = []
|
|
|
self.value_cache: list[torch.Tensor] = []
|
|
|
|
|
|
def update(
|
|
|
self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int
|
|
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
|
"""
|
|
|
Update the key and value caches in-place, and return the necessary keys and value states.
|
|
|
|
|
|
Args:
|
|
|
key_states (`torch.Tensor`): The new key states to cache of shape [batch_size, seq_len, num_heads, head_dim]
|
|
|
value_states (`torch.Tensor`): The new value states to cache of shape [batch_size, seq_len, num_heads, head_dim]
|
|
|
layer_idx (`int`): The index of the layer to update.
|
|
|
|
|
|
Returns:
|
|
|
tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states of shape [batch_size, seq_len, num_heads, head_dim].
|
|
|
"""
|
|
|
|
|
|
if len(self.key_cache) <= layer_idx:
|
|
|
|
|
|
for _ in range(len(self.key_cache), layer_idx):
|
|
|
self.key_cache.append(torch.tensor([]))
|
|
|
self.value_cache.append(torch.tensor([]))
|
|
|
self.key_cache.append(key_states)
|
|
|
self.value_cache.append(value_states)
|
|
|
elif (
|
|
|
not self.key_cache[layer_idx].numel()
|
|
|
):
|
|
|
self.key_cache[layer_idx] = key_states
|
|
|
self.value_cache[layer_idx] = value_states
|
|
|
else:
|
|
|
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=1)
|
|
|
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=1)
|
|
|
|
|
|
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
|
|
|
|
|
def get_seq_length(self, layer_idx: int = 0) -> int:
|
|
|
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
|
|
is_empty_layer = (
|
|
|
len(self.key_cache) == 0
|
|
|
or len(self.key_cache) <= layer_idx
|
|
|
or not self.key_cache[layer_idx].numel()
|
|
|
)
|
|
|
layer_seq_length = self.key_cache[layer_idx].shape[1] if not is_empty_layer else 0
|
|
|
return layer_seq_length
|
|
|
|
|
|
def crop(self, max_length: int) -> None:
|
|
|
"""Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
|
|
|
negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search."""
|
|
|
assert max_length > 0, "max_length must be positive"
|
|
|
|
|
|
if self.get_seq_length() <= max_length:
|
|
|
return
|
|
|
|
|
|
for layer_idx in range(len(self.key_cache)):
|
|
|
if self.key_cache[layer_idx].numel():
|
|
|
self.key_cache[layer_idx] = self.key_cache[layer_idx][:, :max_length, ...]
|
|
|
self.value_cache[layer_idx] = self.value_cache[layer_idx][:, :max_length, ...]
|
|
|
|
|
|
def batch_repeat_interleave(self, repeats: int) -> None:
|
|
|
"""Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
|
|
|
for layer_idx in range(len(self.key_cache)):
|
|
|
if self.key_cache[layer_idx].numel():
|
|
|
self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
|
|
|
self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)
|
|
|
|
|
|
def batch_select_indices(self, indices: torch.Tensor) -> None:
|
|
|
"""Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
|
|
|
for layer_idx in range(len(self.key_cache)):
|
|
|
if self.key_cache[layer_idx].numel():
|
|
|
self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
|
|
|
self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
|
|
|
|
|
|
|
|
|
class KVCache:
|
|
|
def __init__(self, cache_size: int = 4) -> None:
|
|
|
self.cache_size = cache_size
|
|
|
self.tensor_input_field_names = [
|
|
|
"input_ids",
|
|
|
"within_seq_position_ids",
|
|
|
"global_position_ids",
|
|
|
"sequence_ids",
|
|
|
"labels",
|
|
|
]
|
|
|
self.tensor_output_field_names = ["logits", "embeddings"]
|
|
|
self.cache_dict: dict[str, DynamicCache] = {}
|
|
|
self.cache_queue: list[str] = []
|
|
|
|
|
|
def reset(self) -> None:
|
|
|
for k in list(self.cache_dict.keys()):
|
|
|
del self.cache_dict[k]
|
|
|
del self.cache_dict
|
|
|
self.cache_dict = {}
|
|
|
self.cache_queue = []
|
|
|
|
|
|
torch.cuda.empty_cache()
|
|
|
|
|
|
def before_forward(self, batch: dict[str, torch.Tensor]) -> None:
|
|
|
contexts: list[str] | None = batch.get("context", None)
|
|
|
if contexts is None or "context_len" not in batch:
|
|
|
logger.warning_once(
|
|
|
"KVCache requires the batch dict to have both `context` and `context_len` keys to trigger. Skipping."
|
|
|
)
|
|
|
return
|
|
|
|
|
|
context_lens: list[int] = list(set(batch["context_len"]))
|
|
|
contexts: list[str] = list(set(contexts))
|
|
|
if len(contexts) != 1 or len(context_lens) != 1:
|
|
|
logger.warning(
|
|
|
"SingleContextKVCache requires a single context and context length. "
|
|
|
"Multiple contexts or context lengths found in a single batch. Skipping."
|
|
|
)
|
|
|
return
|
|
|
|
|
|
batch_size = batch["input_ids"].shape[0]
|
|
|
|
|
|
unique_context = contexts[0]
|
|
|
unique_context_len = context_lens[0]
|
|
|
batch["use_cache"] = True
|
|
|
|
|
|
if unique_context not in self.cache_dict:
|
|
|
return
|
|
|
|
|
|
self.cache_dict[unique_context].batch_repeat_interleave(batch_size)
|
|
|
past_key_values = self.cache_dict[unique_context]
|
|
|
batch["past_key_values"] = past_key_values
|
|
|
|
|
|
|
|
|
for field_name in self.tensor_input_field_names:
|
|
|
if batch.get(field_name, None) is not None:
|
|
|
batch[field_name] = batch[field_name][:, unique_context_len:]
|
|
|
|
|
|
def after_forward(self, batch: dict[str, Any], outputs: ModelOutput) -> None:
|
|
|
contexts = batch.get("context", None)
|
|
|
context_lens = batch.get("context_len", [])
|
|
|
if contexts is None or len(set(contexts)) != 1 or len(set(context_lens)) != 1 or context_lens[0] == 0:
|
|
|
return
|
|
|
|
|
|
assert batch["use_cache"]
|
|
|
unique_context = contexts[0]
|
|
|
unique_context_len = context_lens[0]
|
|
|
|
|
|
past_key_values = getattr(outputs, "past_key_values", None)
|
|
|
if not isinstance(past_key_values, DynamicCache):
|
|
|
logger.warning_once("KVCache is incompatible with models that don't return a DynamicCache. Skipping.")
|
|
|
return
|
|
|
|
|
|
if "past_key_values" not in batch:
|
|
|
if len(self.cache_queue) == self.cache_size:
|
|
|
last_context = self.cache_queue.pop(0)
|
|
|
if last_context not in self.cache_queue:
|
|
|
del self.cache_dict[last_context]
|
|
|
torch.cuda.empty_cache()
|
|
|
|
|
|
self.cache_dict[unique_context] = past_key_values
|
|
|
self.cache_queue.append(unique_context)
|
|
|
|
|
|
|
|
|
for field_name in self.tensor_input_field_names:
|
|
|
if field_name in batch and batch[field_name] is not None:
|
|
|
batch[field_name] = batch[field_name][:, unique_context_len:]
|
|
|
|
|
|
|
|
|
for field_name in self.tensor_output_field_names:
|
|
|
if field_name in outputs and outputs[field_name] is not None:
|
|
|
outputs[field_name] = outputs[field_name][:, unique_context_len:]
|
|
|
if "hidden_states" in outputs and outputs["hidden_states"] is not None:
|
|
|
outputs["hidden_states"] = [h[:, unique_context_len:] for h in outputs["hidden_states"]]
|
|
|
|
|
|
self.cache_dict[unique_context].crop(unique_context_len)
|
|
|
self.cache_dict[unique_context].batch_select_indices([0])
|
|
|
|
|
|
|
|
|
class AttentionMethod(Enum):
|
|
|
FLASH = "flash"
|
|
|
FLEX = "flex"
|
|
|
|
|
|
|
|
|
class AttentionLayerType(Enum):
|
|
|
WITHIN_SEQ = "within_seq"
|
|
|
GLOBAL = "global"
|
|
|
|
|
|
|
|
|
class AttentionArgs(TypedDict, total=False):
|
|
|
flex_attention_args: FlexAttentionArgs
|
|
|
|
|
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
|
|
"""This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
|
|
|
|
|
|
The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch,
|
|
|
num_attention_heads, seqlen, head_dim)
|
|
|
"""
|
|
|
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
|
|
if n_rep == 1:
|
|
|
return hidden_states
|
|
|
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
|
|
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
|
|
|
|
|
|
|
|
class RotaryPositionalEmbedding(nn.Module):
|
|
|
def __init__(
|
|
|
self, dim: int, max_position_embeddings: int = 2048, base: int = 10000, device: torch.device | None = None
|
|
|
):
|
|
|
super().__init__()
|
|
|
|
|
|
self.dim = dim
|
|
|
self.base = base
|
|
|
self.max_position_embeddings = max_position_embeddings
|
|
|
inv_freq = base ** -(torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
|
|
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
|
|
|
|
|
|
|
|
self._set_sin_cos_cache(seq_len=max_position_embeddings, device=self.inv_freq.device)
|
|
|
|
|
|
@staticmethod
|
|
|
def rotate_half(x: torch.Tensor) -> torch.Tensor:
|
|
|
"""Rotates half the hidden dims of the input."""
|
|
|
x1 = x[..., : x.shape[-1] // 2]
|
|
|
x2 = x[..., x.shape[-1] // 2 :]
|
|
|
return torch.cat((-x2, x1), dim=-1)
|
|
|
|
|
|
def _set_sin_cos_cache(self, seq_len: int, device: torch.device) -> None:
|
|
|
|
|
|
self.max_seq_len_cached = seq_len
|
|
|
t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
|
|
|
angles = torch.outer(t, self.inv_freq.to(device))
|
|
|
angles = torch.cat((angles, angles), dim=1)
|
|
|
self.register_buffer("cos_cached", angles.cos(), persistent=False)
|
|
|
self.register_buffer("sin_cached", angles.sin(), persistent=False)
|
|
|
|
|
|
def forward(
|
|
|
self, q: torch.Tensor, k: torch.Tensor, position_ids: torch.LongTensor, seq_len: int | None = None
|
|
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
|
|
device, dtype = q.device, q.dtype
|
|
|
seq_len = position_ids.max().item() + 1 if seq_len is None else seq_len
|
|
|
|
|
|
if seq_len > self.max_seq_len_cached:
|
|
|
self._set_sin_cos_cache(seq_len=seq_len, device=device)
|
|
|
|
|
|
|
|
|
|
|
|
idxs = position_ids.to(device)
|
|
|
cos = self.cos_cached.to(device=device, dtype=dtype).unsqueeze(-2)[idxs]
|
|
|
sin = self.sin_cached.to(device=device, dtype=dtype).unsqueeze(-2)[idxs]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
q_embed = (q * cos) + (self.rotate_half(q) * sin)
|
|
|
k_embed = (k * cos) + (self.rotate_half(k) * sin)
|
|
|
return q_embed, k_embed
|
|
|
|
|
|
|
|
|
class Attention(nn.Module):
|
|
|
"""Multi-headed attention from 'Attention Is All You Need' paper."""
|
|
|
|
|
|
def __init__(self, config: E1Config, layer_idx: int):
|
|
|
super().__init__()
|
|
|
self.config = config
|
|
|
self.layer_idx = layer_idx
|
|
|
|
|
|
self.hidden_size = config.hidden_size
|
|
|
self.num_heads = config.num_attention_heads
|
|
|
self.head_dim = self.hidden_size // self.num_heads
|
|
|
self.num_kv_heads = config.num_key_value_heads
|
|
|
self.num_key_value_groups = self.num_heads // self.num_kv_heads
|
|
|
self.max_num_seqs = config.max_num_sequences
|
|
|
self.clip_qkv = config.clip_qkv
|
|
|
|
|
|
if (self.head_dim * self.num_heads) != self.hidden_size:
|
|
|
raise ValueError(
|
|
|
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
|
|
f" and `num_heads`: {self.num_heads})."
|
|
|
)
|
|
|
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
|
|
self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
|
|
|
self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
|
|
|
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
|
|
|
|
|
if self.config.global_attention_every_n_layers > 0:
|
|
|
self.layer_type = (
|
|
|
AttentionLayerType.GLOBAL
|
|
|
if (self.layer_idx + 1) % self.config.global_attention_every_n_layers == 0
|
|
|
else AttentionLayerType.WITHIN_SEQ
|
|
|
)
|
|
|
else:
|
|
|
self.layer_type = AttentionLayerType.WITHIN_SEQ
|
|
|
|
|
|
self.rope_theta = (
|
|
|
config.rope_theta_within_seq
|
|
|
if self.layer_type == AttentionLayerType.WITHIN_SEQ
|
|
|
else config.rope_theta_global
|
|
|
)
|
|
|
self.max_position_embeddings = (
|
|
|
config.max_num_positions_within_seq
|
|
|
if self.layer_type == AttentionLayerType.WITHIN_SEQ
|
|
|
else config.max_num_positions_global
|
|
|
)
|
|
|
|
|
|
self.rotary_emb = RotaryPositionalEmbedding(
|
|
|
self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta
|
|
|
)
|
|
|
|
|
|
def prepare_qkv(
|
|
|
self,
|
|
|
hidden_states: torch.Tensor,
|
|
|
position_ids: torch.LongTensor,
|
|
|
past_key_value: DynamicCache | None = None,
|
|
|
use_cache: bool = False,
|
|
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
|
bsz, q_len, _ = hidden_states.size()
|
|
|
query_states: torch.Tensor = self.q_proj(hidden_states)
|
|
|
key_states: torch.Tensor = self.k_proj(hidden_states)
|
|
|
val_states: torch.Tensor = self.v_proj(hidden_states)
|
|
|
|
|
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
|
|
|
key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim)
|
|
|
val_states = val_states.view(bsz, q_len, self.num_kv_heads, self.head_dim)
|
|
|
|
|
|
if self.clip_qkv is not None:
|
|
|
query_states = query_states.clamp(-self.clip_qkv, self.clip_qkv)
|
|
|
key_states = key_states.clamp(-self.clip_qkv, self.clip_qkv)
|
|
|
val_states = val_states.clamp(-self.clip_qkv, self.clip_qkv)
|
|
|
|
|
|
query_states, key_states = self.rotary_emb(query_states, key_states, position_ids)
|
|
|
|
|
|
if use_cache and past_key_value is not None:
|
|
|
key_states, val_states = past_key_value.update(key_states, val_states, self.layer_idx)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
input_dtype = query_states.dtype
|
|
|
if torch.is_autocast_enabled():
|
|
|
target_dtype = torch.get_autocast_gpu_dtype()
|
|
|
else:
|
|
|
target_dtype = self.q_proj.weight.dtype
|
|
|
if input_dtype != target_dtype:
|
|
|
logger.warning_once(
|
|
|
f"The input hidden states seems to be silently casted in {input_dtype}. "
|
|
|
f"This might be because you have upcasted embedding or layer norm layers "
|
|
|
f"in {input_dtype}. We will cast back the input in {target_dtype}."
|
|
|
)
|
|
|
query_states = query_states.to(target_dtype)
|
|
|
key_states = key_states.to(target_dtype)
|
|
|
val_states = val_states.to(target_dtype)
|
|
|
|
|
|
return query_states, key_states, val_states
|
|
|
|
|
|
def forward(
|
|
|
self,
|
|
|
hidden_states: torch.Tensor,
|
|
|
within_seq_position_ids: torch.LongTensor,
|
|
|
global_position_ids: torch.LongTensor,
|
|
|
sequence_ids: torch.LongTensor,
|
|
|
attention_args: AttentionArgs | None = None,
|
|
|
past_key_value: DynamicCache | None = None,
|
|
|
output_attentions: bool = False,
|
|
|
use_cache: bool = False,
|
|
|
) -> tuple[torch.Tensor, torch.Tensor | None, DynamicCache | None]:
|
|
|
is_cache_prefilled = (
|
|
|
use_cache and past_key_value is not None and past_key_value.get_seq_length(self.layer_idx) > 0
|
|
|
)
|
|
|
|
|
|
query_states, key_states, val_states = self.prepare_qkv(
|
|
|
hidden_states=hidden_states,
|
|
|
position_ids=within_seq_position_ids
|
|
|
if self.layer_type == AttentionLayerType.WITHIN_SEQ
|
|
|
else global_position_ids,
|
|
|
past_key_value=past_key_value,
|
|
|
use_cache=use_cache,
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.layer_type == AttentionLayerType.WITHIN_SEQ or is_cache_prefilled:
|
|
|
attention_type = AttentionMethod.FLASH
|
|
|
else:
|
|
|
attention_type = AttentionMethod.FLEX
|
|
|
|
|
|
attn_output, attn_weights = self._attn(
|
|
|
attention_type=attention_type,
|
|
|
query_states=query_states,
|
|
|
key_states=key_states,
|
|
|
val_states=val_states,
|
|
|
sequence_ids=sequence_ids,
|
|
|
attention_args=attention_args,
|
|
|
output_attentions=output_attentions,
|
|
|
)
|
|
|
|
|
|
attn_output = self.o_proj(attn_output)
|
|
|
return attn_output, attn_weights, past_key_value
|
|
|
|
|
|
def _attn(
|
|
|
self,
|
|
|
attention_type: AttentionMethod,
|
|
|
query_states: torch.Tensor,
|
|
|
key_states: torch.Tensor,
|
|
|
val_states: torch.Tensor,
|
|
|
sequence_ids: torch.Tensor,
|
|
|
attention_args: AttentionArgs | None = None,
|
|
|
output_attentions: bool = False,
|
|
|
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
|
|
match attention_type:
|
|
|
case AttentionMethod.FLASH:
|
|
|
f = self._flash_attn
|
|
|
case AttentionMethod.FLEX:
|
|
|
f = self._flex_attn
|
|
|
case _:
|
|
|
raise ValueError(f"No attention implementation found for {attention_type}")
|
|
|
return f(
|
|
|
query_states=query_states,
|
|
|
key_states=key_states,
|
|
|
val_states=val_states,
|
|
|
sequence_ids=sequence_ids,
|
|
|
attention_args=attention_args,
|
|
|
output_attentions=output_attentions,
|
|
|
)
|
|
|
|
|
|
def _flash_attn(
|
|
|
self,
|
|
|
query_states: torch.Tensor,
|
|
|
key_states: torch.Tensor,
|
|
|
val_states: torch.Tensor,
|
|
|
sequence_ids: torch.Tensor,
|
|
|
attention_args: AttentionArgs | None = None,
|
|
|
output_attentions: bool = False,
|
|
|
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
|
|
"""Flash attention implementation.
|
|
|
|
|
|
Calls the public API of flash attention and deals with padding tokens if any are present.
|
|
|
"""
|
|
|
assert not output_attentions, "Flash attention doesn't support returning attention masks"
|
|
|
bsz, q_len = query_states.shape[0], query_states.shape[1]
|
|
|
_, kv_len = key_states.shape[0], key_states.shape[1]
|
|
|
|
|
|
if self.layer_type == AttentionLayerType.GLOBAL:
|
|
|
q_sequence_ids = sequence_ids
|
|
|
if q_len < kv_len:
|
|
|
|
|
|
|
|
|
first_token_id = sequence_ids[:, 0].unsqueeze(1)
|
|
|
k_sequence_ids = torch.cat([first_token_id.expand(bsz, kv_len - q_len), sequence_ids], dim=-1)
|
|
|
else:
|
|
|
k_sequence_ids = sequence_ids
|
|
|
else:
|
|
|
if q_len < kv_len:
|
|
|
key_states = key_states[:, -q_len:]
|
|
|
val_states = val_states[:, -q_len:]
|
|
|
q_sequence_ids = k_sequence_ids = sequence_ids
|
|
|
|
|
|
if is_flash_attention_available():
|
|
|
attn_output = flash_attention_func(
|
|
|
query_states,
|
|
|
key_states,
|
|
|
val_states,
|
|
|
q_sequence_ids=q_sequence_ids,
|
|
|
k_sequence_ids=k_sequence_ids,
|
|
|
causal=False,
|
|
|
)
|
|
|
else:
|
|
|
attn_output = varlen_flex_attention_func(
|
|
|
query_states, key_states, val_states, q_sequence_ids=q_sequence_ids, k_sequence_ids=k_sequence_ids
|
|
|
)
|
|
|
|
|
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
|
|
return attn_output, None
|
|
|
|
|
|
def _flex_attn(
|
|
|
self,
|
|
|
query_states: torch.Tensor,
|
|
|
key_states: torch.Tensor,
|
|
|
val_states: torch.Tensor,
|
|
|
sequence_ids: torch.Tensor,
|
|
|
attention_args: AttentionArgs | None = None,
|
|
|
output_attentions: bool = False,
|
|
|
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
|
|
bsz, q_len = query_states.shape[0], query_states.shape[1]
|
|
|
flex_attention_args = attention_args.get("flex_attention_args", None) if attention_args is not None else None
|
|
|
block_mask = flex_attention_args.get("block_mask", None) if flex_attention_args is not None else None
|
|
|
score_mod = flex_attention_args.get("score_mod", None) if flex_attention_args is not None else None
|
|
|
outputs = flex_attention_func(query_states, key_states, val_states, score_mod=score_mod, block_mask=block_mask)
|
|
|
|
|
|
outputs = outputs.reshape(bsz, q_len, self.hidden_size).contiguous()
|
|
|
return outputs, None
|
|
|
|
|
|
|
|
|
class MLP(nn.Module):
|
|
|
def __init__(self, config: E1Config):
|
|
|
super().__init__()
|
|
|
self.ffn_dim = config.intermediate_size
|
|
|
self.hidden_dim = config.hidden_size
|
|
|
self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
|
|
|
self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
|
|
|
self.act_fn = ACT2FN[config.hidden_act]
|
|
|
|
|
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
|
|
return self.w2(self.act_fn(self.w1(hidden_states)))
|
|
|
|
|
|
|
|
|
class GLUMLP(nn.Module):
|
|
|
def __init__(self, config: E1Config):
|
|
|
super().__init__()
|
|
|
self.ffn_dim = config.intermediate_size
|
|
|
self.hidden_dim = config.hidden_size
|
|
|
self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
|
|
|
self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
|
|
|
self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
|
|
|
self.act_fn = ACT2FN[config.hidden_act]
|
|
|
|
|
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
|
|
hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
|
|
|
hidden_states = self.w2(hidden_states)
|
|
|
return hidden_states
|
|
|
|
|
|
|
|
|
class FFN(nn.Module):
|
|
|
def __init__(self, config: E1Config):
|
|
|
super().__init__()
|
|
|
mlp_cls = GLUMLP if config.gated_mlp else MLP
|
|
|
self.mlp = mlp_cls(config)
|
|
|
|
|
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
|
|
return self.mlp(hidden_states)
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
class E1ModelOutputWithPast(ModelOutput):
|
|
|
"""Base class for model's outputs, with potential hidden states and attentions.
|
|
|
|
|
|
Attributes:
|
|
|
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
|
|
Sequence of hidden-states at the output of the last layer of the model.
|
|
|
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
|
|
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
|
|
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
|
|
|
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
|
|
|
encoder_sequence_length, embed_size_per_head)`.
|
|
|
|
|
|
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
|
|
|
`config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
|
|
|
input) to speed up sequential decoding.
|
|
|
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
|
|
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
|
|
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
|
|
|
|
|
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
|
|
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
|
|
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
|
|
sequence_length)`.
|
|
|
|
|
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
|
|
heads.
|
|
|
"""
|
|
|
|
|
|
last_hidden_state: torch.FloatTensor | None = None
|
|
|
past_key_values: DynamicCache | None = None
|
|
|
hidden_states: tuple[torch.FloatTensor, ...] | None = None
|
|
|
attentions: tuple[torch.FloatTensor, ...] | None = None
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
class E1MaskedLMOutputWithPast(ModelOutput):
|
|
|
loss: torch.FloatTensor | None = None
|
|
|
mlm_loss: torch.FloatTensor | None = None
|
|
|
logits: torch.FloatTensor | None = None
|
|
|
last_hidden_state: torch.FloatTensor | None = None
|
|
|
past_key_values: DynamicCache | None = None
|
|
|
hidden_states: tuple[torch.FloatTensor, ...] | None = None
|
|
|
attentions: tuple[torch.FloatTensor, ...] | None = None
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
class E1ClassificationOutputWithPast(ModelOutput):
|
|
|
loss: torch.FloatTensor | None = None
|
|
|
logits: torch.FloatTensor | None = None
|
|
|
last_hidden_state: torch.FloatTensor | None = None
|
|
|
past_key_values: DynamicCache | None = None
|
|
|
hidden_states: tuple[torch.FloatTensor, ...] | None = None
|
|
|
attentions: tuple[torch.FloatTensor, ...] | None = None
|
|
|
|
|
|
|
|
|
class RMSNorm(nn.Module):
|
|
|
def __init__(self, hidden_size: int, eps: float = 1e-6):
|
|
|
super().__init__()
|
|
|
self.weight = nn.Parameter(torch.ones(hidden_size))
|
|
|
self.variance_epsilon = eps
|
|
|
self.hidden_size = hidden_size
|
|
|
|
|
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
|
|
input_dtype = hidden_states.dtype
|
|
|
if layer_norm is None:
|
|
|
return torch.nn.functional.rms_norm(
|
|
|
hidden_states, (self.hidden_size,), self.weight, self.variance_epsilon
|
|
|
).to(input_dtype)
|
|
|
else:
|
|
|
return layer_norm.rms_norm_fn(
|
|
|
x=hidden_states,
|
|
|
weight=self.weight,
|
|
|
bias=None,
|
|
|
residual=None,
|
|
|
eps=self.variance_epsilon,
|
|
|
dropout_p=0.0,
|
|
|
prenorm=False,
|
|
|
residual_in_fp32=False,
|
|
|
).to(input_dtype)
|
|
|
|
|
|
|
|
|
class NormAttentionNorm(nn.Module):
|
|
|
def __init__(self, config: E1Config, layer_idx: int):
|
|
|
super().__init__()
|
|
|
self.self_attn = Attention(config, layer_idx)
|
|
|
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
|
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
|
|
|
|
def forward(
|
|
|
self,
|
|
|
hidden_states: torch.Tensor,
|
|
|
within_seq_position_ids: torch.LongTensor,
|
|
|
global_position_ids: torch.LongTensor,
|
|
|
sequence_ids: torch.LongTensor,
|
|
|
attention_args: AttentionArgs | None = None,
|
|
|
past_key_value: DynamicCache | None = None,
|
|
|
output_attentions: bool = False,
|
|
|
use_cache: bool = False,
|
|
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, DynamicCache | None]:
|
|
|
residual = hidden_states
|
|
|
hidden_states = self.input_layernorm(hidden_states)
|
|
|
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
|
|
hidden_states=hidden_states,
|
|
|
within_seq_position_ids=within_seq_position_ids,
|
|
|
global_position_ids=global_position_ids,
|
|
|
sequence_ids=sequence_ids,
|
|
|
attention_args=attention_args,
|
|
|
past_key_value=past_key_value,
|
|
|
output_attentions=output_attentions,
|
|
|
use_cache=use_cache,
|
|
|
)
|
|
|
hidden_states = residual + hidden_states
|
|
|
|
|
|
residual = hidden_states
|
|
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
|
|
return hidden_states, residual, self_attn_weights, present_key_value
|
|
|
|
|
|
|
|
|
class DecoderLayer(nn.Module):
|
|
|
def __init__(self, config: E1Config, layer_idx: int):
|
|
|
super().__init__()
|
|
|
self.initializer_range = config.initializer_range
|
|
|
self.hidden_size = config.hidden_size
|
|
|
self.norm_attn_norm = NormAttentionNorm(config, layer_idx)
|
|
|
self.ffn = FFN(config)
|
|
|
|
|
|
def forward(
|
|
|
self,
|
|
|
hidden_states: torch.Tensor,
|
|
|
within_seq_position_ids: torch.LongTensor,
|
|
|
global_position_ids: torch.LongTensor,
|
|
|
sequence_ids: torch.LongTensor,
|
|
|
attention_args: AttentionArgs | None = None,
|
|
|
past_key_value: DynamicCache | None = None,
|
|
|
output_attentions: bool = False,
|
|
|
use_cache: bool = False,
|
|
|
) -> tuple[torch.Tensor, torch.Tensor | None, DynamicCache | None]:
|
|
|
hidden_states, residual, self_attn_weights, present_key_value = self.norm_attn_norm(
|
|
|
hidden_states=hidden_states,
|
|
|
within_seq_position_ids=within_seq_position_ids,
|
|
|
global_position_ids=global_position_ids,
|
|
|
sequence_ids=sequence_ids,
|
|
|
attention_args=attention_args,
|
|
|
past_key_value=past_key_value,
|
|
|
output_attentions=output_attentions,
|
|
|
use_cache=use_cache,
|
|
|
)
|
|
|
|
|
|
|
|
|
hidden_states = self.ffn(hidden_states)
|
|
|
hidden_states = residual + hidden_states
|
|
|
|
|
|
return hidden_states, self_attn_weights, present_key_value
|
|
|
|
|
|
|
|
|
|
|
|
class Pooler:
|
|
|
def __init__(self, pooling_types: List[str]):
|
|
|
self.pooling_types = pooling_types
|
|
|
self.pooling_options = {
|
|
|
'mean': self.mean_pooling,
|
|
|
'max': self.max_pooling,
|
|
|
'norm': self.norm_pooling,
|
|
|
'median': self.median_pooling,
|
|
|
'std': self.std_pooling,
|
|
|
'var': self.var_pooling,
|
|
|
'cls': self.cls_pooling,
|
|
|
'parti': self._pool_parti,
|
|
|
}
|
|
|
|
|
|
def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor:
|
|
|
maxed_attentions = torch.max(attentions, dim=1)[0]
|
|
|
return maxed_attentions
|
|
|
|
|
|
def _page_rank(self, attention_matrix, personalization=None, nstart=None, prune_type="top_k_outdegree"):
|
|
|
|
|
|
|
|
|
|
|
|
G = self._convert_to_graph(attention_matrix)
|
|
|
if G.number_of_nodes() != attention_matrix.shape[0]:
|
|
|
raise Exception(
|
|
|
f"The number of nodes in the graph should be equal to the number of tokens in sequence! You have {G.number_of_nodes()} nodes for {attention_matrix.shape[0]} tokens.")
|
|
|
if G.number_of_edges() == 0:
|
|
|
raise Exception(f"You don't seem to have any attention edges left in the graph.")
|
|
|
|
|
|
return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100)
|
|
|
|
|
|
def _convert_to_graph(self, matrix):
|
|
|
|
|
|
|
|
|
G = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
|
|
|
return G
|
|
|
|
|
|
def _calculate_importance_weights(self, dict_importance, attention_mask: Optional[torch.Tensor] = None):
|
|
|
|
|
|
if attention_mask is not None:
|
|
|
for k in list(dict_importance.keys()):
|
|
|
if attention_mask[k] == 0:
|
|
|
del dict_importance[k]
|
|
|
|
|
|
|
|
|
|
|
|
total = sum(dict_importance.values())
|
|
|
return np.array([v / total for _, v in dict_importance.items()])
|
|
|
|
|
|
def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
|
|
|
maxed_attentions = self._create_pooled_matrices_across_layers(attentions).numpy()
|
|
|
|
|
|
emb_pooled = []
|
|
|
for e, a, mask in zip(emb, maxed_attentions, attention_mask):
|
|
|
dict_importance = self._page_rank(a)
|
|
|
importance_weights = self._calculate_importance_weights(dict_importance, mask)
|
|
|
num_tokens = int(mask.sum().item())
|
|
|
emb_pooled.append(np.average(e[:num_tokens], weights=importance_weights, axis=0))
|
|
|
pooled = torch.tensor(np.array(emb_pooled))
|
|
|
return pooled
|
|
|
|
|
|
def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):
|
|
|
if attention_mask is None:
|
|
|
return emb.mean(dim=1)
|
|
|
else:
|
|
|
attention_mask = attention_mask.unsqueeze(-1)
|
|
|
return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
|
|
|
|
|
|
def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):
|
|
|
if attention_mask is None:
|
|
|
return emb.max(dim=1).values
|
|
|
else:
|
|
|
attention_mask = attention_mask.unsqueeze(-1)
|
|
|
return (emb * attention_mask).max(dim=1).values
|
|
|
|
|
|
def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):
|
|
|
if attention_mask is None:
|
|
|
return emb.norm(dim=1, p=2)
|
|
|
else:
|
|
|
attention_mask = attention_mask.unsqueeze(-1)
|
|
|
return (emb * attention_mask).norm(dim=1, p=2)
|
|
|
|
|
|
def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):
|
|
|
if attention_mask is None:
|
|
|
return emb.median(dim=1).values
|
|
|
else:
|
|
|
attention_mask = attention_mask.unsqueeze(-1)
|
|
|
return (emb * attention_mask).median(dim=1).values
|
|
|
|
|
|
def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):
|
|
|
if attention_mask is None:
|
|
|
return emb.std(dim=1)
|
|
|
else:
|
|
|
|
|
|
var = self.var_pooling(emb, attention_mask, **kwargs)
|
|
|
return torch.sqrt(var)
|
|
|
|
|
|
def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):
|
|
|
if attention_mask is None:
|
|
|
return emb.var(dim=1)
|
|
|
else:
|
|
|
|
|
|
attention_mask = attention_mask.unsqueeze(-1)
|
|
|
|
|
|
mean = (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
|
|
|
mean = mean.unsqueeze(1)
|
|
|
|
|
|
squared_diff = (emb - mean) ** 2
|
|
|
|
|
|
var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
|
|
|
return var
|
|
|
|
|
|
def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):
|
|
|
return emb[:, 0, :]
|
|
|
|
|
|
def __call__(
|
|
|
self,
|
|
|
emb: torch.Tensor,
|
|
|
attention_mask: Optional[torch.Tensor] = None,
|
|
|
attentions: Optional[torch.Tensor] = None
|
|
|
):
|
|
|
final_emb = []
|
|
|
for pooling_type in self.pooling_types:
|
|
|
final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions))
|
|
|
return torch.cat(final_emb, dim=-1)
|
|
|
|
|
|
|
|
|
class EmbeddingMixin:
|
|
|
def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
|
raise NotImplementedError
|
|
|
|
|
|
@property
|
|
|
def device(self) -> torch.device:
|
|
|
"""Get the device of the model."""
|
|
|
return next(self.parameters()).device
|
|
|
|
|
|
def _read_sequences_from_db(self, db_path: str) -> set[str]:
|
|
|
"""Read sequences from SQLite database."""
|
|
|
import sqlite3
|
|
|
sequences = []
|
|
|
with sqlite3.connect(db_path) as conn:
|
|
|
c = conn.cursor()
|
|
|
c.execute("SELECT sequence FROM embeddings")
|
|
|
while True:
|
|
|
row = c.fetchone()
|
|
|
if row is None:
|
|
|
break
|
|
|
sequences.append(row[0])
|
|
|
return set(sequences)
|
|
|
|
|
|
def embed_dataset(
|
|
|
self,
|
|
|
sequences: List[str],
|
|
|
|
|
|
batch_size: int = 2,
|
|
|
max_len: int = 512,
|
|
|
truncate: bool = True,
|
|
|
full_embeddings: bool = False,
|
|
|
embed_dtype: torch.dtype = torch.float32,
|
|
|
pooling_types: List[str] = ['mean'],
|
|
|
sql: bool = False,
|
|
|
save: bool = True,
|
|
|
sql_db_path: str = 'embeddings.db',
|
|
|
save_path: str = 'embeddings.pth',
|
|
|
**kwargs,
|
|
|
) -> Optional[dict[str, torch.Tensor]]:
|
|
|
"""Embed a dataset of protein sequences.
|
|
|
|
|
|
Args:
|
|
|
sequences: List of protein sequences
|
|
|
batch_size: Batch size for processing
|
|
|
max_len: Maximum sequence length
|
|
|
full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
|
|
|
pooling_type: Type of pooling ('mean' or 'cls')
|
|
|
sql: Whether to store embeddings in SQLite database - will be stored in float32
|
|
|
sql_db_path: Path to SQLite database
|
|
|
|
|
|
Returns:
|
|
|
Dictionary mapping sequences to embeddings, or None if sql=True
|
|
|
|
|
|
Note:
|
|
|
- If sql=True, embeddings can only be stored in float32
|
|
|
- sql is ideal if you need to stream a very large dataset for training in real-time
|
|
|
- save=True is ideal if you can store the entire embedding dictionary in RAM
|
|
|
- sql will be used if it is True and save is True or False
|
|
|
- If your sql database or .pth file is already present, they will be scanned first for already embedded sequences
|
|
|
- Sequences will be truncated to max_len and sorted by length in descending order for faster processing
|
|
|
|
|
|
Example:
|
|
|
>>> embedder = EmbeddingMixin()
|
|
|
>>> embedding_dict = embedder.embed_dataset(
|
|
|
sequences=[
|
|
|
'MALWMRLLPLLALLALWGPDPAAA', ... # list of protein sequences
|
|
|
],
|
|
|
batch_size=2, # adjust for your GPU memory
|
|
|
max_len=512, # adjust for your needs
|
|
|
full_embeddings=False, # if True, no pooling is performed
|
|
|
embed_dtype=torch.float32, # cast to what dtype you want
|
|
|
pooling_type=['mean', 'cls'], # more than one pooling type will be concatenated together
|
|
|
sql=False, # if True, embeddings will be stored in SQLite database
|
|
|
sql_db_path='embeddings.db',
|
|
|
save=True, # if True, embeddings will be saved as a .pth file
|
|
|
save_path='embeddings.pth',
|
|
|
)
|
|
|
>>> # embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql
|
|
|
"""
|
|
|
sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences]))
|
|
|
sequences = sorted(sequences, key=len, reverse=True)
|
|
|
hidden_size = self.config.hidden_size
|
|
|
pooler = Pooler(pooling_types) if not full_embeddings else None
|
|
|
|
|
|
def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
|
if full_embeddings or residue_embeddings.ndim == 2:
|
|
|
return residue_embeddings
|
|
|
else:
|
|
|
return pooler(residue_embeddings, attention_mask)
|
|
|
|
|
|
if sql:
|
|
|
import sqlite3
|
|
|
conn = sqlite3.connect(sql_db_path)
|
|
|
c = conn.cursor()
|
|
|
c.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence text PRIMARY KEY, embedding blob)')
|
|
|
already_embedded = self._read_sequences_from_db(sql_db_path)
|
|
|
to_embed = [seq for seq in sequences if seq not in already_embedded]
|
|
|
print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
|
|
|
print(f"Embedding {len(to_embed)} new sequences")
|
|
|
if len(to_embed) > 0:
|
|
|
with torch.no_grad():
|
|
|
for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'):
|
|
|
seqs = to_embed[batch_start:batch_start + batch_size]
|
|
|
input_ids, attention_mask = self._embed(seqs, return_attention_mask=True)
|
|
|
embeddings = get_embeddings(input_ids, attention_mask).float()
|
|
|
for seq, emb, mask in zip(seqs, embeddings, attention_mask):
|
|
|
if full_embeddings:
|
|
|
emb = emb[mask.bool()].reshape(-1, hidden_size)
|
|
|
c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)", (seq, emb.cpu().numpy().tobytes()))
|
|
|
conn.commit()
|
|
|
conn.commit()
|
|
|
conn.close()
|
|
|
return None
|
|
|
|
|
|
embeddings_dict = {}
|
|
|
if os.path.exists(save_path):
|
|
|
embeddings_dict = torch.load(save_path, map_location='cpu', weights_only=True)
|
|
|
to_embed = [seq for seq in sequences if seq not in embeddings_dict]
|
|
|
print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}")
|
|
|
print(f"Embedding {len(to_embed)} new sequences")
|
|
|
else:
|
|
|
to_embed = sequences
|
|
|
print(f"Embedding {len(to_embed)} new sequences")
|
|
|
|
|
|
if len(to_embed) > 0:
|
|
|
with torch.no_grad():
|
|
|
for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'):
|
|
|
seqs = to_embed[batch_start:batch_start + batch_size]
|
|
|
last_hidden_state, attention_mask = self._embed(seqs, return_attention_mask=True)
|
|
|
embeddings = get_embeddings(last_hidden_state, attention_mask).to(embed_dtype)
|
|
|
for seq, emb, mask in zip(seqs, embeddings, attention_mask):
|
|
|
if full_embeddings:
|
|
|
emb = emb[mask.bool()].reshape(-1, hidden_size)
|
|
|
embeddings_dict[seq] = emb.cpu()
|
|
|
|
|
|
if save:
|
|
|
torch.save(embeddings_dict, save_path)
|
|
|
|
|
|
return embeddings_dict
|
|
|
|
|
|
|
|
|
class E1PreTrainedModel(PreTrainedModel):
|
|
|
config_class = E1Config
|
|
|
config: E1Config
|
|
|
base_model_prefix = "model"
|
|
|
supports_gradient_checkpointing = True
|
|
|
_no_split_modules = ["DecoderLayer"]
|
|
|
_transformer_layer_cls = [DecoderLayer]
|
|
|
_skip_keys_device_placement = "past_key_values"
|
|
|
|
|
|
def _init_weights(self, module: nn.Module) -> None:
|
|
|
std = self.config.initializer_range
|
|
|
if isinstance(module, nn.Linear):
|
|
|
module.weight.data.normal_(mean=0.0, std=std)
|
|
|
if module.bias is not None:
|
|
|
module.bias.data.zero_()
|
|
|
elif isinstance(module, nn.Embedding):
|
|
|
module.weight.data.normal_(mean=0.0, std=std)
|
|
|
if module.padding_idx is not None:
|
|
|
module.weight.data[module.padding_idx].zero_()
|
|
|
elif isinstance(module, RMSNorm):
|
|
|
module.weight.data.fill_(1.0)
|
|
|
|
|
|
def post_init(self) -> None:
|
|
|
super().post_init()
|
|
|
|
|
|
def _backward_compatibility_gradient_checkpointing(self) -> None:
|
|
|
if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
|
|
|
self.gradient_checkpointing_enable(dict(use_reentrant=False))
|
|
|
|
|
|
@property
|
|
|
def _device(self) -> torch.device:
|
|
|
return next(self.parameters()).device
|
|
|
|
|
|
@classmethod
|
|
|
def from_pretrained(
|
|
|
cls, pretrained_model_name_or_path: str | os.PathLike | None, *args, **kwargs
|
|
|
) -> "E1PreTrainedModel":
|
|
|
return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
|
|
|
|
|
|
|
|
|
class E1Model(E1PreTrainedModel, EmbeddingMixin):
|
|
|
config: E1Config
|
|
|
config_class = E1Config
|
|
|
def __init__(self, config: E1Config, **kwargs):
|
|
|
E1PreTrainedModel.__init__(self, config, **kwargs)
|
|
|
self.padding_idx = config.pad_token_id
|
|
|
self.vocab_size = config.vocab_size
|
|
|
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
|
|
self.embed_seq_id = nn.Embedding(config.max_num_sequences, config.hidden_size)
|
|
|
self.layers = nn.ModuleList([DecoderLayer(config, i) for i in range(config.num_hidden_layers)])
|
|
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
|
self.gradient_checkpointing = config.gradient_checkpointing
|
|
|
self.prep_tokens = E1BatchPreparer()
|
|
|
self.post_init()
|
|
|
|
|
|
def get_input_embeddings(self) -> nn.Embedding:
|
|
|
return self.embed_tokens
|
|
|
|
|
|
def set_input_embeddings(self, value: nn.Embedding) -> None:
|
|
|
self.embed_tokens = value
|
|
|
|
|
|
@torch.inference_mode()
|
|
|
def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
|
|
|
batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device)
|
|
|
last_hidden_state = self.forward(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state
|
|
|
if return_attention_mask:
|
|
|
attention_mask = (batch['sequence_ids'] != -1).long()
|
|
|
return last_hidden_state, attention_mask
|
|
|
else:
|
|
|
return last_hidden_state
|
|
|
|
|
|
|
|
|
def forward(
|
|
|
self,
|
|
|
input_ids: torch.LongTensor,
|
|
|
within_seq_position_ids: torch.LongTensor,
|
|
|
global_position_ids: torch.LongTensor,
|
|
|
sequence_ids: torch.LongTensor,
|
|
|
past_key_values: DynamicCache | None = None,
|
|
|
use_cache: bool = False,
|
|
|
output_attentions: bool = False,
|
|
|
output_hidden_states: bool = False,
|
|
|
**kwargs
|
|
|
) -> E1ModelOutputWithPast:
|
|
|
"""
|
|
|
Args:
|
|
|
input_ids: (batch_size, seq_length)
|
|
|
within_seq_position_ids: (batch_size, seq_length)
|
|
|
This tensor contains the position of each residue within the sequence itself.
|
|
|
For example, if the input is ["<bos>1ABC2<eos><bos>1DEF2<eos>", "<bos>1GH2<eos><bos>1JKL2<eos><pad>"],
|
|
|
the tensor would be [[0,1,2,3,4,5,6,0,1,2,3,4,5,6], [0,1,2,3,4,5,0,1,2,3,4,5,6,-1]]
|
|
|
global_position_ids: (batch_size, seq_length)
|
|
|
This tensor contains the position of each residue within the global sequence.
|
|
|
For example, if the input is ["<bos>1ABC2<eos><bos>1DEF2<eos>", "<bos>1GH2<eos><bos>1JKL2<eos>"],
|
|
|
the tensor would be [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, -1]]
|
|
|
sequence_ids: (batch_size, seq_length)
|
|
|
This tensor contains the sequence id of each residue.
|
|
|
For example, if the input is ["<bos>1ABC2<eos><bos>1DEF2<eos>", "<bos>1GH2<eos><bos>1JKL2<eos>"],
|
|
|
the tensor would be [[0,0,0,0,0,0,0,1,1,1,1,1,1,1], [0,0,0,0,0,0,1,1,1,1,1,1,1,-1]]
|
|
|
past_key_values: DynamicCache
|
|
|
use_cache: bool
|
|
|
output_attentions: bool
|
|
|
output_hidden_states: bool
|
|
|
|
|
|
Returns:
|
|
|
E1ModelOutputWithPast: Model Outputs
|
|
|
"""
|
|
|
batch_size, seq_length = input_ids.shape
|
|
|
|
|
|
if self.gradient_checkpointing and self.training and torch.is_grad_enabled():
|
|
|
if use_cache:
|
|
|
logger.warning_once(
|
|
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
|
|
)
|
|
|
use_cache = False
|
|
|
|
|
|
if use_cache and past_key_values is None:
|
|
|
past_key_values = DynamicCache()
|
|
|
elif not use_cache:
|
|
|
|
|
|
past_key_values = None
|
|
|
|
|
|
global_position_ids = global_position_ids.view(-1, seq_length).long()
|
|
|
within_seq_position_ids = within_seq_position_ids.view(-1, seq_length).long()
|
|
|
sequence_ids = sequence_ids.view(-1, seq_length).long()
|
|
|
|
|
|
max_position_id = torch.max(within_seq_position_ids).item()
|
|
|
min_position_id = torch.min(within_seq_position_ids).item()
|
|
|
assert max_position_id < self.config.max_num_positions_within_seq and min_position_id >= -1, (
|
|
|
f"Position ids must be in the range [-1, {self.config.max_num_positions_within_seq}); got max {max_position_id} and min {min_position_id}"
|
|
|
)
|
|
|
|
|
|
inputs_embeds = self.embed_tokens(input_ids)
|
|
|
|
|
|
inputs_embeds = inputs_embeds + self.embed_seq_id(sequence_ids.clamp(min=0))
|
|
|
|
|
|
|
|
|
if torch.is_autocast_enabled():
|
|
|
target_dtype = torch.get_autocast_gpu_dtype()
|
|
|
else:
|
|
|
target_dtype = self.layers[0].norm_attn_norm.self_attn.q_proj.weight.dtype
|
|
|
hidden_states = inputs_embeds.to(target_dtype)
|
|
|
|
|
|
|
|
|
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
|
|
|
|
|
|
|
|
|
attention_args: AttentionArgs | None = None
|
|
|
if past_key_values_length == 0:
|
|
|
block_mask = create_block_causal_mask_optimized(sequence_ids)
|
|
|
flex_attention_args = FlexAttentionArgs(block_mask=block_mask)
|
|
|
attention_args = AttentionArgs(flex_attention_args=flex_attention_args)
|
|
|
|
|
|
|
|
|
all_hidden_states = () if output_hidden_states else None
|
|
|
all_self_attns = () if output_attentions else None
|
|
|
next_decoder_cache = None
|
|
|
|
|
|
for decoder_layer in self.layers:
|
|
|
if output_hidden_states:
|
|
|
all_hidden_states += (hidden_states,)
|
|
|
|
|
|
if self.gradient_checkpointing and self.training and torch.is_grad_enabled():
|
|
|
layer_outputs = self._gradient_checkpointing_func(
|
|
|
decoder_layer.__call__,
|
|
|
hidden_states,
|
|
|
within_seq_position_ids,
|
|
|
global_position_ids,
|
|
|
sequence_ids,
|
|
|
attention_args,
|
|
|
past_key_values,
|
|
|
output_attentions,
|
|
|
use_cache,
|
|
|
)
|
|
|
else:
|
|
|
layer_outputs = decoder_layer(
|
|
|
hidden_states,
|
|
|
within_seq_position_ids=within_seq_position_ids,
|
|
|
global_position_ids=global_position_ids,
|
|
|
sequence_ids=sequence_ids,
|
|
|
attention_args=attention_args,
|
|
|
past_key_value=past_key_values,
|
|
|
output_attentions=output_attentions,
|
|
|
use_cache=use_cache,
|
|
|
)
|
|
|
|
|
|
hidden_states, self_attn_weights, present_key_value = layer_outputs
|
|
|
|
|
|
if use_cache:
|
|
|
|
|
|
|
|
|
|
|
|
next_decoder_cache = past_key_values = present_key_value
|
|
|
|
|
|
if output_attentions:
|
|
|
all_self_attns += (self_attn_weights,)
|
|
|
|
|
|
hidden_states = self.norm(hidden_states)
|
|
|
|
|
|
|
|
|
if output_hidden_states:
|
|
|
all_hidden_states += (hidden_states,)
|
|
|
|
|
|
next_cache = next_decoder_cache if use_cache else None
|
|
|
|
|
|
return E1ModelOutputWithPast(
|
|
|
last_hidden_state=hidden_states,
|
|
|
past_key_values=next_cache,
|
|
|
hidden_states=all_hidden_states,
|
|
|
attentions=all_self_attns,
|
|
|
)
|
|
|
|
|
|
|
|
|
class E1ForMaskedLM(E1PreTrainedModel, EmbeddingMixin):
|
|
|
config: E1Config
|
|
|
config_class = E1Config
|
|
|
def __init__(self, config: E1Config, **kwargs):
|
|
|
E1PreTrainedModel.__init__(self, config, **kwargs)
|
|
|
self.model: E1Model = E1Model(config)
|
|
|
self.vocab_size = config.vocab_size
|
|
|
self.mlm_head = torch.nn.Sequential(
|
|
|
nn.Linear(config.hidden_size, config.hidden_size, bias=True),
|
|
|
nn.GELU(),
|
|
|
nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps),
|
|
|
nn.Linear(config.hidden_size, config.vocab_size, bias=True),
|
|
|
)
|
|
|
self.gradient_checkpointing = config.gradient_checkpointing
|
|
|
self.prep_tokens = E1BatchPreparer()
|
|
|
self.post_init()
|
|
|
|
|
|
@property
|
|
|
def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh:
|
|
|
return self.model.device_mesh
|
|
|
|
|
|
@torch.inference_mode()
|
|
|
def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
|
|
|
batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device)
|
|
|
last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state
|
|
|
if return_attention_mask:
|
|
|
attention_mask = (batch['sequence_ids'] != -1).long()
|
|
|
return last_hidden_state, attention_mask
|
|
|
else:
|
|
|
return last_hidden_state
|
|
|
|
|
|
def forward(
|
|
|
self,
|
|
|
input_ids: torch.LongTensor,
|
|
|
within_seq_position_ids: torch.LongTensor,
|
|
|
global_position_ids: torch.LongTensor,
|
|
|
sequence_ids: torch.LongTensor,
|
|
|
labels: torch.LongTensor | None = None,
|
|
|
past_key_values: DynamicCache | None = None,
|
|
|
use_cache: bool = False,
|
|
|
output_attentions: bool = False,
|
|
|
output_hidden_states: bool = False,
|
|
|
**kwargs,
|
|
|
) -> E1MaskedLMOutputWithPast:
|
|
|
"""
|
|
|
Args:
|
|
|
input_ids: (batch_size, seq_length)
|
|
|
within_seq_position_ids: (batch_size, seq_length)
|
|
|
This tensor contains the position of each residue within the sequence itself.
|
|
|
For example, if the input is ["<bos>1ABC2<eos><bos>1DEF2<eos>", "<bos>1GH2<eos><bos>1JKL2<eos><pad>"],
|
|
|
the tensor would be [[0,1,2,3,4,5,6,0,1,2,3,4,5,6], [0,1,2,3,4,5,0,1,2,3,4,5,6,-1]]
|
|
|
global_position_ids: (batch_size, seq_length)
|
|
|
This tensor contains the position of each residue within the global sequence.
|
|
|
For example, if the input is ["<bos>1ABC2<eos><bos>1DEF2<eos>", "<bos>1GH2<eos><bos>1JKL2<eos>"],
|
|
|
the tensor would be [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, -1]]
|
|
|
sequence_ids: (batch_size, seq_length)
|
|
|
This tensor contains the sequence id of each residue.
|
|
|
For example, if the input is ["<bos>1ABC2<eos><bos>1DEF2<eos>", "<bos>1GH2<eos><bos>1JKL2<eos>"],
|
|
|
the tensor would be [[0,0,0,0,0,0,0,1,1,1,1,1,1,1], [0,0,0,0,0,0,1,1,1,1,1,1,1,-1]]
|
|
|
labels: (batch_size, seq_length)
|
|
|
past_key_values: DynamicCache
|
|
|
use_cache: bool
|
|
|
output_attentions: bool
|
|
|
output_hidden_states: bool
|
|
|
|
|
|
Returns:
|
|
|
E1MaskedLMOutputWithPast: Model Outputs
|
|
|
"""
|
|
|
outputs: E1ModelOutputWithPast = self.model(
|
|
|
input_ids=input_ids,
|
|
|
within_seq_position_ids=within_seq_position_ids,
|
|
|
global_position_ids=global_position_ids,
|
|
|
sequence_ids=sequence_ids,
|
|
|
past_key_values=past_key_values,
|
|
|
use_cache=use_cache,
|
|
|
output_attentions=output_attentions,
|
|
|
output_hidden_states=output_hidden_states,
|
|
|
)
|
|
|
|
|
|
x = outputs.last_hidden_state
|
|
|
loss = None
|
|
|
|
|
|
|
|
|
mlm_logits = self.mlm_head(x).float()
|
|
|
mlm_loss = 0.0
|
|
|
if labels is not None:
|
|
|
mlm_logits_flat = mlm_logits.contiguous().view(-1, self.config.vocab_size)
|
|
|
mlm_labels_flat = labels.to(mlm_logits_flat.device).contiguous().view(-1)
|
|
|
mlm_loss = F.cross_entropy(mlm_logits_flat, mlm_labels_flat, reduction="none")
|
|
|
mask = mlm_labels_flat != self.model.padding_idx
|
|
|
n_mlm = mask.sum()
|
|
|
mlm_loss = (mlm_loss * mask.to(mlm_loss)).sum() / (1 if n_mlm == 0 else n_mlm)
|
|
|
loss = 0.0
|
|
|
loss += mlm_loss
|
|
|
|
|
|
return E1MaskedLMOutputWithPast(
|
|
|
loss=loss,
|
|
|
mlm_loss=mlm_loss,
|
|
|
logits=mlm_logits,
|
|
|
last_hidden_state=x,
|
|
|
past_key_values=outputs.past_key_values,
|
|
|
hidden_states=outputs.hidden_states,
|
|
|
attentions=outputs.attentions,
|
|
|
)
|
|
|
|
|
|
|
|
|
class E1ForSequenceClassification(E1PreTrainedModel, EmbeddingMixin):
|
|
|
config: E1Config
|
|
|
config_class = E1Config
|
|
|
def __init__(self, config: E1Config, **kwargs):
|
|
|
E1PreTrainedModel.__init__(self, config, **kwargs)
|
|
|
self.model: E1Model = E1Model(config)
|
|
|
self.vocab_size = config.vocab_size
|
|
|
self.num_labels = config.num_labels
|
|
|
self.classifier = nn.Sequential(
|
|
|
nn.Linear(config.hidden_size * 2, config.hidden_size * 4),
|
|
|
nn.GELU(),
|
|
|
nn.LayerNorm(config.hidden_size * 4),
|
|
|
nn.Linear(config.hidden_size * 4, config.num_labels),
|
|
|
)
|
|
|
self.mse = nn.MSELoss()
|
|
|
self.ce = nn.CrossEntropyLoss()
|
|
|
self.bce = nn.BCEWithLogitsLoss()
|
|
|
self.gradient_checkpointing = config.gradient_checkpointing
|
|
|
self.prep_tokens = E1BatchPreparer()
|
|
|
|
|
|
if 'pooling_types' in kwargs and isinstance(kwargs['pooling_types'], List[str]) and len(kwargs['pooling_types']) > 0:
|
|
|
pooling_types = kwargs['pooling_types']
|
|
|
else:
|
|
|
pooling_types = ['mean', 'var']
|
|
|
self.pooler = Pooler(pooling_types)
|
|
|
self.post_init()
|
|
|
|
|
|
@property
|
|
|
def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh:
|
|
|
return self.model.device_mesh
|
|
|
|
|
|
@torch.inference_mode()
|
|
|
def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
|
|
|
batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device)
|
|
|
last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state
|
|
|
if return_attention_mask:
|
|
|
attention_mask = (batch['sequence_ids'] != -1).long()
|
|
|
return last_hidden_state, attention_mask
|
|
|
else:
|
|
|
return last_hidden_state
|
|
|
|
|
|
def forward(
|
|
|
self,
|
|
|
input_ids: torch.LongTensor,
|
|
|
within_seq_position_ids: torch.LongTensor,
|
|
|
global_position_ids: torch.LongTensor,
|
|
|
sequence_ids: torch.LongTensor,
|
|
|
labels: torch.LongTensor | None = None,
|
|
|
past_key_values: DynamicCache | None = None,
|
|
|
use_cache: bool = False,
|
|
|
output_attentions: bool = False,
|
|
|
output_hidden_states: bool = False,
|
|
|
**kwargs,
|
|
|
) -> E1ClassificationOutputWithPast:
|
|
|
outputs: E1ModelOutputWithPast = self.model(
|
|
|
input_ids=input_ids,
|
|
|
within_seq_position_ids=within_seq_position_ids,
|
|
|
global_position_ids=global_position_ids,
|
|
|
sequence_ids=sequence_ids,
|
|
|
past_key_values=past_key_values,
|
|
|
use_cache=use_cache,
|
|
|
output_attentions=output_attentions,
|
|
|
output_hidden_states=output_hidden_states,
|
|
|
)
|
|
|
|
|
|
attention_mask = (sequence_ids != -1).long()
|
|
|
x = outputs.last_hidden_state
|
|
|
features = self.pooler(x, attention_mask)
|
|
|
logits = self.classifier(features)
|
|
|
loss = None
|
|
|
if labels is not None:
|
|
|
labels = labels.to(logits.device)
|
|
|
if self.config.problem_type is None:
|
|
|
if self.num_labels == 1:
|
|
|
self.config.problem_type = "regression"
|
|
|
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
|
|
self.config.problem_type = "single_label_classification"
|
|
|
else:
|
|
|
self.config.problem_type = "multi_label_classification"
|
|
|
|
|
|
if self.config.problem_type == "regression":
|
|
|
if self.num_labels == 1:
|
|
|
loss = self.mse(logits.flatten(), labels.flatten())
|
|
|
else:
|
|
|
loss = self.mse(logits, labels)
|
|
|
elif self.config.problem_type == "single_label_classification":
|
|
|
loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
|
|
|
elif self.config.problem_type == "multi_label_classification":
|
|
|
loss = self.bce(logits, labels)
|
|
|
|
|
|
return E1ClassificationOutputWithPast(
|
|
|
loss=loss,
|
|
|
logits=logits,
|
|
|
last_hidden_state=x,
|
|
|
past_key_values=outputs.past_key_values,
|
|
|
hidden_states=outputs.hidden_states,
|
|
|
attentions=outputs.attentions,
|
|
|
)
|
|
|
|
|
|
|
|
|
class E1ForTokenClassification(E1PreTrainedModel, EmbeddingMixin):
|
|
|
config: E1Config
|
|
|
config_class = E1Config
|
|
|
def __init__(self, config: E1Config, **kwargs):
|
|
|
E1PreTrainedModel.__init__(self, config, **kwargs)
|
|
|
self.model: E1Model = E1Model(config)
|
|
|
self.vocab_size = config.vocab_size
|
|
|
self.num_labels = config.num_labels
|
|
|
self.classifier = nn.Sequential(
|
|
|
nn.Linear(config.hidden_size * 2, config.hidden_size * 4),
|
|
|
nn.GELU(),
|
|
|
nn.LayerNorm(config.hidden_size * 4),
|
|
|
nn.Linear(config.hidden_size * 4, config.num_labels),
|
|
|
)
|
|
|
self.loss_fct = nn.CrossEntropyLoss()
|
|
|
self.gradient_checkpointing = config.gradient_checkpointing
|
|
|
self.prep_tokens = E1BatchPreparer()
|
|
|
self.post_init()
|
|
|
|
|
|
@property
|
|
|
def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh:
|
|
|
return self.model.device_mesh
|
|
|
|
|
|
@torch.inference_mode()
|
|
|
def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
|
|
|
batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device)
|
|
|
last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state
|
|
|
if return_attention_mask:
|
|
|
attention_mask = (batch['sequence_ids'] != -1).long()
|
|
|
return last_hidden_state, attention_mask
|
|
|
else:
|
|
|
return last_hidden_state
|
|
|
|
|
|
def forward(
|
|
|
self,
|
|
|
input_ids: torch.LongTensor,
|
|
|
within_seq_position_ids: torch.LongTensor,
|
|
|
global_position_ids: torch.LongTensor,
|
|
|
sequence_ids: torch.LongTensor,
|
|
|
labels: torch.LongTensor | None = None,
|
|
|
past_key_values: DynamicCache | None = None,
|
|
|
use_cache: bool = False,
|
|
|
output_attentions: bool = False,
|
|
|
output_hidden_states: bool = False,
|
|
|
**kwargs,
|
|
|
) -> E1ClassificationOutputWithPast:
|
|
|
outputs: E1ModelOutputWithPast = self.model(
|
|
|
input_ids=input_ids,
|
|
|
within_seq_position_ids=within_seq_position_ids,
|
|
|
global_position_ids=global_position_ids,
|
|
|
sequence_ids=sequence_ids,
|
|
|
past_key_values=past_key_values,
|
|
|
use_cache=use_cache,
|
|
|
output_attentions=output_attentions,
|
|
|
output_hidden_states=output_hidden_states,
|
|
|
)
|
|
|
|
|
|
x = outputs.last_hidden_state
|
|
|
logits = self.classifier(x)
|
|
|
loss = None
|
|
|
if labels is not None:
|
|
|
loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
|
|
|
|
return E1ClassificationOutputWithPast(
|
|
|
loss=loss,
|
|
|
logits=logits,
|
|
|
last_hidden_state=x,
|
|
|
past_key_values=outputs.past_key_values,
|
|
|
hidden_states=outputs.hidden_states,
|
|
|
attentions=outputs.attentions,
|
|
|
)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
model = E1ForSequenceClassification.from_pretrained("Profluent-Bio/E1-150m", dtype=torch.bfloat16, num_labels=1).eval().to(device)
|
|
|
print(model)
|
|
|
|
|
|
seqs = [
|
|
|
"MRHGDISSSNDTVGVAVVNYKMPRLHTAAEVLDNARKIAEMIVGMKQGLPGMDLVVFPEYSLQGIMYDPAEMMETAVAIPGEETE",
|
|
|
"IFSRACRKANVWGVFSLTGERHEEHPRKAPYNTLVLIDNNGEIVQKYRKIIPWCPIEGWYPGGQTYVSEGPKGMKISLIICDDGNY",
|
|
|
"PEIWRDCAMKGAELIVRCQGYMYPAKDQQVMMAKAMAWANNCYVAVANAAGFDGVYSYFGHSAIIGFDGRTLGECGEEEMGIQYAQL",
|
|
|
"SLSQIRDARANDQSQNHLFKILHRGYSGLQASGDGDRGLAECPFEFYRTWVTDAEKARENVERLTRSTTGVAQCPVGRLPYEGLEKEA",
|
|
|
]
|
|
|
|
|
|
batch = model.prep_tokens.get_batch_kwargs(seqs, device=device)
|
|
|
batch['labels'] = torch.tensor([0.0, 0.0, 0.0, 0.0], device=device)
|
|
|
|
|
|
last_hidden_state = model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state
|
|
|
print(last_hidden_state.shape)
|
|
|
|