pere commited on
Commit
f12972c
·
1 Parent(s): ce5e747
run_flax_speech_recognition_seq2seq_streaming_v3.py CHANGED
@@ -975,7 +975,8 @@ def main():
975
  eval_preds.extend(jax.device_get(
976
  generated_ids.reshape(-1, gen_kwargs["max_length"])))
977
  eval_labels.extend(labels)
978
- breakpoint()
 
979
  # normalize eval metrics
980
  eval_metrics = get_metrics(eval_metrics)
981
  eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
 
975
  eval_preds.extend(jax.device_get(
976
  generated_ids.reshape(-1, gen_kwargs["max_length"])))
977
  eval_labels.extend(labels)
978
+
979
+
980
  # normalize eval metrics
981
  eval_metrics = get_metrics(eval_metrics)
982
  eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)