| | |
| | |
| | import argparse |
| | import json |
| | import os |
| | from pathlib import Path |
| | import sys |
| |
|
# Directory containing this script; the directory two levels above it is put
# on the import path so that imports relative to that directory resolve.
pwd = os.path.abspath(os.path.dirname(__file__))
_two_levels_up = os.path.join(pwd, "../../")
sys.path.append(_two_levels_up)
| |
|
| | import librosa |
| | from gradio_client import Client |
| | import numpy as np |
| | from sklearn.metrics import precision_score, recall_score, accuracy_score, f1_score |
| | from tqdm import tqdm |
| |
|
| |
|
def get_args():
    """
    Parse the command-line options of the evaluation script.

    Every option is optional and falls back to the default listed in the
    table below.

    :return: argparse.Namespace with attributes ``test_set``, ``output_file``,
        ``vad_engine`` and ``expected_sample_rate``
    """
    # (flag, default value, type converter) for each supported option.
    option_table = (
        ("--test_set", r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\vad", str),
        ("--output_file", r"native_silero_vad.jsonl", str),
        ("--vad_engine", "native_silero_vad", str),
        ("--expected_sample_rate", 8000, int),
    )

    cli = argparse.ArgumentParser()
    for flag, fallback, converter in option_table:
        cli.add_argument(flag, default=fallback, type=converter)

    return cli.parse_args()
| |
|
| |
|
def get_metrics(ground_truth, predictions, total_duration, step=0.01):
    """
    Evaluate predicted intervals against reference intervals on a discretized time axis.

    The range [0, total_duration) is sampled every ``step`` seconds. A sample is
    labeled 1 when it falls inside any interval (both endpoints inclusive) and 0
    otherwise; the two label sequences are then compared sample by sample.

    :param ground_truth: reference intervals, format [[start1, end1], [start2, end2], ...]
    :param predictions: predicted intervals, same format as ``ground_truth``
    :param total_duration: total audio duration (seconds)
    :param step: discretization step (default 10ms)
    :return: dict with the keys "accuracy", "precision", "recall" and "f1"
    """
    grid = np.arange(0, total_duration, step)

    def rasterize(intervals):
        # Turn a list of [start, end] pairs into a 0/1 label per grid sample.
        labels = np.zeros_like(grid, dtype=int)
        for begin, finish in intervals:
            inside = (grid >= begin) & (grid <= finish)
            labels[inside] = 1
        return labels

    y_true = rasterize(ground_truth)
    y_pred = rasterize(predictions)

    # zero_division=0 makes an undefined precision/recall/f1 score come out as 0.
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, zero_division=0),
        "recall": recall_score(y_true, y_pred, zero_division=0),
        "f1": f1_score(y_true, y_pred, zero_division=0),
    }
| |
|
| |
|
def main():
    """
    Evaluate a VAD engine served by a local Gradio app against an annotated test set.

    Reads ``vad.json`` from the ``--test_set`` directory; each entry supplies a
    ``filename`` and its reference ``vad_segments``. Every file is sent to the
    ``/when_click_vad_button`` endpoint, whose last return value is a JSON
    string carrying the predicted ``vad_segments`` and the audio ``duration``.
    Per-file metrics are written to ``--output_file`` as one JSON object per
    line, and duration-weighted running averages are shown on the progress bar.

    :return: None
    """
    args = get_args()

    client = Client("http://127.0.0.1:7866/")

    test_set = Path(args.test_set)
    output_file = Path(args.output_file)

    annotation_file = test_set / "vad.json"

    with open(annotation_file.as_posix(), "r", encoding="utf-8") as f:
        annotation = json.load(f)

    metric_names = ("accuracy", "precision", "recall", "f1")

    total = 0
    total_duration = 0
    # Sum of (metric * duration) per metric, used for duration-weighted averages.
    weighted_sums = {name: 0 for name in metric_names}

    with open(output_file.as_posix(), "w", encoding="utf-8") as f:
        # The context manager closes the bar even if a request fails mid-run;
        # passing the total lets tqdm show completion and an ETA.
        with tqdm(desc="evaluation", total=len(annotation)) as progress_bar:
            for row in annotation:
                filename = row["filename"]
                ground_truth_vad_segments = row["vad_segments"]

                filename = test_set / filename

                _, _, _, message = client.predict(
                    audio_file_t={
                        "path": filename.as_posix(),
                        "meta": {"_type": "gradio.FileData"}
                    },
                    audio_microphone_t=None,
                    start_ring_rate=0.5,
                    end_ring_rate=0.3,
                    ring_max_length=10,
                    min_silence_length=6,
                    max_speech_length=100000,
                    min_speech_length=15,
                    engine=args.vad_engine,
                    api_name="/when_click_vad_button"
                )
                js = json.loads(message)
                prediction_vad_segments = js["vad_segments"]
                duration = js["duration"]

                metrics = get_metrics(ground_truth_vad_segments, prediction_vad_segments, duration)

                row_ = {
                    "filename": filename.as_posix(),
                    "duration": duration,
                    "ground_truth": ground_truth_vad_segments,
                    "prediction": prediction_vad_segments,
                }
                for name in metric_names:
                    row_[name] = metrics[name]
                row_ = json.dumps(row_, ensure_ascii=False)
                f.write(f"{row_}\n")

                total += 1
                total_duration += duration
                for name in metric_names:
                    weighted_sums[name] += metrics[name] * duration

                # Until some audio with non-zero duration has been processed the
                # weighted average is undefined; report 0.0 instead of dividing
                # by zero.
                if total_duration > 0:
                    averages = {
                        name: weighted_sums[name] / total_duration
                        for name in metric_names
                    }
                else:
                    averages = {name: 0.0 for name in metric_names}

                progress_bar.update(1)
                progress_bar.set_postfix({
                    "total": total,
                    "accuracy": averages["accuracy"],
                    "precision": averages["precision"],
                    "recall": averages["recall"],
                    "f1": averages["f1"],
                    "total_duration": f"{round(total_duration / 60, 4)}min",
                })
| |
|
| |
|
# Run the evaluation only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
| |
|