Training checkpoint at step 2000
Browse files- modeling_armt.py +7 -7
modeling_armt.py
CHANGED
|
@@ -440,13 +440,11 @@ def attn_mask_to_4d(attn_mask, upper, query_len):
|
|
| 440 |
return mask
|
| 441 |
|
| 442 |
def invert_attn_mask(attn_mask, dtype):
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
new_mask = (one - attn_mask) * min_dtype
|
| 449 |
-
return new_mask
|
| 450 |
|
| 451 |
|
| 452 |
|
|
@@ -1856,6 +1854,8 @@ def is_empty_past_key_values(past_key_values: Optional[DynamicCache], layer_idx:
|
|
| 1856 |
return True
|
| 1857 |
return False
|
| 1858 |
|
|
|
|
|
|
|
| 1859 |
def segment_tensor(t: torch.Tensor, start_idx: int, end_idx: int, seq_len: int) -> torch.Tensor:
|
| 1860 |
if not isinstance(t, torch.Tensor):
|
| 1861 |
return t
|
|
|
|
| 440 |
return mask
|
| 441 |
|
| 442 |
def invert_attn_mask(attn_mask, dtype):
|
| 443 |
+
min_dtype = torch.finfo(dtype).min
|
| 444 |
+
# Use the same dtype as attn_mask to avoid dtype conversion
|
| 445 |
+
one = torch.tensor(1.0, dtype=attn_mask.dtype, device=attn_mask.device)
|
| 446 |
+
new_mask = (one - attn_mask) * min_dtype
|
| 447 |
+
return new_mask
|
|
|
|
|
|
|
| 448 |
|
| 449 |
|
| 450 |
|
|
|
|
| 1854 |
return True
|
| 1855 |
return False
|
| 1856 |
|
| 1857 |
+
invert_attn_mask = lambda mask, dtype: (_invert_attn_mask(mask, dtype) if not os.environ.get("NOT_INVERT_ATTN_MASK") else mask)
|
| 1858 |
+
|
| 1859 |
def segment_tensor(t: torch.Tensor, start_idx: int, end_idx: int, seq_len: int) -> torch.Tensor:
|
| 1860 |
if not isinstance(t, torch.Tensor):
|
| 1861 |
return t
|