lhallee committed on
Commit
f8b862f
·
verified ·
1 Parent(s): 2ad7446

Upload modeling_e1.py with huggingface_hub

Files changed (1)
  1. modeling_e1.py +2129 -0
modeling_e1.py ADDED
@@ -0,0 +1,2129 @@
1
+ import os
2
+ os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
3
+
4
+ import numpy as np
5
+ import networkx as nx
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.utils.data import Dataset as TorchDataset, DataLoader
10
+ from torch.nn.utils.rnn import pad_sequence
11
+
12
+ from einops import rearrange, repeat
13
+ from enum import Enum
14
+ from typing import Any, TypedDict, Callable, Optional, List
15
+ from dataclasses import dataclass
16
+ from tokenizers import Tokenizer
17
+ from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizerBase
18
+ from transformers.activations import ACT2FN
19
+ from transformers.modeling_outputs import ModelOutput
20
+ from transformers.utils import logging
21
+ from tqdm.auto import tqdm
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ ### Establish attention compatibility
27
+ try:
28
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
29
+ except ImportError:
30
+ logger.warning("Failed to import flash attention; Will be using PyTorch attention instead")
31
+ flash_attn_func = None
32
+ flash_attn_varlen_func = None
33
+
34
+ try:
35
+ from torch.nn.attention.flex_attention import (
36
+ BlockMask,
37
+ create_block_mask,
38
+ flex_attention,
39
+ _create_sparse_block_from_block_mask
40
+ )
41
+
42
+ if torch.cuda.is_available():
43
+ # if on linux, compile the flex attention function
44
+ if os.name == 'posix':
45
+ print("Compiling flex attention")
46
+ flex_attention = torch.compile(flex_attention, dynamic=True)
47
+ else:
48
+ print("Not compiling flex attention, detected non-Linux environment")
49
+
50
+ except ImportError:
51
+ logger.warning("Failed to import flex attention; Will be using PyTorch attention instead")
52
+ flex_attention = None
53
+
54
+ try:
55
+ from kernels import get_kernel
56
+ layer_norm = get_kernel("kernels-community/triton-layer-norm")
57
+ except Exception as e:
58
+ logger.warning(f"Failed to load triton layer norm kernel: {e}; Will be using PyTorch RMSNorm instead")
59
+ layer_norm = None
60
+
61
+
62
+ def is_flash_attention_available() -> bool:
63
+ return (
64
+ flash_attn_func is not None and flash_attn_varlen_func is not None and (os.getenv("USE_FLASH_ATTN", "1") == "1")
65
+ )
66
+
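# Editorial note (hedged sketch, not part of the uploaded file): flash attention is used only
# when the import above succeeded and the USE_FLASH_ATTN environment variable is not "0",
# so the flex/PyTorch fallback can be forced for debugging without uninstalling flash-attn:
def _example_force_pytorch_fallback() -> bool:
    os.environ["USE_FLASH_ATTN"] = "0"  # debugging toggle read by is_flash_attention_available()
    return is_flash_attention_available()  # now returns False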
67
+
68
+ class FlexAttentionArgs(TypedDict, total=False):
69
+ block_mask: BlockMask | None
70
+ score_mod: Callable | None
71
+
72
+
73
+ def create_block_causal_mask_optimized(sequence_ids: torch.Tensor) -> BlockMask:
74
+ # Assumes sequence_ids is sorted in increasing order for each batch item, except for
75
+ # the -1 values, which are used to indicate the padding tokens.
76
+ def document_mask(b, h, q_idx, kv_idx): # type: ignore[no-untyped-def]
77
+ return (
78
+ (sequence_ids[b, q_idx] >= sequence_ids[b, kv_idx])
79
+ & (sequence_ids[b, q_idx] != -1)
80
+ & (sequence_ids[b, kv_idx] != -1)
81
+ )
82
+
83
+ batch_size, seqlen = sequence_ids.shape
84
+ return create_block_mask(document_mask, batch_size, 1, seqlen, seqlen, device=sequence_ids.device)
85
+
86
+
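# Editorial sketch (not part of the uploaded file): the block-causal document mask above lets
# later sequences attend to earlier ones but not vice versa, and excludes padding (-1). A dense
# reference mask for a toy packed batch, useful for eyeballing the semantics:
def _example_dense_block_causal_mask() -> torch.Tensor:
    sequence_ids = torch.tensor([[0, 0, 1, 1, -1]])  # two sequences plus one padding slot
    q = sequence_ids[:, :, None]
    k = sequence_ids[:, None, :]
    return (q >= k) & (q != -1) & (k != -1)  # (batch, q_idx, kv_idx) boolean mask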
87
+ def flex_attention_func(
88
+ query_states: torch.Tensor, # (bs, seqlen, nh, hs)
89
+ key_states: torch.Tensor, # (bs, seqlen, nkv, hs)
90
+ value_states: torch.Tensor, # (bs, seqlen, nkv, hs)
91
+ score_mod: Callable | None = None,
92
+ block_mask: BlockMask | None = None,
93
+ ) -> torch.Tensor:
94
+ assert flex_attention is not None, "Flex Attention is not available in this environment"
95
+ assert score_mod is None, "Score mod is not supported yet"
96
+ query_states = query_states.transpose(1, 2).contiguous() # (bs, nh, seqlen, hs)
97
+ key_states = key_states.transpose(1, 2).contiguous() # (bs, nkv, seqlen, hs)
98
+ value_states = value_states.transpose(1, 2).contiguous() # (bs, nkv, seqlen, hs)
99
+
100
+ outputs = flex_attention(
101
+ query_states,
102
+ key_states,
103
+ value_states,
104
+ block_mask=block_mask,
105
+ score_mod=score_mod,
106
+ enable_gqa=query_states.shape[1] != key_states.shape[1], # if nkv != nh
107
+ )
108
+
109
+ outputs = outputs.transpose(1, 2) # (bs, seqlen, nh, hs)
110
+ return outputs
111
+
112
+
113
+ def flash_attention_func(
114
+ query_states: torch.Tensor, # (bs, seqlen, nh, hs)
115
+ key_states: torch.Tensor, # (bs, seqlen, nkv, hs)
116
+ value_states: torch.Tensor, # (bs, seqlen, nkv, hs)
117
+ q_sequence_ids: torch.Tensor,
118
+ k_sequence_ids: torch.Tensor,
119
+ causal: bool = False,
120
+ ) -> torch.Tensor: # (bs, seqlen, nh, hs)
121
+ # The non-causal path unpads inputs that may contain padding tokens; the attention mask is ignored when causal.
122
+ if not is_flash_attention_available():
123
+ raise ImportError("Flash Attention is not available. Please install flash-attn.")
124
+
125
+ if not causal:
126
+ batch_size, q_len = query_states.shape[0], query_states.shape[1]
127
+ (
128
+ query_states,
129
+ key_states,
130
+ value_states,
131
+ indices_q,
132
+ (cu_seqlens_q, cu_seqlens_k),
133
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
134
+ ) = _unpad_input(query_states, key_states, value_states, q_sequence_ids, k_sequence_ids)
135
+
136
+ attn_output_unpad = flash_attn_varlen_func(
137
+ query_states,
138
+ key_states,
139
+ value_states,
140
+ cu_seqlens_q=cu_seqlens_q,
141
+ cu_seqlens_k=cu_seqlens_k,
142
+ max_seqlen_q=max_seqlen_in_batch_q,
143
+ max_seqlen_k=max_seqlen_in_batch_k,
144
+ causal=False,
145
+ )
146
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, q_len)
147
+
148
+ else:
149
+ attn_output = flash_attn_func(query_states, key_states, value_states, causal=True)
150
+
151
+ return attn_output
152
+
153
+
154
+ class IndexFirstAxis(torch.autograd.Function):
155
+ @staticmethod
156
+ def forward(ctx, input, indices) -> torch.Tensor: # type: ignore[no-untyped-def]
157
+ ctx.save_for_backward(indices)
158
+ assert input.ndim >= 2
159
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
160
+ second_dim = other_shape.numel()
161
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
162
+ # return input[indices]
163
+ return torch.gather(rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)).reshape(
164
+ -1, *other_shape
165
+ )
166
+
167
+ @staticmethod
168
+ def backward(ctx, grad_output) -> tuple[torch.Tensor, None]: # type: ignore[no-untyped-def]
169
+ (indices,) = ctx.saved_tensors
170
+ assert grad_output.ndim >= 2
171
+ other_shape = grad_output.shape[1:]
172
+ grad_output = rearrange(grad_output, "b ... -> b (...)")
173
+ grad_input = torch.zeros(
174
+ [ctx.first_axis_dim, grad_output.shape[1]], device=grad_output.device, dtype=grad_output.dtype
175
+ )
176
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
177
+ # grad_input[indices] = grad_output
178
+ grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
179
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
180
+
181
+
182
+ def block_min_max_seq_ids(SLEN: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
183
+ device = SLEN.device
184
+ total_tokens = torch.sum(SLEN)
185
+ B = (total_tokens + block_size - 1) // block_size
186
+ padding_tokens = B * block_size - total_tokens
187
+ SLEN = torch.cat([SLEN, torch.Tensor([padding_tokens]).to(device)], dim=0)
188
+
189
+ assert torch.sum(SLEN) == B * block_size
190
+
191
+ # Cumulative ends (exclusive) for each sequence; cum[i] == end offset of seq i
192
+ cum = torch.cumsum(SLEN.to(torch.long), dim=0) # (N,)
193
+ total_tokens = cum[-1].item()
194
+
195
+ # Block start/end offsets [start, end) in token index space
196
+ block_starts = torch.arange(0, B * block_size, block_size, device=device, dtype=torch.long) # (B,)
197
+ block_ends = torch.minimum(block_starts + block_size, torch.tensor(total_tokens, device=device)) # (B,)
198
+
199
+ # MIN_SEQ_ID[i] = first sequence whose end > block_start
200
+ # searchsorted with right=True returns first index where cum > value
201
+ MIN_SEQ_ID = torch.searchsorted(cum, block_starts, right=True)
202
+
203
+ # MAX_SEQ_ID[i] = sequence containing the last token in the block (block_end - 1)
204
+ # For empty tail beyond total_tokens we already clipped block_ends.
205
+ last_token_in_block = torch.clamp(block_ends - 1, min=0) # valid only if block has at least 1 token
206
+ MAX_SEQ_ID = torch.searchsorted(cum, last_token_in_block, right=True)
207
+
208
+ return MIN_SEQ_ID, MAX_SEQ_ID
209
+
210
+
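# Editorial worked example (assumes a small block_size chosen only for illustration): with
# SLEN = [3, 5] and block_size=4, tokens pack as [s0 s0 s0 s1 | s1 s1 s1 s1], so block 0 spans
# sequences 0..1 while block 1 touches only sequence 1.
def _example_block_min_max_seq_ids() -> tuple[torch.Tensor, torch.Tensor]:
    min_id, max_id = block_min_max_seq_ids(torch.tensor([3, 5]), block_size=4)
    # expected: min_id == tensor([0, 1]), max_id == tensor([1, 1])
    return min_id, max_id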
211
+ def get_overlapping_blocks(SLEN_Q: torch.Tensor, SLEN_K: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
212
+ MIN_Q, MAX_Q = block_min_max_seq_ids(SLEN_Q)
213
+ MIN_K, MAX_K = block_min_max_seq_ids(SLEN_K)
214
+
215
+ cond1 = MIN_Q.unsqueeze(1) <= MAX_K.unsqueeze(0)
216
+ cond2 = MIN_K.unsqueeze(0) <= MAX_Q.unsqueeze(1)
217
+ overlap = cond1 & cond2
218
+
219
+ cond1 = (MIN_Q == MAX_Q).unsqueeze(1)
220
+ cond2 = (MIN_K == MAX_K).unsqueeze(0)
221
+ same_seq_in_qk = cond1 & cond2
222
+
223
+ full_blocks = overlap & same_seq_in_qk
224
+ partial_blocks = overlap & ~same_seq_in_qk
225
+
226
+ return full_blocks, partial_blocks
227
+
228
+
229
+ def direct_block_mask(SLEN_Q: torch.Tensor, SLEN_K: torch.Tensor) -> BlockMask:
230
+ full_blocks, partial_blocks = get_overlapping_blocks(SLEN_Q, SLEN_K)
231
+ partial_blocks = partial_blocks[None, None]
232
+ full_blocks = full_blocks[None, None]
233
+
234
+ q_doc_id = torch.repeat_interleave(SLEN_Q)
235
+ k_doc_id = torch.repeat_interleave(SLEN_K)
236
+
237
+ def doc_mask(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor) -> torch.Tensor:
238
+ return q_doc_id[q_idx] == k_doc_id[kv_idx]
239
+
240
+ total_q_len = q_doc_id.shape[0]
241
+ total_k_len = k_doc_id.shape[0]
242
+
243
+ return _create_sparse_block_from_block_mask(
244
+ (partial_blocks, full_blocks),
245
+ doc_mask,
246
+ seq_lengths=(total_q_len, total_k_len),
247
+ Q_BLOCK_SIZE=128,
248
+ KV_BLOCK_SIZE=128,
249
+ )
250
+
251
+
252
+ def doc_id_mask(SLEN_Q: torch.Tensor, SLEN_K: torch.Tensor) -> BlockMask:
253
+ q_doc_id = torch.repeat_interleave(SLEN_Q)
254
+ k_doc_id = torch.repeat_interleave(SLEN_K)
255
+
256
+ def doc_mask(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor) -> torch.Tensor:
257
+ return q_doc_id[q_idx] == k_doc_id[kv_idx]
258
+
259
+ total_q_len = q_doc_id.shape[0]
260
+ total_k_len = k_doc_id.shape[0]
261
+
262
+ return create_block_mask(doc_mask, 1, 1, total_q_len, total_k_len, BLOCK_SIZE=128, device=SLEN_Q.device)
263
+
264
+
265
+ def varlen_flex_attention_func(
266
+ query_states: torch.Tensor,
267
+ key_states: torch.Tensor,
268
+ value_states: torch.Tensor,
269
+ q_sequence_ids: torch.Tensor,
270
+ k_sequence_ids: torch.Tensor,
271
+ ) -> torch.Tensor:
272
+ batch_size, q_len = query_states.shape[0], query_states.shape[1]
273
+ (
274
+ query_states,
275
+ key_states,
276
+ value_states,
277
+ indices_q,
278
+ (cu_seqlens_q, cu_seqlens_k),
279
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
280
+ ) = _unpad_input(query_states, key_states, value_states, q_sequence_ids, k_sequence_ids)
281
+
282
+ query_states = query_states.unsqueeze(0).transpose(1, 2).contiguous()
283
+ key_states = key_states.unsqueeze(0).transpose(1, 2).contiguous()
284
+ value_states = value_states.unsqueeze(0).transpose(1, 2).contiguous()
285
+
286
+ seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
287
+ seqlens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
288
+ block_mask = block_mask_creator(seqlens_q, seqlens_k)
289
+
290
+ attn_output_unpad = flex_attention(
291
+ query_states,
292
+ key_states,
293
+ value_states,
294
+ block_mask=block_mask,
295
+ enable_gqa=query_states.shape[1] != key_states.shape[1],
296
+ )
297
+
298
+ attn_output = pad_input(attn_output_unpad.transpose(1, 2).squeeze(0), indices_q, batch_size, q_len)
299
+
300
+ return attn_output
301
+
302
+
303
+ class IndexPutFirstAxis(torch.autograd.Function):
304
+ @staticmethod
305
+ def forward(ctx, values, indices, first_axis_dim) -> torch.Tensor: # type: ignore[no-untyped-def]
306
+ ctx.save_for_backward(indices)
307
+ assert indices.ndim == 1
308
+ assert values.ndim >= 2
309
+ output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
310
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
311
+ output[indices] = values
312
+ # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
313
+ return output
314
+
315
+ @staticmethod
316
+ def backward(ctx, grad_output) -> tuple[torch.Tensor, None, None]: # type: ignore[no-untyped-def]
317
+ (indices,) = ctx.saved_tensors
318
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
319
+ grad_values = grad_output[indices]
320
+ # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
321
+ return grad_values, None, None
322
+
323
+
324
+ index_put_first_axis = IndexPutFirstAxis.apply
325
+
326
+
327
+ def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
328
+ """
329
+ Arguments:
330
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
331
+ indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
332
+ batch: int, batch size for the padded sequence.
333
+ seqlen: int, maximum sequence length for the padded sequence.
334
+ Return:
335
+ hidden_states: (batch, seqlen, ...)
336
+ """
337
+ # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
338
+ # output[indices] = hidden_states
339
+ output = index_put_first_axis(hidden_states, indices, batch * seqlen)
340
+ return rearrange(output, "(b s) ... -> b s ...", b=batch)
341
+
342
+
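# Editorial sketch (hypothetical example values): pad_input is the inverse of the unpadding
# helpers below; it scatters a flat tensor of kept tokens back into a (batch, seqlen, ...)
# layout, filling padding positions with zeros.
def _example_pad_input() -> torch.Tensor:
    kept = torch.tensor([[1.0], [2.0], [3.0]])        # three kept tokens, feature dim 1
    indices = torch.tensor([0, 1, 4])                  # their flat positions in a 2 x 3 batch
    return pad_input(kept, indices, batch=2, seqlen=3)  # -> [[[1.], [2.], [0.]], [[0.], [3.], [0.]]]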
343
+ def _get_unpad_data(sequence_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
344
+ non_pad_indices = sequence_ids != -1
345
+ non_pad_indices = torch.nonzero(non_pad_indices.flatten(), as_tuple=False).flatten()
346
+ sequence_ids = sequence_ids + torch.arange(len(sequence_ids), device=sequence_ids.device)[:, None] * 1e5
347
+ sequence_ids = sequence_ids.flatten()[non_pad_indices]
348
+ _, seqlens_in_batch = torch.unique_consecutive(sequence_ids, return_counts=True)
349
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
350
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
351
+ return non_pad_indices, cu_seqlens, max_seqlen_in_batch
352
+
353
+
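# Editorial worked example: for sequence_ids [[0, 0, 1], [0, -1, -1]] (two packed sequences in
# row 0, one sequence plus padding in row 1), _get_unpad_data should return flat indices
# [0, 1, 2, 3], cu_seqlens [0, 2, 3, 4], and max_seqlen_in_batch 2.
def _example_get_unpad_data() -> tuple[torch.Tensor, torch.Tensor, int]:
    return _get_unpad_data(torch.tensor([[0, 0, 1], [0, -1, -1]]))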
354
+ def _unpad_input(
355
+ query_layer: torch.Tensor,
356
+ key_layer: torch.Tensor,
357
+ value_layer: torch.Tensor,
358
+ q_sequence_ids: torch.Tensor,
359
+ k_sequence_ids: torch.Tensor,
360
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor], tuple[int, int]]:
361
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
362
+ query_length, num_q_heads = query_layer.shape[1], query_layer.shape[2]
363
+ assert query_layer.shape[:2] == q_sequence_ids.shape, (
364
+ f"Shape mismatch between query layer and query sequence ids: {query_layer.shape[:2]} != {q_sequence_ids.shape}"
365
+ )
366
+ assert key_layer.shape[:2] == k_sequence_ids.shape, (
367
+ f"Shape mismatch between key layer and key sequence ids: {key_layer.shape[:2]} != {k_sequence_ids.shape}"
368
+ )
369
+ assert query_length <= kv_seq_len, (
370
+ f"Query length should be less than or equal to KV sequence length: {query_length} <= {kv_seq_len}"
371
+ )
372
+
373
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(k_sequence_ids)
374
+
375
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
376
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
377
+
378
+ if torch.equal(q_sequence_ids, k_sequence_ids):
379
+ indices_q = indices_k
380
+ cu_seqlens_q = cu_seqlens_k
381
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
382
+ else:
383
+ indices_q, cu_seqlens_q, max_seqlen_in_batch_q = _get_unpad_data(q_sequence_ids)
384
+
385
+ query_layer = index_first_axis(query_layer.reshape(batch_size * query_length, num_q_heads, head_dim), indices_q)
386
+
387
+ assert cu_seqlens_q.shape == cu_seqlens_k.shape, (
388
+ f"Query and KV should have the same number of sequences: {cu_seqlens_q.shape} != {cu_seqlens_k.shape}"
389
+ )
390
+
391
+ return (
392
+ query_layer,
393
+ key_layer,
394
+ value_layer,
395
+ indices_q,
396
+ (cu_seqlens_q, cu_seqlens_k),
397
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
398
+ )
399
+
400
+
401
+ index_first_axis = IndexFirstAxis.apply
402
+ block_mask_creator = direct_block_mask if os.getenv("FAST_BLOCK_MASK", "1") == "1" else doc_id_mask
403
+ PAD_TOKEN_ID = 0
404
+
405
+
406
+ def get_tokenizer() -> Tokenizer:
407
+ fname = os.path.join(os.path.dirname(__file__), "tokenizer.json")
408
+ tokenizer: Tokenizer = Tokenizer.from_file(fname)
409
+ assert tokenizer.padding["pad_id"] == PAD_TOKEN_ID, (
410
+ f"Padding token id must be {PAD_TOKEN_ID}, but got {tokenizer.padding['pad_id']}"
411
+ )
412
+
413
+ return tokenizer
414
+
415
+
416
+ @dataclass
417
+ class DataPrepConfig:
418
+ max_num_sequences: int = 512
419
+ max_num_positions_within_seq: int = 8192
420
+ remove_X_tokens: bool = False
421
+
422
+
423
+ def get_context(sequence: str) -> str | None:
424
+ if "," in sequence:
425
+ return sequence.rsplit(",", 1)[0]
426
+ return None
427
+
428
+
429
+ class E1BatchPreparer:
430
+ def __init__(
431
+ self,
432
+ data_prep_config: DataPrepConfig | None = None,
433
+ tokenizer: Tokenizer | None = None,
434
+ preserve_context_labels: bool = False,
435
+ ):
436
+ self.tokenizer = tokenizer or get_tokenizer()
437
+ self.data_prep_config = data_prep_config or DataPrepConfig()
438
+ self.pad_token_id = self.tokenizer.token_to_id("<pad>")
439
+ self.preserve_context_labels = preserve_context_labels
440
+ device = torch.cuda.current_device() if torch.cuda.is_available() else torch.device("cpu")
441
+ self.boundary_token_ids = torch.tensor(
442
+ [self.tokenizer.token_to_id(token) for token in ["<bos>", "<eos>", "1", "2", "<pad>"]], device=device
443
+ ).long()
444
+ self.mask_token = "?" # nosec
445
+ self.mask_token_id = self.tokenizer.token_to_id(self.mask_token)
446
+ self.X_token_id = self.tokenizer.token_to_id("X")
447
+ self.vocab = self.tokenizer.get_vocab()
448
+
449
+ def get_batch_kwargs( # type: ignore[override]
450
+ self, sequences: list[str], device: torch.device = torch.device("cpu"), non_blocking: bool = False
451
+ ) -> dict[str, torch.Tensor | list[str] | list[int]]:
452
+ sequence_encodings = [self.prepare_multiseq(sequence) for sequence in sequences]
453
+ return self.pad_encodings(sequence_encodings, device, non_blocking)
454
+
455
+ def pad_encodings(
456
+ self,
457
+ sequence_encodings: list[dict[str, torch.Tensor]],
458
+ device: torch.device = torch.device("cpu"),
459
+ non_blocking: bool = False,
460
+ ) -> dict[str, torch.Tensor | list[str] | list[int]]:
461
+ non_blocking = non_blocking and device.type == "cuda"
462
+ padded_encodings = {}
463
+ # Note: We use -1 as the padding value for sequence and position ids because the 0 value
464
+ # is a valid value for sequence and position ids. -1 is then used to distinguish valid
465
+ # tokens from padding tokens, for example, when doing padding/unpadding for flash attention.
466
+ for key, padding_value in {
467
+ "input_ids": self.pad_token_id,
468
+ "sequence_ids": -1,
469
+ "within_seq_position_ids": -1,
470
+ "global_position_ids": -1,
471
+ "labels": self.pad_token_id,
472
+ }.items():
473
+ padded_encodings[key] = pad_sequence(
474
+ [enc[key] for enc in sequence_encodings], batch_first=True, padding_value=padding_value
475
+ ).to(device=device, dtype=torch.long, non_blocking=non_blocking)
476
+
477
+ padded_encodings["context"] = [enc["context"] for enc in sequence_encodings]
478
+ padded_encodings["context_len"] = [enc["context_len"] for enc in sequence_encodings]
479
+
480
+ return padded_encodings
481
+
482
+ def prepare_multiseq(self, sequence: str) -> dict[str, torch.Tensor | str | int]:
483
+ single_sequences = sequence.split(",")
484
+ if len(single_sequences) > self.data_prep_config.max_num_sequences:
485
+ raise ValueError(
486
+ f"Number of sequences {len(single_sequences)} exceeds max number of sequences {self.data_prep_config.max_num_sequences}"
487
+ " in the provided multi-sequence instance. Please remove some homologous sequences before trying again."
488
+ )
489
+
490
+ single_sequence_encodings = [self.prepare_singleseq(sequence) for sequence in single_sequences]
491
+
492
+ num_tokens = [len(x["input_ids"]) for x in single_sequence_encodings]
493
+ input_ids = torch.cat([x["input_ids"] for x in single_sequence_encodings])
494
+ labels = torch.cat([x["labels"] for x in single_sequence_encodings])
495
+
496
+ within_seq_position_ids = torch.cat([encoding["position_ids"] for encoding in single_sequence_encodings])
497
+ global_position_ids, ctx_len = [], 0
498
+ for encoding in single_sequence_encodings:
499
+ global_position_ids.append(encoding["position_ids"] + ctx_len)
500
+ ctx_len = max(ctx_len, encoding["position_ids"].max().item() + ctx_len + 1)
501
+ global_position_ids = torch.cat(global_position_ids)
502
+
503
+ sequence_ids = torch.repeat_interleave(torch.tensor(num_tokens))
504
+
505
+ # Get multi-seq context & mask out all but last sequence in multi-seq instance if desired
506
+ context_len = sum(num_tokens[:-1])
507
+ context = self.tokenizer.decode(input_ids[:context_len].tolist(), skip_special_tokens=False)
508
+ if not self.preserve_context_labels:
509
+ labels[:context_len] = self.pad_token_id
510
+
511
+ assert (
512
+ input_ids.shape
513
+ == sequence_ids.shape
514
+ == within_seq_position_ids.shape
515
+ == global_position_ids.shape
516
+ == labels.shape
517
+ ), "Input ids, sequence ids, within seq position ids, global position ids, and labels must have the same shape"
518
+
519
+ assert input_ids.shape[0] >= context_len, "Input ids must have at least as many tokens as the context length"
520
+
521
+ return {
522
+ "input_ids": input_ids,
523
+ "sequence_ids": sequence_ids,
524
+ "within_seq_position_ids": within_seq_position_ids,
525
+ "global_position_ids": global_position_ids,
526
+ "labels": labels,
527
+ "context": context,
528
+ "context_len": context_len,
529
+ }
530
+
531
+ def prepare_singleseq(self, sequence: str) -> dict[str, torch.Tensor]:
532
+ if not self.validate_sequence(sequence):
533
+ raise ValueError(f"Invalid sequence: {sequence}; Input sequence should contain [A-Z] or ? characters only")
534
+
535
+ if len(sequence) > self.data_prep_config.max_num_positions_within_seq:
536
+ raise ValueError(
537
+ f"Sequence length {len(sequence)} exceeds max length {self.data_prep_config.max_num_positions_within_seq}"
538
+ )
539
+
540
+ # Can also use `tokens = torch.tensor(self.tokenizer.encode(f"<bos>1{sequence}2<eos>").ids)`
541
+ # but following is faster since our vocabulary is simple.
542
+ tokens = torch.tensor([self.vocab[token] for token in ["<bos>", "1", *sequence, "2", "<eos>"]])
543
+ position_ids = torch.arange(len(tokens))
544
+
545
+ if self.data_prep_config.remove_X_tokens:
546
+ X_positions = torch.where(tokens != self.X_token_id)[0]
547
+ tokens = tokens[X_positions]
548
+ position_ids = position_ids[X_positions]
549
+
550
+ return {"input_ids": tokens, "labels": tokens, "position_ids": position_ids}
551
+
552
+ def get_boundary_token_mask(self, tokens: torch.Tensor) -> torch.BoolTensor:
553
+ return torch.isin(tokens, self.boundary_token_ids)
554
+
555
+ def get_mask_positions_mask(self, tokens: torch.Tensor) -> torch.BoolTensor:
556
+ return tokens == self.mask_token_id
557
+
558
+ def validate_sequence(self, sequence: str) -> bool:
559
+ assert isinstance(sequence, str), "Sequence must be a string"
560
+ sequence = sequence.replace(self.mask_token, "")
561
+ return sequence.isalpha() and sequence.isupper()
562
+
563
+
564
+
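# Editorial usage sketch (assumes the bundled tokenizer.json sits next to this file, as
# get_tokenizer() expects): homologous sequences are passed as one comma-separated string;
# the preparer tokenizes each, concatenates them with sequence/position ids, pads the batch,
# and masks the context labels unless preserve_context_labels=True. The sequence strings are
# illustrative placeholders only.
def _example_batch_preparer() -> dict[str, torch.Size]:
    preparer = E1BatchPreparer()
    batch = preparer.get_batch_kwargs(["MKTAYIAKQR,MKTAYIAKQ?"])
    return {k: v.shape for k, v in batch.items() if isinstance(v, torch.Tensor)}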
565
+ class E1Config(PretrainedConfig):
566
+ model_type = "E1"
567
+ keys_to_ignore_at_inference = ["past_key_values"]
568
+
569
+ def __init__( # type: ignore
570
+ self,
571
+ # Model architecture/initialization
572
+ vocab_size=None,
573
+ hidden_size=4096,
574
+ intermediate_size=16384,
575
+ gated_mlp=False,
576
+ num_hidden_layers=40,
577
+ num_attention_heads=32,
578
+ num_key_value_heads=8,
579
+ hidden_act="silu",
580
+ rms_norm_eps=1e-5,
581
+ initializer_range=0.02,
582
+ torch_dtype="bfloat16",
583
+ gradient_checkpointing=False,
584
+ no_ffn_gradient_checkpointing=False,
585
+ # Tokenization
586
+ pad_token_id=None,
587
+ bos_token_id=None,
588
+ eos_token_id=None,
589
+ tie_word_embeddings=False,
590
+ # Attention implementation & rotary positional embeddings
591
+ global_attention_every_n_layers=0,
592
+ max_num_sequences=512,
593
+ max_num_positions_within_seq=8192,
594
+ max_num_positions_global=1024 * 128,
595
+ rope_theta_within_seq=10000.0,
596
+ rope_theta_global=100000.0,
597
+ clip_qkv=None,
598
+ **kwargs,
599
+ ) -> None:
600
+ tokenizer = get_tokenizer()
601
+ super().__init__(
602
+ pad_token_id=tokenizer.token_to_id("<pad>"),
603
+ bos_token_id=tokenizer.token_to_id("<bos>"),
604
+ eos_token_id=tokenizer.token_to_id("<eos>"),
605
+ tie_word_embeddings=tie_word_embeddings,
606
+ torch_dtype=torch_dtype,
607
+ **kwargs,
608
+ )
609
+
610
+ self.hidden_size = hidden_size
611
+ if intermediate_size is None:
612
+ intermediate_size = 3 * hidden_size if gated_mlp else 4 * hidden_size
613
+ self.intermediate_size = intermediate_size
614
+ self.gated_mlp = gated_mlp
615
+ self.num_hidden_layers = num_hidden_layers
616
+ self.num_attention_heads = num_attention_heads
617
+ self.max_num_positions_within_seq = max_num_positions_within_seq
618
+ self.max_num_positions_global = max_num_positions_global
619
+
620
+ # for backward compatibility
621
+ if num_key_value_heads is None:
622
+ num_key_value_heads = num_attention_heads
623
+
624
+ self.num_key_value_heads = num_key_value_heads
625
+ self.hidden_act = hidden_act
626
+ self.initializer_range = initializer_range
627
+ self.rms_norm_eps = rms_norm_eps
628
+ self.rope_theta_within_seq = rope_theta_within_seq
629
+ self.rope_theta_global = rope_theta_global
630
+ self.max_num_sequences = max_num_sequences
631
+ assert clip_qkv is None or clip_qkv > 0
632
+ self.clip_qkv = clip_qkv
633
+ self.global_attention_every_n_layers = global_attention_every_n_layers
634
+
635
+ self.vocab_size = tokenizer.get_vocab_size()
636
+ self.gradient_checkpointing = gradient_checkpointing
637
+ self.no_ffn_gradient_checkpointing = no_ffn_gradient_checkpointing
638
+
639
+ if vocab_size is not None:
640
+ if vocab_size < self.vocab_size:
641
+ logger.warning(
642
+ f"Using vocab_size {vocab_size} smaller than {self.vocab_size} from tokenizer. MAKE SURE THIS IS INTENTIONAL."
643
+ )
644
+ self.vocab_size = vocab_size
645
+ elif vocab_size > self.vocab_size:
646
+ logger.warning(f"Using vocab_size {vocab_size} instead of smaller {self.vocab_size} from tokenizer.")
647
+ self.vocab_size = vocab_size
648
+ if pad_token_id is not None and pad_token_id != self.pad_token_id:
649
+ logger.warning(f"Ignoring pad_token_id. Using {self.pad_token_id} from tokenizer")
650
+ if bos_token_id is not None and bos_token_id != self.bos_token_id:
651
+ logger.warning(f"Ignoring bos_token_id. Using {self.bos_token_id} from tokenizer")
652
+ if eos_token_id is not None and eos_token_id != self.eos_token_id:
653
+ logger.warning(f"Ignoring eos_token_id. Using {self.eos_token_id} from tokenizer")
654
+
655
+
656
+ class DynamicCache:
657
+ """
658
+ A cache layer that grows dynamically as more tokens are generated. This is the default for generative models.
659
+ It stores the key and value states as tensors of shape `[batch_size, seq_len, num_heads, head_dim]`.
660
+
661
+ Args:
662
+ key_cache (`list[torch.Tensor]`): The list of key states.
663
+ value_cache (`list[torch.Tensor]`): The list of value states.
664
+ """
665
+
666
+ def __init__(self) -> None:
667
+ self.key_cache: list[torch.Tensor] = []
668
+ self.value_cache: list[torch.Tensor] = []
669
+
670
+ def update(
671
+ self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int
672
+ ) -> tuple[torch.Tensor, torch.Tensor]:
673
+ """
674
+ Update the key and value caches in-place, and return the necessary keys and value states.
675
+
676
+ Args:
677
+ key_states (`torch.Tensor`): The new key states to cache of shape [batch_size, seq_len, num_heads, head_dim]
678
+ value_states (`torch.Tensor`): The new value states to cache of shape [batch_size, seq_len, num_heads, head_dim]
679
+ layer_idx (`int`): The index of the layer to update.
680
+
681
+ Returns:
682
+ tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states of shape [batch_size, seq_len, num_heads, head_dim].
683
+ """
684
+ # Lazy initialization
685
+ if len(self.key_cache) <= layer_idx:
686
+ # There may be skipped layers, fill them with empty lists
687
+ for _ in range(len(self.key_cache), layer_idx):
688
+ self.key_cache.append(torch.tensor([]))
689
+ self.value_cache.append(torch.tensor([]))
690
+ self.key_cache.append(key_states)
691
+ self.value_cache.append(value_states)
692
+ elif (
693
+ not self.key_cache[layer_idx].numel() # prefers not t.numel() to len(t) == 0 to export the model
694
+ ): # fills previously skipped layers; checking for tensor causes errors
695
+ self.key_cache[layer_idx] = key_states
696
+ self.value_cache[layer_idx] = value_states
697
+ else:
698
+ self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=1)
699
+ self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=1)
700
+
701
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
702
+
703
+ def get_seq_length(self, layer_idx: int = 0) -> int:
704
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
705
+ is_empty_layer = (
706
+ len(self.key_cache) == 0 # no cache in any layer
707
+ or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it
708
+ or not self.key_cache[layer_idx].numel() # the layer has no cache
709
+ )
710
+ layer_seq_length = self.key_cache[layer_idx].shape[1] if not is_empty_layer else 0
711
+ return layer_seq_length
712
+
713
+ def crop(self, max_length: int) -> None:
714
+ """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
715
+ negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search."""
716
+ assert max_length > 0, "max_length must be positive"
717
+
718
+ if self.get_seq_length() <= max_length:
719
+ return
720
+
721
+ for layer_idx in range(len(self.key_cache)):
722
+ if self.key_cache[layer_idx].numel():
723
+ self.key_cache[layer_idx] = self.key_cache[layer_idx][:, :max_length, ...]
724
+ self.value_cache[layer_idx] = self.value_cache[layer_idx][:, :max_length, ...]
725
+
726
+ def batch_repeat_interleave(self, repeats: int) -> None:
727
+ """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
728
+ for layer_idx in range(len(self.key_cache)):
729
+ if self.key_cache[layer_idx].numel():
730
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
731
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)
732
+
733
+ def batch_select_indices(self, indices: torch.Tensor) -> None:
734
+ """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
735
+ for layer_idx in range(len(self.key_cache)):
736
+ if self.key_cache[layer_idx].numel():
737
+ self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
738
+ self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
739
+
740
+
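# Editorial sketch of the cache contract documented above: update() concatenates new key/value
# states along the sequence dimension for a given layer, and get_seq_length() reports how many
# tokens that layer currently caches.
def _example_dynamic_cache() -> int:
    cache = DynamicCache()
    k = v = torch.zeros(1, 4, 2, 8)   # (batch, seq_len, num_heads, head_dim)
    cache.update(k, v, layer_idx=0)
    cache.update(k, v, layer_idx=0)   # appended along dim 1
    return cache.get_seq_length(0)    # -> 8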
741
+ class KVCache:
742
+ def __init__(self, cache_size: int = 4) -> None:
743
+ self.cache_size = cache_size
744
+ self.tensor_input_field_names = [
745
+ "input_ids",
746
+ "within_seq_position_ids",
747
+ "global_position_ids",
748
+ "sequence_ids",
749
+ "labels",
750
+ ]
751
+ self.tensor_output_field_names = ["logits", "embeddings"]
752
+ self.cache_dict: dict[str, DynamicCache] = {}
753
+ self.cache_queue: list[str] = []
754
+
755
+ def reset(self) -> None:
756
+ for k in list(self.cache_dict.keys()):
757
+ del self.cache_dict[k]
758
+ del self.cache_dict
759
+ self.cache_dict = {}
760
+ self.cache_queue = []
761
+
762
+ torch.cuda.empty_cache()
763
+
764
+ def before_forward(self, batch: dict[str, torch.Tensor]) -> None:
765
+ contexts: list[str] | None = batch.get("context", None)
766
+ if contexts is None or "context_len" not in batch:
767
+ logger.warning_once(
768
+ "KVCache requires the batch dict to have both `context` and `context_len` keys to trigger. Skipping."
769
+ )
770
+ return
771
+
772
+ context_lens: list[int] = list(set(batch["context_len"]))
773
+ contexts: list[str] = list(set(contexts)) # type: ignore[no-redef]
774
+ if len(contexts) != 1 or len(context_lens) != 1:
775
+ logger.warning(
776
+ "SingleContextKVCache requires a single context and context length. "
777
+ "Multiple contexts or context lengths found in a single batch. Skipping."
778
+ )
779
+ return
780
+
781
+ batch_size = batch["input_ids"].shape[0]
782
+
783
+ unique_context = contexts[0]
784
+ unique_context_len = context_lens[0]
785
+ batch["use_cache"] = True
786
+
787
+ if unique_context not in self.cache_dict:
788
+ return
789
+
790
+ self.cache_dict[unique_context].batch_repeat_interleave(batch_size)
791
+ past_key_values = self.cache_dict[unique_context]
792
+ batch["past_key_values"] = past_key_values
793
+
794
+ # Remove context from the input fields
795
+ for field_name in self.tensor_input_field_names:
796
+ if batch.get(field_name, None) is not None:
797
+ batch[field_name] = batch[field_name][:, unique_context_len:]
798
+
799
+ def after_forward(self, batch: dict[str, Any], outputs: ModelOutput) -> None:
800
+ contexts = batch.get("context", None)
801
+ context_lens = batch.get("context_len", [])
802
+ if contexts is None or len(set(contexts)) != 1 or len(set(context_lens)) != 1 or context_lens[0] == 0:
803
+ return
804
+
805
+ assert batch["use_cache"]
806
+ unique_context = contexts[0]
807
+ unique_context_len = context_lens[0]
808
+
809
+ past_key_values = getattr(outputs, "past_key_values", None)
810
+ if not isinstance(past_key_values, DynamicCache):
811
+ logger.warning_once("KVCache is incompatible with models that don't return a DynamicCache. Skipping.")
812
+ return
813
+
814
+ if "past_key_values" not in batch:
815
+ if len(self.cache_queue) == self.cache_size:
816
+ last_context = self.cache_queue.pop(0)
817
+ if last_context not in self.cache_queue:
818
+ del self.cache_dict[last_context]
819
+ torch.cuda.empty_cache()
820
+
821
+ self.cache_dict[unique_context] = past_key_values
822
+ self.cache_queue.append(unique_context)
823
+
824
+ # Remove context from the input fields
825
+ for field_name in self.tensor_input_field_names:
826
+ if field_name in batch and batch[field_name] is not None:
827
+ batch[field_name] = batch[field_name][:, unique_context_len:]
828
+
829
+ # Remove context from the output fields
830
+ for field_name in self.tensor_output_field_names:
831
+ if field_name in outputs and outputs[field_name] is not None:
832
+ outputs[field_name] = outputs[field_name][:, unique_context_len:]
833
+ if "hidden_states" in outputs and outputs["hidden_states"] is not None:
834
+ outputs["hidden_states"] = [h[:, unique_context_len:] for h in outputs["hidden_states"]]
835
+
836
+ self.cache_dict[unique_context].crop(unique_context_len)
837
+ self.cache_dict[unique_context].batch_select_indices([0])
838
+
839
+
840
+ class AttentionMethod(Enum):
841
+ FLASH = "flash"
842
+ FLEX = "flex"
843
+
844
+
845
+ class AttentionLayerType(Enum):
846
+ WITHIN_SEQ = "within_seq"
847
+ GLOBAL = "global"
848
+
849
+
850
+ class AttentionArgs(TypedDict, total=False):
851
+ flex_attention_args: FlexAttentionArgs
852
+
853
+
854
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
855
+ """This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
856
+
857
+ The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch,
858
+ num_attention_heads, seqlen, head_dim)
859
+ """
860
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
861
+ if n_rep == 1:
862
+ return hidden_states
863
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
864
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
865
+
866
+
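# Editorial sanity check: the reshape-based expansion above should match the documented
# torch.repeat_interleave equivalent.
def _example_repeat_kv_equivalence() -> bool:
    x = torch.randn(2, 4, 6, 8)  # (batch, num_key_value_heads, seqlen, head_dim)
    return torch.equal(repeat_kv(x, 3), torch.repeat_interleave(x, repeats=3, dim=1))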
867
+ class RotaryPositionalEmbedding(nn.Module):
868
+ def __init__(
869
+ self, dim: int, max_position_embeddings: int = 2048, base: int = 10000, device: torch.device | None = None
870
+ ):
871
+ super().__init__()
872
+
873
+ self.dim = dim
874
+ self.base = base
875
+ self.max_position_embeddings = max_position_embeddings
876
+ inv_freq = base ** -(torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
877
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
878
+
879
+ # Build here to make `torch.jit.trace` work.
880
+ self._set_sin_cos_cache(seq_len=max_position_embeddings, device=self.inv_freq.device)
881
+
882
+ @staticmethod
883
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
884
+ """Rotates half the hidden dims of the input."""
885
+ x1 = x[..., : x.shape[-1] // 2]
886
+ x2 = x[..., x.shape[-1] // 2 :]
887
+ return torch.cat((-x2, x1), dim=-1)
888
+
889
+ def _set_sin_cos_cache(self, seq_len: int, device: torch.device) -> None:
890
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
891
+ self.max_seq_len_cached = seq_len
892
+ t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
893
+ angles = torch.outer(t, self.inv_freq.to(device))
894
+ angles = torch.cat((angles, angles), dim=1)
895
+ self.register_buffer("cos_cached", angles.cos(), persistent=False)
896
+ self.register_buffer("sin_cached", angles.sin(), persistent=False)
897
+
898
+ def forward(
899
+ self, q: torch.Tensor, k: torch.Tensor, position_ids: torch.LongTensor, seq_len: int | None = None
900
+ ) -> tuple[torch.Tensor, torch.Tensor]:
901
+ # x: [bsz, seq_len, num_attention_heads, head_size]
902
+ device, dtype = q.device, q.dtype
903
+ seq_len = position_ids.max().item() + 1 if seq_len is None else seq_len
904
+
905
+ if seq_len > self.max_seq_len_cached:
906
+ self._set_sin_cos_cache(seq_len=seq_len, device=device)
907
+
908
+ # cos_cached/sin_cached[position_ids] gets us something of shape (batch_size, seq_len, head_dim),
909
+ # so unsqueeze dimension -2 to broadcast to (batch_size, seq_len, n_heads, head_dim).
910
+ idxs = position_ids.to(device)
911
+ cos = self.cos_cached.to(device=device, dtype=dtype).unsqueeze(-2)[idxs]
912
+ sin = self.sin_cached.to(device=device, dtype=dtype).unsqueeze(-2)[idxs]
913
+
914
+ # Apply rotary positional embeddings to q and k (treating them as complex numbers). The first half is
915
+ # Re[x exp(it)] = Re[x] cos(t) - Im[x] sin(t), while the second half is
916
+ # Im[x exp(it)] = Im[x] cos(t) + Re[x] sin(t). This works b/c both halves of cos/sin are the same.
917
+ q_embed = (q * cos) + (self.rotate_half(q) * sin)
918
+ k_embed = (k * cos) + (self.rotate_half(k) * sin)
919
+ return q_embed, k_embed
920
+
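# Editorial sketch (hypothetical shapes): the rotary embedding applies a pure rotation per
# position, so it preserves the norm of q and k while encoding relative offsets in their
# dot products.
def _example_rope_preserves_norm() -> bool:
    rope = RotaryPositionalEmbedding(dim=8, max_position_embeddings=16)
    q = torch.randn(1, 4, 2, 8)               # (batch, seq_len, n_heads, head_dim)
    position_ids = torch.arange(4).unsqueeze(0)
    q_rot, _ = rope(q, q.clone(), position_ids)
    return torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-4)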
921
+
922
+ class Attention(nn.Module):
923
+ """Multi-headed attention from 'Attention Is All You Need' paper."""
924
+
925
+ def __init__(self, config: E1Config, layer_idx: int):
926
+ super().__init__()
927
+ self.config = config
928
+ self.layer_idx = layer_idx
929
+
930
+ self.hidden_size = config.hidden_size
931
+ self.num_heads = config.num_attention_heads
932
+ self.head_dim = self.hidden_size // self.num_heads
933
+ self.num_kv_heads = config.num_key_value_heads
934
+ self.num_key_value_groups = self.num_heads // self.num_kv_heads
935
+ self.max_num_seqs = config.max_num_sequences
936
+ self.clip_qkv = config.clip_qkv
937
+
938
+ if (self.head_dim * self.num_heads) != self.hidden_size:
939
+ raise ValueError(
940
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
941
+ f" and `num_heads`: {self.num_heads})."
942
+ )
943
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
944
+ self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
945
+ self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
946
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
947
+
948
+ if self.config.global_attention_every_n_layers > 0:
949
+ self.layer_type = (
950
+ AttentionLayerType.GLOBAL
951
+ if (self.layer_idx + 1) % self.config.global_attention_every_n_layers == 0
952
+ else AttentionLayerType.WITHIN_SEQ
953
+ )
954
+ else:
955
+ self.layer_type = AttentionLayerType.WITHIN_SEQ
956
+
957
+ self.rope_theta = (
958
+ config.rope_theta_within_seq
959
+ if self.layer_type == AttentionLayerType.WITHIN_SEQ
960
+ else config.rope_theta_global
961
+ )
962
+ self.max_position_embeddings = (
963
+ config.max_num_positions_within_seq
964
+ if self.layer_type == AttentionLayerType.WITHIN_SEQ
965
+ else config.max_num_positions_global
966
+ )
967
+
968
+ self.rotary_emb = RotaryPositionalEmbedding(
969
+ self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta
970
+ )
971
+
972
+ def prepare_qkv(
973
+ self,
974
+ hidden_states: torch.Tensor,
975
+ position_ids: torch.LongTensor,
976
+ past_key_value: DynamicCache | None = None,
977
+ use_cache: bool = False,
978
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
979
+ bsz, q_len, _ = hidden_states.size()
980
+ query_states: torch.Tensor = self.q_proj(hidden_states)
981
+ key_states: torch.Tensor = self.k_proj(hidden_states)
982
+ val_states: torch.Tensor = self.v_proj(hidden_states)
983
+
984
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
985
+ key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim)
986
+ val_states = val_states.view(bsz, q_len, self.num_kv_heads, self.head_dim)
987
+
988
+ if self.clip_qkv is not None:
989
+ query_states = query_states.clamp(-self.clip_qkv, self.clip_qkv)
990
+ key_states = key_states.clamp(-self.clip_qkv, self.clip_qkv)
991
+ val_states = val_states.clamp(-self.clip_qkv, self.clip_qkv)
992
+
993
+ query_states, key_states = self.rotary_emb(query_states, key_states, position_ids)
994
+
995
+ if use_cache and past_key_value is not None:
996
+ key_states, val_states = past_key_value.update(key_states, val_states, self.layer_idx)
997
+
998
+ # In PEFT, the layer norms are usually cast to float32 for training stability,
999
+ # so the input hidden states may get silently cast to float32. Hence, we need to
1000
+ # cast them back to the target dtype to be sure everything works as expected.
1001
+ input_dtype = query_states.dtype
1002
+ if torch.is_autocast_enabled():
1003
+ target_dtype = torch.get_autocast_gpu_dtype()
1004
+ else:
1005
+ target_dtype = self.q_proj.weight.dtype
1006
+ if input_dtype != target_dtype:
1007
+ logger.warning_once(
1008
+ f"The input hidden states seems to be silently casted in {input_dtype}. "
1009
+ f"This might be because you have upcasted embedding or layer norm layers "
1010
+ f"in {input_dtype}. We will cast back the input in {target_dtype}."
1011
+ )
1012
+ query_states = query_states.to(target_dtype)
1013
+ key_states = key_states.to(target_dtype)
1014
+ val_states = val_states.to(target_dtype)
1015
+
1016
+ return query_states, key_states, val_states
1017
+
1018
+ def forward(
1019
+ self,
1020
+ hidden_states: torch.Tensor,
1021
+ within_seq_position_ids: torch.LongTensor,
1022
+ global_position_ids: torch.LongTensor,
1023
+ sequence_ids: torch.LongTensor,
1024
+ attention_args: AttentionArgs | None = None,
1025
+ past_key_value: DynamicCache | None = None,
1026
+ output_attentions: bool = False,
1027
+ use_cache: bool = False,
1028
+ ) -> tuple[torch.Tensor, torch.Tensor | None, DynamicCache | None]:
1029
+ is_cache_prefilled = (
1030
+ use_cache and past_key_value is not None and past_key_value.get_seq_length(self.layer_idx) > 0
1031
+ )
1032
+
1033
+ query_states, key_states, val_states = self.prepare_qkv(
1034
+ hidden_states=hidden_states,
1035
+ position_ids=within_seq_position_ids
1036
+ if self.layer_type == AttentionLayerType.WITHIN_SEQ
1037
+ else global_position_ids,
1038
+ past_key_value=past_key_value,
1039
+ use_cache=use_cache,
1040
+ )
1041
+
1042
+ # Note: For global attention layers in inference mode, we fall back to flash attention instead of
1044
+ # flex attention once the cache is filled with KV values. This is because once the cache is filled,
1044
+ # the last sequence attends to everything in the cache, so we can make things faster by using a
1045
+ # bidirectional flash attention instead of block-causal flex attention.
1046
+ if self.layer_type == AttentionLayerType.WITHIN_SEQ or is_cache_prefilled:
1047
+ attention_type = AttentionMethod.FLASH
1048
+ else:
1049
+ attention_type = AttentionMethod.FLEX
1050
+
1051
+ attn_output, attn_weights = self._attn(
1052
+ attention_type=attention_type,
1053
+ query_states=query_states,
1054
+ key_states=key_states,
1055
+ val_states=val_states,
1056
+ sequence_ids=sequence_ids,
1057
+ attention_args=attention_args,
1058
+ output_attentions=output_attentions,
1059
+ )
1060
+
1061
+ attn_output = self.o_proj(attn_output)
1062
+ return attn_output, attn_weights, past_key_value
1063
+
1064
+ def _attn(
1065
+ self,
1066
+ attention_type: AttentionMethod,
1067
+ query_states: torch.Tensor,
1068
+ key_states: torch.Tensor,
1069
+ val_states: torch.Tensor,
1070
+ sequence_ids: torch.Tensor,
1071
+ attention_args: AttentionArgs | None = None,
1072
+ output_attentions: bool = False,
1073
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
1074
+ match attention_type:
1075
+ case AttentionMethod.FLASH:
1076
+ f = self._flash_attn
1077
+ case AttentionMethod.FLEX:
1078
+ f = self._flex_attn
1079
+ case _:
1080
+ raise ValueError(f"No attention implementation found for {attention_type}")
1081
+ return f(
1082
+ query_states=query_states,
1083
+ key_states=key_states,
1084
+ val_states=val_states,
1085
+ sequence_ids=sequence_ids,
1086
+ attention_args=attention_args,
1087
+ output_attentions=output_attentions,
1088
+ )
1089
+
1090
+ def _flash_attn(
1091
+ self,
1092
+ query_states: torch.Tensor,
1093
+ key_states: torch.Tensor,
1094
+ val_states: torch.Tensor,
1095
+ sequence_ids: torch.Tensor,
1096
+ attention_args: AttentionArgs | None = None,
1097
+ output_attentions: bool = False,
1098
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
1099
+ """Flash attention implementation.
1100
+
1101
+ Calls the public API of flash attention and deals with padding tokens if any are present.
1102
+ """
1103
+ assert not output_attentions, "Flash attention doesn't support returning attention weights"
1104
+ bsz, q_len = query_states.shape[0], query_states.shape[1]
1105
+ _, kv_len = key_states.shape[0], key_states.shape[1]
1106
+
1107
+ if self.layer_type == AttentionLayerType.GLOBAL: # Only happens in inference
1108
+ q_sequence_ids = sequence_ids
1109
+ if q_len < kv_len:
1110
+ # Assumes query contain only one sequence
1111
+ # and all tokens in query (except padding) will attend to all tokens in KV
1112
+ first_token_id = sequence_ids[:, 0].unsqueeze(1)
1113
+ k_sequence_ids = torch.cat([first_token_id.expand(bsz, kv_len - q_len), sequence_ids], dim=-1)
1114
+ else:
1115
+ k_sequence_ids = sequence_ids
1116
+ else:
1117
+ if q_len < kv_len: # Only happens in inference
1118
+ key_states = key_states[:, -q_len:]
1119
+ val_states = val_states[:, -q_len:]
1120
+ q_sequence_ids = k_sequence_ids = sequence_ids
1121
+
1122
+ if is_flash_attention_available():
1123
+ attn_output = flash_attention_func(
1124
+ query_states,
1125
+ key_states,
1126
+ val_states,
1127
+ q_sequence_ids=q_sequence_ids,
1128
+ k_sequence_ids=k_sequence_ids,
1129
+ causal=False,
1130
+ )
1131
+ else:
1132
+ attn_output = varlen_flex_attention_func(
1133
+ query_states, key_states, val_states, q_sequence_ids=q_sequence_ids, k_sequence_ids=k_sequence_ids
1134
+ )
1135
+
1136
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
1137
+ return attn_output, None
1138
+
1139
+ def _flex_attn(
1140
+ self,
1141
+ query_states: torch.Tensor,
1142
+ key_states: torch.Tensor,
1143
+ val_states: torch.Tensor,
1144
+ sequence_ids: torch.Tensor,
1145
+ attention_args: AttentionArgs | None = None,
1146
+ output_attentions: bool = False,
1147
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
1148
+ bsz, q_len = query_states.shape[0], query_states.shape[1]
1149
+ flex_attention_args = attention_args.get("flex_attention_args", None) if attention_args is not None else None
1150
+ block_mask = flex_attention_args.get("block_mask", None) if flex_attention_args is not None else None
1151
+ score_mod = flex_attention_args.get("score_mod", None) if flex_attention_args is not None else None
1152
+ outputs = flex_attention_func(query_states, key_states, val_states, score_mod=score_mod, block_mask=block_mask)
1153
+
1154
+ outputs = outputs.reshape(bsz, q_len, self.hidden_size).contiguous()
1155
+ return outputs, None
1156
+
1157
+
1158
+ class MLP(nn.Module):
1159
+ def __init__(self, config: E1Config):
1160
+ super().__init__()
1161
+ self.ffn_dim = config.intermediate_size
1162
+ self.hidden_dim = config.hidden_size
1163
+ self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
1164
+ self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
1165
+ self.act_fn = ACT2FN[config.hidden_act]
1166
+
1167
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
1168
+ return self.w2(self.act_fn(self.w1(hidden_states)))
1169
+
1170
+
1171
+ class GLUMLP(nn.Module):
1172
+ def __init__(self, config: E1Config):
1173
+ super().__init__()
1174
+ self.ffn_dim = config.intermediate_size
1175
+ self.hidden_dim = config.hidden_size
1176
+ self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
1177
+ self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
1178
+ self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
1179
+ self.act_fn = ACT2FN[config.hidden_act]
1180
+
1181
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
1182
+ hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
1183
+ hidden_states = self.w2(hidden_states)
1184
+ return hidden_states
1185
+
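# Editorial note (not part of the uploaded file): with the default hidden_act="silu", GLUMLP
# computes the SwiGLU-style w2(silu(w1 x) * w3 x), whereas MLP above is the ungated
# w2(act(w1 x)); FFN below picks between the two based on config.gated_mlp.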
1186
+
1187
+ class FFN(nn.Module):
1188
+ def __init__(self, config: E1Config):
1189
+ super().__init__()
1190
+ mlp_cls = GLUMLP if config.gated_mlp else MLP
1191
+ self.mlp = mlp_cls(config)
1192
+
1193
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
1194
+ return self.mlp(hidden_states)
1195
+
1196
+
1197
+ @dataclass
1198
+ class E1ModelOutputWithPast(ModelOutput):
1199
+ """Base class for model's outputs, with potential hidden states and attentions.
1200
+
1201
+ Attributes:
1202
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
1203
+ Sequence of hidden-states at the output of the last layer of the model.
1204
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1205
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
1206
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
1207
+ `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
1208
+ encoder_sequence_length, embed_size_per_head)`.
1209
+
1210
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
1211
+ `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
1212
+ input) to speed up sequential decoding.
1213
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
1214
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
1215
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
1216
+
1217
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
1218
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
1219
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
1220
+ sequence_length)`.
1221
+
1222
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1223
+ heads.
1224
+ """
1225
+
1226
+ last_hidden_state: torch.FloatTensor | None = None
1227
+ past_key_values: DynamicCache | None = None
1228
+ hidden_states: tuple[torch.FloatTensor, ...] | None = None
1229
+ attentions: tuple[torch.FloatTensor, ...] | None = None
1230
+
1231
+
1232
+ @dataclass
1233
+ class E1MaskedLMOutputWithPast(ModelOutput):
1234
+ loss: torch.FloatTensor | None = None
1235
+ mlm_loss: torch.FloatTensor | None = None
1236
+ logits: torch.FloatTensor | None = None
1237
+ last_hidden_state: torch.FloatTensor | None = None
1238
+ past_key_values: DynamicCache | None = None
1239
+ hidden_states: tuple[torch.FloatTensor, ...] | None = None
1240
+ attentions: tuple[torch.FloatTensor, ...] | None = None
1241
+
1242
+
1243
+ @dataclass
1244
+ class E1ClassificationOutputWithPast(ModelOutput):
1245
+ loss: torch.FloatTensor | None = None
1246
+ logits: torch.FloatTensor | None = None
1247
+ last_hidden_state: torch.FloatTensor | None = None
1248
+ past_key_values: DynamicCache | None = None
1249
+ hidden_states: tuple[torch.FloatTensor, ...] | None = None
1250
+ attentions: tuple[torch.FloatTensor, ...] | None = None
1251
+
1252
+
1253
+ class RMSNorm(nn.Module):
1254
+ def __init__(self, hidden_size: int, eps: float = 1e-6):
1255
+ super().__init__()
1256
+ self.weight = nn.Parameter(torch.ones(hidden_size))
1257
+ self.variance_epsilon = eps
1258
+ self.hidden_size = hidden_size
1259
+
1260
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
1261
+ input_dtype = hidden_states.dtype
1262
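+ # `layer_norm` is an optional fused RMSNorm implementation (presumably imported earlier in
+ # this file behind a try/except); when it is unavailable we fall back to PyTorch's native
+ # F.rms_norm. Both paths cast the result back to the input dtype.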
+ if layer_norm is None:
1263
+ return torch.nn.functional.rms_norm(
1264
+ hidden_states, (self.hidden_size,), self.weight, self.variance_epsilon
1265
+ ).to(input_dtype)
1266
+ else:
1267
+ return layer_norm.rms_norm_fn(
1268
+ x=hidden_states,
1269
+ weight=self.weight,
1270
+ bias=None, # no bias
1271
+ residual=None,
1272
+ eps=self.variance_epsilon,
1273
+ dropout_p=0.0, # no dropout by default
1274
+ prenorm=False,
1275
+ residual_in_fp32=False,
1276
+ ).to(input_dtype)
1277
+
1278
+
1279
+ class NormAttentionNorm(nn.Module):
1280
+ def __init__(self, config: E1Config, layer_idx: int):
1281
+ super().__init__()
1282
+ self.self_attn = Attention(config, layer_idx)
1283
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1284
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1285
+
1286
+ def forward(
1287
+ self,
1288
+ hidden_states: torch.Tensor,
1289
+ within_seq_position_ids: torch.LongTensor,
1290
+ global_position_ids: torch.LongTensor,
1291
+ sequence_ids: torch.LongTensor,
1292
+ attention_args: AttentionArgs | None = None,
1293
+ past_key_value: DynamicCache | None = None,
1294
+ output_attentions: bool = False,
1295
+ use_cache: bool = False,
1296
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, DynamicCache | None]:
1297
+ residual = hidden_states
1298
+ hidden_states = self.input_layernorm(hidden_states)
1299
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
1300
+ hidden_states=hidden_states,
1301
+ within_seq_position_ids=within_seq_position_ids,
1302
+ global_position_ids=global_position_ids,
1303
+ sequence_ids=sequence_ids,
1304
+ attention_args=attention_args,
1305
+ past_key_value=past_key_value,
1306
+ output_attentions=output_attentions,
1307
+ use_cache=use_cache,
1308
+ )
1309
+ hidden_states = residual + hidden_states
1310
+
1311
+ residual = hidden_states
1312
+ hidden_states = self.post_attention_layernorm(hidden_states)
1313
+ return hidden_states, residual, self_attn_weights, present_key_value
1314
+
1315
+
1316
+ class DecoderLayer(nn.Module):
1317
+ def __init__(self, config: E1Config, layer_idx: int):
1318
+ super().__init__()
1319
+ self.initializer_range = config.initializer_range
1320
+ self.hidden_size = config.hidden_size
1321
+ self.norm_attn_norm = NormAttentionNorm(config, layer_idx)
1322
+ self.ffn = FFN(config)
1323
+
1324
+ def forward(
1325
+ self,
1326
+ hidden_states: torch.Tensor,
1327
+ within_seq_position_ids: torch.LongTensor,
1328
+ global_position_ids: torch.LongTensor,
1329
+ sequence_ids: torch.LongTensor,
1330
+ attention_args: AttentionArgs | None = None,
1331
+ past_key_value: DynamicCache | None = None,
1332
+ output_attentions: bool = False,
1333
+ use_cache: bool = False,
1334
+ ) -> tuple[torch.Tensor, torch.Tensor | None, DynamicCache | None]:
1335
+ hidden_states, residual, self_attn_weights, present_key_value = self.norm_attn_norm(
1336
+ hidden_states=hidden_states,
1337
+ within_seq_position_ids=within_seq_position_ids,
1338
+ global_position_ids=global_position_ids,
1339
+ sequence_ids=sequence_ids,
1340
+ attention_args=attention_args,
1341
+ past_key_value=past_key_value,
1342
+ output_attentions=output_attentions,
1343
+ use_cache=use_cache,
1344
+ )
1345
+
1346
+ # Fully Connected
1347
+ hidden_states = self.ffn(hidden_states)
1348
+ hidden_states = residual + hidden_states
1349
+
1350
+ return hidden_states, self_attn_weights, present_key_value
1351
+
1352
+
1353
+ ### Support for embedding datasets with minimal code
1354
+ class Pooler:
1355
+ def __init__(self, pooling_types: List[str]):
1356
+ self.pooling_types = pooling_types
1357
+ self.pooling_options = {
1358
+ 'mean': self.mean_pooling,
1359
+ 'max': self.max_pooling,
1360
+ 'norm': self.norm_pooling,
1361
+ 'median': self.median_pooling,
1362
+ 'std': self.std_pooling,
1363
+ 'var': self.var_pooling,
1364
+ 'cls': self.cls_pooling,
1365
+ 'parti': self._pool_parti,
1366
+ }
1367
+
1368
+ def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor:
1369
+ maxed_attentions = torch.max(attentions, dim=1)[0]
1370
+ return maxed_attentions
1371
+
1372
+ def _page_rank(self, attention_matrix, personalization=None, nstart=None, prune_type="top_k_outdegree"):
1373
+ # Run PageRank on the attention matrix converted to a graph.
1374
+ # Raises exceptions if the graph doesn't match the token sequence or has no edges.
1375
+ # Returns the PageRank scores for each token node.
1376
+ G = self._convert_to_graph(attention_matrix)
1377
+ if G.number_of_nodes() != attention_matrix.shape[0]:
1378
+ raise Exception(
1379
+ f"The number of nodes in the graph should be equal to the number of tokens in sequence! You have {G.number_of_nodes()} nodes for {attention_matrix.shape[0]} tokens.")
1380
+ if G.number_of_edges() == 0:
1381
+ raise Exception("You don't seem to have any attention edges left in the graph.")
1382
+
1383
+ return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100)
1384
+
1385
+ def _convert_to_graph(self, matrix):
1386
+ # Convert a matrix (e.g., attention scores) to a directed graph using networkx.
1387
+ # Each element in the matrix represents a directed edge with a weight.
1388
+ G = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
1389
+ return G
1390
+
1391
+ def _calculate_importance_weights(self, dict_importance, attention_mask: Optional[torch.Tensor] = None):
1392
+ # Remove keys where attention_mask is 0
1393
+ if attention_mask is not None:
1394
+ for k in list(dict_importance.keys()):
1395
+ if attention_mask[k] == 0:
1396
+ del dict_importance[k]
1397
+
1398
+ #dict_importance[0] # remove cls
1399
+ #dict_importance[-1] # remove eos
1400
+ total = sum(dict_importance.values())
1401
+ return np.array([v / total for _, v in dict_importance.items()])
1402
+
1403
+ def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
1404
+ maxed_attentions = self._create_pooled_matrices_across_layers(attentions).float().cpu().numpy()
1405
+ # emb is (b, L, d), maxed_attentions is (b, L, L)
1406
+ emb_pooled = []
1407
+ for e, a, mask in zip(emb, maxed_attentions, attention_mask):
1408
+ dict_importance = self._page_rank(a)
1409
+ importance_weights = self._calculate_importance_weights(dict_importance, mask)
1410
+ num_tokens = int(mask.sum().item())
1411
+ emb_pooled.append(np.average(e[:num_tokens].float().cpu().numpy(), weights=importance_weights, axis=0))
1412
+ pooled = torch.tensor(np.array(emb_pooled))
1413
+ return pooled
1414
+
1415
+ def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
1416
+ if attention_mask is None:
1417
+ return emb.mean(dim=1)
1418
+ else:
1419
+ attention_mask = attention_mask.unsqueeze(-1)
1420
+ return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
1421
+
1422
+ def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
1423
+ if attention_mask is None:
1424
+ return emb.max(dim=1).values
1425
+ else:
1426
+ attention_mask = attention_mask.unsqueeze(-1)
1427
+ # Fill padded positions with -inf so they can never be selected as the max
+ return emb.masked_fill(attention_mask == 0, float('-inf')).max(dim=1).values
1428
+
1429
+ def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
1430
+ if attention_mask is None:
1431
+ return emb.norm(dim=1, p=2)
1432
+ else:
1433
+ attention_mask = attention_mask.unsqueeze(-1)
1434
+ return (emb * attention_mask).norm(dim=1, p=2)
1435
+
1436
+ def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
1437
+ if attention_mask is None:
1438
+ return emb.median(dim=1).values
1439
+ else:
1440
+ attention_mask = attention_mask.unsqueeze(-1)
1441
+ return (emb * attention_mask).median(dim=1).values
1442
+
1443
+ def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
1444
+ if attention_mask is None:
1445
+ return emb.std(dim=1)
1446
+ else:
1447
+ # Compute variance correctly over non-masked positions, then take sqrt
1448
+ var = self.var_pooling(emb, attention_mask, **kwargs)
1449
+ return torch.sqrt(var)
1450
+
1451
+ def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
1452
+ if attention_mask is None:
1453
+ return emb.var(dim=1)
1454
+ else:
1455
+ # Correctly compute variance over only non-masked positions
1456
+ attention_mask = attention_mask.unsqueeze(-1) # (b, L, 1)
1457
+ # Compute mean over non-masked positions
1458
+ mean = (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d)
1459
+ mean = mean.unsqueeze(1) # (b, 1, d)
1460
+ # Compute squared differences from mean, only over non-masked positions
1461
+ squared_diff = (emb - mean) ** 2 # (b, L, d)
1462
+ # Sum squared differences over non-masked positions and divide by count
1463
+ var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d)
1464
+ return var
1465
+
1466
+ def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
1467
+ return emb[:, 0, :]
1468
+
1469
+ def __call__(
1470
+ self,
1471
+ emb: torch.Tensor,
1472
+ attention_mask: Optional[torch.Tensor] = None,
1473
+ attentions: Optional[torch.Tensor] = None
1474
+ ): # [mean, max]
1475
+ final_emb = []
1476
+ for pooling_type in self.pooling_types:
1477
+ final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions)) # (b, d)
1478
+ return torch.cat(final_emb, dim=-1) # (b, n_pooling_types * d)
1479
+
1480
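+ # A minimal, self-contained usage sketch for Pooler (illustrative only; shapes and values
+ # are made up for the example). Concatenating 'mean' and 'cls' doubles the feature dimension.
+ def _example_pooler_usage() -> torch.Tensor:
+     pooler = Pooler(['mean', 'cls'])
+     emb = torch.randn(2, 5, 8)                        # (batch, length, hidden)
+     attention_mask = torch.tensor([[1, 1, 1, 0, 0],
+                                    [1, 1, 1, 1, 1]])  # 0 marks padding
+     return pooler(emb, attention_mask)                # -> (batch, 2 * hidden) == (2, 16)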
+
1481
+ class EmbeddingMixin:
1482
+ def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
1483
+ raise NotImplementedError
1484
+
1485
+ @property
1486
+ def device(self) -> torch.device:
1487
+ """Get the device of the model."""
1488
+ return next(self.parameters()).device
1489
+
1490
+ def _read_sequences_from_db(self, db_path: str) -> set[str]:
1491
+ """Read sequences from SQLite database."""
1492
+ import sqlite3
1493
+ sequences = []
1494
+ with sqlite3.connect(db_path) as conn:
1495
+ c = conn.cursor()
1496
+ c.execute("SELECT sequence FROM embeddings")
1497
+ while True:
1498
+ row = c.fetchone()
1499
+ if row is None:
1500
+ break
1501
+ sequences.append(row[0])
1502
+ return set(sequences)
1503
+
1504
+ def embed_dataset(
1505
+ self,
1506
+ sequences: List[str],
1507
+ #tokenizer: PreTrainedTokenizerBase, # For E1, the tokenizing is handled by _embed
1508
+ batch_size: int = 2,
1509
+ max_len: int = 512,
1510
+ truncate: bool = True,
1511
+ full_embeddings: bool = False,
1512
+ embed_dtype: torch.dtype = torch.float32,
1513
+ pooling_types: List[str] = ['mean'],
1514
+ sql: bool = False,
1515
+ save: bool = True,
1516
+ sql_db_path: str = 'embeddings.db',
1517
+ save_path: str = 'embeddings.pth',
1518
+ ) -> Optional[dict[str, torch.Tensor]]:
1519
+ """Embed a dataset of protein sequences.
1520
+
1521
+ Args:
1522
+ sequences: List of protein sequences
1523
+ batch_size: Batch size for processing
1524
+ max_len: Maximum sequence length
1525
+ full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
1526
+ pooling_types: List of pooling strategies to apply and concatenate (e.g. ['mean'] or ['mean', 'cls'])
1527
+ sql: Whether to store embeddings in SQLite database - will be stored in float32
1528
+ sql_db_path: Path to SQLite database
1529
+
1530
+ Returns:
1531
+ Dictionary mapping sequences to embeddings, or None if sql=True
1532
+
1533
+ Note:
1534
+ - If sql=True, embeddings can only be stored in float32
1535
+ - sql is ideal if you need to stream a very large dataset for training in real-time
1536
+ - save=True is ideal if you can store the entire embedding dictionary in RAM
1537
+ - If sql=True, the SQLite database takes precedence and the save/save_path options are ignored
1538
+ - If your sql database or .pth file is already present, it will be scanned first for already embedded sequences
1539
+ - Sequences will be truncated to max_len (when truncate=True) and sorted by length in descending order for faster processing
1540
+
1541
+ Example:
1542
+ >>> # `model` below is any E1 model class that inherits EmbeddingMixin (e.g. E1Model)
1543
+ >>> embedding_dict = model.embed_dataset(
1544
+ sequences=[
1545
+ 'MALWMRLLPLLALLALWGPDPAAA', ... # list of protein sequences
1546
+ ],
1547
+ batch_size=2, # adjust for your GPU memory
1548
+ max_len=512, # adjust for your needs
1549
+ full_embeddings=False, # if True, no pooling is performed
1550
+ embed_dtype=torch.float32, # cast to what dtype you want
1551
+ pooling_types=['mean', 'cls'], # more than one pooling type will be concatenated together
1552
+ sql=False, # if True, embeddings will be stored in SQLite database
1553
+ sql_db_path='embeddings.db',
1554
+ save=True, # if True, embeddings will be saved as a .pth file
1555
+ save_path='embeddings.pth',
1556
+ )
1557
+ >>> # embedding_dict maps sequences to embedding tensors; with sql=True, embeddings are written to the database and None is returned
1558
+ """
1559
+ sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences]))
1560
+ sequences = sorted(sequences, key=len, reverse=True)
1561
+ hidden_size = self.config.hidden_size
1562
+ pooler = Pooler(pooling_types) if not full_embeddings else None
1563
+
1564
+ def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
1565
+ if full_embeddings or residue_embeddings.ndim == 2: # if already pooled or want residue-wise embeddings
1566
+ return residue_embeddings
1567
+ else:
1568
+ return pooler(residue_embeddings, attention_mask)
1569
+
1570
+ if sql:
1571
+ import sqlite3
1572
+ conn = sqlite3.connect(sql_db_path)
1573
+ c = conn.cursor()
1574
+ c.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence text PRIMARY KEY, embedding blob)')
1575
+ already_embedded = self._read_sequences_from_db(sql_db_path)
1576
+ to_embed = [seq for seq in sequences if seq not in already_embedded]
1577
+ print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
1578
+ print(f"Embedding {len(to_embed)} new sequences")
1579
+ if len(to_embed) > 0:
1580
+ with torch.no_grad():
1581
+ for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'):
1582
+ seqs = to_embed[batch_start:batch_start + batch_size]
1583
+ residue_embeddings, attention_mask = self._embed(seqs, return_attention_mask=True)
1584
+ embeddings = get_embeddings(residue_embeddings, attention_mask).float() # sql requires float32
1585
+ for seq, emb, mask in zip(seqs, embeddings, attention_mask):
1586
+ if full_embeddings:
1587
+ emb = emb[mask.bool()].reshape(-1, hidden_size)
1588
+ c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)", (seq, emb.cpu().numpy().tobytes()))
1589
+ conn.commit()
1590
+ conn.commit()
1591
+ conn.close()
1592
+ return None
1593
+
1594
+ embeddings_dict = {}
1595
+ if os.path.exists(save_path):
1596
+ embeddings_dict = torch.load(save_path, map_location='cpu', weights_only=True)
1597
+ to_embed = [seq for seq in sequences if seq not in embeddings_dict]
1598
+ print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}")
1599
+ print(f"Embedding {len(to_embed)} new sequences")
1600
+ else:
1601
+ to_embed = sequences
1602
+ print(f"Embedding {len(to_embed)} new sequences")
1603
+
1604
+ if len(to_embed) > 0:
1605
+ with torch.no_grad():
1606
+ for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'):
1607
+ seqs = to_embed[batch_start:batch_start + batch_size]
1608
+ last_hidden_state, attention_mask = self._embed(seqs, return_attention_mask=True)
1609
+ embeddings = get_embeddings(last_hidden_state, attention_mask).to(embed_dtype)
1610
+ for seq, emb, mask in zip(seqs, embeddings, attention_mask):
1611
+ if full_embeddings:
1612
+ emb = emb[mask.bool()].reshape(-1, hidden_size)
1613
+ embeddings_dict[seq] = emb.cpu()
1614
+
1615
+ if save:
1616
+ torch.save(embeddings_dict, save_path)
1617
+
1618
+ return embeddings_dict
1619
+
1620
+
1621
+ class E1PreTrainedModel(PreTrainedModel):
1622
+ config_class = E1Config
1623
+ config: E1Config
1624
+ base_model_prefix = "model"
1625
+ supports_gradient_checkpointing = True
1626
+ _no_split_modules = ["DecoderLayer"]
1627
+ _transformer_layer_cls = [DecoderLayer]
1628
+ _skip_keys_device_placement = "past_key_values"
1629
+
1630
+ def _init_weights(self, module: nn.Module) -> None:
1631
+ std = self.config.initializer_range
1632
+ if isinstance(module, nn.Linear):
1633
+ module.weight.data.normal_(mean=0.0, std=std)
1634
+ if module.bias is not None:
1635
+ module.bias.data.zero_()
1636
+ elif isinstance(module, nn.Embedding):
1637
+ module.weight.data.normal_(mean=0.0, std=std)
1638
+ if module.padding_idx is not None:
1639
+ module.weight.data[module.padding_idx].zero_()
1640
+ elif isinstance(module, RMSNorm):
1641
+ module.weight.data.fill_(1.0)
1642
+
1643
+ def post_init(self) -> None:
1644
+ super().post_init()
1645
+
1646
+ def _backward_compatibility_gradient_checkpointing(self) -> None:
1647
+ if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
1648
+ self.gradient_checkpointing_enable(dict(use_reentrant=False))
1649
+
1650
+ @property
1651
+ def _device(self) -> torch.device:
1652
+ return next(self.parameters()).device
1653
+
1654
+ @classmethod
1655
+ def from_pretrained( # type: ignore[no-untyped-def]
1656
+ cls, pretrained_model_name_or_path: str | os.PathLike | None, *args, **kwargs
1657
+ ) -> "E1PreTrainedModel":
1658
+ return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
1659
+
1660
+
1661
+ class E1Model(E1PreTrainedModel, EmbeddingMixin):
1662
+ config: E1Config
1663
+ config_class = E1Config
1664
+ def __init__(self, config: E1Config, **kwargs):
1665
+ E1PreTrainedModel.__init__(self, config, **kwargs)
1666
+ self.padding_idx = config.pad_token_id
1667
+ self.vocab_size = config.vocab_size
1668
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1669
+ self.embed_seq_id = nn.Embedding(config.max_num_sequences, config.hidden_size)
1670
+ self.layers = nn.ModuleList([DecoderLayer(config, i) for i in range(config.num_hidden_layers)])
1671
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1672
+ self.gradient_checkpointing = config.gradient_checkpointing
1673
+ self.prep_tokens = E1BatchPreparer()
1674
+ self.post_init()
1675
+
1676
+ def get_input_embeddings(self) -> nn.Embedding:
1677
+ return self.embed_tokens
1678
+
1679
+ def set_input_embeddings(self, value: nn.Embedding) -> None:
1680
+ self.embed_tokens = value
1681
+
1682
+ @torch.inference_mode()
1683
+ def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
1684
+ batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device)
1685
+ last_hidden_state = self.forward(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state
1686
+ if return_attention_mask:
1687
+ attention_mask = (batch['sequence_ids'] != -1).long()
1688
+ return last_hidden_state, attention_mask
1689
+ else:
1690
+ return last_hidden_state
1691
+
1692
+ # Ignore copy
1693
+ def forward(
1694
+ self,
1695
+ input_ids: torch.LongTensor,
1696
+ within_seq_position_ids: torch.LongTensor,
1697
+ global_position_ids: torch.LongTensor,
1698
+ sequence_ids: torch.LongTensor,
1699
+ past_key_values: DynamicCache | None = None,
1700
+ use_cache: bool = False,
1701
+ output_attentions: bool = False,
1702
+ output_hidden_states: bool = False,
1703
+ **kwargs
1704
+ ) -> E1ModelOutputWithPast:
1705
+ """
1706
+ Args:
1707
+ input_ids: (batch_size, seq_length)
1708
+ within_seq_position_ids: (batch_size, seq_length)
1709
+ This tensor contains the position of each residue within the sequence itself.
1710
+ For example, if the input is ["<bos>1ABC2<eos><bos>1DEF2<eos>", "<bos>1GH2<eos><bos>1JKL2<eos><pad>"],
1711
+ the tensor would be [[0,1,2,3,4,5,6,0,1,2,3,4,5,6], [0,1,2,3,4,5,0,1,2,3,4,5,6,-1]]
1712
+ global_position_ids: (batch_size, seq_length)
1713
+ This tensor contains the position of each residue within the global sequence.
1714
+ For example, if the input is ["<bos>1ABC2<eos><bos>1DEF2<eos>", "<bos>1GH2<eos><bos>1JKL2<eos>"],
1715
+ the tensor would be [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, -1]]
1716
+ sequence_ids: (batch_size, seq_length)
1717
+ This tensor contains the sequence id of each residue.
1718
+ For example, if the input is ["<bos>1ABC2<eos><bos>1DEF2<eos>", "<bos>1GH2<eos><bos>1JKL2<eos>"],
1719
+ the tensor would be [[0,0,0,0,0,0,0,1,1,1,1,1,1,1], [0,0,0,0,0,0,1,1,1,1,1,1,1,-1]]
1720
+ past_key_values: DynamicCache
1721
+ use_cache: bool
1722
+ output_attentions: bool
1723
+ output_hidden_states: bool
1724
+
1725
+ Returns:
1726
+ E1ModelOutputWithPast: Model Outputs
1727
+ """
1728
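+ # Note: these four input tensors are normally produced together by E1BatchPreparer.get_batch_kwargs
+ # (see _embed above and the __main__ demo at the bottom of this file), e.g.
+ #   batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device)
+ #   outputs = self.forward(**batch)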
+ batch_size, seq_length = input_ids.shape
1729
+
1730
+ if self.gradient_checkpointing and self.training and torch.is_grad_enabled():
1731
+ if use_cache:
1732
+ logger.warning_once(
1733
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1734
+ )
1735
+ use_cache = False
1736
+
1737
+ if use_cache and past_key_values is None:
1738
+ past_key_values = DynamicCache()
1739
+ elif not use_cache:
1740
+ # To avoid weirdness with gradient checkpointing: https://github.com/huggingface/transformers/issues/28499
1741
+ past_key_values = None
1742
+
1743
+ global_position_ids = global_position_ids.view(-1, seq_length).long()
1744
+ within_seq_position_ids = within_seq_position_ids.view(-1, seq_length).long()
1745
+ sequence_ids = sequence_ids.view(-1, seq_length).long()
1746
+
1747
+ max_position_id = torch.max(within_seq_position_ids).item()
1748
+ min_position_id = torch.min(within_seq_position_ids).item()
1749
+ assert max_position_id < self.config.max_num_positions_within_seq and min_position_id >= -1, (
1750
+ f"Position ids must be in the range [-1, {self.config.max_num_positions_within_seq}); got max {max_position_id} and min {min_position_id}"
1751
+ )
1752
+
1753
+ inputs_embeds = self.embed_tokens(input_ids)
1754
+ # -1 is used to indicate padding tokens, so we need to clamp the sequence ids to 0
1755
+ inputs_embeds = inputs_embeds + self.embed_seq_id(sequence_ids.clamp(min=0))
1756
+
1757
+ # In case we need to do any manual typecasting
1758
+ if torch.is_autocast_enabled():
1759
+ target_dtype = torch.get_autocast_gpu_dtype()
1760
+ else:
1761
+ target_dtype = self.layers[0].norm_attn_norm.self_attn.q_proj.weight.dtype
1762
+ hidden_states = inputs_embeds.to(target_dtype)
1763
+
1764
+ # (batch_size, query_length, keyval_length)
1765
+ past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
1766
+
1767
+ # Create block mask for flex attention
1768
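+ # The BlockMask is built from sequence_ids so that tokens only attend within their own
+ # packed sequence (block-causal, per the helper's name). It is only constructed on the
+ # first forward pass (no KV cache); with a populated cache, attention_args stays None.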
+ attention_args: AttentionArgs | None = None
1769
+ if past_key_values_length == 0:
1770
+ block_mask = create_block_causal_mask_optimized(sequence_ids)
1771
+ flex_attention_args = FlexAttentionArgs(block_mask=block_mask)
1772
+ attention_args = AttentionArgs(flex_attention_args=flex_attention_args)
1773
+
1774
+ # decoder layers
1775
+ all_hidden_states = () if output_hidden_states else None
1776
+ all_self_attns = () if output_attentions else None
1777
+ next_decoder_cache = None
1778
+
1779
+ for decoder_layer in self.layers:
1780
+ if output_hidden_states:
1781
+ all_hidden_states += (hidden_states,) # type: ignore[operator]
1782
+
1783
+ if self.gradient_checkpointing and self.training and torch.is_grad_enabled():
1784
+ layer_outputs = self._gradient_checkpointing_func(
1785
+ decoder_layer.__call__,
1786
+ hidden_states,
1787
+ within_seq_position_ids,
1788
+ global_position_ids,
1789
+ sequence_ids,
1790
+ attention_args,
1791
+ past_key_values,
1792
+ output_attentions,
1793
+ use_cache,
1794
+ )
1795
+ else:
1796
+ layer_outputs = decoder_layer(
1797
+ hidden_states,
1798
+ within_seq_position_ids=within_seq_position_ids,
1799
+ global_position_ids=global_position_ids,
1800
+ sequence_ids=sequence_ids,
1801
+ attention_args=attention_args,
1802
+ past_key_value=past_key_values,
1803
+ output_attentions=output_attentions,
1804
+ use_cache=use_cache,
1805
+ )
1806
+
1807
+ hidden_states, self_attn_weights, present_key_value = layer_outputs
1808
+
1809
+ if use_cache:
1810
+ # NOTE: it's necessary to re-assign past_key_values because FSDP2
1811
+ # passes certain arguments by value, not by reference.
1812
+ # See https://github.com/huggingface/transformers/issues/38190#issuecomment-2914016168
1813
+ next_decoder_cache = past_key_values = present_key_value
1814
+
1815
+ if output_attentions:
1816
+ all_self_attns += (self_attn_weights,) # type: ignore[operator]
1817
+
1818
+ hidden_states = self.norm(hidden_states)
1819
+
1820
+ # add hidden states from the last decoder layer
1821
+ if output_hidden_states:
1822
+ all_hidden_states += (hidden_states,) # type: ignore[operator]
1823
+
1824
+ next_cache = next_decoder_cache if use_cache else None
1825
+
1826
+ return E1ModelOutputWithPast(
1827
+ last_hidden_state=hidden_states,
1828
+ past_key_values=next_cache,
1829
+ hidden_states=all_hidden_states,
1830
+ attentions=all_self_attns,
1831
+ )
1832
+
1833
+
1834
+ class E1ForMaskedLM(E1PreTrainedModel, EmbeddingMixin):
1835
+ config: E1Config
1836
+ config_class = E1Config
1837
+ def __init__(self, config: E1Config, **kwargs):
1838
+ E1PreTrainedModel.__init__(self, config, **kwargs)
1839
+ self.model: E1Model = E1Model(config)
1840
+ self.vocab_size = config.vocab_size
1841
+ self.mlm_head = torch.nn.Sequential(
1842
+ nn.Linear(config.hidden_size, config.hidden_size, bias=True),
1843
+ nn.GELU(),
1844
+ nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps),
1845
+ nn.Linear(config.hidden_size, config.vocab_size, bias=True),
1846
+ )
1847
+ self.gradient_checkpointing = config.gradient_checkpointing
1848
+ self.prep_tokens = E1BatchPreparer()
1849
+ self.post_init()
1850
+
1851
+ @property
1852
+ def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh:
1853
+ return self.model.device_mesh
1854
+
1855
+ @torch.inference_mode()
1856
+ def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
1857
+ batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device)
1858
+ last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state
1859
+ if return_attention_mask:
1860
+ attention_mask = (batch['sequence_ids'] != -1).long()
1861
+ return last_hidden_state, attention_mask
1862
+ else:
1863
+ return last_hidden_state
1864
+
1865
+ def forward(
1866
+ self,
1867
+ input_ids: torch.LongTensor,
1868
+ within_seq_position_ids: torch.LongTensor,
1869
+ global_position_ids: torch.LongTensor,
1870
+ sequence_ids: torch.LongTensor,
1871
+ labels: torch.LongTensor | None = None,
1872
+ past_key_values: DynamicCache | None = None,
1873
+ use_cache: bool = False,
1874
+ output_attentions: bool = False,
1875
+ output_hidden_states: bool = False,
1876
+ **kwargs,
1877
+ ) -> E1MaskedLMOutputWithPast:
1878
+ """
1879
+ Args:
1880
+ input_ids: (batch_size, seq_length)
1881
+ within_seq_position_ids: (batch_size, seq_length)
1882
+ This tensor contains the position of each residue within the sequence itself.
1883
+ For example, if the input is ["<bos>1ABC2<eos><bos>1DEF2<eos>", "<bos>1GH2<eos><bos>1JKL2<eos><pad>"],
1884
+ the tensor would be [[0,1,2,3,4,5,6,0,1,2,3,4,5,6], [0,1,2,3,4,5,0,1,2,3,4,5,6,-1]]
1885
+ global_position_ids: (batch_size, seq_length)
1886
+ This tensor contains the position of each residue within the global sequence.
1887
+ For example, if the input is ["<bos>1ABC2<eos><bos>1DEF2<eos>", "<bos>1GH2<eos><bos>1JKL2<eos>"],
1888
+ the tensor would be [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, -1]]
1889
+ sequence_ids: (batch_size, seq_length)
1890
+ This tensor contains the sequence id of each residue.
1891
+ For example, if the input is ["<bos>1ABC2<eos><bos>1DEF2<eos>", "<bos>1GH2<eos><bos>1JKL2<eos>"],
1892
+ the tensor would be [[0,0,0,0,0,0,0,1,1,1,1,1,1,1], [0,0,0,0,0,0,1,1,1,1,1,1,1,-1]]
1893
+ labels: (batch_size, seq_length)
1894
+ past_key_values: DynamicCache
1895
+ use_cache: bool
1896
+ output_attentions: bool
1897
+ output_hidden_states: bool
1898
+
1899
+ Returns:
1900
+ E1MaskedLMOutputWithPast: Model Outputs
1901
+ """
1902
+ outputs: E1ModelOutputWithPast = self.model(
1903
+ input_ids=input_ids,
1904
+ within_seq_position_ids=within_seq_position_ids,
1905
+ global_position_ids=global_position_ids,
1906
+ sequence_ids=sequence_ids,
1907
+ past_key_values=past_key_values,
1908
+ use_cache=use_cache,
1909
+ output_attentions=output_attentions,
1910
+ output_hidden_states=output_hidden_states,
1911
+ )
1912
+
1913
+ x = outputs.last_hidden_state
1914
+ loss = None
1915
+
1916
+ # Compute masked language modeling loss
1917
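+ # Labels equal to the padding index are excluded from the loss; the denominator is
+ # clamped to 1 when no positions are labeled, so the loss stays finite.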
+ mlm_logits = self.mlm_head(x).float()
1918
+ mlm_loss = 0.0
1919
+ if labels is not None:
1920
+ mlm_logits_flat = mlm_logits.contiguous().view(-1, self.config.vocab_size)
1921
+ mlm_labels_flat = labels.to(mlm_logits_flat.device).contiguous().view(-1)
1922
+ mlm_loss = F.cross_entropy(mlm_logits_flat, mlm_labels_flat, reduction="none")
1923
+ mask = mlm_labels_flat != self.model.padding_idx
1924
+ n_mlm = mask.sum()
1925
+ mlm_loss = (mlm_loss * mask.to(mlm_loss)).sum() / (1 if n_mlm == 0 else n_mlm)
1926
+ loss = 0.0
1927
+ loss += mlm_loss
1928
+
1929
+ return E1MaskedLMOutputWithPast(
1930
+ loss=loss,
1931
+ mlm_loss=mlm_loss,
1932
+ logits=mlm_logits,
1933
+ last_hidden_state=x,
1934
+ past_key_values=outputs.past_key_values,
1935
+ hidden_states=outputs.hidden_states,
1936
+ attentions=outputs.attentions,
1937
+ )
1938
+
1939
+
1940
+ class E1ForSequenceClassification(E1PreTrainedModel, EmbeddingMixin):
1941
+ config: E1Config
1942
+ config_class = E1Config
1943
+ def __init__(self, config: E1Config, **kwargs):
1944
+ E1PreTrainedModel.__init__(self, config, **kwargs)
1945
+ self.model: E1Model = E1Model(config)
1946
+ self.vocab_size = config.vocab_size
1947
+ self.num_labels = config.num_labels
1948
+ self.classifier = nn.Sequential(
1949
+ nn.Linear(config.hidden_size * 2, config.hidden_size * 4),
1950
+ nn.GELU(),
1951
+ nn.LayerNorm(config.hidden_size * 4),
1952
+ nn.Linear(config.hidden_size * 4, config.num_labels),
1953
+ )
1954
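+ # The classifier input is hidden_size * 2 because the default pooler below concatenates
+ # 'mean' and 'var' features; passing a different number of pooling_types would require
+ # resizing this first Linear layer accordingly.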
+ self.mse = nn.MSELoss()
1955
+ self.ce = nn.CrossEntropyLoss()
1956
+ self.bce = nn.BCEWithLogitsLoss()
1957
+ self.gradient_checkpointing = config.gradient_checkpointing
1958
+ self.prep_tokens = E1BatchPreparer()
1959
+
1960
+ if 'pooling_types' in kwargs and isinstance(kwargs['pooling_types'], list) and len(kwargs['pooling_types']) > 0:
1961
+ pooling_types = kwargs['pooling_types']
1962
+ else:
1963
+ pooling_types = ['mean', 'var']
1964
+ self.pooler = Pooler(pooling_types)
1965
+ self.post_init()
1966
+
1967
+ @property
1968
+ def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh:
1969
+ return self.model.device_mesh
1970
+
1971
+ @torch.inference_mode()
1972
+ def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
1973
+ batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device)
1974
+ last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state
1975
+ if return_attention_mask:
1976
+ attention_mask = (batch['sequence_ids'] != -1).long()
1977
+ return last_hidden_state, attention_mask
1978
+ else:
1979
+ return last_hidden_state
1980
+
1981
+ def forward(
1982
+ self,
1983
+ input_ids: torch.LongTensor,
1984
+ within_seq_position_ids: torch.LongTensor,
1985
+ global_position_ids: torch.LongTensor,
1986
+ sequence_ids: torch.LongTensor,
1987
+ labels: torch.LongTensor | None = None,
1988
+ past_key_values: DynamicCache | None = None,
1989
+ use_cache: bool = False,
1990
+ output_attentions: bool = False,
1991
+ output_hidden_states: bool = False,
1992
+ **kwargs,
1993
+ ) -> E1ClassificationOutputWithPast:
1994
+ outputs: E1ModelOutputWithPast = self.model(
1995
+ input_ids=input_ids,
1996
+ within_seq_position_ids=within_seq_position_ids,
1997
+ global_position_ids=global_position_ids,
1998
+ sequence_ids=sequence_ids,
1999
+ past_key_values=past_key_values,
2000
+ use_cache=use_cache,
2001
+ output_attentions=output_attentions,
2002
+ output_hidden_states=output_hidden_states,
2003
+ )
2004
+
2005
+ attention_mask = (sequence_ids != -1).long()
2006
+ x = outputs.last_hidden_state
2007
+ features = self.pooler(x, attention_mask)
2008
+ logits = self.classifier(features)
2009
+ loss = None
2010
+ if labels is not None:
2011
+ labels = labels.to(logits.device)
2012
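+ # Infer the problem type from num_labels and the label dtype on first use
+ # (regression / single-label / multi-label), following the usual HF convention.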
+ if self.config.problem_type is None:
2013
+ if self.num_labels == 1:
2014
+ self.config.problem_type = "regression"
2015
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
2016
+ self.config.problem_type = "single_label_classification"
2017
+ else:
2018
+ self.config.problem_type = "multi_label_classification"
2019
+
2020
+ if self.config.problem_type == "regression":
2021
+ if self.num_labels == 1:
2022
+ loss = self.mse(logits.flatten(), labels.flatten())
2023
+ else:
2024
+ loss = self.mse(logits, labels)
2025
+ elif self.config.problem_type == "single_label_classification":
2026
+ loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
2027
+ elif self.config.problem_type == "multi_label_classification":
2028
+ loss = self.bce(logits, labels)
2029
+
2030
+ return E1ClassificationOutputWithPast(
2031
+ loss=loss,
2032
+ logits=logits,
2033
+ last_hidden_state=x,
2034
+ past_key_values=outputs.past_key_values,
2035
+ hidden_states=outputs.hidden_states,
2036
+ attentions=outputs.attentions,
2037
+ )
2038
+
2039
+
2040
+ class E1ForTokenClassification(E1PreTrainedModel, EmbeddingMixin):
2041
+ config: E1Config
2042
+ config_class = E1Config
2043
+ def __init__(self, config: E1Config, **kwargs):
2044
+ E1PreTrainedModel.__init__(self, config, **kwargs)
2045
+ self.model: E1Model = E1Model(config)
2046
+ self.vocab_size = config.vocab_size
2047
+ self.num_labels = config.num_labels
2048
+ self.classifier = nn.Sequential(
2049
+ nn.Linear(config.hidden_size, config.hidden_size * 4), # token-level head consumes per-residue hidden states (no pooled concatenation)
2050
+ nn.GELU(),
2051
+ nn.LayerNorm(config.hidden_size * 4),
2052
+ nn.Linear(config.hidden_size * 4, config.num_labels),
2053
+ )
2054
+ self.loss_fct = nn.CrossEntropyLoss()
2055
+ self.gradient_checkpointing = config.gradient_checkpointing
2056
+ self.prep_tokens = E1BatchPreparer()
2057
+ self.post_init()
2058
+
2059
+ @property
2060
+ def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh:
2061
+ return self.model.device_mesh
2062
+
2063
+ @torch.inference_mode()
2064
+ def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor:
2065
+ batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device)
2066
+ last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state
2067
+ if return_attention_mask:
2068
+ attention_mask = (batch['sequence_ids'] != -1).long()
2069
+ return last_hidden_state, attention_mask
2070
+ else:
2071
+ return last_hidden_state
2072
+
2073
+ def forward(
2074
+ self,
2075
+ input_ids: torch.LongTensor,
2076
+ within_seq_position_ids: torch.LongTensor,
2077
+ global_position_ids: torch.LongTensor,
2078
+ sequence_ids: torch.LongTensor,
2079
+ labels: torch.LongTensor | None = None,
2080
+ past_key_values: DynamicCache | None = None,
2081
+ use_cache: bool = False,
2082
+ output_attentions: bool = False,
2083
+ output_hidden_states: bool = False,
2084
+ **kwargs,
2085
+ ) -> E1ClassificationOutputWithPast:
2086
+ outputs: E1ModelOutputWithPast = self.model(
2087
+ input_ids=input_ids,
2088
+ within_seq_position_ids=within_seq_position_ids,
2089
+ global_position_ids=global_position_ids,
2090
+ sequence_ids=sequence_ids,
2091
+ past_key_values=past_key_values,
2092
+ use_cache=use_cache,
2093
+ output_attentions=output_attentions,
2094
+ output_hidden_states=output_hidden_states,
2095
+ )
2096
+
2097
+ x = outputs.last_hidden_state
2098
+ logits = self.classifier(x)
2099
+ loss = None
2100
+ if labels is not None:
2101
+ loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
2102
+
2103
+ return E1ClassificationOutputWithPast(
2104
+ loss=loss,
2105
+ logits=logits,
2106
+ last_hidden_state=x,
2107
+ past_key_values=outputs.past_key_values,
2108
+ hidden_states=outputs.hidden_states,
2109
+ attentions=outputs.attentions,
2110
+ )
2111
+
2112
+
2113
+ if __name__ == "__main__":
2114
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
2115
+ model = E1ForSequenceClassification.from_pretrained("Profluent-Bio/E1-150m", dtype=torch.bfloat16, num_labels=1).eval().to(device)
2116
+ print(model)
2117
+
2118
+ seqs = [
2119
+ "MRHGDISSSNDTVGVAVVNYKMPRLHTAAEVLDNARKIAEMIVGMKQGLPGMDLVVFPEYSLQGIMYDPAEMMETAVAIPGEETE",
2120
+ "IFSRACRKANVWGVFSLTGERHEEHPRKAPYNTLVLIDNNGEIVQKYRKIIPWCPIEGWYPGGQTYVSEGPKGMKISLIICDDGNY",
2121
+ "PEIWRDCAMKGAELIVRCQGYMYPAKDQQVMMAKAMAWANNCYVAVANAAGFDGVYSYFGHSAIIGFDGRTLGECGEEEMGIQYAQL",
2122
+ "SLSQIRDARANDQSQNHLFKILHRGYSGLQASGDGDRGLAECPFEFYRTWVTDAEKARENVERLTRSTTGVAQCPVGRLPYEGLEKEA",
2123
+ ]
2124
+
2125
+ batch = model.prep_tokens.get_batch_kwargs(seqs, device=device)
2126
+ batch['labels'] = torch.tensor([0.0, 0.0, 0.0, 0.0], device=device)
2127
+
2128
+ last_hidden_state = model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state
2129
+ print(last_hidden_state.shape)
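+
+     # A minimal follow-on sketch: any E1 model that inherits EmbeddingMixin can also produce
+     # pooled per-sequence embeddings. Batch size and pooling choices here are illustrative.
+     embedding_dict = model.embed_dataset(
+         sequences=seqs,
+         batch_size=2,
+         max_len=512,
+         full_embeddings=False,
+         pooling_types=['mean'],
+         sql=False,
+         save=False,
+     )
+     print({seq[:10]: emb.shape for seq, emb in embedding_dict.items()})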