|
|
import argparse |
|
|
import os |
|
|
import time |
|
|
from pathlib import Path |
|
|
import csv |
|
|
|
|
|
import torch |
|
|
import librosa |
|
|
from transformers import WhisperForConditionalGeneration, WhisperProcessor |
|
|
|
|
|
def save_csv(file_path, rows):
    """Write `rows` to `file_path` as a CSV file.

    Creates the parent directory if it does not exist (callers write into a
    relative "csv/" directory that may be absent on a fresh checkout).

    Args:
        file_path: Destination path of the CSV file.
        rows: Iterable of row iterables to write.
    """
    parent = os.path.dirname(file_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    # newline="" lets csv.writer control line endings itself; without it,
    # Windows inserts a blank line between every row (see csv module docs).
    with open(file_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerows(rows)
    print(f"write csv to {file_path}")
|
|
|
|
|
|
|
|
def load_audio(audio_path: str, sr: int = 16000):
    """Read an audio file and return it as a mono waveform resampled to `sr` Hz."""
    waveform, _ = librosa.load(audio_path, sr=sr, mono=True)
    return waveform
|
|
|
|
|
|
|
|
def transcribe_file(
    audio_path: str,
    model,
    processor,
    language: str = "Chinese",
    task: str = "transcribe",
    timestamps: bool = False,
    max_new_tokens: int = 255,
):
    """Transcribe (or translate) a single audio file with a Whisper model.

    Args:
        audio_path: Path to the audio file (resampled to 16 kHz mono).
        model: A loaded WhisperForConditionalGeneration instance.
        processor: The matching WhisperProcessor.
        language: Source language name or code (forwarded to generate()).
        task: "transcribe" or "translate" (forwarded to generate()).
        timestamps: Whether to ask the model to emit timestamp tokens.
        max_new_tokens: Generation length cap.

    Returns:
        The decoded text of the first (only) sequence in the batch.
    """
    audio = load_audio(audio_path, sr=16000)
    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")

    # Run on whatever device the model's weights live on.
    device = next(model.parameters()).device
    input_features = inputs["input_features"].to(device)

    # autocast is a no-op (enabled=False) on non-CUDA devices.
    with torch.inference_mode(), torch.autocast(device_type="cuda", enabled=(device.type == "cuda")):
        generated_ids = model.generate(
            input_features=input_features,
            max_new_tokens=max_new_tokens,
            return_timestamps=timestamps,
            # BUGFIX: language/task were accepted but never used, so these
            # parameters had no effect. Whisper's generate() matches language
            # names in lowercase, consistent with how main() sets
            # generation_config.language.
            language=language.lower(),
            task=task,
        )

    text = processor.tokenizer.batch_decode(generated_ids.cpu().numpy(), skip_special_tokens=True)
    return text[0]
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: transcribe one audio file, or every audio file in a directory."""
    cli = argparse.ArgumentParser("Simple Whisper Inference")
    cli.add_argument("--model_path", type=str, default="whisper-large-v3-turbo-finetune",
                     help="本地合并模型路径或HF模型名")
    cli.add_argument("--input", type=str, required=True,
                     help="音频文件路径,或目录(将批量处理其中的音频)")
    cli.add_argument("--language", type=str, default="Chinese",
                     help="语言(如 Chinese / English / zh / en)")
    cli.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"],
                     help="任务:转写或翻译")
    cli.add_argument("--timestamps", action="store_true", help="是否返回时间戳(若模型与版本支持)")
    cli.add_argument("--local_files_only", action="store_true", help="仅本地加载,不联网")
    cli.add_argument("--batch_exts", type=str, default=".wav,.mp3,.flac,.m4a",
                     help="当 --input 是目录时,处理这些后缀的文件,逗号分隔")
    opts = cli.parse_args()

    # Processor (tokenizer + feature extractor) configured for the requested task.
    processor = WhisperProcessor.from_pretrained(
        opts.model_path,
        language=opts.language,
        task=opts.task,
        no_timestamps=not opts.timestamps,
        local_files_only=opts.local_files_only,
    )
    # fp16 only when CUDA is available; device_map="auto" picks the device.
    model = WhisperForConditionalGeneration.from_pretrained(
        opts.model_path,
        device_map="auto",
        local_files_only=opts.local_files_only,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
    # Pin the decoding language and clear legacy forced decoder ids.
    model.generation_config.language = opts.language.lower()
    model.generation_config.forced_decoder_ids = None
    model.eval()

    target = Path(opts.input)
    if target.is_file():
        # Single-file mode.
        transcript = transcribe_file(
            str(target), model, processor,
            language=opts.language, task=opts.task, timestamps=opts.timestamps
        )
        print(f"{target.name} -> {transcript}")
        return

    # Directory mode: recursively collect files with the allowed extensions.
    allowed = {e.strip().lower() for e in opts.batch_exts.split(",")}
    candidates = [f for f in target.rglob("*") if f.suffix.lower() in allowed]
    if not candidates:
        print("目录中未找到可处理的音频文件。")
        return
    for audio_file in sorted(candidates):
        try:
            started = time.time()
            transcript = transcribe_file(
                str(audio_file), model, processor,
                language=opts.language, task=opts.task, timestamps=opts.timestamps
            )
            elapsed = time.time() - started
            print(f"{audio_file.name} -> {transcript}; time cost: {elapsed}")
        except Exception as e:
            # Keep going on a per-file failure.
            print(f"{audio_file.name} -> 失败: {e}")
|
|
|
|
|
def load_model():
    """Load the locally fine-tuned Whisper checkpoint for the benchmark helpers.

    NOTE: paths/device are hard-coded for a local macOS (MPS) setup; this is a
    test harness, not production code. Returns (model, processor).
    """
    checkpoint = "/Users/jeqin/Downloads/whisper-large-v3-turbo-finetune_1219"
    language = "zh"
    started = time.time()
    processor = WhisperProcessor.from_pretrained(
        checkpoint,
        language=language,
        task="transcribe",
        no_timestamps=True,
        local_files_only=True,
    )
    # fp16 only if CUDA exists; on MPS this resolves to float32.
    model = WhisperForConditionalGeneration.from_pretrained(
        checkpoint,
        device_map="mps",
        local_files_only=True,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
    model.generation_config.language = language.lower()
    model.generation_config.forced_decoder_ids = None
    model.eval()
    print("load model time: ", time.time() - started)
    return model, processor
|
|
|
|
|
def run_test_audios():
    """Transcribe the bundled English 16k test clips and dump timings to CSV."""
    model, processor = load_model()
    clip_root = Path("../test_data/audio_clips/")
    rows = [["file_name", "inference_time", "inference_result"]]
    for clip in sorted(clip_root.glob("*en-ac1-16k/*.wav")):
        try:
            started = time.time()
            text = transcribe_file(
                str(clip), model, processor
            )
            elapsed = time.time() - started
            print(f"{clip.name} -> {text}; time cost: {elapsed}")
            rows.append([f"{clip.parent.name}/{clip.name}", elapsed, text])
        except Exception as e:
            # A bad clip should not abort the whole run.
            print(f"{clip.name} -> 失败: {e}")
    save_csv("csv/fine-tune_whisper-0901.csv", rows)
|
|
|
|
|
def run_recordings():
    """Transcribe numbered recordings and score them against reference texts."""
    from scripts.asr_utils import get_origin_text_dict, get_text_distance
    model, processor = load_model()
    recordings = Path("../test_data/recordings/")
    rows = [["file_name", "time", "inference_result"]]
    references = get_origin_text_dict()
    # Filenames are integers ("1.wav", "2.wav", ...) — sort numerically.
    for wav in sorted(recordings.glob("*.wav"), key=lambda x: int(x.stem)):
        print(wav)
        try:
            started = time.time()
            text = transcribe_file(
                str(wav), model, processor
            )
            elapsed = time.time() - started
            print(text)
            print("inference time:", elapsed)
            # Edit distance, normalized distance, and a human-readable diff.
            dist, norm_dist, diff = get_text_distance(references[wav.stem], text)
            rows.append([wav.name, round(elapsed, 3), text, dist, round(norm_dist, 3), diff])
        except Exception as e:
            print(f"{wav.name} -> 失败: {e}")
    save_csv("csv/fine-tune_whisper.csv", rows)
|
|
|
|
|
|
|
|
def run_test_dataset():
    """Run inference over the local dataset manifest and dump results as JSON.

    Partial results are still written if the loop raises or is interrupted,
    because the JSON dump happens after the try/except blocks.
    """
    from test_data.audios import read_dataset
    import json
    model, processor = load_model()
    test_data = Path("../test_data/dataset/dataset.txt")
    audio_parent = Path("../test_data/")
    result_list = []
    count = 0
    try:
        for audio_path, sentence, duration in read_dataset(test_data):
            count += 1
            print(f"processing {count}: {audio_path}")
            t1 = time.time()
            text = transcribe_file(
                str(audio_parent/audio_path), model, processor
            )
            t = time.time() - t1
            print("inference time:", t)
            print(text)
            result_list.append({
                "index": count,
                "audio_path": audio_path,
                "reference": sentence,
                "duration": duration,
                "inference_time": round(t, 3),
                "inference_result": text
            })
    except Exception as e:
        print(e)
    except KeyboardInterrupt as e:
        # KeyboardInterrupt derives from BaseException, so the Exception
        # clause above does not swallow it; Ctrl-C still saves partial results.
        print(e)
    # Robustness: the output directory may not exist on a fresh checkout.
    os.makedirs("csv", exist_ok=True)
    with open("csv/whisper_finetuned_dataset_results.json", "w", encoding="utf-8") as f:
        json.dump(result_list, f, ensure_ascii=False, indent=2)
|
|
|
|
|
def run_test_emilia():
    """Transcribe Emilia ZH clips (up to 5000) and dump results as JSON.

    Partial results are still written on failure or Ctrl-C, since the dump
    happens after the try/except blocks.
    """
    from test_data.audios import read_emilia
    import json
    model, processor = load_model()
    parent = Path("../test_data/ZH-B000008")
    result_list = []
    count = 0
    try:
        for audio_path, sentence, duration in read_emilia(parent, count_limit=5000):
            count += 1
            print(f"processing {count}: {audio_path}")
            t1 = time.time()
            text = transcribe_file(
                str(audio_path), model, processor
            )
            t = time.time() - t1
            print("inference time:", t)
            print(text)
            result_list.append({
                "index": count,
                "audio_path": audio_path.name,
                "reference": sentence,
                "duration": duration,
                "inference_time": round(t, 3),
                "inference_result": text
            })
    except Exception as e:
        print(e)
    except KeyboardInterrupt as e:
        # KeyboardInterrupt is a BaseException; not caught by the clause above.
        print(e)
    # Robustness: the output directory may not exist on a fresh checkout.
    os.makedirs("csv", exist_ok=True)
    with open("csv/whisper_finetune_emilia_results.json", "w", encoding="utf-8") as f:
        json.dump(result_list, f, ensure_ascii=False, indent=2)
|
|
|
|
|
|
|
|
def run_test_st():
    """Transcribe the ST test set (up to 5000 clips) and dump results as JSON.

    Partial results are still written on failure or Ctrl-C, since the dump
    happens after the try/except blocks.
    """
    from test_data.audios import read_st
    import json
    model, processor = load_model()
    result_list = []
    count = 0
    try:
        for audio_path, sentence in read_st(count_limit=5000):
            count += 1
            print(f"processing {count}: {audio_path}")
            t1 = time.time()
            text = transcribe_file(
                str(audio_path), model, processor
            )
            t = time.time() - t1
            print("inference time:", t)
            print(text)
            result_list.append({
                "index": count,
                "audio_path": audio_path.name,
                "reference": sentence,
                "inference_time": round(t, 3),
                "inference_result": text
            })
    except Exception as e:
        print(e)
    except KeyboardInterrupt as e:
        # KeyboardInterrupt is a BaseException; not caught by the clause above.
        print(e)
    # Robustness: the output directory may not exist on a fresh checkout.
    os.makedirs("csv", exist_ok=True)
    with open("csv/whisper_finetune_st_results.json", "w", encoding="utf-8") as f:
        json.dump(result_list, f, ensure_ascii=False, indent=2)
|
|
|
|
|
def run_test_wenet():
    """Transcribe the WeNet test set (up to 5000 clips) and dump results as JSON.

    Partial results are still written on failure or Ctrl-C, since the dump
    happens after the try/except blocks.
    """
    from test_data.audios import read_wenet
    import json
    model, processor = load_model()
    result_list = []
    count = 0
    try:
        for audio_path, sentence in read_wenet(count_limit=5000):
            count += 1
            print(f"processing {count}: {audio_path}")
            t1 = time.time()
            text = transcribe_file(
                str(audio_path), model, processor
            )
            t = time.time() - t1
            print("inference time:", t)
            print(text)
            result_list.append({
                "index": count,
                "audio_path": audio_path.name,
                "reference": sentence,
                "inference_time": round(t, 3),
                "inference_result": text
            })
    except Exception as e:
        print(e)
    except KeyboardInterrupt as e:
        # KeyboardInterrupt is a BaseException; not caught by the clause above.
        print(e)
    # Robustness: the output directory may not exist on a fresh checkout.
    os.makedirs("csv", exist_ok=True)
    with open("csv/whisper_finetune_wenet_results.json", "w", encoding="utf-8") as f:
        json.dump(result_list, f, ensure_ascii=False, indent=2)
|
|
|
|
|
if __name__ == "__main__":
    # Entry point: currently runs the WeNet benchmark harness.
    # Switch to main() to use the argparse-based CLI instead.
    run_test_wenet()
|
|
|