"""Evaluate a Whisper-style speech model on long-form audio and report Mix Error Rate."""
import torch
import glob
import os
import os.path as osp
import numpy as np
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from tqdm import tqdm
from utils.evaluation import MixErrorRate
from utils.transcript_readers import read_vtt
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
from utils.model_utils import mix_language_embeddings
def eval_longform_single(audio_fpath, trans_fpath, pipe, normalizer=None):
    """Transcribe one long-form audio file and return (hypothesis, reference).

    Args:
        audio_fpath: Path to the audio file fed to the ASR pipeline.
        trans_fpath: Path to the transcript file parsed by ``read_vtt``; the
            last field of each segment is taken as the reference text.
        pipe: A transformers ASR pipeline returning timestamped ``chunks``.
        normalizer: Optional callable applied to both hypothesis and
            reference before returning (e.g. ``BasicTextNormalizer``).

    Returns:
        Tuple ``(hyp, ref)`` of space-joined transcript strings.
    """
    reference = " ".join(segment[-1] for segment in read_vtt(trans_fpath))
    pipeline_output = pipe(audio_fpath)
    hypothesis = " ".join(chunk["text"] for chunk in pipeline_output["chunks"])
    if normalizer:
        hypothesis = normalizer(hypothesis)
        reference = normalizer(reference)
    return hypothesis, reference
def main(args):
    """Evaluate a Whisper-style seq2seq ASR model on long-form audio.

    Loads ``args.model_id``, optionally mixes the zh/en language-token
    embeddings, transcribes every ``.wav`` under ``args.testing_data_root``
    with a chunked HF pipeline, scores hypotheses with ``MixErrorRate``, and
    writes per-utterance results, the run configuration, and the zh/en
    embeddings to an output directory.
    """
    device = f"cuda:{args.device_id}" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    model_id = args.model_id
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
    )
    processor = AutoProcessor.from_pretrained(model_id)

    if args.mix_lang_emb:
        print("Mixing language embeddings...")
        if "large" not in model_id:
            raise ValueError("Language embedding mixing is only supported for large models.")
        # Ids 50260 / 50259 are the <|zh|> / <|en|> language tokens in the
        # multilingual Whisper vocabulary; printed before mixing for comparison
        # with the post-mix values logged below.
        zh_emb = model.model.decoder.embed_tokens.weight[50260]
        en_emb = model.model.decoder.embed_tokens.weight[50259]
        print(f"zh: {zh_emb}, {zh_emb.shape}")
        print(f"en: {en_emb}, {en_emb.shape}")
        model = mix_language_embeddings(model, processor.tokenizer, languages=['zh', 'en'])
    model.to(device)
    model.eval()

    # Re-read the (possibly mixed) embeddings so they can be saved at the end.
    zh_emb = model.model.decoder.embed_tokens.weight[50260]
    en_emb = model.model.decoder.embed_tokens.weight[50259]
    print(f"zh: {zh_emb}, {zh_emb.shape}")
    print(f"en: {en_emb}, {en_emb.shape}")

    generate_kwargs = {
        'language': args.language,
        'task': 'transcribe',
        'return_timestamps': True,
    }
    model.generation_config.update(**generate_kwargs)
    processor.tokenizer.set_prefix_tokens(language=args.language, task="transcribe", predict_timestamps=True)
    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        return_timestamps=True,
        generate_kwargs=generate_kwargs,
        max_new_tokens=256,
        torch_dtype=torch_dtype,
        device=device,
    )

    testing_data_root = args.testing_data_root
    audio_fpaths = glob.glob(f"{testing_data_root}/**/*.wav", recursive=True)
    # NOTE(review): transcripts carry a .srt extension but are parsed with
    # read_vtt() in eval_longform_single — confirm the files are really VTT.
    trans_fpaths = [fpath.replace(".wav", ".srt") for fpath in audio_fpaths]
    if args.test:
        # Smoke-test mode: only score the first two utterances.
        audio_fpaths = audio_fpaths[:2]
        trans_fpaths = trans_fpaths[:2]

    normalizer = BasicTextNormalizer() if args.normalize else None

    hyps, refs = [], []
    for audio_fpath, trans_fpath in tqdm(zip(audio_fpaths, trans_fpaths), total=len(audio_fpaths), desc="Evaluating..."):
        hyp, ref = eval_longform_single(audio_fpath, trans_fpath, pipe, normalizer=normalizer)
        hyps.append(hyp)
        refs.append(ref)

    metric = MixErrorRate()
    mer = metric.compute(hyps, refs)
    print(f"{model_id}: {mer}")

    # "large" checkpoints (hub ids) are redirected into a fixed corpus output
    # directory; other model ids are used as the output path directly.
    if "large" in model_id:
        model_id = f"/root/distil-whisper/corpus/output/{model_id}"
    os.makedirs(model_id, exist_ok=True)
    output_dir = osp.join(model_id, "cool_test_real_longform")
    os.makedirs(output_dir, exist_ok=True)
    result_fpath = osp.join(output_dir, f"result_{args.language}_{mer:.4f}.tsv")
    config_output_fpath = osp.join(output_dir, "cool_test_config.txt")

    # Per-utterance hypotheses/references as a TSV for error analysis.
    with open(result_fpath, 'w') as fw:
        print("audio_fpath\thyp\tref", file=fw)
        for hyp, ref, audio_f, trans_f in zip(hyps, refs, audio_fpaths, trans_fpaths):
            print(f"{audio_f}\t{hyp}\t{ref}", file=fw)

    # Append ('a') the run configuration so repeated runs accumulate a log.
    with open(config_output_fpath, 'a') as fw:
        print(f"MER: {mer}", file=fw)
        print(f"Model ID: {model_id}", file=fw)
        print(f"Testing data root: {testing_data_root}", file=fw)
        print(f"Device ID: {args.device_id}", file=fw)
        print(f"Test: {args.test}", file=fw)
        print(f"Normalize: {args.normalize}", file=fw)
        print(f"Mix lang emb: {args.mix_lang_emb}", file=fw)
        print(f"Language: {args.language}", file=fw)
        print(f"Generated kwargs: {generate_kwargs}", file=fw)

    # log zh, en embeddings in np format
    np.save(osp.join(model_id, "zh_emb.npy"), zh_emb.cpu().detach().numpy())
    np.save(osp.join(model_id, "en_emb.npy"), en_emb.cpu().detach().numpy())
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--model_id", type=str, required=True)
parser.add_argument("--testing_data_root", type=str, required=True)
parser.add_argument("--language", type=str, default="zh")
parser.add_argument("--device_id", default=0, type=int)
parser.add_argument("--test", action="store_true")
parser.add_argument("--normalize", action="store_true")
parser.add_argument("--mix_lang_emb", action="store_true")
args = parser.parse_args()
main(args)