JeffreyZhou798 committed on
Commit 578a3ba · verified · 1 Parent(s): 8e3dac3

Update backend/multi_voice_engine.py

Files changed (1)
  1. backend/multi_voice_engine.py +291 -213
backend/multi_voice_engine.py CHANGED
@@ -1,213 +1,291 @@
-"""
-Multi-Voice Engine Module
-Handles SoulX-Singer model inference for multiple voices
-Implements segment-based processing for long scores
-"""
-
-import numpy as np
-import torch
-from typing import Dict, List, Optional, Callable
-import gc
-
-from .config import get_inference_config
-
-
-class MultiVoiceEngine:
-    """
-    Multi-voice synthesis engine using SoulX-Singer.
-
-    Features:
-    - Segment-based processing for long scores (≤8s per segment)
-    - Memory management with garbage collection
-    - Progress callback support
-    """
-
-    def __init__(self, model):
-        """
-        Initialize engine with SoulX-Singer model.
-
-        Args:
-            model: SoulX-Singer model instance
-        """
-        self.model = model
-        self.config = get_inference_config()
-
-    def generate_single_voice(
-        self,
-        metadata: Dict,
-        on_progress: Optional[Callable[[float], None]] = None
-    ) -> np.ndarray:
-        """
-        Generate audio for a single voice.
-
-        Args:
-            metadata: Voice metadata from metadata_generator
-            on_progress: Progress callback function
-
-        Returns:
-            Generated audio array
-        """
-        target = metadata['target']
-        prompt_audio = metadata['prompt_audio']
-
-        # Check if segmentation is needed
-        total_duration = target['duration']
-        segment_duration = self.config['segment_duration']
-
-        if total_duration <= segment_duration:
-            # Single segment
-            return self._generate_segment(prompt_audio, target, on_progress)
-        else:
-            # Multiple segments
-            return self._generate_segments(prompt_audio, target, on_progress)
-
-    def _generate_segment(
-        self,
-        prompt_audio: np.ndarray,
-        target: Dict,
-        on_progress: Optional[Callable[[float], None]] = None
-    ) -> np.ndarray:
-        """
-        Generate a single segment (≤8 seconds).
-
-        Args:
-            prompt_audio: Prompt audio array
-            target: Target metadata
-            on_progress: Progress callback
-
-        Returns:
-            Generated audio for this segment
-        """
-        try:
-            # Prepare model input
-            infer_data = {
-                'prompt': {
-                    'waveform': torch.from_numpy(prompt_audio).float(),
-                    'phoneme': self._phonemes_to_tensor(target['phoneme'][:len(prompt_audio)//100]),
-                    'note_pitch': torch.tensor(target['note_pitch'][:len(prompt_audio)//100]),
-                    'note_type': torch.tensor(target['note_type'][:len(prompt_audio)//100])
-                },
-                'target': {
-                    'phoneme': self._phonemes_to_tensor(target['phoneme']),
-                    'note_pitch': torch.tensor(target['note_pitch']),
-                    'note_type': torch.tensor(target['note_type'])
-                }
-            }
-
-            # Run inference
-            with torch.no_grad():
-                output = self.model.infer(
-                    infer_data,
-                    auto_shift=False,
-                    pitch_shift=0,
-                    n_steps=self.config['n_steps'],
-                    cfg=self.config['cfg'],
-                    control=self.config['control'],
-                    use_fp16=self.config['use_fp16']
-                )
-
-            # Clean up
-            del infer_data
-            gc.collect()
-
-            if on_progress:
-                on_progress(100.0)
-
-            return output.cpu().numpy() if torch.is_tensor(output) else output
-
-        except Exception as e:
-            print(f"Error in _generate_segment: {e}")
-            # Fallback: return silence
-            duration = target.get('duration', 1.0)
-            return np.zeros(int(44100 * duration))
-
-    def _generate_segments(
-        self,
-        prompt_audio: np.ndarray,
-        target: Dict,
-        on_progress: Optional[Callable[[float], None]] = None
-    ) -> np.ndarray:
-        """
-        Generate multiple segments and concatenate.
-
-        Args:
-            prompt_audio: Prompt audio
-            target: Target metadata
-            on_progress: Progress callback
-
-        Returns:
-            Concatenated generated audio
-        """
-        total_duration = target['duration']
-        segment_duration = self.config['segment_duration']
-        num_segments = int(np.ceil(total_duration / segment_duration))
-
-        segments = []
-
-        for i in range(num_segments):
-            # Extract segment metadata
-            start_time = i * segment_duration
-            end_time = min((i + 1) * segment_duration, total_duration)
-
-            segment_target = self._extract_segment(target, start_time, end_time)
-
-            # Generate this segment
-            segment_audio = self._generate_segment(prompt_audio, segment_target)
-            segments.append(segment_audio)
-
-            # Update progress
-            if on_progress:
-                progress = (i + 1) / num_segments * 100
-                on_progress(progress)
-
-            # Memory cleanup
-            gc.collect()
-
-        # Concatenate segments
-        return np.concatenate(segments)
-
-    def _extract_segment(
-        self,
-        target: Dict,
-        start_time: float,
-        end_time: float
-    ) -> Dict:
-        """
-        Extract a time segment from target metadata.
-
-        Args:
-            target: Full target metadata
-            start_time: Segment start time (seconds)
-            end_time: Segment end time (seconds)
-
-        Returns:
-            Segment metadata
-        """
-        # Simplified: just return full target for now
-        # TODO: Implement proper time-based extraction
-        return {
-            'phoneme': target['phoneme'],
-            'note_pitch': target['note_pitch'],
-            'note_type': target['note_type'],
-            'duration': end_time - start_time
-        }
-
-    def _phonemes_to_tensor(self, phonemes: List[str]) -> torch.Tensor:
-        """
-        Convert phoneme list to tensor.
-
-        Args:
-            phonemes: List of phoneme strings
-
-        Returns:
-            Phoneme tensor
-        """
-        # Simplified: convert to indices
-        # TODO: Use proper phoneme vocabulary
-        phoneme_to_idx = {
-            'd ow': 0, 'r ey': 1, 'm iy': 2, 'f aa': 3,
-            's ow l': 4, 'l aa': 5, 't iy': 6
-        }
-
-        indices = [phoneme_to_idx.get(p, 0) for p in phonemes]
-        return torch.tensor(indices, dtype=torch.long)
+"""
+Multi-Voice Engine Module
+Handles SoulX-Singer model inference for multiple voices
+Implements segment-based processing for long scores
+"""
+
+import numpy as np
+import torch
+from typing import Dict, List, Optional, Callable
+import gc
+import os
+import sys
+
+from .config import get_inference_config, get_device
+
+
+class MultiVoiceEngine:
+    """
+    Multi-voice synthesis engine using SoulX-Singer.
+
+    Features:
+    - Segment-based processing for long scores (≤8s per segment)
+    - Memory management with garbage collection
+    - Progress callback support
+    - Uses DataProcessor for proper mel2note generation
+    """
+
+    def __init__(self, model):
+        """
+        Initialize engine with SoulX-Singer model.
+
+        Args:
+            model: SoulX-Singer model instance
+        """
+        self.model = model
+        self.config = get_inference_config()
+        self.device = get_device()
+        self._data_processor = None
+
+    def _get_data_processor(self):
+        """
+        Lazy load DataProcessor with proper configuration.
+
+        Returns:
+            DataProcessor instance
+        """
+        if self._data_processor is None:
+            # Add soulxsinger to path
+            base_path = os.path.dirname(__file__)
+            soulx_path = os.path.join(base_path, '..', 'soulxsinger')
+            if os.path.exists(soulx_path):
+                sys.path.insert(0, os.path.dirname(soulx_path))
+
+            from soulxsinger.utils.data_processor import DataProcessor
+
+            # DataProcessor config from soulxsinger.yaml
+            # hop_size=480, sample_rate=24000
+            self._data_processor = DataProcessor(
+                hop_size=480,
+                sample_rate=24000,
+                device=self.device
+            )
+
+        return self._data_processor
+
+    def generate_single_voice(
+        self,
+        metadata: Dict,
+        on_progress: Optional[Callable[[float], None]] = None
+    ) -> np.ndarray:
+        """
+        Generate audio for a single voice.
+
+        Args:
+            metadata: Voice metadata from metadata_generator
+            on_progress: Progress callback function
+
+        Returns:
+            Generated audio array
+        """
+        target = metadata['target']
+        prompt_audio = metadata['prompt_audio']
+
+        # Check if segmentation is needed
+        total_duration = target['duration']
+        segment_duration = self.config['segment_duration']
+
+        if total_duration <= segment_duration:
+            # Single segment
+            return self._generate_segment(prompt_audio, target, on_progress)
+        else:
+            # Multiple segments
+            return self._generate_segments(prompt_audio, target, on_progress)
+
+    def _generate_segment(
+        self,
+        prompt_audio: np.ndarray,
+        target: Dict,
+        on_progress: Optional[Callable[[float], None]] = None
+    ) -> np.ndarray:
+        """
+        Generate a single segment (≤8 seconds).
+
+        Args:
+            prompt_audio: Prompt audio array
+            target: Target metadata
+            on_progress: Progress callback
+
+        Returns:
+            Generated audio for this segment
+        """
+        try:
+            # Get DataProcessor for mel2note generation
+            data_processor = self._get_data_processor()
+
+            # Prepare target data using DataProcessor.preprocess
+            # This generates mel2note properly
+            target_data = data_processor.preprocess(
+                note_duration=target['note_duration'],  # List[float] in seconds
+                phonemes=target['phoneme'],  # List[str]
+                note_pitch=target['note_pitch'],  # List[int]
+                note_type=target['note_type']  # List[int]
+            )
+
+            # Prepare prompt data
+            prompt_duration = len(prompt_audio) / 24000  # sample_rate=24000
+            prompt_phonemes = target['phoneme'][:min(5, len(target['phoneme']))]
+            prompt_pitches = target['note_pitch'][:min(5, len(target['note_pitch']))]
+            prompt_types = target['note_type'][:min(5, len(target['note_type']))]
+            prompt_durations = [prompt_duration / len(prompt_phonemes)] * len(prompt_phonemes)
+
+            prompt_data = data_processor.preprocess(
+                note_duration=prompt_durations,
+                phonemes=prompt_phonemes,
+                note_pitch=prompt_pitches,
+                note_type=prompt_types
+            )
+
+            # Add waveforms
+            prompt_data['waveform'] = torch.from_numpy(prompt_audio).float().unsqueeze(0).to(self.device)
+
+            # Build infer_data for model
+            infer_data = {
+                'prompt': prompt_data,
+                'target': target_data
+            }
+
+            # Run inference
+            with torch.no_grad():
+                output = self.model.infer(
+                    infer_data,
+                    auto_shift=False,
+                    pitch_shift=0,
+                    n_steps=self.config['n_steps'],
+                    cfg=self.config['cfg'],
+                    control=self.config['control'],
+                    use_fp16=self.config['use_fp16']
+                )
+
+            # Clean up
+            del infer_data
+            del prompt_data
+            del target_data
+            gc.collect()
+
+            if on_progress:
+                on_progress(100.0)
+
+            # Convert to numpy
+            if torch.is_tensor(output):
+                output = output.cpu().numpy()
+
+            # Flatten if needed
+            if len(output.shape) > 1:
+                output = output.flatten()
+
+            return output
+
+        except Exception as e:
+            print(f"Error in _generate_segment: {e}")
+            import traceback
+            traceback.print_exc()
+            # Fallback: return silence
+            duration = target.get('duration', 1.0)
+            return np.zeros(int(24000 * duration))
+
+    def _generate_segments(
+        self,
+        prompt_audio: np.ndarray,
+        target: Dict,
+        on_progress: Optional[Callable[[float], None]] = None
+    ) -> np.ndarray:
+        """
+        Generate multiple segments and concatenate.
+
+        Args:
+            prompt_audio: Prompt audio
+            target: Target metadata
+            on_progress: Progress callback
+
+        Returns:
+            Concatenated generated audio
+        """
+        total_duration = target['duration']
+        segment_duration = self.config['segment_duration']
+        num_segments = int(np.ceil(total_duration / segment_duration))
+
+        segments = []
+
+        for i in range(num_segments):
+            # Extract segment metadata
+            start_time = i * segment_duration
+            end_time = min((i + 1) * segment_duration, total_duration)
+
+            segment_target = self._extract_segment(target, start_time, end_time)
+
+            # Generate this segment
+            segment_audio = self._generate_segment(prompt_audio, segment_target)
+            segments.append(segment_audio)
+
+            # Update progress
+            if on_progress:
+                progress = (i + 1) / num_segments * 100
+                on_progress(progress)
+
+            # Memory cleanup
+            gc.collect()
+
+        # Concatenate segments
+        return np.concatenate(segments)
+
+    def _extract_segment(
+        self,
+        target: Dict,
+        start_time: float,
+        end_time: float
+    ) -> Dict:
+        """
+        Extract a time segment from target metadata.
+
+        Args:
+            target: Full target metadata
+            start_time: Segment start time (seconds)
+            end_time: Segment end time (seconds)
+
+        Returns:
+            Segment metadata
+        """
+        # Calculate which notes fall within this segment
+        note_durations = target['note_duration']
+        phonemes = target['phoneme']
+        note_pitches = target['note_pitch']
+        note_types = target['note_type']
+
+        seg_durations = []
+        seg_phonemes = []
+        seg_pitches = []
+        seg_types = []
+
+        current_time = 0.0
+
+        for i, dur in enumerate(note_durations):
+            note_start = current_time
+            note_end = current_time + dur
+
+            # Check if this note overlaps with segment
+            if note_end > start_time and note_start < end_time:
+                # Calculate overlap
+                overlap_start = max(note_start, start_time)
+                overlap_end = min(note_end, end_time)
+                overlap_duration = overlap_end - overlap_start
+
+                if overlap_duration > 0:
+                    seg_durations.append(overlap_duration)
+                    seg_phonemes.append(phonemes[i])
+                    seg_pitches.append(note_pitches[i])
+                    seg_types.append(note_types[i])
+
+            current_time = note_end
+
+            # Stop if we've passed the segment end
+            if current_time >= end_time:
+                break
+
+        return {
+            'phoneme': seg_phonemes,
+            'note_pitch': seg_pitches,
+            'note_duration': seg_durations,
+            'note_type': seg_types,
+            'duration': end_time - start_time
+        }
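
For reference, a minimal sketch of how the updated engine might be driven, assuming metadata in the shape this module expects (a `target` dict of parallel per-note lists plus a 24 kHz `prompt_audio` array). The `load_soulx_model` helper is hypothetical; substitute however the app actually constructs its SoulX-Singer model, and note the pitch values are assumed MIDI-style note numbers:

    import numpy as np

    from backend.multi_voice_engine import MultiVoiceEngine

    # Hypothetical loader; not part of this module.
    model = load_soulx_model()
    engine = MultiVoiceEngine(model)

    # Assumed metadata shape: parallel per-note lists plus a prompt waveform at 24 kHz.
    metadata = {
        'prompt_audio': np.zeros(24000, dtype=np.float32),  # 1 s placeholder prompt
        'target': {
            'phoneme': ['d ow', 'r ey', 'm iy'],
            'note_pitch': [60, 62, 64],        # assumed MIDI-style note numbers
            'note_type': [1, 1, 1],
            'note_duration': [0.5, 0.5, 0.5],  # seconds per note
            'duration': 1.5,                   # total seconds
        },
    }

    # Scores longer than config['segment_duration'] are split and concatenated internally.
    audio = engine.generate_single_voice(metadata, on_progress=lambda p: print(f"{p:.0f}%"))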