| | import json |
| | import re |
| | import os |
| | from typing import List, Dict, Tuple, Any |
| | import numpy as np |
| | from pathlib import Path |
| | import torch |
| | import torchaudio |
| | import torchaudio.functional as F |
| | import logging |
| | import wespeaker |
| | import shutil |
| | from datetime import datetime |
| | import multiprocessing as mp |
| | from functools import partial |
| | import math |
| | import threading |
| | import time |
| | from concurrent.futures import ThreadPoolExecutor, as_completed |
| | import random |
| |
|
| | |
# Force the 'spawn' start method for all worker processes (force=True
# overrides any previously configured method). Presumably chosen because
# forked workers cannot safely inherit CUDA state from the parent —
# TODO confirm against the parallel batch path.
mp.set_start_method('spawn', force=True)
| |
|
| | |
| | from alignment import AlignmentModel, batch_get_alignment_result |
| |
|
| | class SpeakerSimilarityEvaluator: |
| | """音色相似度评估器""" |
| | |
    def __init__(self, device="cuda",
                 alignment_model_dir='/inspire/hdd/project/embodied-multimodality/public/yqzhang/auto_evaluation_new/models/mms_fa',
                 wespeaker_model_dir='/inspire/ssd/project/embodied-multimodality/public/zylin/speaker_embedding/wespeaker_pretrain/voxblink2_samresnet100_ft',
                 output_dir="./evaluation_results",
                 language="ZH",
                 similarity_max_workers=8):
        """Initialize the evaluator.

        Args:
            device: torch device string used for model inference (e.g. "cuda").
            alignment_model_dir: local directory of the forced-alignment (MMS) model.
            wespeaker_model_dir: local directory of the wespeaker embedding model.
            output_dir: root directory for all intermediate and final artifacts.
            language: "ZH", "EN" or "AUTO"; normalized to upper case.
            similarity_max_workers: thread count used for parallel similarity scoring.
        """
        self.device = device
        self.alignment_model_dir = alignment_model_dir
        self.wespeaker_model_dir = wespeaker_model_dir
        self.language = language.upper()
        self.similarity_max_workers = similarity_max_workers

        # Module-level logger; basicConfig is a no-op if already configured.
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        # Output layout: per-sentence clips, per-speaker prompt clips, scratch
        # files, final results, intermediate results, and alignment JSON dumps.
        self.output_dir = Path(output_dir)
        self.segments_dir = self.output_dir / "segments"
        self.prompts_dir = self.output_dir / "prompts"
        self.temp_dir = self.output_dir / "temp"
        self.results_dir = self.output_dir / "results"
        self.temp_results_dir = self.output_dir / "temp_results"
        self.alignment_dir = self.output_dir / "alignments"

        self._create_output_directories()

        # Heavy models are loaded lazily by _init_models_if_needed (so worker
        # processes can build their own instances after spawn).
        self.alignment_model = None
        self.similarity_model = None

        # Per-thread cache of similarity-model instances (see
        # _get_thread_local_similarity_model).
        self._thread_local = threading.local()

        self.logger.info(f"评估结果将保存到: {self.output_dir}")
        self.logger.info(f"对齐信息将保存到: {self.alignment_dir}")
        self.logger.info(f"使用语言: {self.language}")
| | |
| | def _create_output_directories(self): |
| | """创建输出目录结构""" |
| | for dir_path in [self.segments_dir, self.prompts_dir, self.temp_dir, |
| | self.results_dir, self.temp_results_dir, self.alignment_dir]: |
| | dir_path.mkdir(parents=True, exist_ok=True) |
| | |
| | def _get_safe_filename(self, text: str, max_length: int = 50) -> str: |
| | """生成安全的文件名""" |
| | |
| | safe_text = re.sub(r'[^\u4e00-\u9fff\w\s]', '', text) |
| | |
| | if len(safe_text) > max_length: |
| | safe_text = safe_text[:max_length] |
| | |
| | safe_text = safe_text.replace(' ', '_') |
| | return safe_text if safe_text else "unnamed" |
| | |
| | def _clean_temp_files(self): |
| | """清理临时文件,但保留临时目录""" |
| | if self.temp_dir.exists(): |
| | |
| | for file_path in self.temp_dir.iterdir(): |
| | if file_path.is_file(): |
| | try: |
| | file_path.unlink() |
| | except Exception as e: |
| | self.logger.warning(f"删除临时文件失败: {file_path}, 错误: {e}") |
| | else: |
| | |
| | self.temp_dir.mkdir(parents=True, exist_ok=True) |
| | |
    def _init_models_if_needed(self):
        """Lazily construct the heavy models on first use.

        Deferred out of __init__ so that multiprocessing workers (the file
        forces the 'spawn' start method) can build their own model instances
        instead of inheriting loaded ones. Safe to call repeatedly.
        """
        # Forced-alignment model (MMS) on the configured device.
        if self.alignment_model is None:

            self.alignment_model = AlignmentModel(self.device, self.alignment_model_dir)

        # wespeaker embedding model; _load_wespeaker_model assigns
        # self.similarity_model as a side effect.
        if self.similarity_model is None:
            self._load_wespeaker_model(self.wespeaker_model_dir)
| | |
| | def _is_english_text(self, text: str) -> bool: |
| | """简单判断文本是否主要是英文""" |
| | |
| | english_chars = sum(1 for c in text if c.isascii() and c.isalpha()) |
| | total_chars = sum(1 for c in text if c.isalpha()) |
| | |
| | if total_chars == 0: |
| | return False |
| | |
| | return english_chars / total_chars > 0.8 |
| | |
| | def _detect_language_from_text(self, text: str) -> str: |
| | """从文本内容检测语言""" |
| | clean_text = self.remove_speaker_tags(text) |
| | if self._is_english_text(clean_text): |
| | return "EN" |
| | else: |
| | return "ZH" |
| | |
| | def save_alignment_info(self, alignment_data: Dict[str, Any], input_id: str, file_type: str = "output"): |
| | """ |
| | 保存对齐信息到单独的JSON文件 |
| | |
| | Args: |
| | alignment_data: 对齐信息数据 |
| | input_id: 输入ID |
| | file_type: 文件类型 ("output", "prompt", "segment") |
| | """ |
| | try: |
| | safe_input_id = self._get_safe_filename(input_id) |
| | alignment_filename = f"{safe_input_id}_{file_type}_alignment.json" |
| | alignment_path = self.alignment_dir / alignment_filename |
| | |
| | |
| | alignment_info = { |
| | 'input_id': input_id, |
| | 'file_type': file_type, |
| | 'language': self.language, |
| | 'timestamp': datetime.now().isoformat(), |
| | 'alignment_data': alignment_data |
| | } |
| | |
| | with open(alignment_path, 'w', encoding='utf-8') as f: |
| | json.dump(alignment_info, f, ensure_ascii=False, indent=2) |
| | |
| | self.logger.info(f"对齐信息已保存: {alignment_path}") |
| | return str(alignment_path) |
| | |
| | except Exception as e: |
| | self.logger.error(f"保存对齐信息失败: {e}") |
| | return None |
| | |
    def save_detailed_alignment_info(self, alignments: List[Dict[str, Any]],
                                     text_segments: List[Dict[str, Any]],
                                     input_id: str, audio_path: str,
                                     original_text: str, processed_text: str):
        """Save a detailed alignment report, including the segment-to-alignment mapping.

        Args:
            alignments: per-token alignment dicts ('transcript'/'start'/'end', optional 'score').
            text_segments: sentence-level segment dicts (with start/end times).
            input_id: identifier of the input; used in the output file name.
            audio_path: path of the aligned audio file.
            original_text: raw input text (with speaker tags).
            processed_text: text after punctuation normalization.

        Returns:
            The written file path (str) or None, as returned by save_alignment_info.
        """
        alignment_data = {
            'original_text': original_text,
            'processed_text': processed_text,
            'audio_path': audio_path,
            'language': self.language,
            'total_alignments': len(alignments),
            'total_segments': len(text_segments),
            'alignments': alignments,
            'text_segments': text_segments,
            'segment_alignment_mapping': []
        }

        # For each sentence segment, collect every alignment item whose time
        # span touches the segment's [start_time, end_time] window.
        for segment in text_segments:
            segment_mapping = {
                'segment_id': segment.get('segment_id', 0),
                'segment_text': segment.get('text', ''),
                'speaker_label': segment.get('speaker_label', ''),
                'start_time': segment.get('start_time', 0.0),
                'end_time': segment.get('end_time', 0.0),
                'corresponding_alignments': []
            }

            segment_start = segment.get('start_time', 0.0)
            segment_end = segment.get('end_time', 0.0)

            for i, align_item in enumerate(alignments):
                align_start = align_item.get('start', 0.0)
                align_end = align_item.get('end', 0.0)

                # Match on containment OR any overlap. NOTE: the first clause
                # (full containment) is subsumed by the overlap clause for
                # non-empty spans; kept as written.
                if (align_start >= segment_start and align_end <= segment_end) or \
                   (align_start < segment_end and align_end > segment_start):
                    segment_mapping['corresponding_alignments'].append({
                        'alignment_index': i,
                        'transcript': align_item.get('transcript', ''),
                        'start': align_start,
                        'end': align_end,
                        'score': align_item.get('score', 0.0) if 'score' in align_item else None
                    })

            alignment_data['segment_alignment_mapping'].append(segment_mapping)

        return self.save_alignment_info(alignment_data, input_id, "detailed")
| | |
| | def remove_speaker_tags(self, text: str) -> str: |
| | """删除文本中的说话人标签[S1][S2]""" |
| | return re.sub(r'\[S[12]\]', '', text).strip() |
| | |
| | def extract_speaker_segments(self, text: str) -> List[Dict[str, Any]]: |
| | """提取文本中的说话人片段信息""" |
| | segments = [] |
| | pattern = r'\[S([12])\]([^[]*)' |
| | matches = re.findall(pattern, text) |
| | |
| | for speaker_id, content in matches: |
| | segments.append({ |
| | 'speaker': f'S{speaker_id}', |
| | 'content': content.strip() |
| | }) |
| | return segments |
| | |
    def replace_punctuation_with_comma(self, text: str, language: str = None) -> str:
        """Normalize all punctuation to a single comma character.

        Every punctuation mark (ASCII and CJK) is replaced by the target
        comma, runs of commas are collapsed to one, and leading/trailing
        commas are stripped. Both branches currently use the ASCII comma ','
        as the target.

        Args:
            text: input text.
            language: "EN"/"ZH"/"AUTO"; when None, falls back to
                self.language or content-based detection.

        Returns:
            The normalized text.
        """
        # Resolve the language the same way align_text_with_audio does.
        if language is None:
            if hasattr(self, 'language'):
                language = self.language
            else:
                language = self._detect_language_from_text(text)

        language = language.upper()

        if language == "EN" or (language == "AUTO" and self._is_english_text(text)):
            # English: drop apostrophes first (so contractions don't split),
            # then map the punctuation class to the comma.
            text = re.sub(r"'", '', text)
            target_comma = ','
            comma_pattern = r',+'

            text = re.sub(r'[.,!?;:()\[\]<>\"…·,。;:!?()【】《》""\\、]', target_comma, text)
        else:
            # Chinese (default): same idea, but apostrophes/quotes are mapped
            # to commas rather than removed.
            target_comma = ','
            comma_pattern = r',+'

            text = re.sub(r'[.,!?;:()\[\]<>\'\"…·,。;:!?()【】《》''""\\、]', target_comma, text)

        # Collapse comma runs and trim edge commas.
        text = re.sub(comma_pattern, target_comma, text)
        return text.strip(target_comma)
| | |
    def align_text_with_audio(self, text: str, audio_path: str, language=None) -> List[Dict[str, Any]]:
        """
        Force-align text against an audio file, returning per-token time spans.

        Loads the audio, resamples to the alignment bundle's rate, downmixes
        to mono, moves it to self.device, and runs the alignment model.

        Args:
            text: text to align (speaker tags already removed).
            audio_path: audio file to align against.
            language: optional language override; defaults to self.language
                or content-based detection.

        Returns:
            A list of alignment dicts (used downstream via their
            'transcript', 'start' and 'end' keys).

        Raises:
            RuntimeError: when alignment fails or produces an empty result.
        """
        # Models are loaded lazily (multiprocessing-friendly).
        self._init_models_if_needed()

        # Resolve the effective language.
        if language is None:
            if hasattr(self, 'language'):
                language = self.language
            else:
                language = self._detect_language_from_text(text)
        else:
            language = language.upper()

        waveform, sample_rate = torchaudio.load(audio_path)

        # Match the alignment bundle's expected sample rate.
        if sample_rate != self.alignment_model.bundle.sample_rate:
            waveform = F.resample(waveform, sample_rate, self.alignment_model.bundle.sample_rate)

        # Downmix multi-channel audio to mono.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        waveform = waveform.squeeze(0)

        waveform = waveform.to(self.device)

        # Single-item batch through the batch alignment API; an empty result
        # is treated as a hard failure.
        try:
            alignment_results = batch_get_alignment_result(
                self.alignment_model,
                [waveform],
                [text],
                [language]
            )
            if not alignment_results or not alignment_results[0]:
                raise RuntimeError(f"对齐结果为空: {audio_path}")
            return alignment_results[0]
        except Exception as e:
            self.logger.error(f"音频对齐失败: {audio_path}")
            self.logger.error(f"错误详情: {e}")
            raise RuntimeError(f"音频对齐失败,程序终止。文件: {audio_path},错误: {e}")
| | |
| | def split_audio_segment(self, audio_path: str, start_time: float, end_time: float, output_path: str): |
| | """分割音频片段""" |
| | waveform, sample_rate = torchaudio.load(audio_path) |
| | |
| | start_frame = int(start_time * sample_rate) |
| | end_frame = int(end_time * sample_rate) |
| | |
| | segment = waveform[:, start_frame:end_frame] |
| | |
| | |
| | os.makedirs(os.path.dirname(output_path), exist_ok=True) |
| | |
| | torchaudio.save(output_path, segment, sample_rate) |
| | return output_path |
| | |
| | def concatenate_audio_files(self, audio_files: List[str], output_path: str): |
| | """拼接多个音频文件""" |
| | if not audio_files: |
| | return |
| | |
| | waveforms = [] |
| | sample_rate = None |
| | |
| | for audio_file in audio_files: |
| | if os.path.exists(audio_file): |
| | waveform, sr = torchaudio.load(audio_file) |
| | if sample_rate is None: |
| | sample_rate = sr |
| | elif sr != sample_rate: |
| | waveform = F.resample(waveform, sr, sample_rate) |
| | waveforms.append(waveform) |
| | |
| | if waveforms: |
| | concatenated = torch.cat(waveforms, dim=1) |
| | os.makedirs(os.path.dirname(output_path), exist_ok=True) |
| | torchaudio.save(output_path, concatenated, sample_rate) |
| | |
    def split_audio_by_speaker(self, prompt_text: str, prompt_audio: str, audio_id: str) -> Tuple[str, str]:
        """
        Split the prompt audio into per-speaker reference clips.

        Aligns the tag-stripped prompt text against the prompt audio, maps
        each [S1]/[S2] turn to a time span via character offsets, cuts the
        spans, and concatenates them into one reference wav per speaker.

        Returns:
            (s1_wav_path, s2_wav_path) under prompts_dir. NOTE(review): the
            paths are returned even when a speaker produced no segments, in
            which case that file is never written — callers must tolerate a
            missing file.
        """
        # Per-speaker turns parsed from the tagged text.
        speaker_segments = self.extract_speaker_segments(prompt_text)

        # Alignment runs on the tag-free text.
        clean_text = self.remove_speaker_tags(prompt_text)

        # Resolve AUTO to a concrete language before aligning.
        alignment_language = self.language
        if alignment_language == "AUTO":
            alignment_language = self._detect_language_from_text(clean_text)

        alignments = self.align_text_with_audio(clean_text, prompt_audio, alignment_language)

        # Persist the raw prompt alignment for auditing.
        prompt_alignment_data = {
            'original_text': prompt_text,
            'clean_text': clean_text,
            'audio_path': prompt_audio,
            'language': alignment_language,
            'speaker_segments': speaker_segments,
            'alignments': alignments
        }
        self.save_alignment_info(prompt_alignment_data, audio_id, "prompt")

        # Collected (start, end) time spans per speaker.
        s1_segments = []
        s2_segments = []

        # Map each turn to a time span by walking the alignment items and
        # comparing cumulative character offsets. NOTE(review): this assumes
        # the alignment transcripts concatenate to the same character
        # sequence as the turn texts (no extra spaces/punctuation) — verify
        # for English input where transcripts are space-separated words.
        text_pos = 0
        for seg in speaker_segments:
            seg_text = seg['content'].strip()
            seg_length = len(seg_text)

            start_time = None
            end_time = None

            current_pos = 0
            for align_item in alignments:
                item_text = align_item['transcript']
                item_length = len(item_text)

                # Item starts inside this turn's character window: extend the
                # turn's time span.
                if current_pos >= text_pos and current_pos < text_pos + seg_length:
                    if start_time is None:
                        start_time = align_item['start']
                    end_time = align_item['end']

                current_pos += item_length

            if start_time is not None and end_time is not None:
                if seg['speaker'] == 'S1':
                    s1_segments.append((start_time, end_time))
                else:
                    s2_segments.append((start_time, end_time))

            text_pos += seg_length

        # Output paths for the concatenated per-speaker references.
        safe_audio_id = self._get_safe_filename(audio_id)
        prompts1_path = str(self.prompts_dir / f"{safe_audio_id}_s1.wav")
        prompts2_path = str(self.prompts_dir / f"{safe_audio_id}_s2.wav")

        # Cut each S1 span to a temp file, then concatenate.
        if s1_segments:
            s1_temp_segments = []
            for i, (start, end) in enumerate(s1_segments):
                temp_path = str(self.temp_dir / f"{safe_audio_id}_s1_temp_{i}.wav")
                self.split_audio_segment(prompt_audio, start, end, temp_path)
                s1_temp_segments.append(temp_path)

            self.concatenate_audio_files(s1_temp_segments, prompts1_path)

        # Same for S2.
        if s2_segments:
            s2_temp_segments = []
            for i, (start, end) in enumerate(s2_segments):
                temp_path = str(self.temp_dir / f"{safe_audio_id}_s2_temp_{i}.wav")
                self.split_audio_segment(prompt_audio, start, end, temp_path)
                s2_temp_segments.append(temp_path)

            self.concatenate_audio_files(s2_temp_segments, prompts2_path)

        return prompts1_path, prompts2_path
| | |
| | def map_text_segments_to_speakers(self, original_text: str) -> List[Dict[str, Any]]: |
| | """ |
| | 将原始文本按说话人和标点符号同时分割,保持映射关系 |
| | 支持英文单词级别的处理 |
| | """ |
| | segments = [] |
| | pattern = r'\[S([12])\]([^[]*)' |
| | matches = re.findall(pattern, original_text) |
| | |
| | |
| | alignment_language = self.language |
| | if alignment_language == "AUTO": |
| | alignment_language = self._detect_language_from_text(original_text) |
| | |
| | segment_id = 0 |
| | for speaker_id, content in matches: |
| | speaker = f'S{speaker_id}' |
| | clean_content = content.strip() |
| | comma_content = self.replace_punctuation_with_comma(clean_content, alignment_language) |
| | |
| | |
| | if alignment_language == "EN" or (alignment_language == "AUTO" and self._is_english_text(clean_content)): |
| | |
| | parts = [part.strip() for part in comma_content.split(',') if part.strip()] |
| | else: |
| | |
| | parts = [part.strip() for part in comma_content.split(',') if part.strip()] |
| | |
| | for part in parts: |
| | if part.strip(): |
| | segments.append({ |
| | 'segment_id': segment_id, |
| | 'text': part.strip(), |
| | 'speaker_label': speaker, |
| | 'original_speaker_content': clean_content |
| | }) |
| | segment_id += 1 |
| | |
| | return segments |
| | |
    def split_output_audio_by_comma(self, text: str, output_audio: str, audio_id: str) -> List[Dict[str, Any]]:
        """
        Split the generated output audio into sentence clips.

        Aligns the punctuation-normalized text against the output audio,
        accumulates alignment tokens until a punctuation mark (or the end of
        the alignment stream), and cuts one audio clip per accumulated
        sentence. Speaker labels are taken positionally from the text-derived
        segments.

        Returns:
            A list of segment dicts with 'segment_id', 'text',
            'speaker_label', 'original_speaker_content', 'audio_path',
            'start_time' and 'end_time'.
        """
        # Text-derived segments supply the per-sentence speaker labels.
        text_segments = self.map_text_segments_to_speakers(text)

        clean_text = self.remove_speaker_tags(text)

        # Resolve AUTO to a concrete language once.
        alignment_language = self.language
        if alignment_language == "AUTO":
            alignment_language = self._detect_language_from_text(clean_text)

        # Normalize punctuation so sentence boundaries become commas.
        comma_text = self.replace_punctuation_with_comma(clean_text, alignment_language)

        alignments = self.align_text_with_audio(comma_text, output_audio, alignment_language)

        segments = []
        safe_audio_id = self._get_safe_filename(audio_id)

        # Punctuation set used to detect sentence boundaries in transcripts.
        if alignment_language == "EN" or (alignment_language == "AUTO" and self._is_english_text(clean_text)):
            punctuation_chars = set([',', '.', '!', '?', ';', ':'])
        else:
            punctuation_chars = set([',', '。', '!', '?', ';', ':'])

        # Accumulate alignment items into the current sentence.
        # NOTE: sentence_start_idx is tracked but never read.
        sentence_start_idx = 0
        sentence_alignments = []
        segment_id = 0

        for i, align_item in enumerate(alignments):
            transcript = align_item['transcript']
            sentence_alignments.append(align_item)

            # A boundary is any punctuation in the transcript, or running out
            # of alignment items.
            has_punctuation = any(punct in transcript for punct in punctuation_chars)

            if has_punctuation or i == len(alignments) - 1:

                if sentence_alignments:
                    # Sentence time span covers first through last token.
                    start_time = sentence_alignments[0]['start']
                    end_time = sentence_alignments[-1]['end']

                    # Rebuild the sentence text from the transcripts with
                    # trailing punctuation stripped.
                    sentence_text_parts = []
                    for align in sentence_alignments:

                        if alignment_language == "EN" or (alignment_language == "AUTO" and self._is_english_text(clean_text)):

                            clean_transcript = align['transcript'].rstrip(',.!?;:')
                        else:

                            clean_transcript = align['transcript'].rstrip(',。!?;:')

                        if clean_transcript.strip():
                            sentence_text_parts.append(clean_transcript)

                    # English joins word tokens with spaces; Chinese does not.
                    if alignment_language == "EN" or (alignment_language == "AUTO" and self._is_english_text(clean_text)):
                        sentence_text = ' '.join(sentence_text_parts).strip()
                    else:
                        sentence_text = ''.join(sentence_text_parts).strip()

                    if sentence_text:
                        # Positional speaker lookup: sentence k gets the label
                        # of text segment k, falling back to the last label
                        # (or "S1" when there are no text segments).
                        speaker_label = "S1"
                        if segment_id < len(text_segments):
                            speaker_label = text_segments[segment_id]['speaker_label']
                        elif text_segments:

                            speaker_label = text_segments[-1]['speaker_label']

                        safe_text = self._get_safe_filename(sentence_text, 30)
                        audio_path = str(self.segments_dir / f"{safe_audio_id}_segment_{segment_id:03d}_{safe_text}.wav")

                        # Cut the clip; on failure fall back to an arbitrary
                        # 1-second window at position segment_id (best-effort
                        # — NOTE(review): this fallback window is unrelated to
                        # the alignment and may not contain the sentence).
                        try:
                            self.split_audio_segment(output_audio, start_time, end_time, audio_path)
                        except Exception as e:
                            self.logger.error(f"分割音频失败: {e}")

                            start_time = segment_id * 1.0
                            end_time = (segment_id + 1) * 1.0
                            self.split_audio_segment(output_audio, start_time, end_time, audio_path)

                        segment = {
                            'segment_id': segment_id,
                            'text': sentence_text,
                            'speaker_label': speaker_label,
                            'original_speaker_content': sentence_text,
                            'audio_path': audio_path,
                            'start_time': start_time,
                            'end_time': end_time
                        }

                        segments.append(segment)

                        self.logger.info(f"句子 {segment_id}: '{sentence_text}' ({speaker_label}) -> {start_time:.3f}-{end_time:.3f}s")
                        segment_id += 1

                    # Reset the accumulator for the next sentence.
                    sentence_alignments = []
                    sentence_start_idx = i + 1

        # Persist the full alignment + segment mapping for auditing.
        self.save_detailed_alignment_info(
            alignments, segments, audio_id, output_audio, text, comma_text
        )

        self.logger.info(f"总共分割出 {len(segments)} 个句子片段")
        return segments
| | |
| | def _get_thread_local_similarity_model(self): |
| | """获取线程局部的相似度模型实例(线程安全)""" |
| | if not hasattr(self._thread_local, 'similarity_model'): |
| | |
| | self._thread_local.similarity_model = self._create_similarity_model() |
| | return self._thread_local.similarity_model |
| | |
    def _create_similarity_model(self):
        """Create a fresh wespeaker similarity-model instance.

        Tries, in order: a hardcoded local checkpoint, the configured
        self.wespeaker_model_dir, the pretrained 'chinese' model, and finally
        the pretrained 'english' model.

        NOTE(review): the hardcoded local_model_path is tried BEFORE the
        wespeaker_model_dir passed to __init__, so a caller-supplied model
        dir is only used as a fallback — confirm this priority is intended.

        Raises:
            Exception: re-raised when even the last fallback fails.
        """
        try:
            import wespeaker


            local_model_path = '/inspire/ssd/project/embodied-multimodality/public/zylin/speaker_embedding/wespeaker_pretrain/voxblink2_samresnet100_ft'

            # 1) Hardcoded local checkpoint.
            try:
                model = wespeaker.load_model_local(local_model_path)
                return model
            except Exception as e:
                self.logger.warning(f"加载指定本地模型失败: {e}")

            # 2) Configured model directory, if it exists.
            if os.path.exists(self.wespeaker_model_dir):
                try:
                    model = wespeaker.load_model_local(self.wespeaker_model_dir)
                    return model
                except Exception as e:
                    self.logger.warning(f"加载传入本地模型失败: {e}")

            # 3) Pretrained fallbacks: chinese, then english.
            try:
                model = wespeaker.load_model('chinese')
                return model
            except Exception as e:
                model = wespeaker.load_model('english')
                return model

        except Exception as e:
            self.logger.error(f"创建相似度模型失败: {e}")
            raise
| | |
    def calculate_voice_similarity_thread_safe(self, audio1_path: str, audio2_path: str) -> float:
        """
        Thread-safe speaker-similarity score between two audio files.

        Uses a per-thread wespeaker model. Audio shorter than a minimum
        duration is tiled (repeated) into a temp file so the embedding model
        has enough signal to work with.

        Returns:
            The similarity as float, or None when a file is missing or the
            computation fails.
        """
        try:
            if not os.path.exists(audio1_path) or not os.path.exists(audio2_path):
                self.logger.warning(f"Audio file not found: {audio1_path} or {audio2_path}")
                return None

            # Per-thread model instance (wespeaker objects are not shared
            # across threads).
            similarity_model = self._get_thread_local_similarity_model()

            def process_audio_for_similarity(audio_path, min_duration=0.1):
                """
                Ensure the clip is at least min_duration seconds long.

                Returns (path, is_temp): the original path unchanged, or a
                thread-tagged temp file containing the clip repeated enough
                times to reach min_duration.
                NOTE(review): a zero-length clip would divide by zero here;
                the except falls back to returning the original path.
                """
                try:
                    waveform, sample_rate = torchaudio.load(audio_path)
                    duration = waveform.shape[1] / sample_rate

                    if duration >= min_duration:

                        return audio_path, False

                    # Tile the waveform to reach the minimum duration.
                    repeat_times = math.ceil(min_duration / duration)
                    thread_id = threading.get_ident()

                    repeated_waveform = waveform.repeat(1, repeat_times)

                    # Thread id in the name avoids collisions between workers.
                    temp_filename = f"temp_{thread_id}_{os.path.basename(audio_path)}"
                    temp_path = str(self.temp_dir / temp_filename)

                    torchaudio.save(temp_path, repeated_waveform, sample_rate)

                    return temp_path, True

                except Exception as e:
                    self.logger.error(f"处理音频文件失败: {audio_path}, 错误: {e}")
                    return audio_path, False

            processed_audio1, is_temp1 = process_audio_for_similarity(audio1_path)
            processed_audio2, is_temp2 = process_audio_for_similarity(audio2_path)

            similarity = similarity_model.compute_similarity(processed_audio1, processed_audio2)

            # Best-effort cleanup of tiled temp files.
            # NOTE(review): if compute_similarity raises, these temp files are
            # leaked until _clean_temp_files runs — consider try/finally.
            if is_temp1 and os.path.exists(processed_audio1):
                try:
                    os.remove(processed_audio1)
                except Exception as e:
                    self.logger.warning(f"删除临时文件失败: {processed_audio1}, 错误: {e}")

            if is_temp2 and os.path.exists(processed_audio2):
                try:
                    os.remove(processed_audio2)
                except Exception as e:
                    self.logger.warning(f"删除临时文件失败: {processed_audio2}, 错误: {e}")

            return float(similarity)

        except Exception as e:
            # "window size" errors indicate the clip is still too short for
            # the model's STFT; treat as unscorable rather than fatal.
            if "choose a window size" in str(e) or "window size" in str(e):
                self.logger.warning(f"音频片段仍然过短,无法计算相似度: {audio1_path} vs {audio2_path}")
                return None
            else:
                self.logger.error(f"Failed to compute similarity between {audio1_path} and {audio2_path}: {e}")
                return None
| | |
| | def calculate_segment_similarities_parallel(self, output_segments: List[Dict[str, Any]], |
| | prompts1_path: str, prompts2_path: str) -> List[Dict[str, Any]]: |
| | """ |
| | 并行计算所有segments的相似度 |
| | Args: |
| | output_segments: 音频segments列表 |
| | prompts1_path: S1 prompt音频路径 |
| | prompts2_path: S2 prompt音频路径 |
| | Returns: |
| | 包含相似度信息的segment列表 |
| | """ |
| | |
| | def calculate_single_segment_similarity(segment): |
| | """计算单个segment与两个prompts的相似度""" |
| | try: |
| | |
| | sim1 = self.calculate_voice_similarity_thread_safe(segment['audio_path'], prompts1_path) |
| | sim2 = self.calculate_voice_similarity_thread_safe(segment['audio_path'], prompts2_path) |
| | |
| | return { |
| | 'segment': segment, |
| | 'sim1': sim1, |
| | 'sim2': sim2, |
| | 'success': True |
| | } |
| | except Exception as e: |
| | self.logger.error(f"计算segment {segment['segment_id']} 相似度失败: {e}") |
| | return { |
| | 'segment': segment, |
| | 'sim1': None, |
| | 'sim2': None, |
| | 'success': False |
| | } |
| | |
| | |
| | self.logger.info(f"开始并行计算 {len(output_segments)} 个segments的相似度,使用 {self.similarity_max_workers} 个线程") |
| | |
| | results = [] |
| | with ThreadPoolExecutor(max_workers=self.similarity_max_workers) as executor: |
| | |
| | future_to_segment = { |
| | executor.submit(calculate_single_segment_similarity, segment): segment |
| | for segment in output_segments |
| | } |
| | |
| | |
| | segment_to_result = {} |
| | completed_count = 0 |
| | for future in as_completed(future_to_segment): |
| | result = future.result() |
| | segment_id = result['segment']['segment_id'] |
| | segment_to_result[segment_id] = result |
| | completed_count += 1 |
| | |
| | |
| | if completed_count % 10 == 0 or completed_count == len(output_segments): |
| | self.logger.info(f"相似度计算进度: {completed_count}/{len(output_segments)}") |
| | |
| | |
| | for segment in output_segments: |
| | segment_id = segment['segment_id'] |
| | if segment_id in segment_to_result: |
| | results.append(segment_to_result[segment_id]) |
| | |
| | return results |
| |
|
    def evaluate_single_input(self, data: Dict[str, Any], input_id: str = None) -> Dict[str, Any]:
        """Evaluate speaker similarity for one input item.

        Pipeline: build per-speaker prompt references, split the output audio
        into sentence clips, score every clip against both prompts in
        parallel, then aggregate per-segment predictions into accuracy and
        average label-similarity.

        Args:
            data: input dict; must provide 'text' and 'output_audio' (and the
                prompt fields consumed by get_or_split_prompt_audio, which is
                defined elsewhere in this file).
            input_id: optional identifier; defaults to a timestamped one.

        Returns:
            A result dict with per-segment details, accuracy, average
            similarity, skip counts, alignment file names, and a timestamp.
        """

        if input_id is None:
            input_id = f"input_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

        self.logger.info(f"开始评估输入: {input_id},使用语言: {self.language}")

        # Per-speaker prompt reference clips (cached or freshly split).
        prompts1_path, prompts2_path = self.get_or_split_prompt_audio(data, f"{input_id}_prompt")

        # Sentence-level clips of the generated audio.
        output_segments = self.split_output_audio_by_comma(data['text'], data['output_audio'], f"{input_id}_output")

        # Similarity of every clip vs. both prompts (thread pool).
        similarity_results = self.calculate_segment_similarities_parallel(
            output_segments, prompts1_path, prompts2_path
        )

        segment_results = []
        correct_predictions = 0
        total_segments = 0
        label_similarities = []
        skipped_segments = 0

        for sim_result in similarity_results:
            segment = sim_result['segment']
            sim1 = sim_result['sim1']
            sim2 = sim_result['sim2']

            # Segments whose similarity could not be computed are skipped and
            # excluded from the accuracy denominator.
            if sim1 is None or sim2 is None:
                skipped_segments += 1
                self.logger.info(f"跳过段 {segment['segment_id']}: 相似度计算失败")
                continue

            total_segments += 1

            # Predicted speaker = prompt with the higher similarity
            # (ties go to S2 because of the strict '>').
            predicted_speaker = 'S1' if sim1 > sim2 else 'S2'
            actual_speaker = segment['speaker_label']
            is_correct = predicted_speaker == actual_speaker

            if is_correct:
                correct_predictions += 1

            # Similarity against the labeled (ground-truth) speaker's prompt.
            if actual_speaker == 'S1':
                label_similarity = sim1
            else:
                label_similarity = sim2
            label_similarities.append(label_similarity)

            segment_result = {
                'segment_id': segment['segment_id'],
                'text': segment['text'],
                'speaker_label': actual_speaker,
                'predicted_speaker': predicted_speaker,
                'sim1': sim1,
                'sim2': sim2,
                'label_similarity': label_similarity,
                'is_correct': is_correct,
                'audio_path': segment['audio_path'],
                'start_time': segment.get('start_time', 0.0),
                'end_time': segment.get('end_time', 1.0)
            }
            segment_results.append(segment_result)

        # Aggregates over the scorable segments only.
        accuracy = correct_predictions / total_segments if total_segments > 0 else 0.0
        average_similarity = np.mean(label_similarities) if label_similarities else 0.0

        # Record which alignment files belong to this evaluation.
        evaluation_alignment_summary = {
            'input_id': input_id,
            'language': self.language,
            'prompt_alignment_files': [
                f"{self._get_safe_filename(f'{input_id}_prompt')}_prompt_alignment.json"
            ],
            'output_alignment_file': f"{self._get_safe_filename(f'{input_id}_output')}_detailed_alignment.json",
            'total_segments': total_segments,
            'total_alignments_processed': len(output_segments),
            'alignment_success_rate': total_segments / len(output_segments) if output_segments else 0.0
        }
        self.save_alignment_info(evaluation_alignment_summary, input_id, "summary")

        result = {
            'input_id': input_id,
            'language': self.language,
            'input_data': data,
            'prompts1_path': prompts1_path,
            'prompts2_path': prompts2_path,
            'segments': segment_results,
            'accuracy': accuracy,
            'average_similarity': average_similarity,
            'total_segments': total_segments,
            'correct_predictions': correct_predictions,
            'skipped_segments': skipped_segments,
            'original_total_segments': len(output_segments),
            'alignment_files': {
                'summary': f"{self._get_safe_filename(input_id)}_summary_alignment.json",
                'output_detailed': f"{self._get_safe_filename(f'{input_id}_output')}_detailed_alignment.json",
                'prompt': f"{self._get_safe_filename(f'{input_id}_prompt')}_prompt_alignment.json"
            },
            'timestamp': datetime.now().isoformat()
        }

        self.logger.info(f"完成评估输入: {input_id}, 语言: {self.language}, 有效段: {total_segments}/{len(output_segments)}, 跳过: {skipped_segments}, 准确率: {accuracy:.3f}, 平均相似度: {average_similarity:.3f}")

        return result
| |
|
| | def save_results_to_jsonl(self, results: List[Dict[str, Any]], filename: str = None): |
| | """保存结果到JSONL文件""" |
| | if filename is None: |
| | timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') |
| | filename = f"speaker_similarity_results_{self.language.lower()}_{timestamp}.jsonl" |
| | |
| | output_path = self.results_dir / filename |
| | |
| | with open(output_path, 'w', encoding='utf-8') as f: |
| | for result in results: |
| | f.write(json.dumps(result, ensure_ascii=False) + '\n') |
| | |
| | return str(output_path) |
| | |
| | def save_summary_report(self, results: List[Dict[str, Any]], filename: str = None): |
| | """保存汇总报告""" |
| | if filename is None: |
| | timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') |
| | filename = f"evaluation_summary_{self.language.lower()}_{timestamp}.json" |
| | |
| | summary_path = self.results_dir / filename |
| | |
| | |
| | total_accuracy = np.mean([r['accuracy'] for r in results]) |
| | total_avg_similarity = np.mean([r['average_similarity'] for r in results]) |
| | total_segments = sum([r['total_segments'] for r in results]) |
| | total_correct = sum([r['correct_predictions'] for r in results]) |
| | |
| | summary = { |
| | 'evaluation_summary': { |
| | 'language': self.language, |
| | 'total_inputs': len(results), |
| | 'total_segments': total_segments, |
| | 'total_correct_predictions': total_correct, |
| | 'overall_accuracy': total_accuracy, |
| | 'overall_average_similarity': total_avg_similarity, |
| | 'evaluation_timestamp': datetime.now().isoformat(), |
| | 'output_directory': str(self.output_dir), |
| | 'alignment_directory': str(self.alignment_dir) |
| | }, |
| | 'per_input_results': [ |
| | { |
| | 'input_id': r['input_id'], |
| | 'language': r.get('language', self.language), |
| | 'accuracy': r['accuracy'], |
| | 'average_similarity': r['average_similarity'], |
| | 'total_segments': r['total_segments'], |
| | 'correct_predictions': r['correct_predictions'], |
| | 'output_audio_path': r['input_data']['output_audio'], |
| | 'alignment_files': r.get('alignment_files', {}) |
| | } |
| | for r in results |
| | ] |
| | } |
| | |
| | with open(summary_path, 'w', encoding='utf-8') as f: |
| | json.dump(summary, f, ensure_ascii=False, indent=2) |
| | |
| | return str(summary_path) |
| | |
| | def process_batch_from_jsonl_parallel(self, jsonl_path: str, |
| | processes_per_gpu: int = 16, |
| | results_filename: str = None, |
| | shuffle_data: bool = True): |
| | """从JSONL文件并行批量处理输入数据""" |
| | |
| | input_data = self.load_data_from_jsonl(jsonl_path) |
| | |
| | if not input_data: |
| | self.logger.error("没有有效的输入数据") |
| | return [] |
| | |
| | |
| | if shuffle_data: |
| | random.shuffle(input_data) |
| | self.logger.info(f"已对 {len(input_data)} 条数据进行随机shuffle") |
| | |
| | return self.process_batch_parallel(input_data, processes_per_gpu, results_filename) |
| |
|
| | def process_batch_from_jsonl(self, jsonl_path: str, results_filename: str = None): |
| | """从JSONL文件批量处理输入数据(单进程版本)""" |
| | |
| | input_data = self.load_data_from_jsonl(jsonl_path) |
| | |
| | if not input_data: |
| | self.logger.error("没有有效的输入数据") |
| | return [] |
| | |
| | return self.process_batch_from_data(input_data, results_filename) |
| |
|
    def process_batch_from_data(self, input_data: List[Dict[str, Any]], results_filename: str = None):
        """Evaluate a list of inputs sequentially with incremental JSONL writes.

        Each successful result is appended to the results file immediately
        (via append_result_to_jsonl, defined elsewhere in this file), so a
        crash mid-batch loses at most the in-flight item. A summary report is
        written at the end and temp files are cleaned up.

        Args:
            input_data: list of input dicts (each needs 'text'/'output_audio'
                plus prompt fields consumed downstream).
            results_filename: optional output file name; timestamped default.

        Returns:
            The list of successful result dicts ([] when nothing succeeded).
        """

        if results_filename is None:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            results_filename = f"speaker_similarity_results_{self.language.lower()}_{timestamp}.jsonl"

        results_path = self.results_dir / results_filename

        # Start from a clean file so incremental appends don't mix runs.
        if results_path.exists():
            results_path.unlink()

        results = []

        self.logger.info(f"开始处理 {len(input_data)} 个输入,使用语言: {self.language}...")

        for i, data in enumerate(input_data):
            input_id = f"input_{i+1:03d}"
            print(f"处理第{i+1}/{len(input_data)}个输入: {input_id},语言: {self.language}")

            # Failures on individual inputs are logged and skipped so the
            # rest of the batch still runs.
            try:
                result = self.evaluate_single_input(data, input_id=input_id)
                results.append(result)

                # Incremental persistence of each finished result.
                self.append_result_to_jsonl(result, str(results_path))

            except Exception as e:
                self.logger.error(f"处理输入{input_id}时出错: {e}")
                continue

        if not results:
            self.logger.error("没有成功处理的输入")
            return []

        summary_path = self.save_summary_report(results)

        # Remove scratch files produced during the batch.
        self._clean_temp_files()

        # Console summary (unweighted means over inputs).
        total_accuracy = np.mean([r['accuracy'] for r in results])
        total_avg_similarity = np.mean([r['average_similarity'] for r in results])

        print(f"\n=== 评估完成 ===")
        print(f"使用语言: {self.language}")
        print(f"总体准确率: {total_accuracy:.3f}")
        print(f"总体平均相似度: {total_avg_similarity:.3f}")
        print(f"详细结果已保存到: {results_path}")
        print(f"汇总报告已保存到: {summary_path}")
        print(f"对齐信息已保存到: {self.alignment_dir}")
        print(f"所有中间文件保存在: {self.output_dir}")

        return results
| |
|
| | def _load_wespeaker_model(self, wespeaker_model_dir): |
| | """加载wespeaker模型""" |
| | try: |
| | import wespeaker |
| | |
| | |
| | |
| | local_model_path = '/inspire/ssd/project/embodied-multimodality/public/zylin/speaker_embedding/wespeaker_pretrain/voxblink2_samresnet100_ft' |
| | |
| | try: |
| | self.similarity_model = wespeaker.load_model_local(local_model_path) |
| | self.logger.info(f"成功加载本地wespeaker模型: {local_model_path}") |
| | return |
| | except Exception as e: |
| | self.logger.warning(f"加载指定本地模型失败: {e}") |
| | |
| | |
| | if os.path.exists(wespeaker_model_dir): |
| | try: |
| | self.similarity_model = wespeaker.load_model_local(wespeaker_model_dir) |
| | self.logger.info(f"成功加载传入的本地wespeaker模型: {wespeaker_model_dir}") |
| | return |
| | except Exception as e: |
| | self.logger.warning(f"加载传入本地模型失败: {e}") |
| | |
| | |
| | try: |
| | self.similarity_model = wespeaker.load_model('chinese') |
| | self.logger.info("回退到wespeaker预训练中文模型") |
| | return |
| | except Exception as e: |
| | self.logger.warning(f"加载预训练中文模型失败: {e}") |
| | |
| | |
| | try: |
| | self.similarity_model = wespeaker.load_model('english') |
| | self.logger.info("回退到wespeaker预训练英文模型") |
| | return |
| | except Exception as e: |
| | self.logger.error(f"加载英文模型也失败: {e}") |
| | |
| | |
| | raise Exception("无法加载任何wespeaker模型") |
| | |
| | except ImportError: |
| | raise ImportError("请安装wespeaker: pip install git+https://github.com/wenet-e2e/wespeaker.git") |
| | except Exception as e: |
| | self.logger.error(f"加载wespeaker模型失败: {e}") |
| | raise |
| |
|
| | def load_data_from_jsonl(self, jsonl_path: str) -> List[Dict[str, Any]]: |
| | """从JSONL文件加载数据""" |
| | data = [] |
| | try: |
| | with open(jsonl_path, 'r', encoding='utf-8') as f: |
| | for line_num, line in enumerate(f, 1): |
| | line = line.strip() |
| | if line: |
| | try: |
| | item = json.loads(line) |
| | |
| | required_fields = ['text', 'output_audio'] |
| | for field in required_fields: |
| | if field not in item: |
| | self.logger.error(f"第{line_num}行缺少必要字段: {field}") |
| | continue |
| | |
| | |
| | has_combined_prompt = 'prompt_audio' in item and 'prompt_text' in item |
| | has_separate_prompts = ('prompt_audio_speaker1' in item and |
| | 'prompt_text_speaker1' in item and |
| | 'prompt_audio_speaker2' in item and |
| | 'prompt_text_speaker2' in item) |
| | |
| | if not (has_combined_prompt or has_separate_prompts): |
| | self.logger.error(f"第{line_num}行:需要提供prompt_audio+prompt_text或者分别的speaker音频文件") |
| | continue |
| | |
| | data.append(item) |
| | |
| | except json.JSONDecodeError as e: |
| | self.logger.error(f"第{line_num}行JSON解析错误: {e}") |
| | continue |
| | |
| | self.logger.info(f"从{jsonl_path}成功加载{len(data)}条数据") |
| | return data |
| | |
| | except FileNotFoundError: |
| | self.logger.error(f"JSONL文件不存在: {jsonl_path}") |
| | return [] |
| | except Exception as e: |
| | self.logger.error(f"读取JSONL文件失败: {e}") |
| | return [] |
| |
|
| | @staticmethod |
| | def get_gpu_count(): |
| | """获取可用GPU数量""" |
| | if torch.cuda.is_available(): |
| | return torch.cuda.device_count() |
| | return 0 |
| | |
| | @staticmethod |
| | def split_data_by_gpu(data: List[Dict[str, Any]], num_gpus: int) -> List[List[Dict[str, Any]]]: |
| | """根据GPU数量分割数据""" |
| | if num_gpus == 0: |
| | return [data] |
| | |
| | chunk_size = math.ceil(len(data) / num_gpus) |
| | gpu_chunks = [] |
| | |
| | for i in range(num_gpus): |
| | start_idx = i * chunk_size |
| | end_idx = min((i + 1) * chunk_size, len(data)) |
| | if start_idx < len(data): |
| | gpu_chunks.append(data[start_idx:end_idx]) |
| | |
| | return gpu_chunks |
| | |
| | @staticmethod |
| | def split_data_by_processes(data: List[Dict[str, Any]], num_processes: int) -> List[List[Dict[str, Any]]]: |
| | """根据进程数量分割数据""" |
| | if num_processes <= 1: |
| | return [data] |
| | |
| | chunk_size = math.ceil(len(data) / num_processes) |
| | process_chunks = [] |
| | |
| | for i in range(num_processes): |
| | start_idx = i * chunk_size |
| | end_idx = min((i + 1) * chunk_size, len(data)) |
| | if start_idx < len(data): |
| | process_chunks.append(data[start_idx:end_idx]) |
| | |
| | return process_chunks |
| |
|
| | def append_result_to_jsonl(self, result: Dict[str, Any], filepath: str): |
| | """增量写入结果到JSONL文件""" |
| | os.makedirs(os.path.dirname(filepath), exist_ok=True) |
| | with open(filepath, 'a', encoding='utf-8') as f: |
| | f.write(json.dumps(result, ensure_ascii=False) + '\n') |
| | f.flush() |
| |
|
| | def merge_temp_results(self, temp_files: List[str], final_path: str): |
| | """合并临时结果文件""" |
| | all_results = [] |
| | |
| | for temp_file in temp_files: |
| | if os.path.exists(temp_file): |
| | try: |
| | with open(temp_file, 'r', encoding='utf-8') as f: |
| | for line in f: |
| | line = line.strip() |
| | if line: |
| | result = json.loads(line) |
| | all_results.append(result) |
| | except Exception as e: |
| | self.logger.error(f"读取临时文件失败: {temp_file}, 错误: {e}") |
| | |
| | |
| | with open(final_path, 'w', encoding='utf-8') as f: |
| | for result in all_results: |
| | f.write(json.dumps(result, ensure_ascii=False) + '\n') |
| | |
| | return all_results |
| |
|
    def process_batch_parallel(self, input_data: List[Dict[str, Any]],
                               processes_per_gpu: int = 8,
                               results_filename: str = None,
                               shuffle_data: bool = True):
        """Evaluate all inputs in parallel across the visible GPUs.

        Pipeline: optional shuffle -> one contiguous chunk per GPU -> each
        GPU's chunk further split across worker processes -> every worker
        (process_data_chunk_incremental) appends results to its own temp
        JSONL file -> temp files are merged into the final results file and
        a summary report is written.

        Args:
            input_data: evaluation items (same schema as load_data_from_jsonl).
            processes_per_gpu: requested worker processes per GPU (capped at 16).
            results_filename: final JSONL filename; timestamped default if None.
            shuffle_data: shuffle items first to balance the per-GPU load.

        Returns:
            List of per-input result dicts; [] if nothing succeeded.
        """
        num_gpus = self.get_gpu_count()
        if num_gpus == 0:
            # No CUDA device visible: fall back to the sequential path.
            self.logger.warning("未检测到GPU,将使用CPU单进程处理")
            return self.process_batch_from_data(input_data, results_filename)

        # Cap per-GPU process count to limit oversubscription.
        max_processes_per_gpu = min(processes_per_gpu, 16)
        self.logger.info(f"检测到 {num_gpus} 个GPU,每个GPU将使用 {max_processes_per_gpu} 个进程")

        # Shuffle a copy so the caller's list order is left untouched.
        shuffled_data = input_data.copy()
        if shuffle_data:
            random.shuffle(shuffled_data)
            self.logger.info(f"已对 {len(shuffled_data)} 条数据进行随机shuffle以平衡GPU负载")

        # One contiguous chunk per GPU.
        gpu_chunks = self.split_data_by_gpu(shuffled_data, num_gpus)

        for gpu_id, gpu_data in enumerate(gpu_chunks):
            if gpu_data:
                self.logger.info(f"GPU {gpu_id}: 分配到 {len(gpu_data)} 条数据")

        if results_filename is None:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            results_filename = f"speaker_similarity_results_{self.language.lower()}_{timestamp}.jsonl"

        final_results_path = self.results_dir / results_filename

        # Build per-process argument tuples for every GPU up front.
        all_temp_files = []
        all_gpu_tasks = []

        for gpu_id, gpu_data in enumerate(gpu_chunks):
            if not gpu_data:
                continue

            self.logger.info(f"GPU {gpu_id}: 准备处理 {len(gpu_data)} 条数据")

            process_chunks = self.split_data_by_processes(gpu_data, max_processes_per_gpu)

            gpu_process_args = []
            for proc_id, proc_data in enumerate(process_chunks):
                if proc_data:
                    # Each worker gets its own temp results file ...
                    temp_result_file = str(self.temp_results_dir / f"gpu{gpu_id}_proc{proc_id}_results.jsonl")
                    all_temp_files.append(temp_result_file)

                    # ... and its own output dir so intermediate files don't clash.
                    subprocess_output_dir = str(self.output_dir / f"gpu{gpu_id}_proc{proc_id}")

                    # Tuple layout must match process_data_chunk_incremental's unpacking.
                    gpu_process_args.append((
                        proc_data,
                        gpu_id,
                        proc_id,
                        subprocess_output_dir,
                        temp_result_file,
                        self.alignment_model_dir,
                        self.wespeaker_model_dir,
                        self.language,
                        self.similarity_max_workers
                    ))

            if gpu_process_args:
                all_gpu_tasks.append((gpu_id, gpu_process_args, max_processes_per_gpu))

        def process_gpu_tasks(gpu_task):
            # Runs in a thread: fans this GPU's chunks out to a process pool
            # (workers are spawn-started; see mp.set_start_method at import time).
            gpu_id, process_args, actual_processes = gpu_task
            self.logger.info(f"GPU {gpu_id}: 开始并行处理 {len(process_args)} 个进程")

            with mp.Pool(processes=actual_processes) as pool:
                pool.map(process_data_chunk_incremental, process_args)

            self.logger.info(f"GPU {gpu_id}: 所有进程处理完成")
            return gpu_id

        # One driver thread per GPU keeps the per-GPU pools running concurrently.
        with ThreadPoolExecutor(max_workers=num_gpus) as executor:
            future_to_gpu = {executor.submit(process_gpu_tasks, gpu_task): gpu_task[0]
                             for gpu_task in all_gpu_tasks}

            completed_gpus = []
            for future in as_completed(future_to_gpu):
                gpu_id = future_to_gpu[future]
                try:
                    result_gpu_id = future.result()
                    completed_gpus.append(result_gpu_id)
                    self.logger.info(f"GPU {result_gpu_id} 完成处理")
                except Exception as exc:
                    self.logger.error(f"GPU {gpu_id} 处理时发生异常: {exc}")

        self.logger.info(f"所有GPU处理完成: {completed_gpus}")

        self.logger.info("合并所有临时结果文件...")
        all_results = self.merge_temp_results(all_temp_files, str(final_results_path))

        if not all_results:
            self.logger.error("没有成功处理的数据")
            return []

        summary_path = self.save_summary_report(all_results)

        # Temp files are removed only after a successful merge.
        for temp_file in all_temp_files:
            if os.path.exists(temp_file):
                os.remove(temp_file)

        total_accuracy = np.mean([r['accuracy'] for r in all_results])
        total_avg_similarity = np.mean([r['average_similarity'] for r in all_results])

        print(f"\n=== 并行评估完成 ===")
        print(f"使用语言: {self.language}")
        print(f"使用 {num_gpus} 个GPU,每GPU {max_processes_per_gpu} 个进程")
        print(f"总处理数据: {len(input_data)} 条")
        print(f"成功处理: {len(all_results)} 条")
        print(f"总体准确率: {total_accuracy:.3f}")
        print(f"总体平均相似度: {total_avg_similarity:.3f}")
        print(f"详细结果已保存到: {final_results_path}")
        print(f"汇总报告已保存到: {summary_path}")
        print(f"对齐信息已保存到: {self.alignment_dir}")

        return all_results
| |
|
| | def get_or_split_prompt_audio(self, data: Dict[str, Any], audio_id: str) -> Tuple[str, str]: |
| | """ |
| | 获取或分割prompt音频 |
| | 如果提供了分别的speaker音频文件则直接使用,否则从combined prompt分割 |
| | """ |
| | |
| | if ('prompt_audio_speaker1' in data and 'prompt_audio_speaker2' in data and |
| | 'prompt_text_speaker1' in data and 'prompt_text_speaker2' in data): |
| | |
| | self.logger.info(f"使用预分割的speaker音频文件") |
| | |
| | |
| | try: |
| | |
| | alignment_language = self.language |
| | if alignment_language == "AUTO": |
| | alignment_language = self._detect_language_from_text(data['prompt_text_speaker1']) |
| | |
| | |
| | s1_alignments = self.align_text_with_audio( |
| | data['prompt_text_speaker1'], data['prompt_audio_speaker1'], alignment_language |
| | ) |
| | s1_alignment_data = { |
| | 'speaker': 'S1', |
| | 'text': data['prompt_text_speaker1'], |
| | 'audio_path': data['prompt_audio_speaker1'], |
| | 'language': alignment_language, |
| | 'alignments': s1_alignments |
| | } |
| | self.save_alignment_info(s1_alignment_data, audio_id, "prompt_s1") |
| | |
| | |
| | s2_alignments = self.align_text_with_audio( |
| | data['prompt_text_speaker2'], data['prompt_audio_speaker2'], alignment_language |
| | ) |
| | s2_alignment_data = { |
| | 'speaker': 'S2', |
| | 'text': data['prompt_text_speaker2'], |
| | 'audio_path': data['prompt_audio_speaker2'], |
| | 'language': alignment_language, |
| | 'alignments': s2_alignments |
| | } |
| | self.save_alignment_info(s2_alignment_data, audio_id, "prompt_s2") |
| | |
| | except Exception as e: |
| | self.logger.warning(f"保存预分割音频对齐信息失败: {e}") |
| | |
| | return data['prompt_audio_speaker1'], data['prompt_audio_speaker2'] |
| | |
| | |
| | elif 'prompt_audio' in data and 'prompt_text' in data: |
| | self.logger.info(f"从combined prompt音频分割speaker片段") |
| | return self.split_audio_by_speaker(data['prompt_text'], data['prompt_audio'], audio_id) |
| | |
| | else: |
| | raise ValueError("必须提供prompt_audio+prompt_text或者分别的speaker音频文件") |
| |
|
| | def calculate_voice_similarity(self, audio1_path: str, audio2_path: str) -> float: |
| | """ |
| | 计算两个音频的音色相似度(向后兼容版本) |
| | 对于过短的音频片段,通过复制来达到最小长度要求 |
| | """ |
| | |
| | if threading.current_thread() != threading.main_thread(): |
| | return self.calculate_voice_similarity_thread_safe(audio1_path, audio2_path) |
| | |
| | |
| | self._init_models_if_needed() |
| | |
| | try: |
| | if not os.path.exists(audio1_path) or not os.path.exists(audio2_path): |
| | self.logger.warning(f"Audio file not found: {audio1_path} or {audio2_path}") |
| | return None |
| | |
| | |
| | def process_audio_for_similarity(audio_path, min_duration=0.1): |
| | """ |
| | 处理音频文件,如果过短则复制到满足最小长度要求 |
| | 返回处理后的音频路径和是否为临时文件的标志 |
| | """ |
| | try: |
| | waveform, sample_rate = torchaudio.load(audio_path) |
| | duration = waveform.shape[1] / sample_rate |
| | |
| | if duration >= min_duration: |
| | |
| | return audio_path, False |
| | |
| | |
| | repeat_times = math.ceil(min_duration / duration) |
| | self.logger.info(f"音频过短 ({duration:.3f}s),复制 {repeat_times} 次达到 {min_duration}s 要求: {audio_path}") |
| | |
| | |
| | repeated_waveform = waveform.repeat(1, repeat_times) |
| | |
| | |
| | temp_filename = f"temp_{os.path.basename(audio_path)}" |
| | temp_path = str(self.temp_dir / temp_filename) |
| | |
| | |
| | torchaudio.save(temp_path, repeated_waveform, sample_rate) |
| | |
| | return temp_path, True |
| | |
| | except Exception as e: |
| | self.logger.error(f"处理音频文件失败: {audio_path}, 错误: {e}") |
| | return audio_path, False |
| | |
| | |
| | processed_audio1, is_temp1 = process_audio_for_similarity(audio1_path) |
| | processed_audio2, is_temp2 = process_audio_for_similarity(audio2_path) |
| | |
| | |
| | similarity = self.similarity_model.compute_similarity(processed_audio1, processed_audio2) |
| | |
| | |
| | if is_temp1 and os.path.exists(processed_audio1): |
| | try: |
| | os.remove(processed_audio1) |
| | except Exception as e: |
| | self.logger.warning(f"删除临时文件失败: {processed_audio1}, 错误: {e}") |
| | |
| | if is_temp2 and os.path.exists(processed_audio2): |
| | try: |
| | os.remove(processed_audio2) |
| | except Exception as e: |
| | self.logger.warning(f"删除临时文件失败: {processed_audio2}, 错误: {e}") |
| | |
| | return float(similarity) |
| | |
| | except Exception as e: |
| | |
| | if "choose a window size" in str(e) or "window size" in str(e): |
| | self.logger.warning(f"音频片段仍然过短,无法计算相似度: {audio1_path} vs {audio2_path}") |
| | return None |
| | else: |
| | self.logger.error(f"Failed to compute similarity between {audio1_path} and {audio2_path}: {e}") |
| | return None |
| |
|
| | |
def process_data_chunk_incremental(args):
    """Worker-process entry point: evaluate one data chunk, writing results incrementally.

    `args` is the 9-tuple packed by process_batch_parallel:
    (data_chunk, gpu_id, proc_id, output_dir, temp_result_file,
     alignment_model_dir, wespeaker_model_dir, language, similarity_max_workers).

    Each worker builds its own SpeakerSimilarityEvaluator bound to one CUDA
    device and appends one JSON line per finished input to `temp_result_file`.
    Per-item failures are printed and skipped so one bad item doesn't kill
    the whole chunk; init failures abort the worker after printing.
    """
    data_chunk, gpu_id, proc_id, output_dir, temp_result_file, alignment_model_dir, wespeaker_model_dir, language, similarity_max_workers = args

    # Fall back to CPU when the requested GPU index isn't actually available.
    device = f"cuda:{gpu_id}" if torch.cuda.is_available() and gpu_id < torch.cuda.device_count() else "cpu"

    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.set_device(gpu_id)

        # Stagger worker start-up slightly to reduce init contention.
        time.sleep(proc_id * 0.5)

        evaluator = SpeakerSimilarityEvaluator(
            device=device,
            alignment_model_dir=alignment_model_dir,
            wespeaker_model_dir=wespeaker_model_dir,
            output_dir=output_dir,
            language=language,
            similarity_max_workers=similarity_max_workers
        )

        # Load models eagerly so per-item latency stays predictable.
        evaluator._init_models_if_needed()

        # Start from a fresh temp results file for this worker.
        if os.path.exists(temp_result_file):
            os.remove(temp_result_file)

        for i, data in enumerate(data_chunk):
            input_id = f"gpu{gpu_id}_proc{proc_id}_input_{i+1:03d}"

            try:
                result = evaluator.evaluate_single_input(data, input_id=input_id)

                # Persist immediately so partial progress survives crashes.
                evaluator.append_result_to_jsonl(result, temp_result_file)

                print(f"GPU{gpu_id}-进程{proc_id}: 完成 {input_id} (语言: {language}, 相似度线程: {similarity_max_workers})")

                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            except Exception as e:
                print(f"GPU{gpu_id}-进程{proc_id}: 处理 {input_id} 失败: {e}")

                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                continue

        print(f"GPU{gpu_id}-进程{proc_id}: 所有数据处理完成,结果已写入 {temp_result_file}")

    except Exception as e:
        print(f"GPU{gpu_id}-进程{proc_id}: 初始化失败: {e}")

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
| |
|
def main():
    """CLI entry point: parse arguments, build the evaluator, and run either
    the parallel or the sequential batch pipeline.

    Fixes:
    - `--similarity_workers` help text claimed the default was 8 while the
      actual default is 16; the help string now matches.
    - The language if/elif chain was a no-op (argparse `choices` already
      restricts values to zh/en/auto), so normalization is a single upper().
    """
    import argparse

    parser = argparse.ArgumentParser(description='Speaker Similarity Evaluator')
    parser.add_argument('--jsonl_path', type=str, help='JSONL文件路径')
    parser.add_argument('--output_dir', type=str,
                        default=f"/inspire/hdd/project/embodied-multimodality/public/yqzhang/auto_evaluation_new/eval_res/results_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
                        help='结果保存目录')
    parser.add_argument('--language', type=str, choices=['zh', 'en', 'auto'], default='zh',
                        help='指定语言: zh=中文, en=英文, auto=自动检测 (默认: zh)')
    parser.add_argument('--no_parallel', action='store_true', help='禁用并行处理(默认启用并行)')
    parser.add_argument('--processes_per_gpu', type=int, default=4, help='每个GPU的进程数(建议不超过4)')
    parser.add_argument('--similarity_workers', type=int, default=16, help='相似度计算的线程数(默认: 16)')
    parser.add_argument('--no_shuffle', action='store_true', help='禁用数据shuffle(默认启用shuffle)')
    parser.add_argument('--random_seed', type=int, default=None, help='随机种子(可选,用于结果复现)')

    args = parser.parse_args()

    # Seed all RNGs when reproducibility is requested.
    if args.random_seed is not None:
        random.seed(args.random_seed)
        np.random.seed(args.random_seed)
        torch.manual_seed(args.random_seed)
        print(f"设置随机种子: {args.random_seed}")

    # choices already restricts input to zh/en/auto, so upper-casing is the
    # entire normalization.
    language = args.language.upper()

    evaluator = SpeakerSimilarityEvaluator(
        output_dir=args.output_dir,
        language=language,
        similarity_max_workers=args.similarity_workers
    )

    use_parallel = not args.no_parallel
    use_shuffle = not args.no_shuffle

    print(f"使用语言设置: {language}")
    print(f"相似度计算线程数: {args.similarity_workers}")

    if args.jsonl_path:
        if use_parallel:
            evaluator.process_batch_from_jsonl_parallel(
                args.jsonl_path,
                processes_per_gpu=args.processes_per_gpu,
                shuffle_data=use_shuffle
            )
        else:
            evaluator.process_batch_from_jsonl(args.jsonl_path)
    else:
        # Built-in smoke-test example (paths are cluster-specific).
        input_data = [
            {
                'prompt_audio': "/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_prompt/testset/audio/zhouxingchi/zxc_enhanced.wav",
                'prompt_text': "[S1]你再往前半步我就把你给杀了。[S2]你应该这么做,我也应该死。",
                'text': "[S1]至尊宝,如果有一天我不再是紫霞仙子,只是一个普通的凡人,你还会像现在这样陪着我吗?[S2]这个嘛,那我得先问问月老,看看他给不给我打折!毕竟追仙子要花好多力气的![S1]哼!油嘴滑舌!我是认真的![S2]紫霞,不管你是仙子还是凡人,哪怕变成一根香蕉,我都认得出你。不过……你最好别真变成香蕉,我怕我会忍不住吃掉……[S1]讨厌!谁要变成香蕉啊!那……如果有一天,我们不得不分开呢?[S2]哇!你这话比牛魔王的斧头还狠!不行不行,你得赔我精神损失费![S1]怎么赔?[S2]很简单,让我亲一下,就当是定金![S1]想得美!那如果有一天,你真的忘了我呢?[S2]那我就算翻遍三界,打烂阎王殿,也要把记忆找回来。紫霞,我至尊宝这辈子,赖定你了![S1]傻瓜。",
                'output_audio': "/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_res/from_newckpt_step145000/test_set/output_7.wav"
            }
        ]

        if use_parallel:
            evaluator.process_batch_parallel(input_data, processes_per_gpu=args.processes_per_gpu)
        else:
            evaluator.process_batch_from_data(input_data)
| |
|
| |
|
# Script entry point.
if __name__ == "__main__":
    main()