#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Computes concatenated maximum permutation speaker similarity (cpSIM) scores using:
- A WavLM-based ECAPA-TDNN model for speaker embedding extraction.
- A pyannote pipeline for speaker diarization (segmenting speakers).
"""

import argparse
import logging
import os
import sys
import warnings
from typing import List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from pyannote.audio import Pipeline
from tqdm import tqdm

from zipvoice.eval.models.ecapa_tdnn_wavlm import ECAPA_TDNN_WAVLM
from zipvoice.eval.utils import load_waveform

warnings.filterwarnings("ignore")


def get_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description="Calculate concatenated maximum permutation speaker "
"similarity (cpSIM) score."
)
parser.add_argument(
"--wav-path",
type=str,
required=True,
help="Path to the directory containing evaluated speech files.",
)
parser.add_argument(
"--test-list",
type=str,
help="Path to the tsv file for speaker splitted prompts. "
"Each line contains (audio_name, prompt_text_1, prompt_text_2, "
"prompt_audio_1, prompt_audio_2, text) separated by tabs.",
)
parser.add_argument(
"--test-list-merge",
type=str,
help="Path to the tsv file for merged dialogue prompts. "
"Each line contains (audio_name, prompt_text_dialogue, "
"prompt_audio_dialogue, text) separated by tabs.",
)
parser.add_argument(
"--model-dir",
type=str,
required=True,
help="Local path of our evaluatioin model repository."
"Download from https://huggingface.co/k2-fsa/TTS_eval_models."
"Will use 'tts_eval_models/speaker_similarity/wavlm_large_finetune.pth'"
", 'tts_eval_models/speaker_similarity/wavlm_large/' and "
"tts_eval_models/speaker_similarity/pyannote/ in this script",
)
parser.add_argument(
"--extension",
type=str,
default="wav",
help="Extension of the speech files. Default: wav",
)
return parser
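

# Illustrative test-list rows (tab-separated; utterance names, paths, and
# texts are hypothetical, shown only to document the expected column layout):
#
#   --test-list (split prompts, 6 columns):
#     utt1<TAB>prompt text 1<TAB>prompt text 2<TAB>p1.wav<TAB>p2.wav<TAB>target text
#
#   --test-list-merge (merged dialogue prompt, 4 columns):
#     utt1<TAB>dialogue prompt text<TAB>dialogue.wav<TAB>target text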


class CpSpeakerSimilarity:
"""
Computes concatenated maximum permutation speaker similarity (cpSIM) scores using:
- A WavLM-based ECAPA-TDNN model for speaker embedding extraction.
- A pyannote pipeline for speaker diarization (segmenting speakers).
"""

    def __init__(
self,
sv_model_path: str = "speaker_similarity/wavlm_large_finetune.pth",
ssl_model_path: str = "speaker_similarity/wavlm_large/",
pyannote_model_path: str = "speaker_similarity/pyannote/",
):
"""
Initializes the cpSIM evaluator with the specified models.
Args:
sv_model_path (str): Path of the wavlm-based ECAPA-TDNN model checkpoint.
ssl_model_path (str): Path of the wavlm SSL model directory.
pyannote_model_path (str): Path of the pyannote diarization model directory.
"""
self.sample_rate = 16000
self.device = (
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
logging.info(f"Using device: {self.device}")
# Initialize speaker verification model
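        # feat_dim=1024 matches the hidden size of WavLM-Large features;
        # emb_dim=256 is the dimensionality of the produced speaker embeddings.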
self.sv_model = ECAPA_TDNN_WAVLM(
feat_dim=1024,
channels=512,
emb_dim=256,
sr=self.sample_rate,
ssl_model_path=ssl_model_path,
)
state_dict = torch.load(
sv_model_path, map_location=lambda storage, loc: storage
)
self.sv_model.load_state_dict(state_dict["model"], strict=False)
self.sv_model.to(self.device)
self.sv_model.eval()
# Initialize diarization pipeline
self.diarization_pipeline = Pipeline.from_pretrained(
os.path.join(pyannote_model_path, "pyannote_diarization_config.yaml")
)
self.diarization_pipeline.to(self.device)

    @torch.no_grad()
def get_embeddings_with_diarization(
self, audio_paths: List[str]
) -> List[List[torch.Tensor]]:
"""
Extracts speaker embeddings from audio files
with speaker diarization (for 2-speaker conversations).
Args:
audio_paths: List of paths to audio files (each containing 2 speakers).
Returns:
List of embedding pairs, where each pair is
[embedding_speaker1, embedding_speaker2].
"""
embeddings_list = []
for audio_path in tqdm(
audio_paths, desc="Extracting embeddings with diarization"
):
# Load audio waveform
speech = load_waveform(
audio_path, self.sample_rate, device=self.device, max_seconds=120
)
# Perform speaker diarization (assumes 2 speakers)
diarization = self.diarization_pipeline(
{"waveform": speech.unsqueeze(0), "sample_rate": self.sample_rate},
num_speakers=2,
)
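            # The dict input expects a (channel, num_samples) float tensor,
            # hence the unsqueeze(0) above. `diarization` is a
            # pyannote.core.Annotation; itertracks(yield_label=True) yields
            # (segment, track, label) tuples with segment.start / segment.end
            # in seconds and labels such as "SPEAKER_00" / "SPEAKER_01".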
# Collect speech chunks for each speaker
speaker1_chunks = []
speaker2_chunks = []
for turn, _, speaker in diarization.itertracks(yield_label=True):
start_frame = int(turn.start * self.sample_rate)
end_frame = int(turn.end * self.sample_rate)
chunk = speech[start_frame:end_frame]
if speaker == "SPEAKER_00":
speaker1_chunks.append(chunk)
elif speaker == "SPEAKER_01":
speaker2_chunks.append(chunk)
# Handle cases where diarization fails to detect 2 speakers
if not (speaker1_chunks and speaker2_chunks):
                logging.debug(
                    f"Insufficient speaker chunks in {audio_path}; "
                    "using full audio for both speakers."
                )
speaker1_speech = speech
speaker2_speech = speech
else:
speaker1_speech = torch.cat(speaker1_chunks, dim=0)
speaker2_speech = torch.cat(speaker2_chunks, dim=0)
# Extract embeddings with no gradient computation
try:
emb_speaker1 = self.sv_model([speaker1_speech])
emb_speaker2 = self.sv_model([speaker2_speech])
except Exception as e:
                logging.debug(
                    f"Encountered error '{e}' when extracting embeddings from "
                    "segmented speech; falling back to full audio for both speakers."
                )
emb_speaker1 = self.sv_model([speech])
emb_speaker2 = self.sv_model([speech])
embeddings_list.append([emb_speaker1, emb_speaker2])
return embeddings_list

    @torch.no_grad()
def get_embeddings_from_pairs(
self, audio_pairs: List[Tuple[str, str]]
) -> List[List[torch.Tensor]]:
"""
Extracts speaker embeddings from pairs of single-speaker audio files.
Args:
audio_pairs: List of tuples (path_speaker1, path_speaker2).
Returns:
List of embedding pairs, where each pair is
[embedding_speaker1, embedding_speaker2].
"""
embeddings_list = []
for (path1, path2) in tqdm(
audio_pairs, desc="Extracting embeddings from pairs"
):
# Load audio for each speaker
speech1 = load_waveform(path1, self.sample_rate, device=self.device)
speech2 = load_waveform(path2, self.sample_rate, device=self.device)
# Extract embeddings
emb_speaker1 = self.sv_model([speech1])
emb_speaker2 = self.sv_model([speech2])
embeddings_list.append([emb_speaker1, emb_speaker2])
return embeddings_list

    def score(
self,
wav_path: str,
extension: str,
test_list: str,
prompt_mode: str,
) -> float:
"""
Computes the cpSIM score by comparing embeddings of prompt and evaluated speech.
Args:
wav_path: Directory containing evaluated speech files.
test_list: Path to test list file mapping evaluated files to prompts.
prompt_mode: Either "merge" (2-speaker prompt) or "split"
(two single-speaker prompts).
Returns:
Average cpSIM score across all test pairs.
"""
logging.info(f"Calculating cpSIM score for {wav_path} (mode: {prompt_mode})")
# Load and parse test list
try:
with open(test_list, "r", encoding="utf-8") as f:
lines = [line.strip() for line in f if line.strip()]
except Exception as e:
logging.error(f"Failed to read test list {test_list}: {e}")
raise
if not lines:
raise ValueError(f"Test list {test_list} is empty")
# Collect valid prompt-eval audio pairs
prompt_audios = [] # For "merge": [path]; for "split": [(path1, path2)]
eval_audios = []
for line_num, line in enumerate(lines, 1):
parts = line.split("\t")
if prompt_mode == "merge":
                if len(parts) != 4:
                    raise ValueError(
                        f"Line {line_num}: expected 4 columns, got {len(parts)}"
                    )
audio_name, prompt_text, prompt_audio, text = parts
eval_audio_path = os.path.join(wav_path, f"{audio_name}.{extension}")
prompt_audios.append(prompt_audio)
elif prompt_mode == "split":
                if len(parts) != 6:
                    raise ValueError(
                        f"Line {line_num}: expected 6 columns, got {len(parts)}"
                    )
(
audio_name,
prompt_text1,
prompt_text2,
prompt_audio_1,
prompt_audio_2,
text,
) = parts
eval_audio_path = os.path.join(wav_path, f"{audio_name}.{extension}")
prompt_audios.append((prompt_audio_1, prompt_audio_2))
else:
raise ValueError(f"Invalid prompt_mode: {prompt_mode}")
# Validate file existence
if not os.path.exists(eval_audio_path):
raise FileNotFoundError(f"Evaluated file not found: {eval_audio_path}")
if prompt_mode == "merge":
if not os.path.exists(prompt_audio):
raise FileNotFoundError(
f"Prompt merge file not found: {prompt_audio}"
)
else:
if not (
os.path.exists(prompt_audio_1) and os.path.exists(prompt_audio_2)
):
raise FileNotFoundError(
f"One or more prompt files missing in {prompt_audio_1}, "
f"{prompt_audio_2}"
)
eval_audios.append(eval_audio_path)
if not prompt_audios or not eval_audios:
raise ValueError(f"No valid prompt-eval pairs found in {test_list}")
logging.info(f"Processing {len(prompt_audios)} valid test pairs")
# Extract embeddings for prompts and evaluations
if prompt_mode == "merge":
prompt_embeddings = self.get_embeddings_with_diarization(prompt_audios)
else:
prompt_embeddings = self.get_embeddings_from_pairs(prompt_audios)
eval_embeddings = self.get_embeddings_with_diarization(eval_audios)
if len(prompt_embeddings) != len(eval_embeddings):
raise RuntimeError(
f"Mismatch: {len(prompt_embeddings)} prompt vs "
f" {len(eval_embeddings)} eval embeddings"
)
# Calculate maximum permutation similarity scores
scores = []
for prompt_embs, eval_embs in zip(prompt_embeddings, eval_embeddings):
# Prompt and eval each have 2 embeddings: [emb1, emb2]
sim1 = F.cosine_similarity(
prompt_embs[0], eval_embs[0], dim=-1
) + F.cosine_similarity(prompt_embs[1], eval_embs[1], dim=-1)
sim2 = F.cosine_similarity(
prompt_embs[0], eval_embs[1], dim=-1
) + F.cosine_similarity(prompt_embs[1], eval_embs[0], dim=-1)
            max_sim = torch.max(sim1, sim2).item() / 2  # average over 2 speakers
scores.append(max_sim)
return float(np.mean(scores))


if __name__ == "__main__":
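    # Pin PyTorch to single-threaded execution (presumably so CPU usage stays
    # predictable when several evaluation jobs run in parallel).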
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO, force=True)
parser = get_parser()
args = parser.parse_args()
# Validate test list arguments
if not (args.test_list or args.test_list_merge):
raise ValueError("Either --test-list or --test-list-merge must be provided")
if args.test_list and args.test_list_merge:
raise ValueError(
"Only one of --test-list-split or --test-list-merge can be provided"
)
# Determine mode and test list
if args.test_list:
prompt_mode = "split"
test_list = args.test_list
else:
prompt_mode = "merge"
test_list = args.test_list_merge
# Initialize evaluator
sv_model_path = os.path.join(
args.model_dir, "speaker_similarity/wavlm_large_finetune.pth"
)
ssl_model_path = os.path.join(args.model_dir, "speaker_similarity/wavlm_large/")
pyannote_model_path = os.path.join(args.model_dir, "speaker_similarity/pyannote/")
if (
not os.path.exists(sv_model_path)
or not os.path.exists(ssl_model_path)
or not os.path.exists(pyannote_model_path)
):
        logging.error(
            "Please download the evaluation models from "
            "https://huggingface.co/k2-fsa/TTS_eval_models "
            "and pass their directory with --model-dir"
        )
        sys.exit(1)
cp_sim = CpSpeakerSimilarity(
sv_model_path=sv_model_path,
ssl_model_path=ssl_model_path,
pyannote_model_path=pyannote_model_path,
)
# Compute similarity score
score = cp_sim.score(
wav_path=args.wav_path,
extension=args.extension,
test_list=test_list,
prompt_mode=prompt_mode,
)
print("-" * 50)
logging.info(f"cpSIM score: {score:.3f}")
print("-" * 50)