Commit ·
5021159
1
Parent(s): 1e275bf
add random print sample when eval
Browse files- metric_utils.py +6 -0
metric_utils.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
| 1 |
import numpy as np
|
| 2 |
from datasets import load_metric
|
|
|
|
|
|
|
| 3 |
|
| 4 |
wer_metric = load_metric("./model-bin/metrics/wer")
|
| 5 |
|
|
@@ -18,6 +20,10 @@ def compute_metrics_fn(processor):
|
|
| 18 |
# we do not want to group tokens when computing the metrics
|
| 19 |
label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
wer = wer_metric.compute(predictions=pred_str, references=label_str)
|
| 22 |
|
| 23 |
return {"wer": wer}
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
from datasets import load_metric
|
| 3 |
+
from transformers import logging
|
| 4 |
+
import random
|
| 5 |
|
| 6 |
wer_metric = load_metric("./model-bin/metrics/wer")
|
| 7 |
|
|
|
|
| 20 |
# we do not want to group tokens when computing the metrics
|
| 21 |
label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
|
| 22 |
|
| 23 |
+
random_idx = random.randint(0, len(label_str))
|
| 24 |
+
logging.get_logger().info(
|
| 25 |
+
'\n\n\nRandom sample predict:\nTruth: {}\nPredict: {}'.format(label_str[random_idx], pred_str[random_idx]))
|
| 26 |
+
|
| 27 |
wer = wer_metric.compute(predictions=pred_str, references=label_str)
|
| 28 |
|
| 29 |
return {"wer": wer}
|