test
Browse files
run_flax_speech_recognition_seq2seq_streaming_v3.py
CHANGED
|
@@ -975,7 +975,8 @@ def main():
|
|
| 975 |
eval_preds.extend(jax.device_get(
|
| 976 |
generated_ids.reshape(-1, gen_kwargs["max_length"])))
|
| 977 |
eval_labels.extend(labels)
|
| 978 |
-
|
|
|
|
| 979 |
# normalize eval metrics
|
| 980 |
eval_metrics = get_metrics(eval_metrics)
|
| 981 |
eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
|
|
|
|
| 975 |
eval_preds.extend(jax.device_get(
|
| 976 |
generated_ids.reshape(-1, gen_kwargs["max_length"])))
|
| 977 |
eval_labels.extend(labels)
|
| 978 |
+
|
| 979 |
+
|
| 980 |
# normalize eval metrics
|
| 981 |
eval_metrics = get_metrics(eval_metrics)
|
| 982 |
eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
|