patrickvonplaten committed
Commit a9c7b15 · 1 Parent(s): df75f84
flax_wav2vec2/run_pretraining_loss.py CHANGED
@@ -9,6 +9,7 @@ from fairseq.criterions.wav2vec_criterion import Wav2VecCriterionConfig, Wav2vecCriterion
 from fairseq.tasks.audio_pretraining import AudioPretrainingConfig, AudioPretrainingTask
 
 from transformers import Wav2Vec2ForPreTraining, Wav2Vec2FeatureExtractor
+from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
 
 hf_path = str(sys.argv[1])
 fairseq_wav2vec2_path = str(sys.argv[2])
@@ -16,14 +17,15 @@ fairseq_wav2vec2_path = str(sys.argv[2])
 model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([fairseq_wav2vec2_path])
 
 feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(hf_path, do_normalize=False)
-hf_model = Wav2Vec2ForPreTraining.from_pretrained(hf_path).train()
+hf_model = Wav2Vec2ForPreTraining.from_pretrained(hf_path, diversity_loss_weight=0.0).train()
 
 model = model[0]
+# set these to 0.0 in the original fairseq model config
+# also make sure that numpy uses the same random seed
 model.cfg["attention_dropout"] = 0.0
 model.cfg["dropout_input"] = 0.0
 model.cfg["dropout_features"] = 0.0
 model.train()
-print(model)
 
 dummy_speech_data = datasets.load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
 
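The new comment about numpy matters because _compute_mask_indices in transformers at the time of this commit samples its mask spans through numpy's global RNG, while negative sampling inside the models goes through torch. One way to pin both generators before a comparison run (illustrative, not part of the diff):

import numpy as np
import torch

np.random.seed(0)     # covers the numpy sampling in _compute_mask_indices
torch.manual_seed(0)  # covers the torch sampling inside both models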
@@ -42,7 +44,8 @@ attention_mask = inputs.attention_mask
 
 audio_cfg = AudioPretrainingConfig(labels="ltr", data="./data")
 task = AudioPretrainingTask.setup_task(audio_cfg)
-criterion = Wav2vecCriterion(Wav2VecCriterionConfig(infonce=True, log_keys=["prob_perplexity", "code_perplexity", "temp"], loss_weights=[0.1, 10]), task)
+#criterion = Wav2vecCriterion(Wav2VecCriterionConfig(infonce=True, log_keys=["prob_perplexity", "code_perplexity", "temp"]), task, loss_weights=[0.1, 0.0])
+criterion = Wav2vecCriterion(Wav2VecCriterionConfig(infonce=True, log_keys=["prob_perplexity", "code_perplexity", "temp"]), task, loss_weights=[0.0, 0.0])
 
 sample = {
     "net_input": {
@@ -53,9 +56,35 @@ sample = {
 }
 
 torch.manual_seed(0)
-loss, sample_size, log, result = criterion(model, sample)
+loss, sample_size, log = criterion(model, sample)
+
+if attention_mask is not None:
+    mask_indices_seq_length = hf_model._get_feat_extract_output_lengths(input_values.shape[-1])
+    batch_size = input_values.shape[0]
+    # compute real output lengths according to convolution formula
+    output_lengths = hf_model._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(
+        torch.long
+    )
+    sub_attention_mask = torch.zeros(
+        (batch_size, mask_indices_seq_length), dtype=torch.long, device=input_values.device
+    )
+    # these two operations make sure that all values
+    # before the output length indices are attended to
+    sub_attention_mask[
+        (torch.arange(sub_attention_mask.shape[0], device=input_values.device), output_lengths - 1)
+    ] = 1
+    sub_attention_mask = sub_attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+    # sample randomly masked indices
+    mask_time_indices = _compute_mask_indices(
+        (batch_size, mask_indices_seq_length),
+        hf_model.config.mask_time_prob,
+        hf_model.config.mask_time_length,
+        attention_mask=sub_attention_mask,
+    )
+    mask_time_indices = torch.tensor(mask_time_indices, device=input_values.device)
+
 torch.manual_seed(0)
-hf_result = hf_model(input_values, attention_mask=attention_mask, mask_time_indices=result["mask_indices"].detach())
+hf_result = hf_model(input_values, attention_mask=attention_mask, mask_time_indices=mask_time_indices)
 
 loss.backward()
 hf_result.loss.backward()
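The flip/cumsum/flip idiom added above back-fills ones from the marked last valid frame, turning a one-hot end marker into a left-aligned attention mask; _compute_mask_indices then returns a numpy bool array, hence the torch.tensor(...) conversion at the end. A standalone sanity check with toy lengths (not from the script):

import torch

output_lengths = torch.tensor([3, 5])                # valid frames per example
mask = torch.zeros((2, 5), dtype=torch.long)
mask[(torch.arange(2), output_lengths - 1)] = 1      # mark the last valid frame
mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()  # back-fill ones to the left
print(mask)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])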
@@ -63,6 +92,19 @@ hf_result.loss.backward()
 print("Loss diff %", 100 * (loss.detach().item() - hf_result.loss.detach().item()) / hf_result.loss.detach())
 print("Loss diff abs", (loss.detach().item() - hf_result.loss.detach().item()))
 
-print("perplexity diff %", 100 * (hf_result.codevector_perplexity.detach().item() - result["prob_perplexity"].detach().item()) / hf_result.codevector_perplexity.detach())
+def grad_norm(model):
+    total_norm = 0.0
+    for p in model.parameters():
+        if p.grad is not None:
+            param_norm = p.grad.detach().data.norm(2)
+            total_norm += param_norm.item() ** 2
+    total_norm = total_norm ** 0.5
+    return total_norm
+
+
+print("Fsq grad norm", grad_norm(model))
+print("HF grad norm", grad_norm(hf_model))
 
 print("Grad max/min diff first layer 'feature_extractor.conv_layers[0].conv.weight'", (hf_model.wav2vec2.feature_extractor.conv_layers[0].conv.weight.grad - model.feature_extractor.conv_layers[0][0].weight.grad).abs().max())
+
+print("Grad max/min diff last layer 'feature_extractor.conv_layers[-1].conv.weight'", (hf_model.wav2vec2.feature_extractor.conv_layers[-1].conv.weight.grad - model.feature_extractor.conv_layers[-1][0].weight.grad).abs().max())
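The grad_norm helper added above sums squared per-parameter L2 norms and takes a square root, which equals the global L2 norm of all gradients concatenated; comparing it across the two models is a broader check than the per-layer max diffs. An equivalent formulation (a sketch; assumes backward() has already populated the gradients):

import torch

def grad_norm(model):
    # L2 norm of the vector of per-parameter gradient norms ==
    # L2 norm of all gradients concatenated into one vector
    grads = [p.grad.detach().norm(2) for p in model.parameters() if p.grad is not None]
    return torch.norm(torch.stack(grads), 2).item()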