mazesmazes commited on
Commit
028b931
·
verified ·
1 Parent(s): 6981450

Training in progress - step 1000

Browse files
Files changed (1) hide show
  1. asr_modeling.py +4 -2
asr_modeling.py CHANGED
@@ -110,7 +110,8 @@ def apply_specaugment(
110
  augmented = input_features.clone()
111
 
112
  # Time masking (along time dimension)
113
- if mask_time_prob > 0:
 
114
  time_mask = _compute_mask_indices(
115
  shape=(batch_size, time_steps),
116
  mask_prob=mask_time_prob,
@@ -123,7 +124,8 @@ def apply_specaugment(
123
  augmented = augmented.masked_fill(time_mask, 0.0)
124
 
125
  # Frequency masking (along mel dimension)
126
- if mask_feature_prob > 0:
 
127
  feature_mask = _compute_mask_indices(
128
  shape=(batch_size, n_mels),
129
  mask_prob=mask_feature_prob,
 
110
  augmented = input_features.clone()
111
 
112
  # Time masking (along time dimension)
113
+ # Apply if prob > 0 OR min_masks > 0 (to support fixed mask count with prob=0)
114
+ if mask_time_prob > 0 or mask_time_min_masks > 0:
115
  time_mask = _compute_mask_indices(
116
  shape=(batch_size, time_steps),
117
  mask_prob=mask_time_prob,
 
124
  augmented = augmented.masked_fill(time_mask, 0.0)
125
 
126
  # Frequency masking (along mel dimension)
127
+ # Apply if prob > 0 OR min_masks > 0 (to support fixed mask count with prob=0)
128
+ if mask_feature_prob > 0 or mask_feature_min_masks > 0:
129
  feature_mask = _compute_mask_indices(
130
  shape=(batch_size, n_mels),
131
  mask_prob=mask_feature_prob,