# backend/multi_voice_engine.py
"""
Multi-Voice Engine Module
Handles SoulX-Singer model inference for multiple voices
Implements segment-based processing for long scores
"""
import numpy as np
import torch
from typing import Dict, List, Optional, Callable
import gc
import os
import sys
from .config import get_inference_config, get_device
class MultiVoiceEngine:
"""
Multi-voice synthesis engine using SoulX-Singer.
Features:
- Segment-based processing for long scores (≤8s per segment)
- Memory management with garbage collection
- Progress callback support
- Uses DataProcessor for proper mel2note generation
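Example (illustrative sketch; assumes a loaded SoulX-Singer model and
metadata produced by metadata_generator):
    engine = MultiVoiceEngine(model)
    audio = engine.generate_single_voice(metadata, on_progress=print)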
"""
def __init__(self, model):
"""
Initialize engine with SoulX-Singer model.
Args:
model: SoulX-Singer model instance
"""
self.model = model
self.config = get_inference_config()
self.device = get_device()
self._data_processor = None
def _get_data_processor(self):
"""
Lazy load DataProcessor with proper configuration.
Returns:
DataProcessor instance
"""
if self._data_processor is None:
# Add soulxsinger to path
base_path = os.path.dirname(__file__)
soulx_path = os.path.join(base_path, '..', 'soulxsinger')
if os.path.exists(soulx_path):
sys.path.insert(0, os.path.dirname(soulx_path))
from soulxsinger.utils.data_processor import DataProcessor
# DataProcessor config from soulxsinger.yaml
# hop_size=480, sample_rate=24000
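# (480 samples / 24000 Hz = 20 ms per mel frame; keep in sync with the yaml)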
self._data_processor = DataProcessor(
hop_size=480,
sample_rate=24000,
device=self.device
)
return self._data_processor
def generate_single_voice(
self,
metadata: Dict,
on_progress: Optional[Callable[[float], None]] = None
) -> np.ndarray:
"""
Generate audio for a single voice.
Args:
metadata: Voice metadata from metadata_generator
on_progress: Progress callback function
Returns:
Generated audio array
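Expected metadata layout (only the keys this engine reads are listed):
    {
        'prompt_audio': np.ndarray,       # mono prompt audio at 24 kHz
        'target': {
            'duration': float,            # total length in seconds
            'phoneme': List[str],
            'note_pitch': List[int],
            'note_duration': List[float], # per-note durations in seconds
            'note_type': List[int],
        }
    }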
"""
target = metadata['target']
prompt_audio = metadata['prompt_audio']
# Check if segmentation is needed
total_duration = target['duration']
segment_duration = self.config['segment_duration']
if total_duration <= segment_duration:
# Single segment
return self._generate_segment(prompt_audio, target, on_progress)
else:
# Multiple segments
return self._generate_segments(prompt_audio, target, on_progress)
def _generate_segment(
self,
prompt_audio: np.ndarray,
target: Dict,
on_progress: Optional[Callable[[float], None]] = None
) -> np.ndarray:
"""
Generate a single segment (at most segment_duration from config, nominally 8 s).
Args:
prompt_audio: Prompt audio array
target: Target metadata
on_progress: Progress callback
Returns:
Generated audio for this segment
"""
try:
# Get DataProcessor for mel2note generation
data_processor = self._get_data_processor()
# Prepare target data using DataProcessor.preprocess
# This generates mel2note properly
target_data = data_processor.preprocess(
note_duration=target['note_duration'], # List[float] in seconds
phonemes=target['phoneme'], # List[str]
note_pitch=target['note_pitch'], # List[int]
note_type=target['note_type'] # List[int]
)
# Prepare prompt data
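# The prompt is described with (up to) the first five target notes and an
# even split of the prompt duration. This is only a rough alignment so that
# DataProcessor can build prompt-side features of matching length.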
prompt_duration = len(prompt_audio) / 24000 # sample_rate=24000
prompt_phonemes = target['phoneme'][:5]
prompt_pitches = target['note_pitch'][:5]
prompt_types = target['note_type'][:5]
prompt_durations = [prompt_duration / len(prompt_phonemes)] * len(prompt_phonemes)
prompt_data = data_processor.preprocess(
note_duration=prompt_durations,
phonemes=prompt_phonemes,
note_pitch=prompt_pitches,
note_type=prompt_types
)
# Attach the prompt waveform as a batched float tensor on the target device
prompt_data['waveform'] = torch.from_numpy(prompt_audio).float().unsqueeze(0).to(self.device)
# Build infer_data for model
infer_data = {
'prompt': prompt_data,
'target': target_data
}
# Run inference
with torch.no_grad():
output = self.model.infer(
infer_data,
auto_shift=False,
pitch_shift=0,
n_steps=self.config['n_steps'],
cfg=self.config['cfg'],
control=self.config['control'],
use_fp16=self.config['use_fp16']
)
# Clean up
del infer_data
del prompt_data
del target_data
gc.collect()
if on_progress:
on_progress(100.0)
# Convert to numpy
if torch.is_tensor(output):
output = output.cpu().numpy()
# Flatten if needed
if len(output.shape) > 1:
output = output.flatten()
return output
except Exception as e:
print(f"Error in _generate_segment: {e}")
import traceback
traceback.print_exc()
# Fallback: return silence
duration = target.get('duration', 1.0)
return np.zeros(int(24000 * duration))
def _generate_segments(
self,
prompt_audio: np.ndarray,
target: Dict,
on_progress: Optional[Callable[[float], None]] = None
) -> np.ndarray:
"""
Generate multiple segments and concatenate.
Args:
prompt_audio: Prompt audio
target: Target metadata
on_progress: Progress callback
Returns:
Concatenated generated audio
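Note:
    Segments are synthesized independently and joined with a plain
    np.concatenate; notes straddling a boundary are truncated to their
    overlap (see _extract_segment), so audible seams at joins are possible.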
"""
total_duration = target['duration']
segment_duration = self.config['segment_duration']
num_segments = int(np.ceil(total_duration / segment_duration))
segments = []
for i in range(num_segments):
# Extract segment metadata
start_time = i * segment_duration
end_time = min((i + 1) * segment_duration, total_duration)
segment_target = self._extract_segment(target, start_time, end_time)
# Generate this segment
segment_audio = self._generate_segment(prompt_audio, segment_target)
segments.append(segment_audio)
# Update progress
if on_progress:
progress = (i + 1) / num_segments * 100
on_progress(progress)
# Memory cleanup
gc.collect()
# Concatenate segments
return np.concatenate(segments)
def _extract_segment(
self,
target: Dict,
start_time: float,
end_time: float
) -> Dict:
"""
Extract a time segment from target metadata.
Args:
target: Full target metadata
start_time: Segment start time (seconds)
end_time: Segment end time (seconds)
Returns:
Segment metadata
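Example:
    With note_duration=[3.0, 3.0, 4.0] and a segment of [0.0, 8.0), the
    first two notes are kept whole and the third is truncated to its
    2.0 s overlap, giving note_duration=[3.0, 3.0, 2.0] in the result.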
"""
# Calculate which notes fall within this segment
note_durations = target['note_duration']
phonemes = target['phoneme']
note_pitches = target['note_pitch']
note_types = target['note_type']
seg_durations = []
seg_phonemes = []
seg_pitches = []
seg_types = []
current_time = 0.0
for i, dur in enumerate(note_durations):
note_start = current_time
note_end = current_time + dur
# Check if this note overlaps with segment
if note_end > start_time and note_start < end_time:
# Calculate overlap
overlap_start = max(note_start, start_time)
overlap_end = min(note_end, end_time)
overlap_duration = overlap_end - overlap_start
if overlap_duration > 0:
seg_durations.append(overlap_duration)
seg_phonemes.append(phonemes[i])
seg_pitches.append(note_pitches[i])
seg_types.append(note_types[i])
current_time = note_end
# Stop if we've passed the segment end
if current_time >= end_time:
break
return {
'phoneme': seg_phonemes,
'note_pitch': seg_pitches,
'note_duration': seg_durations,
'note_type': seg_types,
'duration': end_time - start_time
}
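

if __name__ == "__main__":
    # Minimal smoke test of segment extraction only; no model or audio is
    # needed because _extract_segment is a pure function of the metadata.
    # The demo metadata below is made up for illustration, and the relative
    # import at the top means this must be run as a module, e.g.:
    #   python -m backend.multi_voice_engine
    demo_target = {
        'phoneme': ['a', 'i', 'u'],
        'note_pitch': [60, 62, 64],
        'note_duration': [3.0, 3.0, 4.0],
        'note_type': [1, 1, 1],
        'duration': 10.0,
    }
    engine = MultiVoiceEngine(model=None)  # model is not used by _extract_segment
    print(engine._extract_segment(demo_target, 0.0, 8.0))   # third note cut to 2.0 s
    print(engine._extract_segment(demo_target, 8.0, 10.0))  # remaining 2.0 s tail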