|
|
import torch |
|
|
import glob |
|
|
import os |
|
|
import os.path as osp |
|
|
import numpy as np |
|
|
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline |
|
|
from tqdm import tqdm |
|
|
from utils.evaluation import MixErrorRate |
|
|
from utils.transcript_readers import read_vtt |
|
|
from transformers.models.whisper.english_normalizer import BasicTextNormalizer |
|
|
from utils.model_utils import mix_language_embeddings |
|
|
|
|
|
def eval_longform_single(audio_fpath, trans_fpath, pipe, normalizer=None):
    """Transcribe one recording and pair the hypothesis with its reference.

    Args:
        audio_fpath: Handed unchanged to ``pipe``.
        trans_fpath: Handed to ``read_vtt``; the last element of every
            returned segment is taken as that segment's text.
        pipe: Callable whose result can be indexed with ``"chunks"`` to get
            an iterable of mappings that each hold a ``"text"`` entry.
        normalizer: Optional callable applied to both strings. Skipped when
            falsy.

    Returns:
        A ``(hyp, ref)`` tuple: the space-joined chunk texts produced by
        ``pipe`` and the space-joined reference segment texts.
    """
    # Reference side: concatenate the text field of every transcript segment.
    reference_texts = [segment[-1] for segment in read_vtt(trans_fpath)]
    ref = " ".join(reference_texts)

    # Hypothesis side: concatenate the text of every chunk from the pipeline.
    result = pipe(audio_fpath)
    hyp = " ".join(piece["text"] for piece in result["chunks"])

    if not normalizer:
        return hyp, ref
    # Tuple elements evaluate left to right, so hyp is normalized before ref.
    return normalizer(hyp), normalizer(ref)
|
|
|
|
|
def _print_language_embeddings(model):
    """Print and return the ``(zh, en)`` rows of the decoder token embedding."""
    embed_weight = model.model.decoder.embed_tokens.weight
    # NOTE(review): 50260 / 50259 are presumably the <|zh|> / <|en|> language
    # token ids of the tokenizer in use -- confirm for every supported model.
    zh_emb = embed_weight[50260]
    en_emb = embed_weight[50259]
    print(f"zh: {zh_emb}, {zh_emb.shape}")
    print(f"en: {en_emb}, {en_emb.shape}")
    return zh_emb, en_emb


def main(args):
    """Run long-form ASR evaluation over a directory tree and save results.

    Loads the model named by ``args.model_id``, optionally mixes the zh/en
    language embeddings, transcribes every ``*.wav`` found recursively under
    ``args.testing_data_root``, scores the hypotheses with ``MixErrorRate``
    and writes a per-file TSV, an appended config log and the two language
    embedding rows (``zh_emb.npy`` / ``en_emb.npy``) to disk.

    Args:
        args: Namespace providing ``model_id``, ``testing_data_root``,
            ``language``, ``device_id``, ``test``, ``normalize`` and
            ``mix_lang_emb``.

    Raises:
        ValueError: If ``args.mix_lang_emb`` is set and ``"large"`` does not
            occur in ``args.model_id``.
        FileNotFoundError: If no ``*.wav`` file exists under
            ``args.testing_data_root``.
    """
    cuda_available = torch.cuda.is_available()
    device = f"cuda:{args.device_id}" if cuda_available else "cpu"
    torch_dtype = torch.float16 if cuda_available else torch.float32

    model_id = args.model_id

    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
    )
    processor = AutoProcessor.from_pretrained(model_id)
    if args.mix_lang_emb:
        print("Mixing language embeddings...")
        if "large" not in model_id:
            raise ValueError("Language embedding mixing is only supported for large models.")
        # Show the embeddings before mixing so they can be compared with the
        # values printed again below.
        _print_language_embeddings(model)
        model = mix_language_embeddings(model, processor.tokenizer, languages=['zh', 'en'])
    model.to(device)
    model.eval()

    # These rows are printed now and saved to disk at the very end.
    zh_emb, en_emb = _print_language_embeddings(model)

    generate_kwargs = {
        'language': args.language,
        'task': 'transcribe',
        'return_timestamps': True,
    }
    model.generation_config.update(**generate_kwargs)
    processor.tokenizer.set_prefix_tokens(language=args.language, task="transcribe", predict_timestamps=True)

    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        return_timestamps=True,
        generate_kwargs=generate_kwargs,
        max_new_tokens=256,
        torch_dtype=torch_dtype,
        device=device,
    )

    testing_data_root = args.testing_data_root
    # Sort so that the `--test` subset and the TSV row order are reproducible;
    # glob order otherwise depends on the file system.
    audio_fpaths = sorted(glob.glob(f"{testing_data_root}/**/*.wav", recursive=True))
    if not audio_fpaths:
        raise FileNotFoundError(f"No .wav files found under {testing_data_root}")
    # Swap only the file extension; str.replace would also rewrite a ".wav"
    # that happens to appear in a directory name.
    # NOTE(review): transcripts are located with a .srt suffix but parsed by
    # read_vtt in eval_longform_single -- confirm the reader handles SRT.
    trans_fpaths = [osp.splitext(audio_fpath)[0] + ".srt" for audio_fpath in audio_fpaths]
    if args.test:
        # Smoke-test mode: evaluate only the first two recordings.
        audio_fpaths = audio_fpaths[:2]
        trans_fpaths = trans_fpaths[:2]

    normalizer = BasicTextNormalizer() if args.normalize else None

    hyps, refs = [], []
    for audio_fpath, trans_fpath in tqdm(zip(audio_fpaths, trans_fpaths), total=len(audio_fpaths), desc="Evaluating..."):
        hyp, ref = eval_longform_single(audio_fpath, trans_fpath, pipe, normalizer=normalizer)
        hyps.append(hyp)
        refs.append(ref)

    metric = MixErrorRate()
    mer = metric.compute(hyps, refs)
    print(f"{model_id}: {mer}")

    if "large" in model_id:
        # From here on model_id doubles as the output directory.
        model_id = f"/root/distil-whisper/corpus/output/{model_id}"
        os.makedirs(model_id, exist_ok=True)
    output_dir = osp.join(model_id, "cool_test_real_longform")
    os.makedirs(output_dir, exist_ok=True)
    result_fpath = osp.join(output_dir, f"result_{args.language}_{mer:.4f}.tsv")
    config_output_fpath = osp.join(output_dir, "cool_test_config.txt")

    # Explicit UTF-8: transcripts may hold non-ASCII text (the default
    # language is zh) and the platform default encoding may not cover it.
    with open(result_fpath, 'w', encoding="utf-8") as fw:
        print("audio_fpath\thyp\tref", file=fw)
        for hyp, ref, audio_f in zip(hyps, refs, audio_fpaths):
            print(f"{audio_f}\t{hyp}\t{ref}", file=fw)

    # Append mode keeps the settings of earlier runs in the same log file.
    with open(config_output_fpath, 'a', encoding="utf-8") as fw:
        print(f"MER: {mer}", file=fw)
        print(f"Model ID: {model_id}", file=fw)
        print(f"Testing data root: {testing_data_root}", file=fw)
        print(f"Device ID: {args.device_id}", file=fw)
        print(f"Test: {args.test}", file=fw)
        print(f"Normalize: {args.normalize}", file=fw)
        print(f"Mix lang emb: {args.mix_lang_emb}", file=fw)
        print(f"Language: {args.language}", file=fw)
        print(f"Generated kwargs: {generate_kwargs}", file=fw)

    np.save(osp.join(model_id, "zh_emb.npy"), zh_emb.cpu().detach().numpy())
    np.save(osp.join(model_id, "en_emb.npy"), en_emb.cpu().detach().numpy())
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    # Command-line entry point: collect the evaluation settings and run main().
    cli = argparse.ArgumentParser()

    # Required inputs: which model to evaluate and where the test audio lives.
    cli.add_argument("--model_id", required=True, type=str)
    cli.add_argument("--testing_data_root", required=True, type=str)

    # Optional settings with defaults.
    cli.add_argument("--language", default="zh", type=str)
    cli.add_argument("--device_id", type=int, default=0)

    # Boolean switches, all off unless given on the command line.
    for switch in ("--test", "--normalize", "--mix_lang_emb"):
        cli.add_argument(switch, action="store_true")

    main(cli.parse_args())
|
|
|