lhallee committed · Commit c7e57f6 · verified · 1 Parent(s): 944b48d

Upload modeling_fast_esmfold.py with huggingface_hub

Files changed (1)
  1. modeling_fast_esmfold.py +70 -7
modeling_fast_esmfold.py CHANGED

@@ -112,15 +112,34 @@ def _kernels_flash_forward(
     key_states: torch.Tensor,
     value_states: torch.Tensor,
     causal: bool = False,
+    softmax_scale: Optional[float] = None,
 ) -> torch.Tensor:
+    """Flash-attention forward, optionally overriding the softmax scale.
+
+    When `softmax_scale is None`, the flash kernel applies its default
+    `1 / sqrt(head_dim)`. Pass `softmax_scale=1.0` if the caller has already
+    pre-scaled Q (the convention used by ESM2, DPLM, DPLM2, E1, ESMFold).
+    Failing to override when Q is pre-scaled produces DOUBLE scaling and
+    catastrophic downstream drift -- on DPLM-150M (30 layers) this was observed
+    as pooled-embedding cosine ~-0.12 and argmax agreement ~0.27 vs sdpa.
+    """
     assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
     if FLASH_KERNEL_VARIANT == "flash_attn2":
-        return FLASH_KERNEL.fwd(q=query_states, k=key_states, v=value_states, is_causal=causal)[0]
+        return FLASH_KERNEL.fwd(
+            q=query_states, k=key_states, v=value_states,
+            softmax_scale=softmax_scale, is_causal=causal,
+        )[0]
     if FLASH_KERNEL_VARIANT == "flash_attn3":
         try:
-            output = FLASH_KERNEL.flash_attn_func(q=query_states, k=key_states, v=value_states, causal=causal)
+            output = FLASH_KERNEL.flash_attn_func(
+                q=query_states, k=key_states, v=value_states,
+                softmax_scale=softmax_scale, causal=causal,
+            )
         except TypeError:
-            output = FLASH_KERNEL.flash_attn_func(query_states, key_states, value_states, 0.0, None, causal)
+            output = FLASH_KERNEL.flash_attn_func(
+                query_states, key_states, value_states,
+                0.0, softmax_scale, causal,
+            )
     if isinstance(output, tuple):
         return output[0]
     return output

@@ -136,14 +155,20 @@ def _kernels_flash_varlen_forward(
     max_seqlen_in_batch_q: int,
     max_seqlen_in_batch_k: int,
     causal: bool = False,
+    softmax_scale: Optional[float] = None,
 ) -> torch.Tensor:
+    """Varlen flash-attention forward, optionally overriding the softmax scale.
+
+    See `_kernels_flash_forward` docstring for why `softmax_scale=1.0` must be
+    passed when Q has been pre-scaled by the caller.
+    """
     assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
     if FLASH_KERNEL_VARIANT == "flash_attn2":
         return FLASH_KERNEL.varlen_fwd(
             q=query_states, k=key_states, v=value_states,
             cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
             max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
-            is_causal=causal,
+            softmax_scale=softmax_scale, is_causal=causal,
         )[0]
     if FLASH_KERNEL_VARIANT == "flash_attn3":
         try:

@@ -151,14 +176,14 @@ def _kernels_flash_varlen_forward(
                 q=query_states, k=key_states, v=value_states,
                 cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
                 max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
-                causal=causal,
+                softmax_scale=softmax_scale, causal=causal,
             )
         except TypeError:
             output = FLASH_KERNEL.flash_attn_varlen_func(
                 query_states, key_states, value_states,
                 cu_seqlens_q, cu_seqlens_k,
                 max_seqlen_in_batch_q, max_seqlen_in_batch_k,
-                0.0, None, causal,
+                0.0, softmax_scale, causal,
             )
     if isinstance(output, tuple):
         return output[0]

@@ -239,7 +264,21 @@ def kernels_flash_attention_func(
     value_states: torch.Tensor,
     attention_mask_2d: Optional[torch.Tensor] = None,
     causal: bool = False,
+    softmax_scale: Optional[float] = None,
 ) -> torch.Tensor:
+    """Public flash-attention entry point with optional padding handling.
+
+    `softmax_scale`:
+        None  -> kernel applies its default `1 / sqrt(head_dim)`.
+        float -> kernel uses the given scale (pass 1.0 when Q is pre-scaled
+                 by the caller).
+
+    IMPORTANT: if your family multiplies Q by `1/sqrt(head_dim)` before calling
+    this function (as ESM2, DPLM, DPLM2, E1, and ESMFold do) you MUST pass
+    `softmax_scale=1.0`. Otherwise the kernel applies its default scale ON TOP
+    of the caller's, producing effective scale `1/head_dim` and catastrophic
+    downstream drift that compounds across layers.
+    """
     assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
     if not causal and attention_mask_2d is not None:
         batch_size, q_len = query_states.shape[:2]

@@ -251,11 +290,13 @@ def kernels_flash_attention_func(
             query_states=query_states, key_states=key_states, value_states=value_states,
             cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
             max_seqlen_in_batch_q=max_seqlen_q, max_seqlen_in_batch_k=max_seqlen_k,
+            softmax_scale=softmax_scale,
         )
         return pad_input(attn_output_unpad, indices_q, batch_size, q_len)
     else:
         return _kernels_flash_forward(
-            query_states=query_states, key_states=key_states, value_states=value_states, causal=causal,
+            query_states=query_states, key_states=key_states, value_states=value_states,
+            causal=causal, softmax_scale=softmax_scale,
         )
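The softmax_scale plumbing in the hunks above is easiest to verify in isolation. Below is a minimal sketch, not part of this commit, that uses plain PyTorch SDPA (its `scale` kwarg, available since PyTorch 2.1) to stand in for the flash kernel and reproduce the double-scaling failure the new docstrings describe:

import math
import torch
import torch.nn.functional as F

B, H, L, D = 2, 4, 16, 64
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))

# Caller convention used by ESM2-style models: Q arrives pre-scaled.
q_scaled = q / math.sqrt(D)

# Correct: tell the kernel not to scale again (scale=1.0).
correct = F.scaled_dot_product_attention(q_scaled, k, v, scale=1.0)

# Bug: kernel applies its default 1/sqrt(D) on top -> effective scale 1/D.
double_scaled = F.scaled_dot_product_attention(q_scaled, k, v)

# Reference: unscaled Q with the kernel's default scale.
reference = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(correct, reference, atol=1e-5))  # True
print((double_scaled - reference).abs().max())        # large, not float noise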
@@ -338,6 +379,25 @@ def get_attention_mask(
     attention_mask_4d = attention_mask_2d[:, None, None, :]
     return attention_mask_2d, attention_mask_4d, None
 
+
+def bool_to_additive_mask(
+    bool_mask: torch.Tensor,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    """Convert a bool mask (True = valid) to a float additive mask (0.0 valid, -inf invalid).
+
+    Why this exists: calling `bool_mask.masked_fill(bool_mask.logical_not(), float('-inf'))`
+    directly on a bool tensor returns a bool tensor -- because `-inf` casts to `True` -- and
+    silently drops the mask entirely. Always allocate a float tensor first, then fill it.
+    This helper is the sanctioned way to build an SDPA additive mask from a bool validity mask.
+    """
+    assert bool_mask.dtype == torch.bool, (
+        f"bool_to_additive_mask requires a bool tensor, got dtype={bool_mask.dtype}"
+    )
+    additive = torch.zeros_like(bool_mask, dtype=dtype)
+    additive.masked_fill_(bool_mask.logical_not(), float("-inf"))
+    return additive
+
 """FastESMFold: Self-contained ESMFold with FastESM2 attention backends + built-in Test-Time Training.
 
 Usage:
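The failure mode that `bool_to_additive_mask` guards against is easy to trip silently. A minimal sketch, not part of this commit, demonstrating the pitfall and the correct construction:

import torch

valid = torch.tensor([[True, True, False]])  # True = valid position

# WRONG: masked_fill on a bool tensor returns a bool tensor; float("-inf")
# casts to True, so the result is all-True and the mask is silently dropped.
wrong = valid.masked_fill(valid.logical_not(), float("-inf"))
print(wrong.dtype, wrong)  # torch.bool tensor([[True, True, True]])

# RIGHT: allocate a float tensor first, then fill the invalid positions.
additive = torch.zeros_like(valid, dtype=torch.float32)
additive.masked_fill_(valid.logical_not(), float("-inf"))
print(additive)  # tensor([[0., 0., -inf]])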
 
@@ -497,9 +557,12 @@ class EsmSelfAttention(nn.Module):
         query_BLHD = query_BHLD.transpose(1, 2).contiguous()
         key_BLHD = key_BHLD.transpose(1, 2).contiguous()
         value_BLHD = value_BHLD.transpose(1, 2).contiguous()
+        # Q is pre-scaled by self.scale in forward() -- pass softmax_scale=1.0
+        # to prevent the kernel from applying its default 1/sqrt(head_dim).
         attn_output = kernels_flash_attention_func(
             query_states=query_BLHD, key_states=key_BLHD, value_states=value_BLHD,
             attention_mask_2d=attention_mask_2d, causal=False,
+            softmax_scale=1.0,
         )
         return rearrange(attn_output, "b s h d -> b s (h d)"), None
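A parity probe is the quickest way to confirm the call-site fix. Below is a sketch, not part of this commit, that assumes a CUDA device with the flash kernel available; it compares `kernels_flash_attention_func` under the patched convention (pre-scaled Q, `softmax_scale=1.0`) against an SDPA reference, the same kind of check behind the DPLM-150M drift numbers quoted in the docstring:

import math
import torch
import torch.nn.functional as F

def sdpa_reference(q_blhd, k_blhd, v_blhd):
    # SDPA expects (B, H, L, D); Q is already pre-scaled, so scale=1.0.
    q, k, v = (t.transpose(1, 2) for t in (q_blhd, k_blhd, v_blhd))
    return F.scaled_dot_product_attention(q, k, v, scale=1.0).transpose(1, 2)

B, L, H, D = 2, 32, 4, 64
q = torch.randn(B, L, H, D, dtype=torch.float16, device="cuda") / math.sqrt(D)
k = torch.randn(B, L, H, D, dtype=torch.float16, device="cuda")
v = torch.randn(B, L, H, D, dtype=torch.float16, device="cuda")

ref = sdpa_reference(q, k, v)
out = kernels_flash_attention_func(  # from modeling_fast_esmfold.py
    query_states=q, key_states=k, value_states=v,
    attention_mask_2d=None, causal=False, softmax_scale=1.0,
)
cos = F.cosine_similarity(out.float().flatten(1), ref.float().flatten(1), dim=-1)
print(cos)  # ~1.0 per batch element with softmax_scale=1.0; drifts if omitted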