Commit ·
df75f84
1
Parent(s): 5d32d31
up
Browse files
flax_wav2vec2/logits.npy
CHANGED
|
Binary files a/flax_wav2vec2/logits.npy and b/flax_wav2vec2/logits.npy differ
|
|
|
flax_wav2vec2/run_pretraining_loss.py
CHANGED
|
@@ -16,10 +16,14 @@ fairseq_wav2vec2_path = str(sys.argv[2])
|
|
| 16 |
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([fairseq_wav2vec2_path])
|
| 17 |
|
| 18 |
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(hf_path, do_normalize=False)
|
| 19 |
-
hf_model = Wav2Vec2ForPreTraining.from_pretrained(hf_path)
|
| 20 |
|
| 21 |
model = model[0]
|
| 22 |
-
model.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
dummy_speech_data = datasets.load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
| 25 |
|
|
@@ -51,9 +55,14 @@ sample = {
|
|
| 51 |
torch.manual_seed(0)
|
| 52 |
loss, sample_size, log, result = criterion(model, sample)
|
| 53 |
torch.manual_seed(0)
|
| 54 |
-
hf_result = hf_model(input_values, attention_mask=attention_mask, mask_time_indices=result["mask_indices"]
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
print("Loss diff %", 100 * (loss.detach().item() - hf_result.loss.detach().item()) / hf_result.loss.detach())
|
| 57 |
print("Loss diff abs", (loss.detach().item() - hf_result.loss.detach().item()))
|
| 58 |
|
| 59 |
print("perplexity diff %", 100 * (hf_result.codevector_perplexity.detach().item() - result["prob_perplexity"].detach().item()) / hf_result.codevector_perplexity.detach())
|
|
|
|
|
|
|
|
|
| 16 |
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([fairseq_wav2vec2_path])
|
| 17 |
|
| 18 |
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(hf_path, do_normalize=False)
|
| 19 |
+
hf_model = Wav2Vec2ForPreTraining.from_pretrained(hf_path).train()
|
| 20 |
|
| 21 |
model = model[0]
|
| 22 |
+
model.cfg["attention_dropout"] = 0.0
|
| 23 |
+
model.cfg["dropout_input"] = 0.0
|
| 24 |
+
model.cfg["dropout_features"] = 0.0
|
| 25 |
+
model.train()
|
| 26 |
+
print(model)
|
| 27 |
|
| 28 |
dummy_speech_data = datasets.load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
| 29 |
|
|
|
|
| 55 |
torch.manual_seed(0)
|
| 56 |
loss, sample_size, log, result = criterion(model, sample)
|
| 57 |
torch.manual_seed(0)
|
| 58 |
+
hf_result = hf_model(input_values, attention_mask=attention_mask, mask_time_indices=result["mask_indices"].detach())
|
| 59 |
+
|
| 60 |
+
loss.backward()
|
| 61 |
+
hf_result.loss.backward()
|
| 62 |
|
| 63 |
print("Loss diff %", 100 * (loss.detach().item() - hf_result.loss.detach().item()) / hf_result.loss.detach())
|
| 64 |
print("Loss diff abs", (loss.detach().item() - hf_result.loss.detach().item()))
|
| 65 |
|
| 66 |
print("perplexity diff %", 100 * (hf_result.codevector_perplexity.detach().item() - result["prob_perplexity"].detach().item()) / hf_result.codevector_perplexity.detach())
|
| 67 |
+
|
| 68 |
+
print("Grad max/min diff first layer 'feature_extractor.conv_layers[0].conv.weight'", (hf_model.wav2vec2.feature_extractor.conv_layers[0].conv.weight.grad - model.feature_extractor.conv_layers[0][0].weight.grad).abs().max())
|
flax_wav2vec2/run_pretraining_loss_flax.py
CHANGED
|
@@ -4,6 +4,7 @@ import fairseq
|
|
| 4 |
import torch
|
| 5 |
import optax
|
| 6 |
import jax.numpy as jnp
|
|
|
|
| 7 |
from flax.training.common_utils import onehot
|
| 8 |
|
| 9 |
import soundfile as sf
|
|
@@ -20,10 +21,10 @@ model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([fairse
|
|
| 20 |
|
| 21 |
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(hf_path, do_normalize=False)
|
| 22 |
flax_hf_model = FlaxWav2Vec2ForPreTraining.from_pretrained(hf_path)
|
| 23 |
-
hf_model = Wav2Vec2ForPreTraining.from_pretrained(hf_path)
|
| 24 |
|
| 25 |
model = model[0]
|
| 26 |
-
model.
|
| 27 |
|
| 28 |
dummy_speech_data = datasets.load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
| 29 |
|
|
@@ -54,7 +55,7 @@ def compute_contrastive_loss(
|
|
| 54 |
neg_is_pos = jnp.concatenate([jnp.full((1,) + loss_logits.shape[1:], False), neg_is_pos], axis=0)
|
| 55 |
|
| 56 |
# make sure incorrectly sampled vectors don't contribute to loss
|
| 57 |
-
loss_logits = jnp.where(neg_is_pos, -
|
| 58 |
|
| 59 |
predictions = loss_logits.transpose(2, 1, 0).reshape(-1, loss_logits.shape[0])
|
| 60 |
targets = ((1 - mask_time_indices) * -100).transpose(1, 0).flatten()
|
|
@@ -88,9 +89,11 @@ sample = {
|
|
| 88 |
torch.manual_seed(0)
|
| 89 |
loss, sample_size, log, result = criterion(model, sample)
|
| 90 |
torch.manual_seed(0)
|
| 91 |
-
hf_result = hf_model(input_values, attention_mask=attention_mask, mask_time_indices=result["mask_indices"])
|
| 92 |
|
| 93 |
-
|
|
|
|
|
|
|
| 94 |
|
| 95 |
negative_indices = hf_result["sampled_negative_indices"].detach().numpy()
|
| 96 |
num_negatives = 100
|
|
@@ -114,3 +117,5 @@ loss = contrastive_loss + diversity_loss_weight * diversity_loss
|
|
| 114 |
|
| 115 |
|
| 116 |
print("Loss diff %", 100 * (hf_result.loss.detach().item() - loss.item()) / loss)
|
|
|
|
|
|
|
|
|
| 4 |
import torch
|
| 5 |
import optax
|
| 6 |
import jax.numpy as jnp
|
| 7 |
+
import jax
|
| 8 |
from flax.training.common_utils import onehot
|
| 9 |
|
| 10 |
import soundfile as sf
|
|
|
|
| 21 |
|
| 22 |
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(hf_path, do_normalize=False)
|
| 23 |
flax_hf_model = FlaxWav2Vec2ForPreTraining.from_pretrained(hf_path)
|
| 24 |
+
hf_model = Wav2Vec2ForPreTraining.from_pretrained(hf_path).train()
|
| 25 |
|
| 26 |
model = model[0]
|
| 27 |
+
model.train()
|
| 28 |
|
| 29 |
dummy_speech_data = datasets.load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
| 30 |
|
|
|
|
| 55 |
neg_is_pos = jnp.concatenate([jnp.full((1,) + loss_logits.shape[1:], False), neg_is_pos], axis=0)
|
| 56 |
|
| 57 |
# make sure incorrectly sampled vectors don't contribute to loss
|
| 58 |
+
loss_logits = jnp.where(neg_is_pos, -2**30, loss_logits)
|
| 59 |
|
| 60 |
predictions = loss_logits.transpose(2, 1, 0).reshape(-1, loss_logits.shape[0])
|
| 61 |
targets = ((1 - mask_time_indices) * -100).transpose(1, 0).flatten()
|
|
|
|
| 89 |
torch.manual_seed(0)
|
| 90 |
loss, sample_size, log, result = criterion(model, sample)
|
| 91 |
torch.manual_seed(0)
|
| 92 |
+
hf_result = hf_model(input_values, attention_mask=attention_mask, mask_time_indices=result["mask_indices"], code_vec_indices=result["code_idxs"])
|
| 93 |
|
| 94 |
+
print(100 * "=")
|
| 95 |
+
|
| 96 |
+
outputs = flax_hf_model(input_values.numpy(), attention_mask=attention_mask.numpy(), mask_time_indices=result["mask_indices"].numpy(), train=True, gumbel_rng=jax.random.PRNGKey(0), code_vec_indices=result["code_idxs"])
|
| 97 |
|
| 98 |
negative_indices = hf_result["sampled_negative_indices"].detach().numpy()
|
| 99 |
num_negatives = 100
|
|
|
|
| 117 |
|
| 118 |
|
| 119 |
print("Loss diff %", 100 * (hf_result.loss.detach().item() - loss.item()) / loss)
|
| 120 |
+
|
| 121 |
+
print("perplexity diff %", 100 * (outputs.codevector_perplexity - result["prob_perplexity"].detach().item()) / outputs.codevector_perplexity)
|