Training in progress - step 1000
Browse files- asr_modeling.py +4 -2
asr_modeling.py
CHANGED
|
@@ -110,7 +110,8 @@ def apply_specaugment(
|
|
| 110 |
augmented = input_features.clone()
|
| 111 |
|
| 112 |
# Time masking (along time dimension)
|
| 113 |
-
if
|
|
|
|
| 114 |
time_mask = _compute_mask_indices(
|
| 115 |
shape=(batch_size, time_steps),
|
| 116 |
mask_prob=mask_time_prob,
|
|
@@ -123,7 +124,8 @@ def apply_specaugment(
|
|
| 123 |
augmented = augmented.masked_fill(time_mask, 0.0)
|
| 124 |
|
| 125 |
# Frequency masking (along mel dimension)
|
| 126 |
-
if
|
|
|
|
| 127 |
feature_mask = _compute_mask_indices(
|
| 128 |
shape=(batch_size, n_mels),
|
| 129 |
mask_prob=mask_feature_prob,
|
|
|
|
| 110 |
augmented = input_features.clone()
|
| 111 |
|
| 112 |
# Time masking (along time dimension)
|
| 113 |
+
# Apply if prob > 0 OR min_masks > 0 (to support fixed mask count with prob=0)
|
| 114 |
+
if mask_time_prob > 0 or mask_time_min_masks > 0:
|
| 115 |
time_mask = _compute_mask_indices(
|
| 116 |
shape=(batch_size, time_steps),
|
| 117 |
mask_prob=mask_time_prob,
|
|
|
|
| 124 |
augmented = augmented.masked_fill(time_mask, 0.0)
|
| 125 |
|
| 126 |
# Frequency masking (along mel dimension)
|
| 127 |
+
# Apply if prob > 0 OR min_masks > 0 (to support fixed mask count with prob=0)
|
| 128 |
+
if mask_feature_prob > 0 or mask_feature_min_masks > 0:
|
| 129 |
feature_mask = _compute_mask_indices(
|
| 130 |
shape=(batch_size, n_mels),
|
| 131 |
mask_prob=mask_feature_prob,
|