patrickvonplaten commited on
Commit
df75f84
·
1 Parent(s): 5d32d31
flax_wav2vec2/logits.npy CHANGED
Binary files a/flax_wav2vec2/logits.npy and b/flax_wav2vec2/logits.npy differ
 
flax_wav2vec2/run_pretraining_loss.py CHANGED
@@ -16,10 +16,14 @@ fairseq_wav2vec2_path = str(sys.argv[2])
16
  model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([fairseq_wav2vec2_path])
17
 
18
  feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(hf_path, do_normalize=False)
19
- hf_model = Wav2Vec2ForPreTraining.from_pretrained(hf_path)
20
 
21
  model = model[0]
22
- model.eval()
 
 
 
 
23
 
24
  dummy_speech_data = datasets.load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
25
 
@@ -51,9 +55,14 @@ sample = {
51
  torch.manual_seed(0)
52
  loss, sample_size, log, result = criterion(model, sample)
53
  torch.manual_seed(0)
54
- hf_result = hf_model(input_values, attention_mask=attention_mask, mask_time_indices=result["mask_indices"], fsq_negs=result["negs"])
 
 
 
55
 
56
  print("Loss diff %", 100 * (loss.detach().item() - hf_result.loss.detach().item()) / hf_result.loss.detach())
57
  print("Loss diff abs", (loss.detach().item() - hf_result.loss.detach().item()))
58
 
59
  print("perplexity diff %", 100 * (hf_result.codevector_perplexity.detach().item() - result["prob_perplexity"].detach().item()) / hf_result.codevector_perplexity.detach())
 
 
 
16
  model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([fairseq_wav2vec2_path])
17
 
18
  feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(hf_path, do_normalize=False)
19
+ hf_model = Wav2Vec2ForPreTraining.from_pretrained(hf_path).train()
20
 
21
  model = model[0]
22
+ model.cfg["attention_dropout"] = 0.0
23
+ model.cfg["dropout_input"] = 0.0
24
+ model.cfg["dropout_features"] = 0.0
25
+ model.train()
26
+ print(model)
27
 
28
  dummy_speech_data = datasets.load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
29
 
 
55
  torch.manual_seed(0)
56
  loss, sample_size, log, result = criterion(model, sample)
57
  torch.manual_seed(0)
58
+ hf_result = hf_model(input_values, attention_mask=attention_mask, mask_time_indices=result["mask_indices"].detach())
59
+
60
+ loss.backward()
61
+ hf_result.loss.backward()
62
 
63
  print("Loss diff %", 100 * (loss.detach().item() - hf_result.loss.detach().item()) / hf_result.loss.detach())
64
  print("Loss diff abs", (loss.detach().item() - hf_result.loss.detach().item()))
65
 
66
  print("perplexity diff %", 100 * (hf_result.codevector_perplexity.detach().item() - result["prob_perplexity"].detach().item()) / hf_result.codevector_perplexity.detach())
67
+
68
+ print("Grad max/min diff first layer 'feature_extractor.conv_layers[0].conv.weight'", (hf_model.wav2vec2.feature_extractor.conv_layers[0].conv.weight.grad - model.feature_extractor.conv_layers[0][0].weight.grad).abs().max())
flax_wav2vec2/run_pretraining_loss_flax.py CHANGED
@@ -4,6 +4,7 @@ import fairseq
4
  import torch
5
  import optax
6
  import jax.numpy as jnp
 
7
  from flax.training.common_utils import onehot
8
 
9
  import soundfile as sf
@@ -20,10 +21,10 @@ model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([fairse
20
 
21
  feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(hf_path, do_normalize=False)
22
  flax_hf_model = FlaxWav2Vec2ForPreTraining.from_pretrained(hf_path)
23
- hf_model = Wav2Vec2ForPreTraining.from_pretrained(hf_path)
24
 
25
  model = model[0]
26
- model.eval()
27
 
28
  dummy_speech_data = datasets.load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
29
 
@@ -54,7 +55,7 @@ def compute_contrastive_loss(
54
  neg_is_pos = jnp.concatenate([jnp.full((1,) + loss_logits.shape[1:], False), neg_is_pos], axis=0)
55
 
56
  # make sure incorrectly sampled vectors don't contribute to loss
57
- loss_logits = jnp.where(neg_is_pos, -1e9, loss_logits)
58
 
59
  predictions = loss_logits.transpose(2, 1, 0).reshape(-1, loss_logits.shape[0])
60
  targets = ((1 - mask_time_indices) * -100).transpose(1, 0).flatten()
@@ -88,9 +89,11 @@ sample = {
88
  torch.manual_seed(0)
89
  loss, sample_size, log, result = criterion(model, sample)
90
  torch.manual_seed(0)
91
- hf_result = hf_model(input_values, attention_mask=attention_mask, mask_time_indices=result["mask_indices"])
92
 
93
- outputs = flax_hf_model(input_values.numpy(), attention_mask=attention_mask.numpy(), mask_time_indices=result["mask_indices"].numpy())
 
 
94
 
95
  negative_indices = hf_result["sampled_negative_indices"].detach().numpy()
96
  num_negatives = 100
@@ -114,3 +117,5 @@ loss = contrastive_loss + diversity_loss_weight * diversity_loss
114
 
115
 
116
  print("Loss diff %", 100 * (hf_result.loss.detach().item() - loss.item()) / loss)
 
 
 
4
  import torch
5
  import optax
6
  import jax.numpy as jnp
7
+ import jax
8
  from flax.training.common_utils import onehot
9
 
10
  import soundfile as sf
 
21
 
22
  feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(hf_path, do_normalize=False)
23
  flax_hf_model = FlaxWav2Vec2ForPreTraining.from_pretrained(hf_path)
24
+ hf_model = Wav2Vec2ForPreTraining.from_pretrained(hf_path).train()
25
 
26
  model = model[0]
27
+ model.train()
28
 
29
  dummy_speech_data = datasets.load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
30
 
 
55
  neg_is_pos = jnp.concatenate([jnp.full((1,) + loss_logits.shape[1:], False), neg_is_pos], axis=0)
56
 
57
  # make sure incorrectly sampled vectors don't contribute to loss
58
+ loss_logits = jnp.where(neg_is_pos, -2**30, loss_logits)
59
 
60
  predictions = loss_logits.transpose(2, 1, 0).reshape(-1, loss_logits.shape[0])
61
  targets = ((1 - mask_time_indices) * -100).transpose(1, 0).flatten()
 
89
  torch.manual_seed(0)
90
  loss, sample_size, log, result = criterion(model, sample)
91
  torch.manual_seed(0)
92
+ hf_result = hf_model(input_values, attention_mask=attention_mask, mask_time_indices=result["mask_indices"], code_vec_indices=result["code_idxs"])
93
 
94
+ print(100 * "=")
95
+
96
+ outputs = flax_hf_model(input_values.numpy(), attention_mask=attention_mask.numpy(), mask_time_indices=result["mask_indices"].numpy(), train=True, gumbel_rng=jax.random.PRNGKey(0), code_vec_indices=result["code_idxs"])
97
 
98
  negative_indices = hf_result["sampled_negative_indices"].detach().numpy()
99
  num_negatives = 100
 
117
 
118
 
119
  print("Loss diff %", 100 * (hf_result.loss.detach().item() - loss.item()) / loss)
120
+
121
+ print("perplexity diff %", 100 * (outputs.codevector_perplexity - result["prob_perplexity"].detach().item()) / outputs.codevector_perplexity)