# TestTranslator/scripts/run_whisper_finetuned.py
# (Hugging Face file-viewer metadata preserved from the scrape:
#  author yujuanqin — "update path to relative" — commit 8a27f7c,
#  raw / history / blame, 6.7 kB)
import argparse
import os
import time
from pathlib import Path
import csv
import numpy as np
import torch
import librosa
from transformers import WhisperForConditionalGeneration, WhisperProcessor
def save_csv(file_path, rows):
    """Write `rows` (an iterable of row sequences) to `file_path` as CSV.

    Creates the parent directory if it is missing (callers pass paths like
    "csv/…" which would otherwise raise FileNotFoundError) and opens the
    file with newline="" as required by the csv module, so rows are not
    followed by blank lines on Windows.
    """
    parent = os.path.dirname(file_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    # newline="" is mandated by the csv docs for files handed to csv.writer.
    with open(file_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerows(rows)
    print(f"write csv to {file_path}")
def load_audio(audio_path: str, sr: int = 16000):
    """Read an audio file and return it as a mono float32 numpy array at `sr` Hz."""
    waveform, _native_rate = librosa.load(audio_path, sr=sr, mono=True)
    return waveform
def transcribe_file(
    audio_path: str,
    model,
    processor,
    language: str = "Chinese",
    task: str = "transcribe",
    timestamps: bool = False,
    max_new_tokens: int = 255,
):
    """Transcribe (or translate) one audio file with a Whisper model.

    Args:
        audio_path: path to the audio file; resampled to 16 kHz mono.
        model: a WhisperForConditionalGeneration (or compatible) model.
        processor: the matching WhisperProcessor.
        language: target language for generation (e.g. "Chinese", "zh").
        task: "transcribe" or "translate".
        timestamps: whether to request timestamps (only some transformers
            versions honor `return_timestamps`).
        max_new_tokens: cap on generated token count.

    Returns:
        The decoded text of the single generated sequence.
    """
    # Prepare log-mel features at Whisper's expected 16 kHz rate.
    audio = load_audio(audio_path, sr=16000)
    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")

    # Move features onto whatever device the model lives on.
    device = next(model.parameters()).device
    input_features = inputs["input_features"].to(device)

    # Generate; autocast is only enabled when running on CUDA.
    with torch.inference_mode(), torch.autocast(device_type="cuda", enabled=(device.type == "cuda")):
        generated_ids = model.generate(
            input_features=input_features,
            max_new_tokens=max_new_tokens,
            # Bug fix: `language` and `task` were accepted but never used.
            # transformers accepts lowercase language names or codes.
            language=language.lower(),
            task=task,
            return_timestamps=timestamps,  # unsupported versions ignore this
        )

    # Decode token ids to text, dropping special tokens.
    text = processor.tokenizer.batch_decode(generated_ids.cpu().numpy(), skip_special_tokens=True)
    return text[0]
def main():
    """CLI entry point: transcribe one audio file, or every audio file in a directory."""
    parser = argparse.ArgumentParser("Simple Whisper Inference")
    parser.add_argument("--model_path", type=str, default="whisper-large-v3-turbo-finetune",
                        help="本地合并模型路径或HF模型名")
    parser.add_argument("--input", type=str, required=True,
                        help="音频文件路径,或目录(将批量处理其中的音频)")
    parser.add_argument("--language", type=str, default="Chinese",
                        help="语言(如 Chinese / English / zh / en)")
    parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"],
                        help="任务:转写或翻译")
    parser.add_argument("--timestamps", action="store_true", help="是否返回时间戳(若模型与版本支持)")
    parser.add_argument("--local_files_only", action="store_true", help="仅本地加载,不联网")
    parser.add_argument("--batch_exts", type=str, default=".wav,.mp3,.flac,.m4a",
                        help="当 --input 是目录时,处理这些后缀的文件,逗号分隔")
    args = parser.parse_args()

    # Load processor and model from the given path / hub name.
    processor = WhisperProcessor.from_pretrained(
        args.model_path,
        language=args.language,
        task=args.task,
        no_timestamps=not args.timestamps,
        local_files_only=args.local_files_only,
    )
    model = WhisperForConditionalGeneration.from_pretrained(
        args.model_path,
        device_map="auto",
        local_files_only=args.local_files_only,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
    model.generation_config.language = args.language.lower()
    model.generation_config.forced_decoder_ids = None
    model.eval()

    target = Path(args.input)
    if target.is_file():
        # Single-file mode.
        result = transcribe_file(
            str(target), model, processor,
            language=args.language, task=args.task, timestamps=args.timestamps
        )
        print(f"{target.name} -> {result}")
        return

    # Directory mode: walk recursively and keep only the requested suffixes.
    wanted = {ext.strip().lower() for ext in args.batch_exts.split(",")}
    candidates = sorted(p for p in target.rglob("*") if p.suffix.lower() in wanted)
    if not candidates:
        print("目录中未找到可处理的音频文件。")
        return
    for candidate in candidates:
        try:
            started = time.time()
            result = transcribe_file(
                str(candidate), model, processor,
                language=args.language, task=args.task, timestamps=args.timestamps
            )
            finished = time.time()
            print(f"{candidate.name} -> {result}; time cost: {finished-started}")
        except Exception as err:
            print(f"{candidate.name} -> 失败: {err}")
def load_model(
    model_path: str = "/Users/jeqin/Downloads/whisper-large-v3-turbo-finetune-0901",
    lang: str = "zh",
    device_map: str = "mps",
    local_files_only: bool = True,
):
    """Load a fine-tuned Whisper model and its processor, reporting load time.

    Previously all of these values were hard-coded; they are now parameters
    with the original values as defaults, so existing callers are unchanged.

    Args:
        model_path: local directory (or hub name) of the merged model.
        lang: language code/name for the processor and generation config.
        device_map: device placement (default "mps" for Apple Silicon).
        local_files_only: skip network lookups when True.

    Returns:
        (model, processor) tuple, with the model in eval mode.
    """
    t0 = time.time()
    processor = WhisperProcessor.from_pretrained(
        model_path,
        language=lang,
        task="transcribe",
        no_timestamps=True,
        local_files_only=local_files_only,
    )
    model = WhisperForConditionalGeneration.from_pretrained(
        model_path,
        device_map=device_map,
        local_files_only=local_files_only,
        # NOTE(review): dtype checks CUDA, so on MPS this stays float32 —
        # presumably intentional (fp16 on MPS can be flaky); confirm.
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
    model.generation_config.language = lang.lower()
    model.generation_config.forced_decoder_ids = None
    model.eval()
    print("load model time: ", time.time() - t0)
    return model, processor
def run_test_audios():
    """Transcribe the bundled test audio set and dump per-file timings/results to CSV."""
    model, processor = load_model()
    audio_root = Path("../tests/test_data/test_audios/")
    records = [["file_name", "inference_time", "inference_result"]]
    for wav in sorted(audio_root.glob("*en-ac1-16k/*.wav")):  # *s/randomforest*.wav"
        try:
            started = time.time()
            result = transcribe_file(
                str(wav), model, processor
            )
            elapsed = time.time() - started
            print(f"{wav.name} -> {result}; time cost: {elapsed}")
            records.append([f"{wav.parent.name}/{wav.name}", elapsed, result])
        except Exception as err:
            print(f"{wav.name} -> 失败: {err}")
    save_csv("csv/fine-tune_whisper-0901.csv", records)
def run_recordings():
    """Transcribe recorded clips, score them against reference text, and save a CSV.

    Relies on project helpers `get_origin_text_dict` (stem -> reference text)
    and `get_text_distance` (returns distance, normalized distance, diff).
    """
    from scripts.asr_utils import get_origin_text_dict, get_text_distance

    model, processor = load_model()
    audios = Path("../tests/test_data/recordings/")
    # Bug fix: the header had 3 columns but each data row has 6
    # (name, time, text, distance, normalized distance, diff).
    rows = [["file_name", "time", "inference_result",
             "edit_distance", "normalized_distance", "diff"]]
    original = get_origin_text_dict()
    # Recordings are named by integer stem (e.g. "3.wav"); sort numerically.
    for audio in sorted(audios.glob("*.wav"), key=lambda x: int(x.stem)):
        print(audio)
        try:
            t0 = time.time()
            text = transcribe_file(
                str(audio), model, processor
            )
            t = time.time() - t0
            print(text)
            print("inference time:", t)
            d, nd, diff = get_text_distance(original[audio.stem], text)
            rows.append([audio.name, round(t, 3), text, d, round(nd, 3), diff])
        except Exception as e:
            print(f"{audio.name} -> 失败: {e}")
    save_csv("csv/fine-tune_whisper.csv", rows)
if __name__ == "__main__":
    # main()  # CLI entry point, currently bypassed in favor of the recordings benchmark
    run_recordings()