levi default wer norm
Browse files- benchmark_utils.py +8 -50
benchmark_utils.py
CHANGED
|
@@ -66,8 +66,9 @@ def ASRmanifest(
|
|
| 66 |
with torch.no_grad():
|
| 67 |
with autocast():
|
| 68 |
try:
|
| 69 |
-
result = asr_pipeline(audiofile)
|
| 70 |
asrtext = result['text']
|
|
|
|
| 71 |
except (FileNotFoundError, ValueError) as e:
|
| 72 |
print(f'SKIPPED: {audiofile}')
|
| 73 |
continue
|
|
@@ -77,49 +78,6 @@ def ASRmanifest(
|
|
| 77 |
compute_time = (et-st)
|
| 78 |
print(f'...transcription complete in {compute_time:.1f} sec')
|
| 79 |
|
| 80 |
-
def load_model(
|
| 81 |
-
model_path:str,
|
| 82 |
-
language='english',
|
| 83 |
-
use_int8 = False,
|
| 84 |
-
device_map='auto'):
|
| 85 |
-
|
| 86 |
-
warnings.filterwarnings("ignore")
|
| 87 |
-
transformers.utils.logging.set_verbosity_error()
|
| 88 |
-
|
| 89 |
-
try:
|
| 90 |
-
model = WhisperForConditionalGeneration.from_pretrained(
|
| 91 |
-
model_path,
|
| 92 |
-
load_in_8bit=use_int8,
|
| 93 |
-
device_map=device_map,
|
| 94 |
-
use_cache=False,
|
| 95 |
-
)
|
| 96 |
-
try:
|
| 97 |
-
processor=WhisperProcessor.from_pretrained(model_path, language=language, task="transcribe")
|
| 98 |
-
except OSError:
|
| 99 |
-
print('missing tokenizer and preprocessor config files in save dir, checking directory above...')
|
| 100 |
-
processor=WhisperProcessor.from_pretrained(os.path.join(model_path,'..'), language=language, task="transcribe")
|
| 101 |
-
|
| 102 |
-
except OSError as e:
|
| 103 |
-
print(f'{e}: possibly missing model or config file in model path. Will check for adapter...')
|
| 104 |
-
# check if PEFT
|
| 105 |
-
if os.path.isdir(os.path.join(model_path , "adapter_model")):
|
| 106 |
-
print('found adapter...loading PEFT model')
|
| 107 |
-
# checkpoint dir needs an adapter_model subdir containing adapter_model.bin and adapter_config.json
|
| 108 |
-
peft_config = PeftConfig.from_pretrained(os.path.join(model_path , "adapter_model"))
|
| 109 |
-
print(f'...loading and merging LORA weights to base model {peft_config.base_model_name_or_path}')
|
| 110 |
-
model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path,
|
| 111 |
-
load_in_8bit=use_int8,
|
| 112 |
-
device_map=device_map,
|
| 113 |
-
use_cache=False,
|
| 114 |
-
)
|
| 115 |
-
model = PeftModel.from_pretrained(model, os.path.join(model_path,"adapter_model"))
|
| 116 |
-
model = model.merge_and_unload()
|
| 117 |
-
processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task="transcribe")
|
| 118 |
-
else:
|
| 119 |
-
raise e
|
| 120 |
-
model.eval()
|
| 121 |
-
return(model, processor)
|
| 122 |
-
|
| 123 |
def prepare_pipeline(model_path, generate_opts):
|
| 124 |
"""Prepare a pipeline for ASR inference
|
| 125 |
Args:
|
|
@@ -128,16 +86,16 @@ def prepare_pipeline(model_path, generate_opts):
|
|
| 128 |
Returns:
|
| 129 |
pipeline: ASR pipeline
|
| 130 |
"""
|
| 131 |
-
|
| 132 |
-
model_path=model_path)
|
| 133 |
|
| 134 |
asr_pipeline = pipeline(
|
| 135 |
"automatic-speech-recognition",
|
| 136 |
-
model=
|
| 137 |
tokenizer=processor.tokenizer,
|
| 138 |
feature_extractor=processor.feature_extractor,
|
| 139 |
generate_kwargs=generate_opts,
|
| 140 |
-
|
|
|
|
| 141 |
return asr_pipeline
|
| 142 |
|
| 143 |
#%% WER evaluation functions
|
|
@@ -285,7 +243,7 @@ def wer_from_df(
|
|
| 285 |
hypcol='hyp',
|
| 286 |
return_alignments=False,
|
| 287 |
normalise = True,
|
| 288 |
-
text_norm_method='
|
| 289 |
printout=True):
|
| 290 |
"""Compute WER from a dataframe containing a ref col and a hyp col
|
| 291 |
WER is computed on the edit operation counts over the whole df,
|
|
@@ -338,7 +296,7 @@ def wer_from_csv(
|
|
| 338 |
hypcol='hyp',
|
| 339 |
return_alignments=False,
|
| 340 |
normalise = True,
|
| 341 |
-
text_norm_method='
|
| 342 |
printout=True):
|
| 343 |
|
| 344 |
res = pd.read_csv(csv_path).astype(str)
|
|
|
|
| 66 |
with torch.no_grad():
|
| 67 |
with autocast():
|
| 68 |
try:
|
| 69 |
+
result = asr_pipeline(audiofile )
|
| 70 |
asrtext = result['text']
|
| 71 |
+
asr_pipeline.call_count = 0
|
| 72 |
except (FileNotFoundError, ValueError) as e:
|
| 73 |
print(f'SKIPPED: {audiofile}')
|
| 74 |
continue
|
|
|
|
| 78 |
compute_time = (et-st)
|
| 79 |
print(f'...transcription complete in {compute_time:.1f} sec')
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
def prepare_pipeline(model_path, generate_opts):
|
| 82 |
"""Prepare a pipeline for ASR inference
|
| 83 |
Args:
|
|
|
|
| 86 |
Returns:
|
| 87 |
pipeline: ASR pipeline
|
| 88 |
"""
|
| 89 |
+
processor = WhisperProcessor.from_pretrained(model_path)
|
|
|
|
| 90 |
|
| 91 |
asr_pipeline = pipeline(
|
| 92 |
"automatic-speech-recognition",
|
| 93 |
+
model=model_path,
|
| 94 |
tokenizer=processor.tokenizer,
|
| 95 |
feature_extractor=processor.feature_extractor,
|
| 96 |
generate_kwargs=generate_opts,
|
| 97 |
+
model_kwargs={"load_in_8bit": False},
|
| 98 |
+
device_map='auto')
|
| 99 |
return asr_pipeline
|
| 100 |
|
| 101 |
#%% WER evaluation functions
|
|
|
|
| 243 |
hypcol='hyp',
|
| 244 |
return_alignments=False,
|
| 245 |
normalise = True,
|
| 246 |
+
text_norm_method='levi',
|
| 247 |
printout=True):
|
| 248 |
"""Compute WER from a dataframe containing a ref col and a hyp col
|
| 249 |
WER is computed on the edit operation counts over the whole df,
|
|
|
|
| 296 |
hypcol='hyp',
|
| 297 |
return_alignments=False,
|
| 298 |
normalise = True,
|
| 299 |
+
text_norm_method='levi' ,
|
| 300 |
printout=True):
|
| 301 |
|
| 302 |
res = pd.read_csv(csv_path).astype(str)
|