lhallee committed (verified)
Commit 65fa067 · Parent(s): 87d29d9

Upload modeling_esm_plusplus.py with huggingface_hub

Files changed (1):
  1. modeling_esm_plusplus.py +11 -9
modeling_esm_plusplus.py CHANGED

@@ -399,9 +399,9 @@ def get_attention_mask(
         attention_mask: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
     if attention_mask is None:
-        token_attention_mask = torch.ones((batch_size, seq_len), device=device).bool()
+        attention_mask_2d = torch.ones((batch_size, seq_len), device=device).bool()
     else:
-        token_attention_mask = attention_mask.bool()
+        attention_mask_2d = attention_mask.bool()
 
     if attn_backend == "flex":
         assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
@@ -409,8 +409,10 @@ def get_attention_mask(
         if attention_mask is None:
             flex_block_mask = None
         else:
+            valid_lens = attention_mask_2d.sum(dim=-1)
+
             def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
-                return (token_attention_mask[batch_idx, q_idx] == token_attention_mask[batch_idx, kv_idx]) & (token_attention_mask[batch_idx, q_idx] != 0)
+                return (q_idx < valid_lens[batch_idx]) & (kv_idx < valid_lens[batch_idx])
 
             flex_block_mask = create_block_mask(
                 mask_mod,
@@ -420,12 +422,12 @@
                 seq_len,
                 device=device,
             )
-        extended_attention_mask = None
+        attention_mask_4d = None
     else:
         flex_block_mask = None
-        extended_attention_mask = token_attention_mask[:, None, :, None] & token_attention_mask[:, None, None, :]
+        attention_mask_4d = attention_mask_2d[:, None, :, None] & attention_mask_2d[:, None, None, :]
 
-    return extended_attention_mask, flex_block_mask
+    return attention_mask_4d, flex_block_mask
 
 
 class ESMplusplusConfig(PretrainedConfig):
@@ -938,7 +940,7 @@ class TransformerStack(nn.Module):
         attentions = () if output_attentions else None
 
         # move to 4D attention mask or flex block mask
-        attention_mask, flex_block_mask = get_attention_mask(
+        attention_mask_4d, flex_block_mask = get_attention_mask(
             attn_backend=self._attn_backend,
             batch_size=x.shape[0],
             seq_len=x.shape[1],
@@ -951,14 +953,14 @@
             x, attn_weights = self._gradient_checkpointing_func(
                 block.__call__,
                 x=x,
-                attention_mask=attention_mask,
+                attention_mask=attention_mask_4d,
                 flex_block_mask=flex_block_mask,
                 output_attentions=output_attentions,
             )
         else:
             x, attn_weights = block(
                 x=x,
-                attention_mask=attention_mask,
+                attention_mask=attention_mask_4d,
                 flex_block_mask=flex_block_mask,
                 output_attentions=output_attentions,
             )
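For context on the flex path: below is a minimal sketch of how the new length-based mask_mod feeds create_block_mask and flex_attention. It assumes PyTorch 2.5+ (torch.nn.attention.flex_attention) and right-padded attention masks; the shapes, device choice, and surrounding setup are illustrative, not this repository's actual call sites.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Illustrative shapes; the model's real dimensions differ.
batch_size, n_heads, seq_len, head_dim = 2, 4, 128, 64
device = "cuda" if torch.cuda.is_available() else "cpu"

# Right-padded 2D mask: sample 0 keeps 100 tokens, sample 1 keeps all 128.
attention_mask_2d = torch.zeros((batch_size, seq_len), dtype=torch.bool, device=device)
attention_mask_2d[0, :100] = True
attention_mask_2d[1, :] = True

# As in the commit: per-sample valid lengths, compared against indices.
# This assumes padding is contiguous on the right.
valid_lens = attention_mask_2d.sum(dim=-1)

def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
    return (q_idx < valid_lens[batch_idx]) & (kv_idx < valid_lens[batch_idx])

# H=None broadcasts the mask over all heads.
flex_block_mask = create_block_mask(mask_mod, batch_size, None, seq_len, seq_len, device=device)

q = torch.randn(batch_size, n_heads, seq_len, head_dim, device=device)
k, v = torch.randn_like(q), torch.randn_like(q)
out = flex_attention(q, k, v, block_mask=flex_block_mask)
print(out.shape)  # torch.Size([2, 4, 128, 64])

Compared with the removed predicate, which gathered two mask values per (q_idx, kv_idx) pair, the index-vs-length comparison reads a single per-sample scalar, at the cost of assuming right padding.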
 
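For the non-flex backends, the 4D mask is a plain broadcast AND of valid query positions against valid key positions. A self-contained check of that broadcast (values illustrative):

import torch

# One sample of four tokens, the last one padding.
attention_mask_2d = torch.tensor([[1, 1, 1, 0]], dtype=torch.bool)

# (B, 1, L, 1) AND (B, 1, 1, L) -> (B, 1, L, L): True only where both
# the query and the key token are real (non-pad).
attention_mask_4d = attention_mask_2d[:, None, :, None] & attention_mask_2d[:, None, None, :]
print(attention_mask_4d.shape)  # torch.Size([1, 1, 4, 4])
print(attention_mask_4d[0, 0].int())
# tensor([[1, 1, 1, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 0],
#         [0, 0, 0, 0]], dtype=torch.int32)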
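The removed equality-based predicate and the new length-based one select the same (q, kv) pairs whenever the mask is contiguous right padding, which a quick check confirms:

import torch

mask = torch.tensor([1, 1, 1, 0, 0], dtype=torch.bool)  # right-padded
valid_len = mask.sum()
q_idx = torch.arange(5)[:, None]   # (5, 1)
kv_idx = torch.arange(5)[None, :]  # (1, 5)

# Old: both positions carry the same mask value and that value is non-zero.
old = (mask[q_idx] == mask[kv_idx]) & (mask[q_idx] != 0)
# New: both positions fall inside the valid prefix.
new = (q_idx < valid_len) & (kv_idx < valid_len)
assert torch.equal(old, new)

With left padding or interior gaps the two forms would diverge (the old one reduces to mask[q] AND mask[kv] for any pattern), so the new form presumably relies on batches being right-padded.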