Spaces:
Paused
Paused
File size: 8,102 Bytes
2967cdb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 |
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Han Zhu
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Computes speaker similarity (SIM-o) using a WavLM-based
ECAPA-TDNN speaker verification model.
"""
import argparse
import logging
import os
import warnings
from typing import List
import numpy as np
import torch
from tqdm import tqdm
from zipvoice.eval.models.ecapa_tdnn_wavlm import ECAPA_TDNN_WAVLM
from zipvoice.eval.utils import load_waveform
warnings.filterwarnings("ignore")
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the SIM-o evaluation script.

    Returns:
        argparse.ArgumentParser: Parser exposing the required options
            ``--wav-path``, ``--test-list`` and ``--model-dir`` plus the
            optional ``--extension`` (default ``wav``).
    """
    parser = argparse.ArgumentParser(
        description="Calculate speaker similarity (SIM-o) score."
    )
    parser.add_argument(
        "--wav-path",
        type=str,
        required=True,
        help="Path to the directory containing evaluated speech files.",
    )
    parser.add_argument(
        "--test-list",
        type=str,
        required=True,
        # The column layout described here must stay in sync with the
        # four-field unpacking done in SpeakerSimilarity.score().
        help="Path to the file list that contains the correspondence between prompts "
        "and evaluated speech. Each line contains (audio_name, prompt_text, "
        "prompt_audio, text) separated by tabs.",
    )
    parser.add_argument(
        "--model-dir",
        type=str,
        required=True,
        # Adjacent string literals are concatenated verbatim, so every piece
        # except the last ends with an explicit space.
        help="Local path of our evaluation model repository. "
        "Download from https://huggingface.co/k2-fsa/TTS_eval_models. "
        "Will use 'tts_eval_models/speaker_similarity/wavlm_large_finetune.pth' "
        "and 'tts_eval_models/speaker_similarity/wavlm_large/' in this script.",
    )
    parser.add_argument(
        "--extension",
        type=str,
        default="wav",
        help="Extension of the speech files. Default: wav",
    )
    return parser
class SpeakerSimilarity:
    """
    Computes speaker similarity (SIM-o) using a WavLM-based
    ECAPA-TDNN speaker verification model.
    """

    def __init__(
        self,
        sv_model_path: str = "speaker_similarity/wavlm_large_finetune.pth",
        ssl_model_path: str = "speaker_similarity/wavlm_large/",
    ):
        """
        Initializes the speaker similarity evaluator with the specified models.

        Args:
            sv_model_path (str): Path of the wavlm-based ECAPA-TDNN model checkpoint.
                The checkpoint must store the weights under the ``"model"`` key.
            ssl_model_path (str): Path of the wavlm SSL model directory.
        """
        self.sample_rate = 16000
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.info(f"Using device: {self.device}")

        self.model = ECAPA_TDNN_WAVLM(
            feat_dim=1024,
            channels=512,
            emb_dim=256,
            sr=self.sample_rate,
            ssl_model_path=ssl_model_path,
        )

        # NOTE(security): torch.load deserializes with pickle; only load
        # checkpoints obtained from a trusted source.
        # The tensors are first materialized on CPU and moved to the target
        # device together with the model below.
        state_dict = torch.load(sv_model_path, map_location="cpu")
        # strict=False: keys that do not match between the checkpoint and the
        # model are tolerated instead of raising.
        self.model.load_state_dict(state_dict["model"], strict=False)
        self.model.to(self.device)
        self.model.eval()

    @torch.no_grad()
    def get_embeddings(self, wav_paths: List[str]) -> List[torch.Tensor]:
        """
        Extracts speaker embeddings from a list of audio files.

        Files are processed one at a time, in the given order, with gradient
        tracking disabled.

        Args:
            wav_paths (List[str]): List of paths to audio files.

        Returns:
            List[torch.Tensor]: One model output per input path, in input order.
        """
        embeddings = []
        for wav_path in tqdm(wav_paths, desc="Extracting speaker embeddings"):
            # Load the waveform at the model sample rate on the model device.
            # NOTE(review): max_seconds=120 presumably caps the audio length;
            # confirm against load_waveform.
            speech = load_waveform(
                wav_path, self.sample_rate, device=self.device, max_seconds=120
            )
            # The model is called with a list of waveforms (here a single one).
            embeddings.append(self.model([speech]))
        return embeddings

    def score(self, wav_path: str, extension: str, test_list: str) -> float:
        """
        Computes the Speaker Similarity (SIM-o) score between reference and
        evaluated speech.

        Args:
            wav_path (str): Path to the directory containing evaluated speech files.
            extension (str): File extension of the evaluated speech files. A
                leading dot is optional ("wav" and ".wav" are equivalent).
            test_list (str): Path to the test list file mapping evaluated files
                to reference prompts. Every non-blank line holds four
                tab-separated fields: wav name, prompt text, prompt wav path
                and target text.

        Returns:
            float: Average similarity score between reference and evaluated embeddings.

        Raises:
            ValueError: If the test list has no non-blank lines or a line does
                not contain exactly four fields.
            FileNotFoundError: If a prompt file or an evaluated file is missing.
            RuntimeError: If the numbers of prompt and evaluated embeddings differ.
        """
        logging.info(f"Calculating Speaker Similarity (SIM-o) score for {wav_path}")

        # Read test pairs, skipping blank lines
        try:
            with open(test_list, "r", encoding="utf-8") as f:
                lines = [line.strip().split("\t") for line in f if line.strip()]
        except Exception as e:
            logging.error(f"Failed to read test list: {e}")
            raise
        if not lines:
            raise ValueError(f"Test list {test_list} is empty or malformed")

        # Accept both "wav" and ".wav"; without this a leading dot would
        # produce "name..wav" and every evaluated file would look missing.
        extension = extension.lstrip(".")

        # Parse and validate all pairs up front so that a bad entry fails
        # before any embedding is extracted.
        prompt_wavs = []
        eval_wavs = []
        for line in lines:
            if len(line) != 4:
                raise ValueError(f"Invalid line: {line}")
            # Only the wav name and the prompt audio path are needed here.
            wav_name, _, prompt_wav, _ = line
            eval_wav_path = os.path.join(wav_path, f"{wav_name}.{extension}")
            if not os.path.exists(prompt_wav):
                raise FileNotFoundError(f"Prompt file not found: {prompt_wav}")
            if not os.path.exists(eval_wav_path):
                raise FileNotFoundError(f"Evaluated file not found: {eval_wav_path}")
            prompt_wavs.append(prompt_wav)
            eval_wavs.append(eval_wav_path)
        logging.info(f"Found {len(prompt_wavs)} valid test pairs")

        # Extract embeddings
        prompt_embeddings = self.get_embeddings(prompt_wavs)
        eval_embeddings = self.get_embeddings(eval_wavs)
        if len(prompt_embeddings) != len(eval_embeddings):
            raise RuntimeError(
                f"Mismatch: {len(prompt_embeddings)} prompt vs "
                f" {len(eval_embeddings)} eval embeddings"
            )

        # Cosine similarity of every (prompt, evaluated) pair along the last
        # dimension; .item() requires each result to hold a single value.
        scores = [
            torch.nn.functional.cosine_similarity(prompt_emb, eval_emb, dim=-1).item()
            for prompt_emb, eval_emb in zip(prompt_embeddings, eval_embeddings)
        ]
        return float(np.mean(scores))
if __name__ == "__main__":
    # Script entry point: parse the CLI options, load the evaluation models
    # from --model-dir and report the SIM-o score of --wav-path.

    # Limit PyTorch to one intra-op and one inter-op thread.
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)

    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    # force=True replaces any handlers already attached to the root logger.
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)

    parser = get_parser()
    args = parser.parse_args()

    # Initialize evaluator
    sv_model_path = os.path.join(
        args.model_dir, "speaker_similarity/wavlm_large_finetune.pth"
    )
    ssl_model_path = os.path.join(args.model_dir, "speaker_similarity/wavlm_large/")
    if not os.path.exists(sv_model_path) or not os.path.exists(ssl_model_path):
        logging.error(
            "Please download evaluation models from "
            "https://huggingface.co/k2-fsa/TTS_eval_models"
            " and pass this dir with --model-dir"
        )
        # The bare exit() helper is injected by the `site` module and may be
        # missing (e.g. `python -S`); SystemExit is always available.
        raise SystemExit(1)

    sim_evaluator = SpeakerSimilarity(
        sv_model_path=sv_model_path, ssl_model_path=ssl_model_path
    )

    # Compute similarity score
    score = sim_evaluator.score(args.wav_path, args.extension, args.test_list)
    print("-" * 50)
    logging.info(f"SIM-o score: {score:.3f}")
    print("-" * 50)
|