|
|
import argparse
|
|
|
|
|
|
from transformers import pipeline
|
|
|
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
|
|
|
from datasets import load_dataset, Audio
|
|
|
import evaluate
|
|
|
|
|
|
# Word-error-rate metric, loaded once at import time and used by main() to score predictions.
wer_metric = evaluate.load("wer")
|
|
|
|
|
|
|
|
|
def is_target_text_in_range(ref):
    """Decide whether a reference transcript should be kept for scoring.

    Args:
        ref: Reference transcript text.

    Returns:
        False when the whitespace-stripped text is empty or equals the
        placeholder "ignore time segment in scoring"; True otherwise.
    """
    stripped = ref.strip()
    return stripped not in ("ignore time segment in scoring", "")
|
|
|
|
|
|
|
|
|
def get_text(sample):
    """Return the transcript stored in a dataset sample.

    Datasets name their transcript column differently, so the known column
    names are probed in a fixed priority order and the first one present wins.

    Args:
        sample: Mapping of column name to value for one dataset example.

    Returns:
        The value of the first transcript column found in ``sample``.

    Raises:
        ValueError: If none of the known transcript columns is present. The
            message lists the columns that were actually found.
    """
    # Priority order matters: it decides which column wins when several exist.
    text_columns = ("text", "sentence", "normalized_text", "transcript", "transcription")
    for column in text_columns:
        if column in sample:
            return sample[column]

    found = ", ".join(str(key) for key in sample.keys())
    raise ValueError(
        "Expected transcript column of either 'text', 'sentence', 'normalized_text', "
        "'transcript' or 'transcription'. Got sample of "
        f"{found}. Ensure a text column name is present in the dataset."
    )
|
|
|
|
|
|
|
|
|
# Shared text normalizer, applied to the references in normalise() and to the predictions in main().
whisper_norm = BasicTextNormalizer()
|
|
|
|
|
|
|
|
|
def normalise(batch):
    """Attach a normalised transcript to a sample.

    The transcript is looked up with ``get_text`` and passed through the
    module-level ``whisper_norm`` normalizer; the result is stored under the
    ``"norm_text"`` key.

    Args:
        batch: Mutable mapping for one dataset example; modified in place.

    Returns:
        The same ``batch`` object with ``"norm_text"`` set.
    """
    raw_text = get_text(batch)
    normalised_text = whisper_norm(raw_text)
    batch["norm_text"] = normalised_text
    return batch
|
|
|
|
|
|
|
|
|
def data(dataset):
    """Yield one pipeline input per sample of ``dataset``.

    Args:
        dataset: Iterable of samples, each providing an ``"audio"`` mapping
            and a ``"norm_text"`` entry.

    Yields:
        A dict holding every entry of the sample's ``"audio"`` mapping plus a
        ``"reference"`` key set to the sample's ``"norm_text"``.
    """
    # Iterate directly: the positional index from enumerate() was never used.
    for item in dataset:
        yield {**item["audio"], "reference": item["norm_text"]}
|
|
|
|
|
|
|
|
|
def main(args):
    """Evaluate a speech recognition model on a dataset and report its WER.

    Builds an ASR pipeline for ``args.model_id``, loads and normalises the
    requested dataset split, transcribes it, computes the word error rate,
    prints it, and pushes the result with ``evaluate.push_to_hub``.

    Args:
        args: Parsed command-line namespace providing ``model_id``, ``device``,
            ``language``, ``dataset``, ``config``, ``split``, ``streaming``,
            ``max_eval_samples`` and ``batch_size``.
    """
    batch_size = args.batch_size
    whisper_asr = pipeline(
        "automatic-speech-recognition", model=args.model_id, device=args.device
    )

    # Pin the decoder prompt to the requested language and the "transcribe" task.
    whisper_asr.model.config.forced_decoder_ids = (
        whisper_asr.tokenizer.get_decoder_prompt_ids(
            language=args.language, task="transcribe"
        )
    )

    dataset = load_dataset(
        args.dataset,
        args.config,
        split=args.split,
        streaming=args.streaming,
        use_auth_token=True,
    )

    # --max_eval_samples defaults to None, meaning "evaluate everything";
    # only truncate the dataset when a limit was actually requested.
    if args.max_eval_samples is not None:
        dataset = dataset.take(args.max_eval_samples)

    # Decode audio at 16 kHz, add the normalised reference text, and drop
    # samples whose reference is empty or the "ignore" placeholder.
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
    dataset = dataset.map(normalise)
    dataset = dataset.filter(is_target_text_in_range, input_columns=["norm_text"])

    predictions = []
    references = []

    # Run batched inference; predictions get the same normalisation as references.
    for out in whisper_asr(data(dataset), batch_size=batch_size):
        predictions.append(whisper_norm(out["text"]))
        # NOTE(review): presumably the pipeline returns the pass-through
        # "reference" value wrapped in a sequence, hence the [0] - confirm.
        references.append(out["reference"][0])

    wer = wer_metric.compute(references=references, predictions=predictions)
    # Express as a percentage rounded to two decimals.
    wer = round(100 * wer, 2)

    print("WER:", wer)
    evaluate.push_to_hub(
        model_id=args.model_id,
        metric_value=wer,
        metric_type="wer",
        metric_name="WER",
        dataset_name=args.dataset,
        dataset_type=args.dataset,
        dataset_split=args.split,
        dataset_config=args.config,
        task_type="automatic-speech-recognition",
        task_name="Automatic Speech Recognition"
    )
|
|
|
|
|
|
|
|
|
# Command-line entry point: parse the evaluation options and run main().
if __name__ == "__main__":

    def _str_to_bool(value):
        """Parse a command-line boolean such as 'true'/'false', 'yes'/'no' or '1'/'0'."""
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        # The empty string is kept as False because that was the only way to
        # disable the flag when it was parsed with type=bool.
        if lowered in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_id",
        type=str,
        required=True,
        help="Model identifier. Should be loadable with 🤗 Transformers",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="mozilla-foundation/common_voice_11_0",
        help="Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets",
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Config of the dataset. *E.g.* `'en'` for the English split of Common Voice",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Split of the dataset. *E.g.* `'test'`",
    )

    parser.add_argument(
        "--device",
        type=int,
        default=-1,
        help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Number of samples to go through each streamed batch.",
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
    )
    parser.add_argument(
        "--streaming",
        # type=bool would turn every non-empty string, including "False",
        # into True; parse the text explicitly instead.
        type=_str_to_bool,
        default=True,
        help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
    )
    parser.add_argument(
        "--language",
        type=str,
        required=True,
        help="Two letter language code for the transcription language, e.g. use 'en' for English.",
    )
    args = parser.parse_args()

    main(args)