Commit ·
a9c7b15
1
Parent(s): df75f84
correct
Browse files
flax_wav2vec2/run_pretraining_loss.py
CHANGED
|
@@ -9,6 +9,7 @@ from fairseq.criterions.wav2vec_criterion import Wav2VecCriterionConfig, Wav2vec
|
|
| 9 |
from fairseq.tasks.audio_pretraining import AudioPretrainingConfig, AudioPretrainingTask
|
| 10 |
|
| 11 |
from transformers import Wav2Vec2ForPreTraining, Wav2Vec2FeatureExtractor
|
|
|
|
| 12 |
|
| 13 |
hf_path = str(sys.argv[1])
|
| 14 |
fairseq_wav2vec2_path = str(sys.argv[2])
|
|
@@ -16,14 +17,15 @@ fairseq_wav2vec2_path = str(sys.argv[2])
|
|
| 16 |
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([fairseq_wav2vec2_path])
|
| 17 |
|
| 18 |
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(hf_path, do_normalize=False)
|
| 19 |
-
hf_model = Wav2Vec2ForPreTraining.from_pretrained(hf_path).train()
|
| 20 |
|
| 21 |
model = model[0]
|
|
|
|
|
|
|
| 22 |
model.cfg["attention_dropout"] = 0.0
|
| 23 |
model.cfg["dropout_input"] = 0.0
|
| 24 |
model.cfg["dropout_features"] = 0.0
|
| 25 |
model.train()
|
| 26 |
-
print(model)
|
| 27 |
|
| 28 |
dummy_speech_data = datasets.load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
| 29 |
|
|
@@ -42,7 +44,8 @@ attention_mask = inputs.attention_mask
|
|
| 42 |
|
| 43 |
audio_cfg = AudioPretrainingConfig(labels="ltr", data="./data")
|
| 44 |
task = AudioPretrainingTask.setup_task(audio_cfg)
|
| 45 |
-
criterion = Wav2vecCriterion(Wav2VecCriterionConfig(infonce=True, log_keys=["prob_perplexity", "code_perplexity", "temp"], loss_weights=[0.1,
|
|
|
|
| 46 |
|
| 47 |
sample = {
|
| 48 |
"net_input": {
|
|
@@ -53,9 +56,35 @@ sample = {
|
|
| 53 |
}
|
| 54 |
|
| 55 |
torch.manual_seed(0)
|
| 56 |
-
loss, sample_size, log
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
torch.manual_seed(0)
|
| 58 |
-
hf_result = hf_model(input_values, attention_mask=attention_mask, mask_time_indices=
|
| 59 |
|
| 60 |
loss.backward()
|
| 61 |
hf_result.loss.backward()
|
|
@@ -63,6 +92,19 @@ hf_result.loss.backward()
|
|
| 63 |
print("Loss diff %", 100 * (loss.detach().item() - hf_result.loss.detach().item()) / hf_result.loss.detach())
|
| 64 |
print("Loss diff abs", (loss.detach().item() - hf_result.loss.detach().item()))
|
| 65 |
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
print("Grad max/min diff first layer 'feature_extractor.conv_layers[0].conv.weight'", (hf_model.wav2vec2.feature_extractor.conv_layers[0].conv.weight.grad - model.feature_extractor.conv_layers[0][0].weight.grad).abs().max())
|
|
|
|
|
|
|
|
|
| 9 |
from fairseq.tasks.audio_pretraining import AudioPretrainingConfig, AudioPretrainingTask
|
| 10 |
|
| 11 |
from transformers import Wav2Vec2ForPreTraining, Wav2Vec2FeatureExtractor
|
| 12 |
+
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
|
| 13 |
|
| 14 |
hf_path = str(sys.argv[1])
|
| 15 |
fairseq_wav2vec2_path = str(sys.argv[2])
|
|
|
|
| 17 |
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([fairseq_wav2vec2_path])
|
| 18 |
|
| 19 |
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(hf_path, do_normalize=False)
|
| 20 |
+
hf_model = Wav2Vec2ForPreTraining.from_pretrained(hf_path, diversity_loss_weight=0.0).train()
|
| 21 |
|
| 22 |
model = model[0]
|
| 23 |
+
# set the dropout values below to 0.0 in the original fairseq model config
|
| 24 |
+
# also make sure that numpy uses the same random seed
|
| 25 |
model.cfg["attention_dropout"] = 0.0
|
| 26 |
model.cfg["dropout_input"] = 0.0
|
| 27 |
model.cfg["dropout_features"] = 0.0
|
| 28 |
model.train()
|
|
|
|
| 29 |
|
| 30 |
dummy_speech_data = datasets.load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
| 31 |
|
|
|
|
| 44 |
|
| 45 |
audio_cfg = AudioPretrainingConfig(labels="ltr", data="./data")
|
| 46 |
task = AudioPretrainingTask.setup_task(audio_cfg)
|
| 47 |
+
#criterion = Wav2vecCriterion(Wav2VecCriterionConfig(infonce=True, log_keys=["prob_perplexity", "code_perplexity", "temp"]), task, loss_weights=[0.1, 0.0])
|
| 48 |
+
criterion = Wav2vecCriterion(Wav2VecCriterionConfig(infonce=True, log_keys=["prob_perplexity", "code_perplexity", "temp"]), task, loss_weights=[0.0, 0.0])
|
| 49 |
|
| 50 |
sample = {
|
| 51 |
"net_input": {
|
|
|
|
| 56 |
}
|
| 57 |
|
| 58 |
torch.manual_seed(0)
|
| 59 |
+
loss, sample_size, log = criterion(model, sample)
|
| 60 |
+
|
| 61 |
+
if attention_mask is not None:
|
| 62 |
+
mask_indices_seq_length = hf_model._get_feat_extract_output_lengths(input_values.shape[-1])
|
| 63 |
+
batch_size = input_values.shape[0]
|
| 64 |
+
# compute real output lengths according to convolution formula
|
| 65 |
+
output_lengths = hf_model._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(
|
| 66 |
+
torch.long
|
| 67 |
+
)
|
| 68 |
+
sub_attention_mask = torch.zeros(
|
| 69 |
+
(batch_size, mask_indices_seq_length), dtype=torch.long, device=input_values.device
|
| 70 |
+
)
|
| 71 |
+
# these two operations make sure that all values
|
| 72 |
+
# before the output lengths indices are attended to
|
| 73 |
+
sub_attention_mask[
|
| 74 |
+
(torch.arange(sub_attention_mask.shape[0], device=input_values.device), output_lengths - 1)
|
| 75 |
+
] = 1
|
| 76 |
+
sub_attention_mask = sub_attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
|
| 77 |
+
# sample randomly masked indices
|
| 78 |
+
mask_time_indices = _compute_mask_indices(
|
| 79 |
+
(batch_size, mask_indices_seq_length),
|
| 80 |
+
hf_model.config.mask_time_prob,
|
| 81 |
+
hf_model.config.mask_time_length,
|
| 82 |
+
attention_mask=sub_attention_mask,
|
| 83 |
+
)
|
| 84 |
+
mask_time_indices = torch.tensor(mask_time_indices, device=input_values.device)
|
| 85 |
+
|
| 86 |
torch.manual_seed(0)
|
| 87 |
+
hf_result = hf_model(input_values, attention_mask=attention_mask, mask_time_indices=mask_time_indices)
|
| 88 |
|
| 89 |
loss.backward()
|
| 90 |
hf_result.loss.backward()
|
|
|
|
| 92 |
print("Loss diff %", 100 * (loss.detach().item() - hf_result.loss.detach().item()) / hf_result.loss.detach())
|
| 93 |
print("Loss diff abs", (loss.detach().item() - hf_result.loss.detach().item()))
|
| 94 |
|
| 95 |
+
def grad_norm(model):
    """Return the global L2 norm of all gradients currently stored on ``model``.

    Parameters whose ``.grad`` is ``None`` are skipped, so a model on which
    no backward pass has populated gradients yields ``0.0``.

    Args:
        model: Object exposing ``parameters()`` that yields tensors carrying
            a ``.grad`` attribute (e.g. a ``torch.nn.Module``).

    Returns:
        float: Square root of the sum of squared per-parameter gradient
        2-norms.
    """
    squared_norms = (
        p.grad.detach().norm(2).item() ** 2
        for p in model.parameters()
        if p.grad is not None
    )
    # Start from 0.0 so the result is a float even when nothing has a gradient.
    return sum(squared_norms, 0.0) ** 0.5
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
print("Fsq grad norm", grad_norm(model))
|
| 106 |
+
print("HF grad norm", grad_norm(hf_model))
|
| 107 |
|
| 108 |
print("Grad max/min diff first layer 'feature_extractor.conv_layers[0].conv.weight'", (hf_model.wav2vec2.feature_extractor.conv_layers[0].conv.weight.grad - model.feature_extractor.conv_layers[0][0].weight.grad).abs().max())
|
| 109 |
+
|
| 110 |
+
# Largest absolute gradient difference between the HF and fairseq models on the LAST feature-extractor conv layer (index -1).
print("Grad max/min diff last layer 'feature_extractor.conv_layers[-1].conv.weight'", (hf_model.wav2vec2.feature_extractor.conv_layers[-1].conv.weight.grad - model.feature_extractor.conv_layers[-1][0].weight.grad).abs().max())
|