irodkin committed on
Commit
10d0216
·
verified ·
1 Parent(s): 1b7a9f0

Training checkpoint at step 2000

Browse files
Files changed (1) hide show
  1. modeling_armt.py +7 -7
modeling_armt.py CHANGED
@@ -440,13 +440,11 @@ def attn_mask_to_4d(attn_mask, upper, query_len):
440
  return mask
441
 
442
  def invert_attn_mask(attn_mask, dtype):
443
- if os.environ.get("NOT_INVERT_ATTN_MASK"):
444
- return attn_mask
445
- min_dtype = torch.finfo(dtype).min
446
- # Use the same dtype as attn_mask to avoid dtype conversion
447
- one = torch.tensor(1.0, dtype=attn_mask.dtype, device=attn_mask.device)
448
- new_mask = (one - attn_mask) * min_dtype
449
- return new_mask
450
 
451
 
452
 
@@ -1856,6 +1854,8 @@ def is_empty_past_key_values(past_key_values: Optional[DynamicCache], layer_idx:
1856
  return True
1857
  return False
1858
 
 
 
1859
  def segment_tensor(t: torch.Tensor, start_idx: int, end_idx: int, seq_len: int) -> torch.Tensor:
1860
  if not isinstance(t, torch.Tensor):
1861
  return t
 
440
  return mask
441
 
442
  def invert_attn_mask(attn_mask, dtype):
443
+ min_dtype = torch.finfo(dtype).min
444
+ # Use the same dtype as attn_mask to avoid dtype conversion
445
+ one = torch.tensor(1.0, dtype=attn_mask.dtype, device=attn_mask.device)
446
+ new_mask = (one - attn_mask) * min_dtype
447
+ return new_mask
 
 
448
 
449
 
450
 
 
1854
  return True
1855
  return False
1856
 
1857
+ _invert_attn_mask = invert_attn_mask  # alias the real implementation before rebinding the public name
+ invert_attn_mask = lambda mask, dtype: (_invert_attn_mask(mask, dtype) if not os.environ.get("NOT_INVERT_ATTN_MASK") else mask)
1858
+
1859
  def segment_tensor(t: torch.Tensor, start_idx: int, end_idx: int, seq_len: int) -> torch.Tensor:
1860
  if not isinstance(t, torch.Tensor):
1861
  return t