irodkin committed on
Commit
8eadfe3
·
verified ·
1 Parent(s): 4d6c6fe

Training checkpoint at step 9000

Browse files
Files changed (1) hide show
  1. modeling_armt.py +3 -3
modeling_armt.py CHANGED
@@ -440,8 +440,6 @@ def attn_mask_to_4d(attn_mask, upper, query_len):
440
  return mask
441
 
442
  def invert_attn_mask(attn_mask, dtype):
443
- if os.environ.get("NOT_INVERT_ATTN_MASK"):
444
- return attn_mask
445
  min_dtype = torch.finfo(dtype).min
446
  # Use the same dtype as attn_mask to avoid dtype conversion
447
  one = torch.tensor(1.0, dtype=attn_mask.dtype, device=attn_mask.device)
@@ -1829,7 +1827,7 @@ except Exception as e:
1829
  raise e
1830
 
1831
  # Reuse utilities from the existing implementation to ensure identical math
1832
- # inlined language_modeling: removed import DPFP, invert_attn_mask, attn_mask_to_4d
1833
 
1834
  def reverse_invert_attn_mask(mask: torch.Tensor) -> torch.Tensor:
1835
  if os.environ.get("NOT_INVERT_ATTN_MASK"):
@@ -1856,6 +1854,8 @@ def is_empty_past_key_values(past_key_values: Optional[DynamicCache], layer_idx:
1856
  return True
1857
  return False
1858
 
 
 
1859
  def segment_tensor(t: torch.Tensor, start_idx: int, end_idx: int, seq_len: int) -> torch.Tensor:
1860
  if not isinstance(t, torch.Tensor):
1861
  return t
 
440
  return mask
441
 
442
  def invert_attn_mask(attn_mask, dtype):
 
 
443
  min_dtype = torch.finfo(dtype).min
444
  # Use the same dtype as attn_mask to avoid dtype conversion
445
  one = torch.tensor(1.0, dtype=attn_mask.dtype, device=attn_mask.device)
 
1827
  raise e
1828
 
1829
  # Reuse utilities from the existing implementation to ensure identical math
1830
+ # inlined language_modeling: removed import DPFP, invert_attn_mask as _invert_attn_mask, attn_mask_to_4d
1831
 
1832
  def reverse_invert_attn_mask(mask: torch.Tensor) -> torch.Tensor:
1833
  if os.environ.get("NOT_INVERT_ATTN_MASK"):
 
1854
  return True
1855
  return False
1856
 
1857
+ invert_attn_mask = lambda mask, dtype: (_invert_attn_mask(mask, dtype) if not os.environ.get("NOT_INVERT_ATTN_MASK") else mask)
1858
+
1859
  def segment_tensor(t: torch.Tensor, start_idx: int, end_idx: int, seq_len: int) -> torch.Tensor:
1860
  if not isinstance(t, torch.Tensor):
1861
  return t