Update eval script to compute WER for each sample individually. Update FLEURS predictions with a WER column
Browse files
predictions/preds_google_fleurs_be_by_test_20221221-101048.tsv
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
predictions/preds_google_fleurs_be_by_test_20221221-101048.xlsx
ADDED
|
Binary file (242 kB). View file
|
|
|
src/run_eval_whisper_streaming.py
CHANGED
|
@@ -10,6 +10,7 @@ from transformers import pipeline
|
|
| 10 |
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
|
| 11 |
from datasets import load_dataset, Audio
|
| 12 |
import evaluate
|
|
|
|
| 13 |
|
| 14 |
from belarusian_text_normalizer import BelarusianTextNormalizer
|
| 15 |
|
|
@@ -33,6 +34,21 @@ wer_metric = evaluate.load("wer")
|
|
| 33 |
text_normalizer = BelarusianTextNormalizer()
|
| 34 |
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
def is_target_text_in_range(ref):
|
| 37 |
if ref.strip() == "ignore time segment in scoring":
|
| 38 |
return False
|
|
@@ -106,15 +122,30 @@ def main(args):
|
|
| 106 |
logger.info(f'WER: {wer}')
|
| 107 |
|
| 108 |
if args.save_predictions is True:
|
| 109 |
-
preds_fp = f'preds_{args.dataset}_{args.config}_{args.split}_{now_str}.tsv'
|
| 110 |
preds_fp = clean_filename(preds_fp)
|
| 111 |
logger.info(f'saving predictions to: "{preds_fp}"')
|
|
|
|
| 112 |
preds_df = pd.DataFrame({
|
| 113 |
'audio_path': audio_paths,
|
| 114 |
'prediction_norm': predictions_norm, 'reference_norm': references_norm,
|
| 115 |
'prediction': predictions, 'reference': references,
|
| 116 |
})
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
else:
|
| 119 |
logger.info('save_predictions is False. will not save predictions to a file')
|
| 120 |
|
|
|
|
| 10 |
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
|
| 11 |
from datasets import load_dataset, Audio
|
| 12 |
import evaluate
|
| 13 |
+
import jiwer
|
| 14 |
|
| 15 |
from belarusian_text_normalizer import BelarusianTextNormalizer
|
| 16 |
|
|
|
|
| 34 |
text_normalizer = BelarusianTextNormalizer()
|
| 35 |
|
| 36 |
|
| 37 |
+
def pull_columns(df: pd.DataFrame, cols) -> pd.DataFrame:
    """Return a copy of ``df`` with ``cols`` moved to the front.

    The requested columns come first, in the order given; every other
    column follows in its original relative order. No column is dropped.

    Args:
        df: Source dataframe. It is not modified.
        cols: A single column name or an iterable of column names.
            Repeated names are ignored after their first occurrence.

    Returns:
        A new dataframe holding the same columns as ``df``, reordered.

    Raises:
        AssertionError: If any requested column is missing from ``df``,
            or if the reordered frame ends up with a different number of
            columns than ``df``.
    """
    if isinstance(cols, str):
        cols = [cols]
    # dict.fromkeys drops repeated names while keeping first-seen order, so
    # a duplicated entry cannot duplicate the column in the result.
    cols = list(dict.fromkeys(cols))

    # Raise explicitly rather than via `assert`, which is stripped under
    # `python -O`. AssertionError is kept so existing callers see the same
    # exception type as before.
    absent_cols = [c for c in cols if c not in df.columns]
    if absent_cols:
        raise AssertionError(f'{absent_cols} columns are absent in df')

    pulled = set(cols)
    cols_rest = [c for c in df.columns if c not in pulled]
    new_df = df[cols + cols_rest].copy()

    # Sanity check: reordering must neither add nor lose columns.
    if new_df.shape[1] != df.shape[1]:
        raise AssertionError(
            f'expected {df.shape[1]} columns after reordering, got {new_df.shape[1]}'
        )
    return new_df
|
| 50 |
+
|
| 51 |
+
|
| 52 |
def is_target_text_in_range(ref):
|
| 53 |
if ref.strip() == "ignore time segment in scoring":
|
| 54 |
return False
|
|
|
|
| 122 |
logger.info(f'WER: {wer}')
|
| 123 |
|
| 124 |
if args.save_predictions is True:
|
| 125 |
+
preds_fp = f'preds_{args.dataset}_{args.config}_{args.split}_{now_str}.xlsx'
|
| 126 |
preds_fp = clean_filename(preds_fp)
|
| 127 |
logger.info(f'saving predictions to: "{preds_fp}"')
|
| 128 |
+
|
| 129 |
preds_df = pd.DataFrame({
|
| 130 |
'audio_path': audio_paths,
|
| 131 |
'prediction_norm': predictions_norm, 'reference_norm': references_norm,
|
| 132 |
'prediction': predictions, 'reference': references,
|
| 133 |
})
|
| 134 |
+
|
| 135 |
+
logger.info('computing WER for each item individually')
|
| 136 |
+
preds_df['wer'] = preds_df.apply(
|
| 137 |
+
lambda row: 100 * jiwer.wer(
|
| 138 |
+
truth=row['reference_norm'], hypothesis=row['prediction_norm']),
|
| 139 |
+
axis=1
|
| 140 |
+
)
|
| 141 |
+
preds_df.sort_values('wer', ascending=False, inplace=True)
|
| 142 |
+
|
| 143 |
+
# Use pull_columns instead of direct dataframe indexing so that any
|
| 144 |
+
# columns added to the dataframe in the future are kept rather than dropped.
|
| 145 |
+
cols_order = ['audio_path', 'wer', 'prediction_norm', 'reference_norm', 'prediction', 'reference']
|
| 146 |
+
preds_df = pull_columns(preds_df, cols=cols_order)
|
| 147 |
+
|
| 148 |
+
preds_df.to_excel(preds_fp, index=False)
|
| 149 |
else:
|
| 150 |
logger.info('save_predictions is False. will not save predictions to a file')
|
| 151 |
|