lhallee committed
Commit a042556 · verified · 1 Parent(s): 63fb721

Upload modeling_dplm.py with huggingface_hub

Files changed (1)
modeling_dplm.py: +15 −11
modeling_dplm.py CHANGED
@@ -427,17 +427,21 @@ def get_attention_mask(
     if attn_backend == "flex":
         assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
 
-        def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
-            return token_attention_mask[batch_idx, q_idx] & token_attention_mask[batch_idx, kv_idx]
-
-        flex_block_mask = create_block_mask(
-            mask_mod,
-            batch_size,
-            1,
-            seq_len,
-            seq_len,
-            device=device,
-        )
+        if attention_mask is None:
+            flex_block_mask = None
+        else:
+            sequence_ids = torch.where(token_attention_mask, 1, -1)
+            def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+                return (sequence_ids[batch_idx, q_idx] == sequence_ids[batch_idx, kv_idx]) & (sequence_ids[batch_idx, q_idx] != -1)
+
+            flex_block_mask = create_block_mask(
+                mask_mod,
+                batch_size,
+                1,
+                seq_len,
+                seq_len,
+                device=device,
+            )
         extended_attention_mask = None
     else:
         flex_block_mask = None
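
What changed, in brief: the hunk adds an early attention_mask-is-None branch that skips block-mask construction entirely, and it rephrases mask_mod in terms of a sequence_ids tensor (1 for real tokens, -1 for padding). For a plain padding mask the old and new predicates admit exactly the same query/key pairs. The sketch below is not part of the repository; it assumes only that token_attention_mask is a boolean (batch, seq_len) tensor, as the surrounding code implies, and checks the equivalence densely without needing the flex backend.

import torch

# A toy padding mask: one batch row whose last position is padding.
token_attention_mask = torch.tensor([[True, True, True, False]])
sequence_ids = torch.where(token_attention_mask, 1, -1)

def old_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
    # Pre-commit predicate: query and key must both be real tokens.
    return token_attention_mask[batch_idx, q_idx] & token_attention_mask[batch_idx, kv_idx]

def new_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
    # Post-commit predicate: same sequence id, and the query is not padding.
    return (sequence_ids[batch_idx, q_idx] == sequence_ids[batch_idx, kv_idx]) & (
        sequence_ids[batch_idx, q_idx] != -1
    )

# Evaluate both predicates densely over the full (seq_len, seq_len) grid.
q_idx = torch.arange(4).view(-1, 1)   # column vector of query positions
kv_idx = torch.arange(4).view(1, -1)  # row vector of key positions
old = old_mask_mod(0, 0, q_idx, kv_idx)
new = new_mask_mod(0, 0, q_idx, kv_idx)
print(torch.equal(old, new))  # True: padded pairs are blocked either way

The observable behavior change is therefore the attention_mask-is-None shortcut; the sequence_ids phrasing would also extend naturally to distinct per-sequence ids for packed inputs, though this hunk only distinguishes valid tokens from padding.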
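
To exercise the rebuilt hunk end to end, here is a minimal sketch, assuming PyTorch >= 2.5 where flex attention ships as torch.nn.attention.flex_attention; batch_size, seq_len, and device are stand-ins for the variables the diff takes from its caller.

import torch
from torch.nn.attention.flex_attention import create_block_mask

batch_size, seq_len, device = 2, 256, "cpu"

# Stand-in padding mask: the second row's entire second half is padding.
token_attention_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
token_attention_mask[1, 128:] = False
sequence_ids = torch.where(token_attention_mask, 1, -1)

def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
    return (sequence_ids[batch_idx, q_idx] == sequence_ids[batch_idx, kv_idx]) & (
        sequence_ids[batch_idx, q_idx] != -1
    )

# Same call shape as the diff: (mask_mod, B, H, Q_LEN, KV_LEN, device=...).
flex_block_mask = create_block_mask(mask_mod, batch_size, 1, seq_len, seq_len, device=device)
print(flex_block_mask.sparsity())  # percentage of blocks the kernel can skip

Building a BlockMask rather than a dense additive mask lets flex attention skip fully masked tiles at kernel launch, which is presumably why the flex branch sets extended_attention_mask = None and carries flex_block_mask instead.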