lhallee committed
Commit ce5b02d · verified · 1 Parent(s): 238b112

Upload modeling_dplm2.py with huggingface_hub

Files changed (1):
  modeling_dplm2.py +15 -11
modeling_dplm2.py CHANGED
@@ -426,17 +426,21 @@ def get_attention_mask(
     if attn_backend == "flex":
         assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
 
-        def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
-            return token_attention_mask[batch_idx, q_idx] & token_attention_mask[batch_idx, kv_idx]
-
-        flex_block_mask = create_block_mask(
-            mask_mod,
-            batch_size,
-            1,
-            seq_len,
-            seq_len,
-            device=device,
-        )
+        if attention_mask is None:
+            flex_block_mask = None
+        else:
+            sequence_ids = torch.where(token_attention_mask, 1, -1)
+            def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+                return (sequence_ids[batch_idx, q_idx] == sequence_ids[batch_idx, kv_idx]) & (sequence_ids[batch_idx, q_idx] != -1)
+
+            flex_block_mask = create_block_mask(
+                mask_mod,
+                batch_size,
+                1,
+                seq_len,
+                seq_len,
+                device=device,
+            )
         extended_attention_mask = None
     else:
         flex_block_mask = None
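
In effect, the commit makes two changes to the flex-attention path: it skips block-mask construction entirely when attention_mask is None, and it replaces the direct boolean AND of token_attention_mask at the query and key positions with an id-based comparison, where padding positions receive sequence id -1 and two positions may attend only if they share a valid id. Below is a minimal standalone sketch of the new logic, assuming create_block_mask comes from torch.nn.attention.flex_attention (PyTorch 2.5+); the shapes and example tensors are illustrative and not part of the commit.

    # Hypothetical standalone sketch of the updated masking logic; the demo
    # shapes and tensors below are illustrative, not from modeling_dplm2.py.
    import torch
    from torch.nn.attention.flex_attention import create_block_mask

    batch_size, seq_len, device = 2, 128, "cpu"

    # Assumed layout: True marks real tokens, False marks padding.
    token_attention_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    token_attention_mask[0, :100] = True
    token_attention_mask[1, :64] = True

    # Real tokens get id 1, padding gets id -1, so padding never matches.
    sequence_ids = torch.where(token_attention_mask, 1, -1)

    def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
        # Attend only when query and key share a valid (non-padding) id.
        return (sequence_ids[batch_idx, q_idx] == sequence_ids[batch_idx, kv_idx]) & (
            sequence_ids[batch_idx, q_idx] != -1
        )

    flex_block_mask = create_block_mask(
        mask_mod, batch_size, 1, seq_len, seq_len, device=device
    )

With ids restricted to 1 and -1, this reproduces the old two-sided padding check; the comparison form would also accommodate packed batches with distinct per-sequence ids if sequence_ids ever carried more than two values, though that is a possible extension rather than something this commit implements.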