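"""Speaker-timbre similarity evaluation for two-speaker TTS output.

Splits generated dialogue audio into sentence-level segments via forced
alignment, compares each segment against per-speaker prompt audio using
wespeaker embeddings, and reports speaker-assignment accuracy and average
similarity. Supports multi-GPU, multi-process batch evaluation over JSONL
inputs.
"""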
import json
import re
import os
from typing import List, Dict, Tuple, Any
import numpy as np
from pathlib import Path
import torch
import torchaudio
import torchaudio.functional as F
import logging
import wespeaker
import shutil
from datetime import datetime
import multiprocessing as mp
from functools import partial
import math
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
import random  # used to shuffle the input data
# Use the 'spawn' start method for multiprocessing (required for CUDA compatibility)
mp.set_start_method('spawn', force=True)
# Project-local word-alignment module
from alignment import AlignmentModel, batch_get_alignment_result
class SpeakerSimilarityEvaluator:
"""音色相似度评估器"""
def __init__(self, device="cuda",
alignment_model_dir='/inspire/hdd/project/embodied-multimodality/public/yqzhang/auto_evaluation_new/models/mms_fa',
wespeaker_model_dir='/inspire/ssd/project/embodied-multimodality/public/zylin/speaker_embedding/wespeaker_pretrain/voxblink2_samresnet100_ft',
output_dir="./evaluation_results",
language="ZH",
similarity_max_workers=8):
"""初始化评估器"""
self.device = device
self.alignment_model_dir = alignment_model_dir
self.wespeaker_model_dir = wespeaker_model_dir
self.language = language.upper() # 添加语言参数
self.similarity_max_workers = similarity_max_workers # 相似度计算线程数
# 先设置日志系统
logging.basicConfig(level=logging.INFO)
self.logger = logging.getLogger(__name__)
# 设置输出目录结构
self.output_dir = Path(output_dir)
self.segments_dir = self.output_dir / "segments" # 分割后的音频片段
self.prompts_dir = self.output_dir / "prompts" # prompt音频的S1和S2片段
self.temp_dir = self.output_dir / "temp" # 临时文件
self.results_dir = self.output_dir / "results" # 评估结果
self.temp_results_dir = self.output_dir / "temp_results" # 临时结果文件
self.alignment_dir = self.output_dir / "alignments" # 对齐信息保存目录
# 创建所有必要的目录
self._create_output_directories()
# 在多进程环境中延迟模型初始化
self.alignment_model = None
self.similarity_model = None
# 线程局部存储,用于线程安全的模型访问
self._thread_local = threading.local()
# 记录运行信息
self.logger.info(f"评估结果将保存到: {self.output_dir}")
self.logger.info(f"对齐信息将保存到: {self.alignment_dir}")
self.logger.info(f"使用语言: {self.language}")
def _create_output_directories(self):
"""创建输出目录结构"""
for dir_path in [self.segments_dir, self.prompts_dir, self.temp_dir,
self.results_dir, self.temp_results_dir, self.alignment_dir]:
dir_path.mkdir(parents=True, exist_ok=True)
def _get_safe_filename(self, text: str, max_length: int = 50) -> str:
"""生成安全的文件名"""
# 移除特殊字符,只保留中文、英文、数字和基本符号
safe_text = re.sub(r'[^\u4e00-\u9fff\w\s]', '', text)
# 限制长度
if len(safe_text) > max_length:
safe_text = safe_text[:max_length]
# 替换空格为下划线
safe_text = safe_text.replace(' ', '_')
return safe_text if safe_text else "unnamed"
def _clean_temp_files(self):
"""清理临时文件,但保留临时目录"""
if self.temp_dir.exists():
# 只删除临时目录中的文件,不删除目录本身
for file_path in self.temp_dir.iterdir():
if file_path.is_file():
try:
file_path.unlink()
except Exception as e:
self.logger.warning(f"删除临时文件失败: {file_path}, 错误: {e}")
else:
# 如果临时目录不存在,重新创建
self.temp_dir.mkdir(parents=True, exist_ok=True)
def _init_models_if_needed(self):
"""延迟初始化模型(用于多进程环境)"""
# 初始化对齐模型 - 修正参数顺序
if self.alignment_model is None:
# 根据AlignmentModel的构造函数,应该是(device, model_dir)而不是(model_dir, device)
self.alignment_model = AlignmentModel(self.device, self.alignment_model_dir)
# 初始化相似度模型
if self.similarity_model is None:
self._load_wespeaker_model(self.wespeaker_model_dir)
def _is_english_text(self, text: str) -> bool:
"""简单判断文本是否主要是英文"""
# 计算英文字符的比例
english_chars = sum(1 for c in text if c.isascii() and c.isalpha())
total_chars = sum(1 for c in text if c.isalpha())
if total_chars == 0:
return False
return english_chars / total_chars > 0.8 # 如果80%以上是英文字符,认为是英文
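    # Example (illustrative): "Hello world" -> True; "你好 world" -> False,
    # since only 5 of its 7 letters are ASCII (ratio ~0.71 < 0.8).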
def _detect_language_from_text(self, text: str) -> str:
"""从文本内容检测语言"""
clean_text = self.remove_speaker_tags(text)
if self._is_english_text(clean_text):
return "EN"
else:
return "ZH"
def save_alignment_info(self, alignment_data: Dict[str, Any], input_id: str, file_type: str = "output"):
"""
保存对齐信息到单独的JSON文件
Args:
alignment_data: 对齐信息数据
input_id: 输入ID
file_type: 文件类型 ("output", "prompt", "segment")
"""
try:
safe_input_id = self._get_safe_filename(input_id)
alignment_filename = f"{safe_input_id}_{file_type}_alignment.json"
alignment_path = self.alignment_dir / alignment_filename
# 添加元数据
alignment_info = {
'input_id': input_id,
'file_type': file_type,
'language': self.language,
'timestamp': datetime.now().isoformat(),
'alignment_data': alignment_data
}
with open(alignment_path, 'w', encoding='utf-8') as f:
json.dump(alignment_info, f, ensure_ascii=False, indent=2)
self.logger.info(f"对齐信息已保存: {alignment_path}")
return str(alignment_path)
except Exception as e:
self.logger.error(f"保存对齐信息失败: {e}")
return None
def save_detailed_alignment_info(self, alignments: List[Dict[str, Any]],
text_segments: List[Dict[str, Any]],
input_id: str, audio_path: str,
original_text: str, processed_text: str):
"""
保存详细的对齐信息,包括分段信息
Args:
alignments: 对齐结果列表
text_segments: 文本分段信息
input_id: 输入ID
audio_path: 音频文件路径
original_text: 原始文本
processed_text: 处理后的文本
"""
alignment_data = {
'original_text': original_text,
'processed_text': processed_text,
'audio_path': audio_path,
'language': self.language,
'total_alignments': len(alignments),
'total_segments': len(text_segments),
'alignments': alignments,
'text_segments': text_segments,
'segment_alignment_mapping': []
}
        # Map each text segment to its alignment items
for segment in text_segments:
segment_mapping = {
'segment_id': segment.get('segment_id', 0),
'segment_text': segment.get('text', ''),
'speaker_label': segment.get('speaker_label', ''),
'start_time': segment.get('start_time', 0.0),
'end_time': segment.get('end_time', 0.0),
'corresponding_alignments': []
}
            # Find the alignment items belonging to this segment
segment_start = segment.get('start_time', 0.0)
segment_end = segment.get('end_time', 0.0)
for i, align_item in enumerate(alignments):
align_start = align_item.get('start', 0.0)
align_end = align_item.get('end', 0.0)
                # Keep alignment items that fall within, or overlap, the segment's time range
if (align_start >= segment_start and align_end <= segment_end) or \
(align_start < segment_end and align_end > segment_start):
segment_mapping['corresponding_alignments'].append({
'alignment_index': i,
'transcript': align_item.get('transcript', ''),
'start': align_start,
'end': align_end,
'score': align_item.get('score', 0.0) if 'score' in align_item else None
})
alignment_data['segment_alignment_mapping'].append(segment_mapping)
return self.save_alignment_info(alignment_data, input_id, "detailed")
def remove_speaker_tags(self, text: str) -> str:
"""删除文本中的说话人标签[S1][S2]"""
return re.sub(r'\[S[12]\]', '', text).strip()
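    # Example (illustrative): remove_speaker_tags("[S1]你好。[S2]再见。") -> "你好。再见。"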
def extract_speaker_segments(self, text: str) -> List[Dict[str, Any]]:
"""提取文本中的说话人片段信息"""
segments = []
pattern = r'\[S([12])\]([^[]*)'
matches = re.findall(pattern, text)
for speaker_id, content in matches:
segments.append({
'speaker': f'S{speaker_id}',
'content': content.strip()
})
return segments
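    # Example (illustrative): extract_speaker_segments("[S1]你好[S2]再见") ->
    #   [{'speaker': 'S1', 'content': '你好'}, {'speaker': 'S2', 'content': '再见'}]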
def replace_punctuation_with_comma(self, text: str, language: str = None) -> str:
"""将所有标点符号替换为逗号,连续逗号只保留一个,根据语言选择正确的逗号类型"""
# 如果未指定语言,使用类的默认语言设置或自动检测
if language is None:
if hasattr(self, 'language'):
language = self.language
else:
language = self._detect_language_from_text(text)
language = language.upper()
# 根据语言选择逗号类型和处理策略
if language == "EN" or (language == "AUTO" and self._is_english_text(text)):
# 英文处理:先删除撇号,再替换其他标点符号
text = re.sub(r"'", '', text) # 删除撇号(don't -> dont)
target_comma = ',' # 英文逗号
comma_pattern = r',+' # 匹配连续英文逗号
# 更新正则表达式,不包含撇号
text = re.sub(r'[.,!?;:()\[\]<>\"…·,。;:!?()【】《》""\\、]', target_comma, text)
else:
# 中文处理:包含撇号在替换范围内
target_comma = ',' # 中文逗号
comma_pattern = r',+' # 匹配连续中文逗号
# 更新正则表达式以匹配更多的标点符号
text = re.sub(r'[.,!?;:()\[\]<>\'\"…·,。;:!?()【】《》''""\\、]', target_comma, text)
text = re.sub(comma_pattern, target_comma, text)
return text.strip(target_comma)
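    # Example (illustrative): replace_punctuation_with_comma("你好!世界。", "ZH") -> "你好,世界"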
def align_text_with_audio(self, text: str, audio_path: str, language=None) -> List[Dict[str, Any]]:
"""
文本和音频的词对齐
返回每个词对应的音频时间段
"""
# 确保模型已初始化
self._init_models_if_needed()
# 如果未指定语言,使用类的默认语言设置或自动检测
if language is None:
if hasattr(self, 'language'):
language = self.language
else:
language = self._detect_language_from_text(text)
else:
language = language.upper()
# 加载音频
waveform, sample_rate = torchaudio.load(audio_path)
# 重采样到模型要求的采样率
if sample_rate != self.alignment_model.bundle.sample_rate:
waveform = F.resample(waveform, sample_rate, self.alignment_model.bundle.sample_rate)
# 转换为单声道
if waveform.shape[0] > 1:
waveform = torch.mean(waveform, dim=0, keepdim=True)
waveform = waveform.squeeze(0) # 移除批次维度
# 将音频移动到正确的设备
waveform = waveform.to(self.device)
# 执行对齐
try:
alignment_results = batch_get_alignment_result(
self.alignment_model,
[waveform],
[text],
[language]
)
if not alignment_results or not alignment_results[0]:
raise RuntimeError(f"对齐结果为空: {audio_path}")
return alignment_results[0]
except Exception as e:
self.logger.error(f"音频对齐失败: {audio_path}")
self.logger.error(f"错误详情: {e}")
raise RuntimeError(f"音频对齐失败,程序终止。文件: {audio_path},错误: {e}")
def split_audio_segment(self, audio_path: str, start_time: float, end_time: float, output_path: str):
"""分割音频片段"""
waveform, sample_rate = torchaudio.load(audio_path)
start_frame = int(start_time * sample_rate)
end_frame = int(end_time * sample_rate)
segment = waveform[:, start_frame:end_frame]
# 确保输出目录存在
os.makedirs(os.path.dirname(output_path), exist_ok=True)
torchaudio.save(output_path, segment, sample_rate)
return output_path
def concatenate_audio_files(self, audio_files: List[str], output_path: str):
"""拼接多个音频文件"""
if not audio_files:
return
waveforms = []
sample_rate = None
for audio_file in audio_files:
if os.path.exists(audio_file):
waveform, sr = torchaudio.load(audio_file)
if sample_rate is None:
sample_rate = sr
elif sr != sample_rate:
waveform = F.resample(waveform, sr, sample_rate)
waveforms.append(waveform)
if waveforms:
concatenated = torch.cat(waveforms, dim=1)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
torchaudio.save(output_path, concatenated, sample_rate)
def split_audio_by_speaker(self, prompt_text: str, prompt_audio: str, audio_id: str) -> Tuple[str, str]:
"""
根据说话人标签分割prompt音频
返回S1和S2的音频片段路径
"""
# 1. 提取说话人片段
speaker_segments = self.extract_speaker_segments(prompt_text)
# 2. 删除标签后进行词对齐 - 如果失败则直接抛出异常
clean_text = self.remove_speaker_tags(prompt_text)
# 检测语言或使用设置的语言
alignment_language = self.language
if alignment_language == "AUTO":
alignment_language = self._detect_language_from_text(clean_text)
alignments = self.align_text_with_audio(clean_text, prompt_audio, alignment_language)
        # Save the prompt alignment info
prompt_alignment_data = {
'original_text': prompt_text,
'clean_text': clean_text,
'audio_path': prompt_audio,
'language': alignment_language,
'speaker_segments': speaker_segments,
'alignments': alignments
}
self.save_alignment_info(prompt_alignment_data, audio_id, "prompt")
        # 3. Split the audio according to the alignment result
        s1_segments = []
        s2_segments = []
        # Find the time span of each speaker segment
        text_pos = 0
        for seg in speaker_segments:
            seg_text = seg['content'].strip()
            seg_length = len(seg_text)
            # Locate this segment's start and end within the alignment result
            start_time = None
            end_time = None
            current_pos = 0
for align_item in alignments:
item_text = align_item['transcript']
item_length = len(item_text)
if current_pos >= text_pos and current_pos < text_pos + seg_length:
if start_time is None:
start_time = align_item['start']
end_time = align_item['end']
current_pos += item_length
if start_time is not None and end_time is not None:
if seg['speaker'] == 'S1':
s1_segments.append((start_time, end_time))
else:
s2_segments.append((start_time, end_time))
text_pos += seg_length
        # 4. Split and concatenate the audio segments
        safe_audio_id = self._get_safe_filename(audio_id)
        prompts1_path = str(self.prompts_dir / f"{safe_audio_id}_s1.wav")
        prompts2_path = str(self.prompts_dir / f"{safe_audio_id}_s2.wav")
        # Split all S1 segments
        if s1_segments:
            s1_temp_segments = []
            for i, (start, end) in enumerate(s1_segments):
                temp_path = str(self.temp_dir / f"{safe_audio_id}_s1_temp_{i}.wav")
                self.split_audio_segment(prompt_audio, start, end, temp_path)
                s1_temp_segments.append(temp_path)
            # Concatenate the S1 segments
            self.concatenate_audio_files(s1_temp_segments, prompts1_path)
        # Split all S2 segments
        if s2_segments:
            s2_temp_segments = []
            for i, (start, end) in enumerate(s2_segments):
                temp_path = str(self.temp_dir / f"{safe_audio_id}_s2_temp_{i}.wav")
                self.split_audio_segment(prompt_audio, start, end, temp_path)
                s2_temp_segments.append(temp_path)
            # Concatenate the S2 segments
            self.concatenate_audio_files(s2_temp_segments, prompts2_path)
        return prompts1_path, prompts2_path
def map_text_segments_to_speakers(self, original_text: str) -> List[Dict[str, Any]]:
"""
将原始文本按说话人和标点符号同时分割,保持映射关系
支持英文单词级别的处理
"""
segments = []
pattern = r'\[S([12])\]([^[]*)'
matches = re.findall(pattern, original_text)
# 检测语言或使用设置的语言
alignment_language = self.language
if alignment_language == "AUTO":
alignment_language = self._detect_language_from_text(original_text)
segment_id = 0
for speaker_id, content in matches:
speaker = f'S{speaker_id}'
clean_content = content.strip()
comma_content = self.replace_punctuation_with_comma(clean_content, alignment_language)
# 根据语言选择正确的逗号分割
if alignment_language == "EN" or (alignment_language == "AUTO" and self._is_english_text(clean_content)):
# 英文:按英文逗号分割,保持单词完整性
parts = [part.strip() for part in comma_content.split(',') if part.strip()]
else:
# 中文:按中文逗号分割
parts = [part.strip() for part in comma_content.split(',') if part.strip()]
for part in parts:
if part.strip():
segments.append({
'segment_id': segment_id,
'text': part.strip(),
'speaker_label': speaker,
'original_speaker_content': clean_content
})
segment_id += 1
return segments
def split_output_audio_by_comma(self, text: str, output_audio: str, audio_id: str) -> List[Dict[str, Any]]:
"""
根据逗号分割输出音频,返回每小段的信息 - 基于词对齐结果中的标点符号划分句子
"""
# 1. 获取文本片段和对应的说话人(用于获取speaker标签)
text_segments = self.map_text_segments_to_speakers(text)
# 2. 删除标签并替换标点符号
clean_text = self.remove_speaker_tags(text)
# 3. 检测语言或使用设置的语言
alignment_language = self.language
if alignment_language == "AUTO":
alignment_language = self._detect_language_from_text(clean_text)
# 使用检测到的语言替换标点符号
comma_text = self.replace_punctuation_with_comma(clean_text, alignment_language)
# 4. 词对齐 - 如果失败则直接抛出异常
alignments = self.align_text_with_audio(comma_text, output_audio, alignment_language)
# 5. 根据标点符号划分句子
segments = []
safe_audio_id = self._get_safe_filename(audio_id)
# 确定标点符号(根据语言选择,英文不包含撇号)
if alignment_language == "EN" or (alignment_language == "AUTO" and self._is_english_text(clean_text)):
punctuation_chars = set([',', '.', '!', '?', ';', ':']) # 不包含撇号
else:
punctuation_chars = set([',', '。', '!', '?', ';', ':'])
# 顺序扫描对齐结果,根据标点符号划分句子
sentence_start_idx = 0
sentence_alignments = []
segment_id = 0
        for i, align_item in enumerate(alignments):
            transcript = align_item['transcript']
            sentence_alignments.append(align_item)
            # Punctuation inside the transcript marks the end of a sentence
            has_punctuation = any(punct in transcript for punct in punctuation_chars)
            if has_punctuation or i == len(alignments) - 1:  # punctuation hit, or last word
                # Build the sentence segment
                if sentence_alignments:
                    # Sentence start and end times
                    start_time = sentence_alignments[0]['start']
                    end_time = sentence_alignments[-1]['end']
                    # Build the sentence text (punctuation removed)
                    sentence_text_parts = []
                    for align in sentence_alignments:
                        # The cleaning strategy depends on the language
                        if alignment_language == "EN" or (alignment_language == "AUTO" and self._is_english_text(clean_text)):
                            # English: strip punctuation (apostrophes were already removed)
                            clean_transcript = align['transcript'].rstrip(',.!?;:')
                        else:
                            # Chinese: strip Chinese punctuation
                            clean_transcript = align['transcript'].rstrip(',。!?;:')
                        if clean_transcript.strip():
                            sentence_text_parts.append(clean_transcript)
                    # Join according to the language
                    if alignment_language == "EN" or (alignment_language == "AUTO" and self._is_english_text(clean_text)):
                        sentence_text = ' '.join(sentence_text_parts).strip()  # English: join with spaces
                    else:
                        sentence_text = ''.join(sentence_text_parts).strip()  # Chinese: join directly
                    if sentence_text:  # process non-empty sentences only
                        # Determine the speaker label (from the original text_segments when possible)
                        speaker_label = "S1"  # default
                        if segment_id < len(text_segments):
                            speaker_label = text_segments[segment_id]['speaker_label']
                        elif text_segments:
                            # Out of range: fall back to the last segment's speaker
                            speaker_label = text_segments[-1]['speaker_label']
                        # Build the audio file path
                        safe_text = self._get_safe_filename(sentence_text, 30)
                        audio_path = str(self.segments_dir / f"{safe_audio_id}_segment_{segment_id:03d}_{safe_text}.wav")
                        # Cut the audio
                        try:
                            self.split_audio_segment(output_audio, start_time, end_time, audio_path)
                        except Exception as e:
                            self.logger.error(f"Audio splitting failed: {e}")
                            # Fall back to a default time interval
                            start_time = segment_id * 1.0
                            end_time = (segment_id + 1) * 1.0
                            self.split_audio_segment(output_audio, start_time, end_time, audio_path)
                        # Build the segment record
                        segment = {
                            'segment_id': segment_id,
                            'text': sentence_text,
                            'speaker_label': speaker_label,
                            'original_speaker_content': sentence_text,  # simplified here
                            'audio_path': audio_path,
                            'start_time': start_time,
                            'end_time': end_time
                        }
                        segments.append(segment)
                        self.logger.info(f"Sentence {segment_id}: '{sentence_text}' ({speaker_label}) -> {start_time:.3f}-{end_time:.3f}s")
                        segment_id += 1
                # Reset for the next sentence
                sentence_alignments = []
                sentence_start_idx = i + 1
        # Save the detailed alignment info
        self.save_detailed_alignment_info(
            alignments, segments, audio_id, output_audio, text, comma_text
        )
        self.logger.info(f"Split into {len(segments)} sentence segments in total")
        return segments
def _get_thread_local_similarity_model(self):
"""获取线程局部的相似度模型实例(线程安全)"""
if not hasattr(self._thread_local, 'similarity_model'):
# 为当前线程创建独立的模型实例
self._thread_local.similarity_model = self._create_similarity_model()
return self._thread_local.similarity_model
def _create_similarity_model(self):
"""创建新的相似度模型实例"""
try:
import wespeaker
# 使用与主模型相同的加载逻辑
local_model_path = '/inspire/ssd/project/embodied-multimodality/public/zylin/speaker_embedding/wespeaker_pretrain/voxblink2_samresnet100_ft'
try:
model = wespeaker.load_model_local(local_model_path)
return model
except Exception as e:
self.logger.warning(f"加载指定本地模型失败: {e}")
# 回退方案
if os.path.exists(self.wespeaker_model_dir):
try:
model = wespeaker.load_model_local(self.wespeaker_model_dir)
return model
except Exception as e:
self.logger.warning(f"加载传入本地模型失败: {e}")
# 最终回退到预训练模型
try:
model = wespeaker.load_model('chinese')
return model
except Exception as e:
model = wespeaker.load_model('english')
return model
except Exception as e:
self.logger.error(f"创建相似度模型失败: {e}")
raise
def calculate_voice_similarity_thread_safe(self, audio1_path: str, audio2_path: str) -> float:
"""
线程安全的音色相似度计算
对于过短的音频片段,通过复制来达到最小长度要求
"""
try:
if not os.path.exists(audio1_path) or not os.path.exists(audio2_path):
self.logger.warning(f"Audio file not found: {audio1_path} or {audio2_path}")
return None
# 获取线程局部的模型实例
similarity_model = self._get_thread_local_similarity_model()
# 检查并处理音频文件长度
def process_audio_for_similarity(audio_path, min_duration=0.1):
"""
处理音频文件,如果过短则复制到满足最小长度要求
返回处理后的音频路径和是否为临时文件的标志
"""
try:
waveform, sample_rate = torchaudio.load(audio_path)
duration = waveform.shape[1] / sample_rate
if duration >= min_duration:
# 音频长度足够,直接返回原路径
return audio_path, False
# 音频过短,需要复制
repeat_times = math.ceil(min_duration / duration)
thread_id = threading.get_ident()
# 复制音频
repeated_waveform = waveform.repeat(1, repeat_times)
# 生成临时文件路径(包含线程ID避免冲突)
temp_filename = f"temp_{thread_id}_{os.path.basename(audio_path)}"
temp_path = str(self.temp_dir / temp_filename)
# 保存复制后的音频
torchaudio.save(temp_path, repeated_waveform, sample_rate)
return temp_path, True
except Exception as e:
self.logger.error(f"处理音频文件失败: {audio_path}, 错误: {e}")
return audio_path, False
# 处理两个音频文件
processed_audio1, is_temp1 = process_audio_for_similarity(audio1_path)
processed_audio2, is_temp2 = process_audio_for_similarity(audio2_path)
# 计算相似度
similarity = similarity_model.compute_similarity(processed_audio1, processed_audio2)
# 清理临时文件
if is_temp1 and os.path.exists(processed_audio1):
try:
os.remove(processed_audio1)
except Exception as e:
self.logger.warning(f"删除临时文件失败: {processed_audio1}, 错误: {e}")
if is_temp2 and os.path.exists(processed_audio2):
try:
os.remove(processed_audio2)
except Exception as e:
self.logger.warning(f"删除临时文件失败: {processed_audio2}, 错误: {e}")
return float(similarity)
except Exception as e:
# 检查是否是窗口大小错误或其他计算错误
if "choose a window size" in str(e) or "window size" in str(e):
self.logger.warning(f"音频片段仍然过短,无法计算相似度: {audio1_path} vs {audio2_path}")
return None
else:
self.logger.error(f"Failed to compute similarity between {audio1_path} and {audio2_path}: {e}")
return None
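    # Tiling example (illustrative): a 0.03 s clip with min_duration=0.1 is repeated
    # math.ceil(0.1 / 0.03) = 4 times, yielding a 0.12 s clip before embedding.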
def calculate_segment_similarities_parallel(self, output_segments: List[Dict[str, Any]],
prompts1_path: str, prompts2_path: str) -> List[Dict[str, Any]]:
"""
并行计算所有segments的相似度
Args:
output_segments: 音频segments列表
prompts1_path: S1 prompt音频路径
prompts2_path: S2 prompt音频路径
Returns:
包含相似度信息的segment列表
"""
def calculate_single_segment_similarity(segment):
"""计算单个segment与两个prompts的相似度"""
try:
# 使用线程安全的相似度计算方法
sim1 = self.calculate_voice_similarity_thread_safe(segment['audio_path'], prompts1_path)
sim2 = self.calculate_voice_similarity_thread_safe(segment['audio_path'], prompts2_path)
return {
'segment': segment,
'sim1': sim1,
'sim2': sim2,
'success': True
}
except Exception as e:
self.logger.error(f"计算segment {segment['segment_id']} 相似度失败: {e}")
return {
'segment': segment,
'sim1': None,
'sim2': None,
'success': False
}
# 使用线程池并行处理所有segments
self.logger.info(f"开始并行计算 {len(output_segments)} 个segments的相似度,使用 {self.similarity_max_workers} 个线程")
        results = []
        with ThreadPoolExecutor(max_workers=self.similarity_max_workers) as executor:
            # Submit one task per segment
            future_to_segment = {
                executor.submit(calculate_single_segment_similarity, segment): segment
                for segment in output_segments
            }
            # Collect the results (the original order is restored below)
            segment_to_result = {}
            completed_count = 0
            for future in as_completed(future_to_segment):
                result = future.result()
                segment_id = result['segment']['segment_id']
                segment_to_result[segment_id] = result
                completed_count += 1
                # Report progress every 10 completed segments
                if completed_count % 10 == 0 or completed_count == len(output_segments):
                    self.logger.info(f"Similarity progress: {completed_count}/{len(output_segments)}")
        # Return the results in segment_id order
        for segment in output_segments:
            segment_id = segment['segment_id']
            if segment_id in segment_to_result:
                results.append(segment_to_result[segment_id])
        return results
def evaluate_single_input(self, data: Dict[str, Any], input_id: str = None) -> Dict[str, Any]:
"""评估单个输入的音色相似度"""
# 生成输入ID
if input_id is None:
input_id = f"input_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
self.logger.info(f"开始评估输入: {input_id},使用语言: {self.language}")
# 1. 获取或分割prompt音频
prompts1_path, prompts2_path = self.get_or_split_prompt_audio(data, f"{input_id}_prompt")
# 2. 分割output音频(这里会保存详细对齐信息)
output_segments = self.split_output_audio_by_comma(data['text'], data['output_audio'], f"{input_id}_output")
# 3. 并行计算每小段的相似度
similarity_results = self.calculate_segment_similarities_parallel(
output_segments, prompts1_path, prompts2_path
)
# 4. 处理相似度结果
segment_results = []
correct_predictions = 0
total_segments = 0 # 只计算有效段数
label_similarities = [] # 每小段与其标签的相似度
skipped_segments = 0 # 跳过的段数
        for sim_result in similarity_results:
            segment = sim_result['segment']
            sim1 = sim_result['sim1']
            sim2 = sim_result['sim2']
            # Skip the segment if either similarity is None (audio too short or computation failed)
            if sim1 is None or sim2 is None:
                skipped_segments += 1
                self.logger.info(f"Skipping segment {segment['segment_id']}: similarity computation failed")
                continue
            # Only valid segments count toward the metrics
            total_segments += 1
            # Predict the actual speaker
            predicted_speaker = 'S1' if sim1 > sim2 else 'S2'
            actual_speaker = segment['speaker_label']
            is_correct = predicted_speaker == actual_speaker
            if is_correct:
                correct_predictions += 1
            # Similarity to the labeled speaker
            if actual_speaker == 'S1':
                label_similarity = sim1
            else:
                label_similarity = sim2
            label_similarities.append(label_similarity)
            segment_result = {
                'segment_id': segment['segment_id'],
                'text': segment['text'],
                'speaker_label': actual_speaker,
                'predicted_speaker': predicted_speaker,
                'sim1': sim1,
                'sim2': sim2,
                'label_similarity': label_similarity,
                'is_correct': is_correct,
                'audio_path': segment['audio_path'],
                'start_time': segment.get('start_time', 0.0),
                'end_time': segment.get('end_time', 1.0)
            }
            segment_results.append(segment_result)
        # 5. Compute the overall metrics (valid segments only)
        accuracy = correct_predictions / total_segments if total_segments > 0 else 0.0
        average_similarity = np.mean(label_similarities) if label_similarities else 0.0
        # 6. Save a summary of the alignment info for this evaluation
        evaluation_alignment_summary = {
            'input_id': input_id,
            'language': self.language,
            'prompt_alignment_files': [
                f"{self._get_safe_filename(f'{input_id}_prompt')}_prompt_alignment.json"
            ],
            'output_alignment_file': f"{self._get_safe_filename(f'{input_id}_output')}_detailed_alignment.json",
            'total_segments': total_segments,
            'total_alignments_processed': len(output_segments),
            'alignment_success_rate': total_segments / len(output_segments) if output_segments else 0.0
        }
        self.save_alignment_info(evaluation_alignment_summary, input_id, "summary")
        result = {
            'input_id': input_id,
            'language': self.language,
            'input_data': data,  # keep the original input data
            'prompts1_path': prompts1_path,
            'prompts2_path': prompts2_path,
            'segments': segment_results,
            'accuracy': accuracy,
            'average_similarity': average_similarity,
            'total_segments': total_segments,  # valid segments
            'correct_predictions': correct_predictions,
            'skipped_segments': skipped_segments,  # skipped segments
            'original_total_segments': len(output_segments),  # original segment count
            'alignment_files': {
                'summary': f"{self._get_safe_filename(input_id)}_summary_alignment.json",
                'output_detailed': f"{self._get_safe_filename(f'{input_id}_output')}_detailed_alignment.json",
                'prompt': f"{self._get_safe_filename(f'{input_id}_prompt')}_prompt_alignment.json"
            },
            'timestamp': datetime.now().isoformat()
        }
        self.logger.info(f"Finished input: {input_id}, language: {self.language}, valid segments: {total_segments}/{len(output_segments)}, skipped: {skipped_segments}, accuracy: {accuracy:.3f}, average similarity: {average_similarity:.3f}")
        return result
def save_results_to_jsonl(self, results: List[Dict[str, Any]], filename: str = None):
"""保存结果到JSONL文件"""
if filename is None:
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
filename = f"speaker_similarity_results_{self.language.lower()}_{timestamp}.jsonl"
output_path = self.results_dir / filename
with open(output_path, 'w', encoding='utf-8') as f:
for result in results:
f.write(json.dumps(result, ensure_ascii=False) + '\n')
return str(output_path)
def save_summary_report(self, results: List[Dict[str, Any]], filename: str = None):
"""保存汇总报告"""
if filename is None:
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
filename = f"evaluation_summary_{self.language.lower()}_{timestamp}.json"
summary_path = self.results_dir / filename
        # Compute the overall statistics
total_accuracy = np.mean([r['accuracy'] for r in results])
total_avg_similarity = np.mean([r['average_similarity'] for r in results])
total_segments = sum([r['total_segments'] for r in results])
total_correct = sum([r['correct_predictions'] for r in results])
summary = {
'evaluation_summary': {
'language': self.language,
'total_inputs': len(results),
'total_segments': total_segments,
'total_correct_predictions': total_correct,
'overall_accuracy': total_accuracy,
'overall_average_similarity': total_avg_similarity,
'evaluation_timestamp': datetime.now().isoformat(),
'output_directory': str(self.output_dir),
'alignment_directory': str(self.alignment_dir)
},
'per_input_results': [
{
'input_id': r['input_id'],
'language': r.get('language', self.language),
'accuracy': r['accuracy'],
'average_similarity': r['average_similarity'],
'total_segments': r['total_segments'],
'correct_predictions': r['correct_predictions'],
'output_audio_path': r['input_data']['output_audio'],
'alignment_files': r.get('alignment_files', {})
}
for r in results
]
}
with open(summary_path, 'w', encoding='utf-8') as f:
json.dump(summary, f, ensure_ascii=False, indent=2)
return str(summary_path)
def process_batch_from_jsonl_parallel(self, jsonl_path: str,
processes_per_gpu: int = 16,
results_filename: str = None,
shuffle_data: bool = True):
"""从JSONL文件并行批量处理输入数据"""
# 加载数据
input_data = self.load_data_from_jsonl(jsonl_path)
if not input_data:
self.logger.error("没有有效的输入数据")
return []
# 对数据进行shuffle,使分配更均匀
if shuffle_data:
random.shuffle(input_data)
self.logger.info(f"已对 {len(input_data)} 条数据进行随机shuffle")
return self.process_batch_parallel(input_data, processes_per_gpu, results_filename)
def process_batch_from_jsonl(self, jsonl_path: str, results_filename: str = None):
"""从JSONL文件批量处理输入数据(单进程版本)"""
# 加载数据
input_data = self.load_data_from_jsonl(jsonl_path)
if not input_data:
self.logger.error("没有有效的输入数据")
return []
return self.process_batch_from_data(input_data, results_filename)
def process_batch_from_data(self, input_data: List[Dict[str, Any]], results_filename: str = None):
"""处理数据列表(单进程版本,用于兼容),支持增量写入"""
# 准备结果文件
if results_filename is None:
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
results_filename = f"speaker_similarity_results_{self.language.lower()}_{timestamp}.jsonl"
results_path = self.results_dir / results_filename
# 如果文件已存在,删除它(重新开始)
if results_path.exists():
results_path.unlink()
results = []
self.logger.info(f"开始处理 {len(input_data)} 个输入,使用语言: {self.language}...")
for i, data in enumerate(input_data):
input_id = f"input_{i+1:03d}"
print(f"处理第{i+1}/{len(input_data)}个输入: {input_id},语言: {self.language}")
try:
result = self.evaluate_single_input(data, input_id=input_id)
results.append(result)
# 增量写入结果
self.append_result_to_jsonl(result, str(results_path))
except Exception as e:
self.logger.error(f"处理输入{input_id}时出错: {e}")
continue
if not results:
self.logger.error("没有成功处理的输入")
return []
# 保存汇总报告
summary_path = self.save_summary_report(results)
# 清理临时文件
self._clean_temp_files()
# 打印总体统计
total_accuracy = np.mean([r['accuracy'] for r in results])
total_avg_similarity = np.mean([r['average_similarity'] for r in results])
print(f"\n=== 评估完成 ===")
print(f"使用语言: {self.language}")
print(f"总体准确率: {total_accuracy:.3f}")
print(f"总体平均相似度: {total_avg_similarity:.3f}")
print(f"详细结果已保存到: {results_path}")
print(f"汇总报告已保存到: {summary_path}")
print(f"对齐信息已保存到: {self.alignment_dir}")
print(f"所有中间文件保存在: {self.output_dir}")
return results
def _load_wespeaker_model(self, wespeaker_model_dir):
"""加载wespeaker模型"""
try:
import wespeaker
# 使用load_model_local方法加载本地模型
# 根据你提供的参考,使用你指定的模型路径
local_model_path = '/inspire/ssd/project/embodied-multimodality/public/zylin/speaker_embedding/wespeaker_pretrain/voxblink2_samresnet100_ft'
try:
self.similarity_model = wespeaker.load_model_local(local_model_path)
self.logger.info(f"成功加载本地wespeaker模型: {local_model_path}")
return
except Exception as e:
self.logger.warning(f"加载指定本地模型失败: {e}")
# 回退方案1: 尝试使用传入的模型目录
if os.path.exists(wespeaker_model_dir):
try:
self.similarity_model = wespeaker.load_model_local(wespeaker_model_dir)
self.logger.info(f"成功加载传入的本地wespeaker模型: {wespeaker_model_dir}")
return
except Exception as e:
self.logger.warning(f"加载传入本地模型失败: {e}")
# 回退方案2: 使用预训练的中文模型
try:
self.similarity_model = wespeaker.load_model('chinese')
self.logger.info("回退到wespeaker预训练中文模型")
return
except Exception as e:
self.logger.warning(f"加载预训练中文模型失败: {e}")
# 回退方案3: 使用预训练的英文模型
try:
self.similarity_model = wespeaker.load_model('english')
self.logger.info("回退到wespeaker预训练英文模型")
return
except Exception as e:
self.logger.error(f"加载英文模型也失败: {e}")
# 如果所有方法都失败,抛出异常
raise Exception("无法加载任何wespeaker模型")
except ImportError:
raise ImportError("请安装wespeaker: pip install git+https://github.com/wenet-e2e/wespeaker.git")
except Exception as e:
self.logger.error(f"加载wespeaker模型失败: {e}")
raise
def load_data_from_jsonl(self, jsonl_path: str) -> List[Dict[str, Any]]:
"""从JSONL文件加载数据"""
data = []
try:
with open(jsonl_path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(f, 1):
line = line.strip()
if line:
try:
item = json.loads(line)
# 验证必要字段
required_fields = ['text', 'output_audio']
for field in required_fields:
if field not in item:
self.logger.error(f"第{line_num}行缺少必要字段: {field}")
continue
# 验证音频路径模式:要么有prompt_audio和prompt_text,要么有分别的speaker音频文件
has_combined_prompt = 'prompt_audio' in item and 'prompt_text' in item
has_separate_prompts = ('prompt_audio_speaker1' in item and
'prompt_text_speaker1' in item and
'prompt_audio_speaker2' in item and
'prompt_text_speaker2' in item)
if not (has_combined_prompt or has_separate_prompts):
self.logger.error(f"第{line_num}行:需要提供prompt_audio+prompt_text或者分别的speaker音频文件")
continue
data.append(item)
except json.JSONDecodeError as e:
self.logger.error(f"第{line_num}行JSON解析错误: {e}")
continue
self.logger.info(f"从{jsonl_path}成功加载{len(data)}条数据")
return data
except FileNotFoundError:
self.logger.error(f"JSONL文件不存在: {jsonl_path}")
return []
except Exception as e:
self.logger.error(f"读取JSONL文件失败: {e}")
return []
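    # A minimal sketch of the JSONL schema that load_data_from_jsonl accepts
    # (the paths are illustrative placeholders). Combined-prompt form:
    #   {"text": "[S1]...[S2]...", "output_audio": "/path/out.wav",
    #    "prompt_audio": "/path/prompt.wav", "prompt_text": "[S1]...[S2]..."}
    # The separate-prompt form instead provides prompt_audio_speaker1/2 and
    # prompt_text_speaker1/2.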
@staticmethod
def get_gpu_count():
"""获取可用GPU数量"""
if torch.cuda.is_available():
return torch.cuda.device_count()
return 0
@staticmethod
def split_data_by_gpu(data: List[Dict[str, Any]], num_gpus: int) -> List[List[Dict[str, Any]]]:
"""根据GPU数量分割数据"""
if num_gpus == 0:
return [data]
chunk_size = math.ceil(len(data) / num_gpus)
gpu_chunks = []
for i in range(num_gpus):
start_idx = i * chunk_size
end_idx = min((i + 1) * chunk_size, len(data))
if start_idx < len(data):
gpu_chunks.append(data[start_idx:end_idx])
return gpu_chunks
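    # Example (illustrative): 10 records over 3 GPUs -> chunks of 4, 4, and 2
    # (chunk_size = ceil(10 / 3) = 4); split_data_by_processes behaves the same way.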
@staticmethod
def split_data_by_processes(data: List[Dict[str, Any]], num_processes: int) -> List[List[Dict[str, Any]]]:
"""根据进程数量分割数据"""
if num_processes <= 1:
return [data]
chunk_size = math.ceil(len(data) / num_processes)
process_chunks = []
for i in range(num_processes):
start_idx = i * chunk_size
end_idx = min((i + 1) * chunk_size, len(data))
if start_idx < len(data):
process_chunks.append(data[start_idx:end_idx])
return process_chunks
def append_result_to_jsonl(self, result: Dict[str, Any], filepath: str):
"""增量写入结果到JSONL文件"""
os.makedirs(os.path.dirname(filepath), exist_ok=True)
with open(filepath, 'a', encoding='utf-8') as f:
f.write(json.dumps(result, ensure_ascii=False) + '\n')
f.flush() # 强制刷新缓冲区
def merge_temp_results(self, temp_files: List[str], final_path: str):
"""合并临时结果文件"""
all_results = []
for temp_file in temp_files:
if os.path.exists(temp_file):
try:
with open(temp_file, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if line:
result = json.loads(line)
all_results.append(result)
except Exception as e:
self.logger.error(f"读取临时文件失败: {temp_file}, 错误: {e}")
# 写入最终文件
with open(final_path, 'w', encoding='utf-8') as f:
for result in all_results:
f.write(json.dumps(result, ensure_ascii=False) + '\n')
return all_results
def process_batch_parallel(self, input_data: List[Dict[str, Any]],
                               processes_per_gpu: int = 8,  # conservative default process count
results_filename: str = None,
shuffle_data: bool = True):
"""并行批量处理输入数据"""
# 1. 检查GPU数量
num_gpus = self.get_gpu_count()
if num_gpus == 0:
self.logger.warning("未检测到GPU,将使用CPU单进程处理")
return self.process_batch_from_data(input_data, results_filename)
# 限制每个GPU的进程数,避免CUDA内存冲突
max_processes_per_gpu = min(processes_per_gpu, 16)
self.logger.info(f"检测到 {num_gpus} 个GPU,每个GPU将使用 {max_processes_per_gpu} 个进程")
# 2. 对数据进行shuffle(如果还没有shuffle过)
shuffled_data = input_data.copy()
if shuffle_data:
random.shuffle(shuffled_data)
self.logger.info(f"已对 {len(shuffled_data)} 条数据进行随机shuffle以平衡GPU负载")
# 3. 按GPU分割数据
gpu_chunks = self.split_data_by_gpu(shuffled_data, num_gpus)
# 打印每个GPU分配到的数据量
for gpu_id, gpu_data in enumerate(gpu_chunks):
if gpu_data:
self.logger.info(f"GPU {gpu_id}: 分配到 {len(gpu_data)} 条数据")
        # 4. Prepare the results file path
        if results_filename is None:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            results_filename = f"speaker_similarity_results_{self.language.lower()}_{timestamp}.jsonl"
        final_results_path = self.results_dir / results_filename
        # 5. Prepare the per-process arguments for every GPU
        all_temp_files = []
        all_gpu_tasks = []
        for gpu_id, gpu_data in enumerate(gpu_chunks):
            if not gpu_data:
                continue
            self.logger.info(f"GPU {gpu_id}: preparing to process {len(gpu_data)} records")
            # Split this GPU's data across its processes
            process_chunks = self.split_data_by_processes(gpu_data, max_processes_per_gpu)
            # Build the argument tuples for this GPU's processes
            gpu_process_args = []
            for proc_id, proc_data in enumerate(process_chunks):
                if proc_data:
                    temp_result_file = str(self.temp_results_dir / f"gpu{gpu_id}_proc{proc_id}_results.jsonl")
                    all_temp_files.append(temp_result_file)
                    # Subprocess output directories live inside the main output directory
                    subprocess_output_dir = str(self.output_dir / f"gpu{gpu_id}_proc{proc_id}")
                    gpu_process_args.append((
                        proc_data,
                        gpu_id,
                        proc_id,
                        subprocess_output_dir,
                        temp_result_file,
                        self.alignment_model_dir,
                        self.wespeaker_model_dir,
                        self.language,  # language setting
                        self.similarity_max_workers  # similarity-computation thread count
                    ))
            if gpu_process_args:
                all_gpu_tasks.append((gpu_id, gpu_process_args, max_processes_per_gpu))
        # 6. Run all GPUs in parallel via a ThreadPoolExecutor
        def process_gpu_tasks(gpu_task):
            gpu_id, process_args, actual_processes = gpu_task
            self.logger.info(f"GPU {gpu_id}: starting {len(process_args)} process(es) in parallel")
            # Each GPU gets its own process pool to avoid cross-GPU interference
            with mp.Pool(processes=actual_processes) as pool:
                pool.map(process_data_chunk_incremental, process_args)
            self.logger.info(f"GPU {gpu_id}: all processes finished")
            return gpu_id
        # Drive all GPUs concurrently with a thread pool
        with ThreadPoolExecutor(max_workers=num_gpus) as executor:
            # Submit one task per GPU
            future_to_gpu = {executor.submit(process_gpu_tasks, gpu_task): gpu_task[0]
                             for gpu_task in all_gpu_tasks}
            # Wait for all GPUs to finish
            completed_gpus = []
            for future in as_completed(future_to_gpu):
                gpu_id = future_to_gpu[future]
                try:
                    result_gpu_id = future.result()
                    completed_gpus.append(result_gpu_id)
                    self.logger.info(f"GPU {result_gpu_id} finished")
                except Exception as exc:
                    self.logger.error(f"GPU {gpu_id} raised an exception: {exc}")
        self.logger.info(f"All GPUs finished: {completed_gpus}")
        # 7. Merge all temporary result files
        self.logger.info("Merging temporary result files...")
        all_results = self.merge_temp_results(all_temp_files, str(final_results_path))
        if not all_results:
            self.logger.error("No data was processed successfully")
            return []
        # 8. Generate the summary report
        summary_path = self.save_summary_report(all_results)
        # 9. Clean up the temporary files
        for temp_file in all_temp_files:
            if os.path.exists(temp_file):
                os.remove(temp_file)
        # 10. Print the overall statistics
        total_accuracy = np.mean([r['accuracy'] for r in all_results])
        total_avg_similarity = np.mean([r['average_similarity'] for r in all_results])
        print("\n=== Parallel evaluation complete ===")
        print(f"Language: {self.language}")
        print(f"Used {num_gpus} GPU(s), {max_processes_per_gpu} process(es) per GPU")
        print(f"Total input records: {len(input_data)}")
        print(f"Successfully processed: {len(all_results)}")
        print(f"Overall accuracy: {total_accuracy:.3f}")
        print(f"Overall average similarity: {total_avg_similarity:.3f}")
        print(f"Detailed results saved to: {final_results_path}")
        print(f"Summary report saved to: {summary_path}")
        print(f"Alignment info saved to: {self.alignment_dir}")
        return all_results
def get_or_split_prompt_audio(self, data: Dict[str, Any], audio_id: str) -> Tuple[str, str]:
"""
获取或分割prompt音频
如果提供了分别的speaker音频文件则直接使用,否则从combined prompt分割
"""
# 检查是否有分别的speaker音频文件
if ('prompt_audio_speaker1' in data and 'prompt_audio_speaker2' in data and
'prompt_text_speaker1' in data and 'prompt_text_speaker2' in data):
self.logger.info(f"使用预分割的speaker音频文件")
# 即使使用预分割的音频,也保存对齐信息
try:
# 检测语言或使用设置的语言
alignment_language = self.language
if alignment_language == "AUTO":
alignment_language = self._detect_language_from_text(data['prompt_text_speaker1'])
# 对S1音频进行对齐
s1_alignments = self.align_text_with_audio(
data['prompt_text_speaker1'], data['prompt_audio_speaker1'], alignment_language
)
s1_alignment_data = {
'speaker': 'S1',
'text': data['prompt_text_speaker1'],
'audio_path': data['prompt_audio_speaker1'],
'language': alignment_language,
'alignments': s1_alignments
}
self.save_alignment_info(s1_alignment_data, audio_id, "prompt_s1")
                # Align the S2 audio
s2_alignments = self.align_text_with_audio(
data['prompt_text_speaker2'], data['prompt_audio_speaker2'], alignment_language
)
s2_alignment_data = {
'speaker': 'S2',
'text': data['prompt_text_speaker2'],
'audio_path': data['prompt_audio_speaker2'],
'language': alignment_language,
'alignments': s2_alignments
}
self.save_alignment_info(s2_alignment_data, audio_id, "prompt_s2")
            except Exception as e:
                self.logger.warning(f"Failed to save alignment info for the pre-split audio: {e}")
return data['prompt_audio_speaker1'], data['prompt_audio_speaker2']
        # Otherwise split the combined prompt
        elif 'prompt_audio' in data and 'prompt_text' in data:
            self.logger.info("Splitting speaker segments from the combined prompt audio")
            return self.split_audio_by_speaker(data['prompt_text'], data['prompt_audio'], audio_id)
        else:
            raise ValueError("Either prompt_audio+prompt_text or separate per-speaker audio files must be provided")
def calculate_voice_similarity(self, audio1_path: str, audio2_path: str) -> float:
"""
计算两个音频的音色相似度(向后兼容版本)
对于过短的音频片段,通过复制来达到最小长度要求
"""
# 如果在多线程环境中,使用线程安全版本
if threading.current_thread() != threading.main_thread():
return self.calculate_voice_similarity_thread_safe(audio1_path, audio2_path)
# 确保模型已初始化
self._init_models_if_needed()
try:
if not os.path.exists(audio1_path) or not os.path.exists(audio2_path):
self.logger.warning(f"Audio file not found: {audio1_path} or {audio2_path}")
return None
# 检查并处理音频文件长度
def process_audio_for_similarity(audio_path, min_duration=0.1):
"""
处理音频文件,如果过短则复制到满足最小长度要求
返回处理后的音频路径和是否为临时文件的标志
"""
try:
waveform, sample_rate = torchaudio.load(audio_path)
duration = waveform.shape[1] / sample_rate
if duration >= min_duration:
# 音频长度足够,直接返回原路径
return audio_path, False
# 音频过短,需要复制
repeat_times = math.ceil(min_duration / duration)
self.logger.info(f"音频过短 ({duration:.3f}s),复制 {repeat_times} 次达到 {min_duration}s 要求: {audio_path}")
# 复制音频
repeated_waveform = waveform.repeat(1, repeat_times)
# 生成临时文件路径
temp_filename = f"temp_{os.path.basename(audio_path)}"
temp_path = str(self.temp_dir / temp_filename)
# 保存复制后的音频
torchaudio.save(temp_path, repeated_waveform, sample_rate)
return temp_path, True
except Exception as e:
self.logger.error(f"处理音频文件失败: {audio_path}, 错误: {e}")
return audio_path, False
# 处理两个音频文件
processed_audio1, is_temp1 = process_audio_for_similarity(audio1_path)
processed_audio2, is_temp2 = process_audio_for_similarity(audio2_path)
# 计算相似度
similarity = self.similarity_model.compute_similarity(processed_audio1, processed_audio2)
# 清理临时文件
if is_temp1 and os.path.exists(processed_audio1):
try:
os.remove(processed_audio1)
except Exception as e:
self.logger.warning(f"删除临时文件失败: {processed_audio1}, 错误: {e}")
if is_temp2 and os.path.exists(processed_audio2):
try:
os.remove(processed_audio2)
except Exception as e:
self.logger.warning(f"删除临时文件失败: {processed_audio2}, 错误: {e}")
return float(similarity)
except Exception as e:
# 检查是否是窗口大小错误或其他计算错误
if "choose a window size" in str(e) or "window size" in str(e):
self.logger.warning(f"音频片段仍然过短,无法计算相似度: {audio1_path} vs {audio2_path}")
return None
else:
self.logger.error(f"Failed to compute similarity between {audio1_path} and {audio2_path}: {e}")
return None
# Module-level worker for multiprocessing (supports incremental writes)
def process_data_chunk_incremental(args):
    """Worker that processes one chunk of data (incremental-write version)."""
    data_chunk, gpu_id, proc_id, output_dir, temp_result_file, alignment_model_dir, wespeaker_model_dir, language, similarity_max_workers = args
    # Pick the GPU for this process
    device = f"cuda:{gpu_id}" if torch.cuda.is_available() and gpu_id < torch.cuda.device_count() else "cpu"
    try:
        # Clear CUDA state to avoid cross-process interference
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            # Bind this process to its GPU
            torch.cuda.set_device(gpu_id)
        # Stagger initialization slightly to avoid simultaneous start-up contention
        time.sleep(proc_id * 0.5)
        # Create an evaluator with the model paths, language, and similarity thread count
        evaluator = SpeakerSimilarityEvaluator(
            device=device,
            alignment_model_dir=alignment_model_dir,
            wespeaker_model_dir=wespeaker_model_dir,
            output_dir=output_dir,
            language=language,  # language setting
            similarity_max_workers=similarity_max_workers  # similarity-computation thread count
        )
        # Lazily initialize the models
        evaluator._init_models_if_needed()
        # Clear the temp result file if it exists
        if os.path.exists(temp_result_file):
            os.remove(temp_result_file)
        # Process the chunk
        for i, data in enumerate(data_chunk):
            input_id = f"gpu{gpu_id}_proc{proc_id}_input_{i+1:03d}"
            try:
                result = evaluator.evaluate_single_input(data, input_id=input_id)
                # Write the result to the temp file immediately
                evaluator.append_result_to_jsonl(result, temp_result_file)
                print(f"GPU{gpu_id}-proc{proc_id}: finished {input_id} (language: {language}, similarity threads: {similarity_max_workers})")
                # Clear the CUDA cache after each item
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except Exception as e:
                print(f"GPU{gpu_id}-proc{proc_id}: failed on {input_id}: {e}")
                # Clear the CUDA cache on error as well
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                continue
        print(f"GPU{gpu_id}-proc{proc_id}: all data processed; results written to {temp_result_file}")
    except Exception as e:
        print(f"GPU{gpu_id}-proc{proc_id}: initialization failed: {e}")
        # Clear the CUDA cache on error
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
def main():
"""主函数示例"""
import argparse
parser = argparse.ArgumentParser(description='Speaker Similarity Evaluator')
parser.add_argument('--jsonl_path', type=str, help='JSONL文件路径')
parser.add_argument('--output_dir', type=str,
default=f"/inspire/hdd/project/embodied-multimodality/public/yqzhang/auto_evaluation_new/eval_res/results_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
help='结果保存目录')
parser.add_argument('--language', type=str, choices=['zh', 'en', 'auto'], default='zh',
help='指定语言: zh=中文, en=英文, auto=自动检测 (默认: zh)')
parser.add_argument('--no_parallel', action='store_true', help='禁用并行处理(默认启用并行)')
parser.add_argument('--processes_per_gpu', type=int, default=4, help='每个GPU的进程数(建议不超过4)')
parser.add_argument('--similarity_workers', type=int, default=16, help='相似度计算的线程数(默认: 8)')
parser.add_argument('--no_shuffle', action='store_true', help='禁用数据shuffle(默认启用shuffle)')
parser.add_argument('--random_seed', type=int, default=None, help='随机种子(可选,用于结果复现)')
args = parser.parse_args()
# 设置随机种子(如果指定)
if args.random_seed is not None:
random.seed(args.random_seed)
np.random.seed(args.random_seed)
torch.manual_seed(args.random_seed)
print(f"设置随机种子: {args.random_seed}")
# 语言参数处理
language = args.language.upper()
if language == 'AUTO':
language = 'AUTO'
elif language == 'EN':
language = 'EN'
else:
language = 'ZH' # 默认中文
    # Create the evaluator with the output directory, language, and similarity thread count
    evaluator = SpeakerSimilarityEvaluator(
        output_dir=args.output_dir,
        language=language,
        similarity_max_workers=args.similarity_workers
    )
    # Parallel processing is the default unless explicitly disabled
    use_parallel = not args.no_parallel
    use_shuffle = not args.no_shuffle
    print(f"Language setting: {language}")
    print(f"Similarity worker threads: {args.similarity_workers}")
    if args.jsonl_path:
        # Process data from the JSONL file
        if use_parallel:
            evaluator.process_batch_from_jsonl_parallel(
                args.jsonl_path,
                processes_per_gpu=args.processes_per_gpu,
                shuffle_data=use_shuffle
            )
        else:
            evaluator.process_batch_from_jsonl(args.jsonl_path)
    else:
        # Fall back to the built-in sample data (kept for compatibility)
input_data = [
{
'prompt_audio': "/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_prompt/testset/audio/zhouxingchi/zxc_enhanced.wav",
'prompt_text': "[S1]你再往前半步我就把你给杀了。[S2]你应该这么做,我也应该死。",
'text': "[S1]至尊宝,如果有一天我不再是紫霞仙子,只是一个普通的凡人,你还会像现在这样陪着我吗?[S2]这个嘛,那我得先问问月老,看看他给不给我打折!毕竟追仙子要花好多力气的![S1]哼!油嘴滑舌!我是认真的![S2]紫霞,不管你是仙子还是凡人,哪怕变成一根香蕉,我都认得出你。不过……你最好别真变成香蕉,我怕我会忍不住吃掉……[S1]讨厌!谁要变成香蕉啊!那……如果有一天,我们不得不分开呢?[S2]哇!你这话比牛魔王的斧头还狠!不行不行,你得赔我精神损失费![S1]怎么赔?[S2]很简单,让我亲一下,就当是定金![S1]想得美!那如果有一天,你真的忘了我呢?[S2]那我就算翻遍三界,打烂阎王殿,也要把记忆找回来。紫霞,我至尊宝这辈子,赖定你了![S1]傻瓜。",
'output_audio': "/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_res/from_newckpt_step145000/test_set/output_7.wav"
}
]
        # Process the data
        if use_parallel:
            evaluator.process_batch_parallel(input_data, processes_per_gpu=args.processes_per_gpu)
        else:
            evaluator.process_batch_from_data(input_data)
if __name__ == "__main__":
main()
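# Example invocation (illustrative; the script filename and paths are placeholders):
#   python speaker_similarity_eval.py --jsonl_path data/testset.jsonl \
#       --language zh --processes_per_gpu 4 --similarity_workers 16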