| """ |
| Multi-Voice Engine Module |
| Handles SoulX-Singer model inference for multiple voices |
| Implements segment-based processing for long scores |
| """ |
|
|
| import numpy as np |
| import torch |
| from typing import Dict, List, Optional, Callable |
| import gc |
| import os |
| import sys |
|
|
| from .config import get_inference_config, get_device |
|
|
|
|
class MultiVoiceEngine:
    """
    Multi-voice synthesis engine using SoulX-Singer.

    Features:
    - Segment-based processing for long scores (one segment per
      ``config['segment_duration']`` seconds, nominally <= 8s)
    - Memory management with explicit garbage collection between segments
    - Progress callback support
    - Uses the bundled DataProcessor for mel2note preprocessing
    """

    # Audio parameters the DataProcessor is constructed with.
    # NOTE(review): assumed to match the SoulX-Singer checkpoint's
    # training configuration (24 kHz, hop 480) — confirm against the model.
    _SAMPLE_RATE = 24000
    _HOP_SIZE = 480

    # Number of leading target notes reused to fabricate a prompt score.
    _PROMPT_NOTE_COUNT = 5

    def __init__(self, model):
        """
        Initialize engine with a SoulX-Singer model.

        Args:
            model: SoulX-Singer model instance exposing an ``infer`` method.
        """
        self.model = model
        self.config = get_inference_config()
        self.device = get_device()
        # Lazily created by _get_data_processor() on first use.
        self._data_processor = None

    def _get_data_processor(self):
        """
        Lazily import and construct the DataProcessor.

        The soulxsinger package lives next to this package, so its parent
        directory is added to ``sys.path`` (once, normalized) before import.

        Returns:
            Cached DataProcessor instance.
        """
        if self._data_processor is None:
            # Resolve the sibling 'soulxsinger' package relative to this
            # file; abspath avoids leaving an unnormalized '..' entry on
            # sys.path, and the membership check avoids duplicates.
            base_path = os.path.dirname(os.path.abspath(__file__))
            soulx_path = os.path.join(base_path, '..', 'soulxsinger')
            if os.path.exists(soulx_path):
                parent = os.path.abspath(os.path.join(soulx_path, os.pardir))
                if parent not in sys.path:
                    sys.path.insert(0, parent)

            from soulxsinger.utils.data_processor import DataProcessor

            self._data_processor = DataProcessor(
                hop_size=self._HOP_SIZE,
                sample_rate=self._SAMPLE_RATE,
                device=self.device,
            )

        return self._data_processor

    def generate_single_voice(
        self,
        metadata: Dict,
        on_progress: Optional[Callable[[float], None]] = None
    ) -> np.ndarray:
        """
        Generate audio for a single voice.

        Args:
            metadata: Voice metadata from metadata_generator; must contain
                'target' (score dict with 'duration') and 'prompt_audio'.
            on_progress: Optional callback receiving progress in percent.

        Returns:
            Generated mono audio array.
        """
        target = metadata['target']
        prompt_audio = metadata['prompt_audio']

        total_duration = target['duration']
        segment_duration = self.config['segment_duration']

        # Short scores fit in a single model call; long scores are split
        # into fixed-length segments and concatenated.
        if total_duration <= segment_duration:
            return self._generate_segment(prompt_audio, target, on_progress)
        return self._generate_segments(prompt_audio, target, on_progress)

    def _generate_segment(
        self,
        prompt_audio: np.ndarray,
        target: Dict,
        on_progress: Optional[Callable[[float], None]] = None
    ) -> np.ndarray:
        """
        Generate one segment (<= segment_duration seconds) of audio.

        On any failure the error is logged and silence of the segment's
        duration is returned, so a single bad segment does not abort a
        multi-segment render.

        Args:
            prompt_audio: Prompt audio array (assumed mono at 24 kHz —
                TODO confirm upstream).
            target: Target score metadata ('phoneme', 'note_pitch',
                'note_duration', 'note_type', 'duration').
            on_progress: Optional percent-progress callback.

        Returns:
            Generated audio for this segment, flattened to 1-D.
        """
        try:
            data_processor = self._get_data_processor()

            target_data = data_processor.preprocess(
                note_duration=target['note_duration'],
                phonemes=target['phoneme'],
                note_pitch=target['note_pitch'],
                note_type=target['note_type']
            )

            # Fabricate a short pseudo-score for the prompt audio: reuse
            # the first few target notes (slicing clamps automatically)
            # and spread the prompt's duration evenly across them.
            prompt_phonemes = target['phoneme'][:self._PROMPT_NOTE_COUNT]
            prompt_pitches = target['note_pitch'][:self._PROMPT_NOTE_COUNT]
            prompt_types = target['note_type'][:self._PROMPT_NOTE_COUNT]
            if not prompt_phonemes:
                # An empty score cannot be synthesized; take the silence
                # fallback below instead of dividing by zero.
                raise ValueError("target contains no phonemes")
            prompt_duration = len(prompt_audio) / self._SAMPLE_RATE
            per_note = prompt_duration / len(prompt_phonemes)
            prompt_durations = [per_note] * len(prompt_phonemes)

            prompt_data = data_processor.preprocess(
                note_duration=prompt_durations,
                phonemes=prompt_phonemes,
                note_pitch=prompt_pitches,
                note_type=prompt_types
            )

            # Model expects the prompt waveform as a (1, samples) tensor.
            prompt_data['waveform'] = (
                torch.from_numpy(prompt_audio).float().unsqueeze(0).to(self.device)
            )

            infer_data = {
                'prompt': prompt_data,
                'target': target_data
            }

            with torch.no_grad():
                output = self.model.infer(
                    infer_data,
                    auto_shift=False,
                    pitch_shift=0,
                    n_steps=self.config['n_steps'],
                    cfg=self.config['cfg'],
                    control=self.config['control'],
                    use_fp16=self.config['use_fp16']
                )

            # Release intermediate tensors promptly to keep peak memory down.
            del infer_data
            del prompt_data
            del target_data
            gc.collect()

            if on_progress:
                on_progress(100.0)

            if torch.is_tensor(output):
                output = output.cpu().numpy()

            # Collapse (1, samples) or similar to a 1-D waveform.
            if output.ndim > 1:
                output = output.flatten()

            return output

        except Exception as e:
            # Best-effort rendering: log and return silence so sibling
            # segments/voices still complete.
            print(f"Error in _generate_segment: {e}")
            import traceback
            traceback.print_exc()

            duration = target.get('duration', 1.0)
            return np.zeros(int(self._SAMPLE_RATE * duration))

    def _generate_segments(
        self,
        prompt_audio: np.ndarray,
        target: Dict,
        on_progress: Optional[Callable[[float], None]] = None
    ) -> np.ndarray:
        """
        Generate a long score as consecutive segments and concatenate.

        Args:
            prompt_audio: Prompt audio reused for every segment.
            target: Full target score metadata.
            on_progress: Optional percent-progress callback, invoked once
                per completed segment.

        Returns:
            Concatenated generated audio.
        """
        total_duration = target['duration']
        segment_duration = self.config['segment_duration']
        num_segments = int(np.ceil(total_duration / segment_duration))

        segments = []

        for i in range(num_segments):
            start_time = i * segment_duration
            end_time = min((i + 1) * segment_duration, total_duration)

            segment_target = self._extract_segment(target, start_time, end_time)

            # Per-segment progress is reported here, not via the inner
            # call, so the callback sees monotonic overall progress.
            segment_audio = self._generate_segment(prompt_audio, segment_target)
            segments.append(segment_audio)

            if on_progress:
                on_progress((i + 1) / num_segments * 100)

            # Keep peak memory bounded between segments.
            gc.collect()

        return np.concatenate(segments)

    def _extract_segment(
        self,
        target: Dict,
        start_time: float,
        end_time: float
    ) -> Dict:
        """
        Extract the notes overlapping a time window from target metadata.

        Notes straddling a window edge are truncated to the overlap, so
        the returned durations sum to at most ``end_time - start_time``.

        Args:
            target: Full target metadata.
            start_time: Window start (seconds).
            end_time: Window end (seconds).

        Returns:
            Segment metadata dict with the same keys as ``target``.
        """
        note_durations = target['note_duration']
        phonemes = target['phoneme']
        note_pitches = target['note_pitch']
        note_types = target['note_type']

        seg_durations = []
        seg_phonemes = []
        seg_pitches = []
        seg_types = []

        current_time = 0.0

        for i, dur in enumerate(note_durations):
            note_start = current_time
            note_end = current_time + dur

            # Keep any note that overlaps [start_time, end_time),
            # clipped to the window.
            if note_end > start_time and note_start < end_time:
                overlap_start = max(note_start, start_time)
                overlap_end = min(note_end, end_time)
                overlap_duration = overlap_end - overlap_start

                if overlap_duration > 0:
                    seg_durations.append(overlap_duration)
                    seg_phonemes.append(phonemes[i])
                    seg_pitches.append(note_pitches[i])
                    seg_types.append(note_types[i])

            current_time = note_end

            # Everything after this note lies past the window.
            if current_time >= end_time:
                break

        return {
            'phoneme': seg_phonemes,
            'note_pitch': seg_pitches,
            'note_duration': seg_durations,
            'note_type': seg_types,
            'duration': end_time - start_time
        }
|
|