| import numpy as np | |
| from datasets import load_metric | |
| from transformers import logging | |
| import random | |
# Word-error-rate metric, loaded from a local (offline) copy of the
# HF "wer" metric script rather than downloading it from the Hub.
# Used by the compute_metrics closure defined below.
wer_metric = load_metric("./model-bin/metrics/wer")
# print(wer_metric)
def compute_metrics_fn(processor):
    """Build a ``compute_metrics`` callback reporting word error rate (WER).

    Args:
        processor: a Wav2Vec2-style processor exposing ``batch_decode`` and
            ``tokenizer.pad_token_id`` (assumed CTC setup — TODO confirm
            against the training script).

    Returns:
        A function taking an ``EvalPrediction``-like object (``predictions``
        logits of shape (batch, time, vocab) — presumably; ``label_ids``)
        and returning ``{"wer": float}``. As a side effect it logs one
        randomly chosen (reference, hypothesis) pair per evaluation.
    """
    def compute(pred):
        pred_logits = pred.predictions
        # Greedy decoding: pick the most likely token at every frame.
        pred_ids = np.argmax(pred_logits, axis=-1)
        # -100 marks positions ignored by the loss; restore the pad id so
        # the tokenizer can decode the reference labels.
        pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
        pred_str = processor.batch_decode(pred_ids)
        # We do not want to group tokens when computing the metrics.
        label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
        # BUG FIX: random.randint(a, b) is inclusive of b, so the original
        # randint(0, len(label_str)) could yield an out-of-range index and
        # raise IndexError. randrange(n) samples from [0, n).
        if label_str:
            random_idx = random.randrange(len(label_str))
            logging.get_logger().info(
                '\n\n\nRandom sample predict:\nTruth: {}\nPredict: {}'.format(
                    label_str[random_idx], pred_str[random_idx]))
        wer = wer_metric.compute(predictions=pred_str, references=label_str)
        return {"wer": wer}
    return compute