lhallee commited on
Commit
fd7a6b7
·
verified ·
1 Parent(s): cdc6f1b

Upload modeling_esm_plusplus.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_esm_plusplus.py +1 -2
modeling_esm_plusplus.py CHANGED
@@ -409,9 +409,8 @@ def get_attention_mask(
409
  if attention_mask is None:
410
  flex_block_mask = None
411
  else:
412
- sequence_ids = torch.where(token_attention_mask, 1, -1)
413
  def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
414
- return (sequence_ids[batch_idx, q_idx] == sequence_ids[batch_idx, kv_idx]) & (sequence_ids[batch_idx, q_idx] != -1)
415
 
416
  flex_block_mask = create_block_mask(
417
  mask_mod,
 
409
  if attention_mask is None:
410
  flex_block_mask = None
411
  else:
 
412
  def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
413
+ return (token_attention_mask[batch_idx, q_idx] == token_attention_mask[batch_idx, kv_idx]) & (token_attention_mask[batch_idx, q_idx] != 0)
414
 
415
  flex_block_mask = create_block_mask(
416
  mask_mod,