import argparse
import csv
import json
import os
import time
from pathlib import Path

import torch
import librosa
from transformers import WhisperForConditionalGeneration, WhisperProcessor


def save_csv(file_path, rows):
    """Write rows to a CSV file."""
    with open(file_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerows(rows)
    print(f"write csv to {file_path}")


def load_audio(audio_path: str, sr: int = 16000):
    # Read the audio and convert it to a 16 kHz mono float32 numpy array
    audio, _ = librosa.load(audio_path, sr=sr, mono=True)
    return audio


def transcribe_file(
    audio_path: str,
    model,
    processor,
    language: str = "Chinese",
    task: str = "transcribe",
    timestamps: bool = False,
    max_new_tokens: int = 255,
):
    """Transcribe a single audio file and return the decoded text.

    `language` and `task` are kept for API symmetry; decoding behaviour is
    configured on the processor / generation_config when the model is loaded.
    """
    # Prepare input features
    audio = load_audio(audio_path, sr=16000)
    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")

    # Move features to the model's device
    device = next(model.parameters()).device
    input_features = inputs["input_features"].to(device)

    # Generate token ids
    with torch.inference_mode(), torch.autocast(device_type="cuda", enabled=(device.type == "cuda")):
        generated_ids = model.generate(
            input_features=input_features,
            max_new_tokens=max_new_tokens,
            return_timestamps=timestamps,  # only supported by some versions; ignored when unsupported
        )

    # Decode to text
    text = processor.tokenizer.batch_decode(generated_ids.cpu().numpy(), skip_special_tokens=True)
    return text[0]


def main():
    """CLI entry point: transcribe a single audio file or a directory of audio files."""
    parser = argparse.ArgumentParser("Simple Whisper Inference")
    parser.add_argument("--model_path", type=str, default="whisper-large-v3-turbo-finetune",
                        help="Path to a locally merged model, or an HF model name")
    parser.add_argument("--input", type=str, required=True,
                        help="Path to an audio file, or a directory (all audio files inside will be processed)")
    parser.add_argument("--language", type=str, default="Chinese",
                        help="Language (e.g. Chinese / English / zh / en)")
    parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"],
                        help="Task: transcribe or translate")
    parser.add_argument("--timestamps", action="store_true",
                        help="Return timestamps (if supported by the model and library version)")
    parser.add_argument("--local_files_only", action="store_true",
                        help="Load from local files only, no network access")
    parser.add_argument("--batch_exts", type=str, default=".wav,.mp3,.flac,.m4a",
                        help="When --input is a directory, process files with these extensions (comma-separated)")
    args = parser.parse_args()

    # Load processor & model
    processor = WhisperProcessor.from_pretrained(
        args.model_path,
        language=args.language,
        task=args.task,
        no_timestamps=not args.timestamps,
        local_files_only=args.local_files_only,
    )
    model = WhisperForConditionalGeneration.from_pretrained(
        args.model_path,
        device_map="auto",
        local_files_only=args.local_files_only,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
    model.generation_config.language = args.language.lower()
    model.generation_config.forced_decoder_ids = None
    model.eval()

    path = Path(args.input)
    if path.is_file():
        text = transcribe_file(
            str(path), model, processor, language=args.language, task=args.task, timestamps=args.timestamps
        )
        print(f"{path.name} -> {text}")
    else:
        # Batch-process a directory
        exts = {e.strip().lower() for e in args.batch_exts.split(",")}
        files = [p for p in path.rglob("*") if p.suffix.lower() in exts]
        if not files:
            print("No processable audio files found in the directory.")
            return
        for p in sorted(files):
            try:
                t0 = time.time()
                text = transcribe_file(
                    str(p), model, processor, language=args.language, task=args.task, timestamps=args.timestamps
                )
                t1 = time.time()
                print(f"{p.name} -> {text}; time cost: {t1 - t0}")
            except Exception as e:
                print(f"{p.name} -> failed: {e}")


def load_model():
    """Load the fine-tuned Whisper model and processor from a local path (MPS device)."""
    # model_path = "/Users/jeqin/Downloads/checkpoint-39000-full/whisper-large-v3-turbo-finetune"
    model_path = "/Users/jeqin/Downloads/whisper-large-v3-turbo-finetune_1219"
    lang = "zh"
    t0 = time.time()
    processor = WhisperProcessor.from_pretrained(
        model_path,
        language=lang,
        task="transcribe",
        no_timestamps=True,
        local_files_only=True,
    )
    model = WhisperForConditionalGeneration.from_pretrained(
        model_path,
        device_map="mps",
        local_files_only=True,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
    model.generation_config.language = lang.lower()
    model.generation_config.forced_decoder_ids = None
    model.eval()
    print("load model time: ", time.time() - t0)
    return model, processor


def run_test_audios():
    """Transcribe the local test audio clips and save results to CSV."""
    model, processor = load_model()
    audios = Path("../test_data/audio_clips/")
    rows = [["file_name", "inference_time", "inference_result"]]
    for audio in sorted(audios.glob("*en-ac1-16k/*.wav")):  # *s/randomforest*.wav"
        try:
            t0 = time.time()
            text = transcribe_file(str(audio), model, processor)
            t = time.time() - t0
            print(f"{audio.name} -> {text}; time cost: {t}")
            rows.append([f"{audio.parent.name}/{audio.name}", t, text])
        except Exception as e:
            print(f"{audio.name} -> failed: {e}")
    save_csv("csv/fine-tune_whisper-0901.csv", rows)


def run_recordings():
    """Transcribe the recordings, compare against the reference text, and save results to CSV."""
    from scripts.asr_utils import get_origin_text_dict, get_text_distance

    model, processor = load_model()
    audios = Path("../test_data/recordings/")
    rows = [["file_name", "time", "inference_result", "distance", "normalized_distance", "diff"]]
    original = get_origin_text_dict()
    for audio in sorted(audios.glob("*.wav"), key=lambda x: int(x.stem)):
        print(audio)
        try:
            t0 = time.time()
            text = transcribe_file(str(audio), model, processor)
            t = time.time() - t0
            print(text)
            print("inference time:", t)
            d, nd, diff = get_text_distance(original[audio.stem], text)
            rows.append([audio.name, round(t, 3), text, d, round(nd, 3), diff])
        except Exception as e:
            print(f"{audio.name} -> failed: {e}")
    save_csv("csv/fine-tune_whisper.csv", rows)


def run_test_dataset():
    """Run inference over the AIShell dataset list and dump results to JSON."""
    from test_data.audios import read_dataset

    model, processor = load_model()
    test_data = Path("../test_data/AIShell/dataset/dataset.txt")
    audio_parent = Path("../test_data/")
    result_list = []
    count = 0
    try:
        for audio_path, sentence, duration in read_dataset(test_data):
            count += 1
            print(f"processing {count}: {audio_path}")
            t1 = time.time()
            text = transcribe_file(str(audio_parent / audio_path), model, processor)
            t = time.time() - t1
            print("inference time:", t)
            print(text)
            result_list.append({
                "index": count,
                "audio_path": audio_path,
                "reference": sentence,
                "duration": duration,
                "inference_time": round(t, 3),
                "inference_result": text,
            })
    except Exception as e:
        print(e)
    except KeyboardInterrupt as e:
        print(e)
    with open("csv/whisper_finetuned_dataset_results.json", "w", encoding="utf-8") as f:
        json.dump(result_list, f, ensure_ascii=False, indent=2)


def run_test_emilia():
    """Run inference over an Emilia subset and dump results to JSON."""
    from test_data.audios import read_emilia

    model, processor = load_model()
    parent = Path("../test_data/ZH-B000008")
    result_list = []
    count = 0
    try:
        for audio_path, sentence, duration in read_emilia(parent, count_limit=5000):
            count += 1
            print(f"processing {count}: {audio_path}")
            t1 = time.time()
            text = transcribe_file(str(audio_path), model, processor)
            t = time.time() - t1
            print("inference time:", t)
            print(text)
            result_list.append({
                "index": count,
                "audio_path": audio_path.name,
                "reference": sentence,
                "duration": duration,
                "inference_time": round(t, 3),
                "inference_result": text,
            })
    except Exception as e:
        print(e)
    except KeyboardInterrupt as e:
        print(e)
    with open("csv/whisper_finetune_emilia_results.json", "w", encoding="utf-8") as f:
        json.dump(result_list, f, ensure_ascii=False, indent=2)


def run_test_st():
    """Run inference over the ST-CMDS dataset and dump results to JSON."""
    from test_data.audios import read_st

    model, processor = load_model()
    # parent = Path("../test_data/ST-CMDS-20170001_1-OS")
    result_list = []
    count = 0
    try:
        for audio_path, sentence in read_st(count_limit=5000):
            count += 1
            print(f"processing {count}: {audio_path}")
            t1 = time.time()
            text = transcribe_file(str(audio_path), model, processor)
            t = time.time() - t1
            print("inference time:", t)
            print(text)
            result_list.append({
                "index": count,
                "audio_path": audio_path.name,
                "reference": sentence,
                # "duration": duration,
                "inference_time": round(t, 3),
                "inference_result": text,
            })
    except Exception as e:
        print(e)
    except KeyboardInterrupt as e:
        print(e)
    with open("csv/whisper_finetune_st_results.json", "w", encoding="utf-8") as f:
        json.dump(result_list, f, ensure_ascii=False, indent=2)


def run_test_wenet():
    """Run inference over the wenet test set and dump results to JSON."""
    from test_data.audios import read_wenet

    model, processor = load_model()
    result_list = []
    count = 0
    try:
        for audio_path, sentence in read_wenet(count_limit=5000):
            count += 1
            print(f"processing {count}: {audio_path}")
            t1 = time.time()
            text = transcribe_file(str(audio_path), model, processor)
            t = time.time() - t1
            print("inference time:", t)
            print(text)
            result_list.append({
                "index": count,
                "audio_path": audio_path.name,
                "reference": sentence,
                # "duration": duration,
                "inference_time": round(t, 3),
                "inference_result": text,
            })
    except Exception as e:
        print(e)
    except KeyboardInterrupt as e:
        print(e)
    with open("csv/whisper_finetune_wenet_results.json", "w", encoding="utf-8") as f:
        json.dump(result_list, f, ensure_ascii=False, indent=2)


if __name__ == "__main__":
    # main()
    # run_recordings()
    # run_test_dataset()
    # run_test_emilia()
    run_test_wenet()
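
# Example CLI invocation of main() (illustrative only; "whisper_infer.py" is an assumed
# filename -- use whatever this script is actually saved as). The flags correspond to the
# argparse options defined above; the paths mirror the defaults and test-data layout used
# by the run_* helpers and are not prescriptive:
#
#   python whisper_infer.py \
#       --model_path whisper-large-v3-turbo-finetune \
#       --input ../test_data/recordings \
#       --language Chinese \
#       --local_files_only
#
# To run it this way, switch the call under __main__ from run_test_wenet() to main().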