Patrick von Platen committed
Commit 15d0d9a · 1 Parent(s): 3f3a3d2
flax_wav2vec2/branches_to_use.txt ADDED
@@ -0,0 +1,2 @@
+ fairseq: https://github.com/patrickvonplaten/fairseq/tree/save_w2v_pretraining_check (see https://github.com/patrickvonplaten/fairseq/pulls)
+ transformers: https://github.com/patrickvonplaten/transformers/tree/debug_wav2vec2_pretraining (see https://github.com/huggingface/transformers/pull/12743)
flax_wav2vec2/run_pretraining_loss.py ADDED
@@ -0,0 +1,72 @@
+ #!/usr/bin/env python3
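+ # Compare the wav2vec2 pretraining loss computed by fairseq against the one
+ # computed by transformers (see branches_to_use.txt for the branches to check out).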
+ import datasets
+ import fairseq
+ import torch
+
+ import soundfile as sf
+ import sys
+ from fairseq.criterions.wav2vec_criterion import Wav2VecCriterionConfig, Wav2vecCriterion
+ from fairseq.tasks.audio_pretraining import AudioPretrainingConfig, AudioPretrainingTask
+
+ from transformers import Wav2Vec2ForPreTraining, Wav2Vec2FeatureExtractor
+
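+ # sys.argv[1]: path to the Hugging Face Wav2Vec2ForPreTraining checkpoint
+ # sys.argv[2]: path to the fairseq wav2vec2 checkpoint (.pt)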
+ hf_path = str(sys.argv[1])
+ fairseq_wav2vec2_path = str(sys.argv[2])
+
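+ # Load the fairseq checkpoint together with its config and task.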
+ model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([fairseq_wav2vec2_path])
+
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(hf_path, do_normalize=False)
+ hf_model = Wav2Vec2ForPreTraining.from_pretrained(hf_path)
+
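+ # load_model_ensemble_and_task returns a list of models; take the first and switch to eval mode.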
+ model = model[0]
+ model.eval()
+
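+ # A small dummy LibriSpeech split is enough for a numerical comparison.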
+ dummy_speech_data = datasets.load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
+
+
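+ # Decode each audio file into a raw speech array.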
+ def map_to_array(batch):
+     speech_array, _ = sf.read(batch["file"])
+     batch["speech"] = speech_array
+     return batch
+
+
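+ # Decode all files, then batch the first three utterances with padding.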
+ dummy_speech_data = dummy_speech_data.map(map_to_array, remove_columns=["file"])
+ inputs = feature_extractor(dummy_speech_data[:3]["speech"], return_tensors="pt", padding="longest", return_attention_mask=True)
+
+ input_values = inputs.input_values
+ attention_mask = inputs.attention_mask
+
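+ # Recreate the fairseq pretraining task and its InfoNCE criterion.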
+ audio_cfg = AudioPretrainingConfig(labels="ltr", data="./data")
+ task = AudioPretrainingTask.setup_task(audio_cfg)
+ criterion = Wav2vecCriterion(Wav2VecCriterionConfig(infonce=True, log_keys=["prob_perplexity", "code_perplexity", "temp"], loss_weights=[0.1, 10]), task)
+
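+ # fairseq's padding_mask is True at padded positions, i.e. the inverse of the HF attention_mask.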
+ sample = {
+     "net_input": {
+         "source": input_values,
+         "padding_mask": attention_mask.ne(1),
+     },
+     "id": torch.zeros((1,)),
+ }
+
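+ # Seed both forward passes identically and feed the fairseq mask indices and
+ # negatives into the HF model (fsq_negs is exposed by the debug branch above).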
+ torch.manual_seed(0)
+ loss, sample_size, log, result = criterion(model, sample)
+ torch.manual_seed(0)
+ hf_result = hf_model(input_values, attention_mask=attention_mask, mask_time_indices=result["mask_indices"], fsq_negs=result["negs"])
+
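+ # Relative differences in percent; both should be close to zero if the implementations match.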
+ print("Loss diff %", 100 * (loss.detach().item() - hf_result.loss.detach().item()) / hf_result.loss.detach().item())
+
+ print("perplexity diff %", 100 * (hf_result.codevector_perplexity.detach().item() - result["prob_perplexity"].detach().item()) / hf_result.codevector_perplexity.detach().item())
flax_wav2vec2/run_pt_fsq_comp.sh ADDED
@@ -0,0 +1,3 @@
+ #!/usr/bin/env bash
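+ # args: the HF checkpoint (submodule directory) and the fairseq checkpoint (.pt)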
+ ./run_pretraining_loss.py wav2vec2-large-lv60 wav2vec_vox_new.pt
flax_wav2vec2/wav2vec2-large-lv60 ADDED
@@ -0,0 +1 @@
+ Subproject commit 1f3eef2bbbac0a61cae0cf882fd615fc960a062f
flax_wav2vec2/wav2vec_vox_new.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9b0748fbd4c725ff62266e3b9544cf948d117bc7fa2dc49528184de547736844
+ size 3174007860
config.json → generation/config.json RENAMED
File without changes
run_flax_pt_generation.py → generation/run_flax_pt_generation.py RENAMED
File without changes