Training checkpoint at step 9000
Browse files- modeling_armt.py +3 -3
modeling_armt.py
CHANGED
|
@@ -440,8 +440,6 @@ def attn_mask_to_4d(attn_mask, upper, query_len):
|
|
| 440 |
return mask
|
| 441 |
|
| 442 |
def invert_attn_mask(attn_mask, dtype):
|
| 443 |
-
if os.environ.get("NOT_INVERT_ATTN_MASK"):
|
| 444 |
-
return attn_mask
|
| 445 |
min_dtype = torch.finfo(dtype).min
|
| 446 |
# Use the same dtype as attn_mask to avoid dtype conversion
|
| 447 |
one = torch.tensor(1.0, dtype=attn_mask.dtype, device=attn_mask.device)
|
|
@@ -1829,7 +1827,7 @@ except Exception as e:
|
|
| 1829 |
raise e
|
| 1830 |
|
| 1831 |
# Reuse utilities from the existing implementation to ensure identical math
|
| 1832 |
-
# inlined language_modeling: removed import DPFP, invert_attn_mask, attn_mask_to_4d
|
| 1833 |
|
| 1834 |
def reverse_invert_attn_mask(mask: torch.Tensor) -> torch.Tensor:
|
| 1835 |
if os.environ.get("NOT_INVERT_ATTN_MASK"):
|
|
@@ -1856,6 +1854,8 @@ def is_empty_past_key_values(past_key_values: Optional[DynamicCache], layer_idx:
|
|
| 1856 |
return True
|
| 1857 |
return False
|
| 1858 |
|
|
|
|
|
|
|
| 1859 |
def segment_tensor(t: torch.Tensor, start_idx: int, end_idx: int, seq_len: int) -> torch.Tensor:
|
| 1860 |
if not isinstance(t, torch.Tensor):
|
| 1861 |
return t
|
|
|
|
| 440 |
return mask
|
| 441 |
|
| 442 |
def invert_attn_mask(attn_mask, dtype):
|
|
|
|
|
|
|
| 443 |
min_dtype = torch.finfo(dtype).min
|
| 444 |
# Use the same dtype as attn_mask to avoid dtype conversion
|
| 445 |
one = torch.tensor(1.0, dtype=attn_mask.dtype, device=attn_mask.device)
|
|
|
|
| 1827 |
raise e
|
| 1828 |
|
| 1829 |
# Reuse utilities from the existing implementation to ensure identical math
|
| 1830 |
+
# inlined language_modeling: removed import DPFP, invert_attn_mask as _invert_attn_mask, attn_mask_to_4d
|
| 1831 |
|
| 1832 |
def reverse_invert_attn_mask(mask: torch.Tensor) -> torch.Tensor:
|
| 1833 |
if os.environ.get("NOT_INVERT_ATTN_MASK"):
|
|
|
|
| 1854 |
return True
|
| 1855 |
return False
|
| 1856 |
|
| 1857 |
+
invert_attn_mask = lambda mask, dtype: (_invert_attn_mask(mask, dtype) if not os.environ.get("NOT_INVERT_ATTN_MASK") else mask)
|
| 1858 |
+
|
| 1859 |
def segment_tensor(t: torch.Tensor, start_idx: int, end_idx: int, seq_len: int) -> torch.Tensor:
|
| 1860 |
if not isinstance(t, torch.Tensor):
|
| 1861 |
return t
|