# TestTranslator/scripts/run_whisper_finetuned.py
# Author: yujuanqin — "update scripts" (commit 42742c6)
import argparse
import os
import time
from pathlib import Path
import csv
import torch
import librosa
from transformers import WhisperForConditionalGeneration, WhisperProcessor
def save_csv(file_path, rows):
    """Write *rows* (an iterable of row sequences) to *file_path* as CSV.

    The file is opened with newline="" as the csv module requires;
    without it the writer emits a blank line after every row on Windows.
    """
    with open(file_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerows(rows)
    print(f"write csv to {file_path}")
def load_audio(audio_path: str, sr: int = 16000):
    """Read an audio file as mono float32 samples resampled to *sr* Hz."""
    waveform, _ = librosa.load(audio_path, sr=sr, mono=True)
    return waveform
def transcribe_file(
    audio_path: str,
    model,
    processor,
    language: str = "Chinese",
    task: str = "transcribe",
    timestamps: bool = False,
    max_new_tokens: int = 255,
):
    """Transcribe (or translate) a single audio file with a Whisper model.

    Args:
        audio_path: path to the audio file; it is resampled to 16 kHz mono.
        model: a WhisperForConditionalGeneration (already on its target device).
        processor: the matching WhisperProcessor.
        language: generation language hint (e.g. "Chinese", "zh").
        task: "transcribe" or "translate".
        timestamps: ask generate() for timestamp tokens (version-dependent).
        max_new_tokens: generation length cap.

    Returns:
        The decoded hypothesis text for the (single) input.
    """
    # Prepare log-mel features from 16 kHz mono audio.
    audio = load_audio(audio_path, sr=16000)
    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
    # Move features onto whatever device the model lives on.
    device = next(model.parameters()).device
    input_features = inputs["input_features"].to(device)
    # Generate; autocast is a no-op unless running on CUDA.
    with torch.inference_mode(), torch.autocast(device_type="cuda", enabled=(device.type == "cuda")):
        generated_ids = model.generate(
            input_features=input_features,
            max_new_tokens=max_new_tokens,
            # BUG FIX: language/task were accepted by this function but never
            # forwarded, so callers' choices were silently ignored.
            language=language.lower() if language else None,
            task=task,
            return_timestamps=timestamps,  # only some versions support this; ignored otherwise
        )
    # Decode token ids back to text.
    text = processor.tokenizer.batch_decode(generated_ids.cpu().numpy(), skip_special_tokens=True)
    return text[0]
def main():
    """CLI entry point: transcribe one audio file, or every audio file in a directory."""
    parser = argparse.ArgumentParser("Simple Whisper Inference")
    parser.add_argument("--model_path", type=str, default="whisper-large-v3-turbo-finetune",
                        help="本地合并模型路径或HF模型名")
    parser.add_argument("--input", type=str, required=True,
                        help="音频文件路径,或目录(将批量处理其中的音频)")
    parser.add_argument("--language", type=str, default="Chinese",
                        help="语言(如 Chinese / English / zh / en)")
    parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"],
                        help="任务:转写或翻译")
    parser.add_argument("--timestamps", action="store_true", help="是否返回时间戳(若模型与版本支持)")
    parser.add_argument("--local_files_only", action="store_true", help="仅本地加载,不联网")
    parser.add_argument("--batch_exts", type=str, default=".wav,.mp3,.flac,.m4a",
                        help="当 --input 是目录时,处理这些后缀的文件,逗号分隔")
    args = parser.parse_args()

    # Load processor & model.
    processor = WhisperProcessor.from_pretrained(
        args.model_path,
        language=args.language,
        task=args.task,
        no_timestamps=not args.timestamps,
        local_files_only=args.local_files_only,
    )
    model = WhisperForConditionalGeneration.from_pretrained(
        args.model_path,
        device_map="auto",
        local_files_only=args.local_files_only,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
    model.generation_config.language = args.language.lower()
    model.generation_config.forced_decoder_ids = None
    model.eval()

    target = Path(args.input)
    if target.is_file():
        # Single-file mode.
        result = transcribe_file(
            str(target), model, processor,
            language=args.language, task=args.task, timestamps=args.timestamps
        )
        print(f"{target.name} -> {result}")
        return

    # Directory mode: collect files whose suffix matches --batch_exts.
    wanted = {ext.strip().lower() for ext in args.batch_exts.split(",")}
    candidates = [f for f in target.rglob("*") if f.suffix.lower() in wanted]
    if not candidates:
        print("目录中未找到可处理的音频文件。")
        return
    for audio_file in sorted(candidates):
        try:
            started = time.time()
            result = transcribe_file(
                str(audio_file), model, processor,
                language=args.language, task=args.task, timestamps=args.timestamps
            )
            finished = time.time()
            print(f"{audio_file.name} -> {result}; time cost: {finished-started}")
        except Exception as err:
            print(f"{audio_file.name} -> 失败: {err}")
def load_model():
    """Load the locally fine-tuned Whisper model/processor used by the ad-hoc tests.

    Paths are hardcoded for a local macOS setup; the model is placed on MPS.
    Returns (model, processor).
    """
    # model_path = "/Users/jeqin/Downloads/checkpoint-39000-full/whisper-large-v3-turbo-finetune"
    model_path = "/Users/jeqin/Downloads/whisper-large-v3-turbo-finetune_1219"
    lang = "zh"
    started = time.time()
    processor = WhisperProcessor.from_pretrained(
        model_path,
        language=lang,
        task="transcribe",
        no_timestamps=True,
        local_files_only=True,
    )
    # NOTE(review): dtype is selected by a CUDA check while the device is MPS,
    # so on a Mac this always loads float32 — presumably intentional; confirm.
    model = WhisperForConditionalGeneration.from_pretrained(
        model_path,
        device_map="mps",
        local_files_only=True,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
    model.generation_config.language = lang.lower()
    model.generation_config.forced_decoder_ids = None
    model.eval()
    print("load model time: ", time.time() - started)
    return model, processor
def run_test_audios():
    """Batch-transcribe the clip test set and dump per-file timing/results to CSV."""
    model, processor = load_model()
    clip_root = Path("../test_data/audio_clips/")
    table = [["file_name", "inference_time", "inference_result"]]
    for wav in sorted(clip_root.glob("*en-ac1-16k/*.wav")):  # *s/randomforest*.wav"
        try:
            started = time.time()
            result = transcribe_file(str(wav), model, processor)
            elapsed = time.time() - started
            print(f"{wav.name} -> {result}; time cost: {elapsed}")
            table.append([f"{wav.parent.name}/{wav.name}", elapsed, result])
        except Exception as err:
            print(f"{wav.name} -> 失败: {err}")
    save_csv("csv/fine-tune_whisper-0901.csv", table)
def run_recordings():
    """Transcribe the recorded wavs, score each against its reference text, save a CSV.

    Each data row holds: file name, inference time, hypothesis text, edit
    distance, normalized edit distance, and the diff vs. the reference.
    """
    from scripts.asr_utils import get_origin_text_dict, get_text_distance
    model, processor = load_model()
    audios = Path("../test_data/recordings/")
    # BUG FIX: the header previously listed only 3 columns while every data
    # row appended below has 6.
    rows = [["file_name", "time", "inference_result", "distance", "norm_distance", "diff"]]
    original = get_origin_text_dict()
    # Recordings are named by integer index; sort numerically, not lexically.
    for audio in sorted(audios.glob("*.wav"), key=lambda x: int(x.stem)):
        print(audio)
        try:
            t0 = time.time()
            text = transcribe_file(
                str(audio), model, processor
            )
            t = time.time() - t0
            print(text)
            print("inference time:", t)
            d, nd, diff = get_text_distance(original[audio.stem], text)
            rows.append([audio.name, round(t, 3), text, d, round(nd, 3), diff])
        except Exception as e:
            print(f"{audio.name} -> 失败: {e}")
    save_csv("csv/fine-tune_whisper.csv", rows)
def run_test_dataset():
    """Run inference over the dataset listing and dump the results to JSON.

    Best-effort: on any error or Ctrl-C, results collected so far are still
    written out.
    """
    import json
    from test_data.audios import read_dataset
    model, processor = load_model()
    test_data = Path("../test_data/dataset/dataset.txt")
    audio_parent = Path("../test_data/")
    result_list = []
    count = 0
    try:
        for audio_path, sentence, duration in read_dataset(test_data):
            count += 1
            print(f"processing {count}: {audio_path}")
            t1 = time.time()
            text = transcribe_file(
                str(audio_parent/audio_path), model, processor
            )
            t = time.time() - t1
            print("inference time:", t)
            print(text)
            result_list.append({
                "index": count,
                "audio_path": audio_path,
                "reference": sentence,
                "duration": duration,
                "inference_time": round(t, 3),
                "inference_result": text
            })
    except Exception as e:
        # Deliberate best-effort: log and fall through to save partial results.
        print(e)
    except KeyboardInterrupt as e:
        print(e)
    with open("csv/whisper_finetuned_dataset_results.json", "w", encoding="utf-8") as f:
        json.dump(result_list, f, ensure_ascii=False, indent=2)
def run_test_emilia():
    """Transcribe up to 5000 Emilia ZH clips and dump the results to JSON."""
    from test_data.audios import read_emilia
    model, processor = load_model()
    parent = Path("../test_data/ZH-B000008")
    results = []
    idx = 0
    try:
        for clip_path, reference, clip_duration in read_emilia(parent, count_limit=5000):
            idx += 1
            print(f"processing {idx}: {clip_path}")
            started = time.time()
            hypothesis = transcribe_file(str(clip_path), model, processor)
            elapsed = time.time() - started
            print("inference time:", elapsed)
            print(hypothesis)
            results.append({
                "index": idx,
                "audio_path": clip_path.name,
                "reference": reference,
                "duration": clip_duration,
                "inference_time": round(elapsed, 3),
                "inference_result": hypothesis
            })
    except Exception as err:
        # Best-effort batch run: partial results are still saved below.
        print(err)
    except KeyboardInterrupt as err:
        print(err)
    import json
    with open("csv/whisper_finetune_emilia_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
def run_test_st():
    """Transcribe up to 5000 ST-CMDS samples and dump the results to JSON."""
    from test_data.audios import read_st
    model, processor = load_model()
    # parent = Path("../test_data/ST-CMDS-20170001_1-OS")
    results = []
    idx = 0
    try:
        for clip_path, reference in read_st(count_limit=5000):
            idx += 1
            print(f"processing {idx}: {clip_path}")
            started = time.time()
            hypothesis = transcribe_file(str(clip_path), model, processor)
            elapsed = time.time() - started
            print("inference time:", elapsed)
            print(hypothesis)
            results.append({
                "index": idx,
                "audio_path": clip_path.name,
                "reference": reference,
                # "duration": duration,
                "inference_time": round(elapsed, 3),
                "inference_result": hypothesis
            })
    except Exception as err:
        # Best-effort batch run: partial results are still saved below.
        print(err)
    except KeyboardInterrupt as err:
        print(err)
    import json
    with open("csv/whisper_finetune_st_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
def run_test_wenet():
    """Transcribe up to 5000 WeNet samples and dump the results to JSON."""
    from test_data.audios import read_wenet
    model, processor = load_model()
    results = []
    idx = 0
    try:
        for clip_path, reference in read_wenet(count_limit=5000):
            idx += 1
            print(f"processing {idx}: {clip_path}")
            started = time.time()
            hypothesis = transcribe_file(str(clip_path), model, processor)
            elapsed = time.time() - started
            print("inference time:", elapsed)
            print(hypothesis)
            results.append({
                "index": idx,
                "audio_path": clip_path.name,
                "reference": reference,
                # "duration": duration,
                "inference_time": round(elapsed, 3),
                "inference_result": hypothesis
            })
    except Exception as err:
        # Best-effort batch run: partial results are still saved below.
        print(err)
    except KeyboardInterrupt as err:
        print(err)
    import json
    with open("csv/whisper_finetune_wenet_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
if __name__ == "__main__":
    # Entry-point toggle: uncomment exactly one of the drivers below.
    # main() is the general CLI; the run_* functions are ad-hoc local tests.
    # main()
    # run_recordings()
    # run_test_dataset()
    # run_test_emilia()
    run_test_wenet()