kokole committed on
Commit
339c325
·
1 Parent(s): 1fd58ee

add fp16 support for svc

Browse files
cli/inference_svc.py CHANGED
@@ -67,8 +67,18 @@ def process(args, config, model: torch.nn.Module):
67
  n_step = args.n_steps if hasattr(args, "n_steps") else config.infer.n_steps
68
  cfg = args.cfg if hasattr(args, "cfg") else config.infer.cfg
69
 
70
- generated_audio, generated_shift = model.infer(pt_wav, gt_wav, pt_f0, gt_f0, auto_shift=args.auto_shift, pitch_shift=args.pitch_shift, n_steps=n_step, cfg=cfg)
71
- generated_audio = generated_audio.squeeze().cpu().numpy()
 
 
 
 
 
 
 
 
 
 
72
  if args.pitch_shift != generated_shift:
73
  args.pitch_shift = generated_shift
74
  # print(f"Applied pitch shift of {generated_shift} semitones to match GT F0 contour.")
@@ -99,7 +109,14 @@ if __name__ == "__main__":
99
  parser.add_argument("--pitch_shift", type=int, default=0)
100
  parser.add_argument("--n_steps", type=int, default=32)
101
  parser.add_argument("--cfg", type=float, default=3.0)
 
 
 
 
 
 
102
  args = parser.parse_args()
103
-
104
  config = load_config(args.config)
 
105
  main(args, config)
 
67
  n_step = args.n_steps if hasattr(args, "n_steps") else config.infer.n_steps
68
  cfg = args.cfg if hasattr(args, "cfg") else config.infer.cfg
69
 
70
+ generated_audio, generated_shift = model.infer(
71
+ pt_wav=pt_wav,
72
+ gt_wav=gt_wav,
73
+ pt_f0=pt_f0,
74
+ gt_f0=gt_f0,
75
+ auto_shift=args.auto_shift,
76
+ pitch_shift=args.pitch_shift,
77
+ n_steps=n_step,
78
+ cfg=cfg,
79
+ use_fp16=args.use_fp16,
80
+ )
81
+ generated_audio = generated_audio.squeeze().float().cpu().numpy()
82
  if args.pitch_shift != generated_shift:
83
  args.pitch_shift = generated_shift
84
  # print(f"Applied pitch shift of {generated_shift} semitones to match GT F0 contour.")
 
109
  parser.add_argument("--pitch_shift", type=int, default=0)
110
  parser.add_argument("--n_steps", type=int, default=32)
111
  parser.add_argument("--cfg", type=float, default=3.0)
112
+ parser.add_argument(
113
+ "--fp16",
114
+ action="store_true",
115
+ default=False,
116
+ help="Use FP16 inference (faster on GPU)",
117
+ )
118
  args = parser.parse_args()
119
+
120
  config = load_config(args.config)
121
+ args.use_fp16 = args.fp16
122
  main(args, config)
soulxsinger/models/soulxsinger_svc.py CHANGED
@@ -4,6 +4,7 @@ import torch.nn.functional as F
4
  import numpy as np
5
  from tqdm import tqdm
6
  from typing import Optional, Dict, Any, List, Tuple
 
7
 
8
  from soulxsinger.models.modules.vocoder import Vocoder
9
  from soulxsinger.models.modules.decoder import CFMDecoder
@@ -11,6 +12,10 @@ from soulxsinger.models.modules.mel_transform import MelSpectrogramEncoder
11
  from soulxsinger.models.modules.whisper_encoder import WhisperEncoder
12
 
13
 
 
 
 
 
14
  class SoulXSingerSVC(nn.Module):
15
  """
16
  SoulXSinger SVC model.
@@ -186,6 +191,7 @@ class SoulXSingerSVC(nn.Module):
186
  pitch_shift=0,
187
  n_steps=32,
188
  cfg=3,
 
189
  ):
190
  """
191
  SVC inference pipeline. First build vocal segments based on F0 contour, then run inference for each segment and merge results.
@@ -198,6 +204,7 @@ class SoulXSingerSVC(nn.Module):
198
  pitch_shift: manual pitch shift in semitones (overrides auto_shift if > 0)
199
  n_steps: number of diffusion steps for inference
200
  cfg: classifier-free guidance scale for inference
 
201
  """
202
 
203
  # calculate auto pitch shift
@@ -212,17 +219,29 @@ class SoulXSingerSVC(nn.Module):
212
  else:
213
  pitch_shift = pitch_shift
214
 
 
 
 
 
 
 
 
 
 
 
215
  # if target audio is less than 30 seconds, infer the whole audio
216
  if gt_wav.shape[-1] < 30 * self.audio_cfg.sample_rate:
217
- generated_audio = self.infer_segment(
218
- pt_wav=pt_wav,
219
- gt_wav=gt_wav,
220
- pt_f0=pt_f0,
221
- gt_f0=gt_f0,
222
- pitch_shift=pitch_shift,
223
- n_steps=n_steps,
224
- cfg=cfg,
225
- )
 
 
226
  return generated_audio, pitch_shift
227
 
228
  # if target audio is longer than 30 seconds, build vocal segments and infer each segment
@@ -258,15 +277,17 @@ class SoulXSingerSVC(nn.Module):
258
 
259
  segment_gt_wav = gt_wav[:, wav_start:wav_end]
260
  segment_gt_f0 = gt_f0[:, f0_start:f0_end]
261
- segment_generated_audio = self.infer_segment(
262
- pt_wav=pt_wav,
263
- gt_wav=segment_gt_wav,
264
- pt_f0=pt_f0,
265
- gt_f0=segment_gt_f0,
266
- pitch_shift=pitch_shift,
267
- n_steps=n_steps,
268
- cfg=cfg,
269
- )
 
 
270
 
271
  segment_start = int(round(seg_start_sec * self.audio_cfg.sample_rate))
272
  segment_end = int(round(seg_end_sec * self.audio_cfg.sample_rate))
@@ -276,8 +297,7 @@ class SoulXSingerSVC(nn.Module):
276
 
277
  return generated_audio, pitch_shift
278
 
279
- def infer_segment(self, pt_wav, gt_wav, pt_f0, gt_f0, pitch_shift=0, n_steps=32, cfg=3):
280
- pt_mel = self.mel(pt_wav)
281
  len_prompt_mel = pt_mel.shape[1]
282
  pt_f0 = F.pad(pt_f0, (0, 0, 0, max(0, len_prompt_mel - pt_f0.shape[1])))[:, :len_prompt_mel]
283
 
@@ -308,7 +328,7 @@ class SoulXSingerSVC(nn.Module):
308
  )
309
 
310
  generated_audio = self.vocoder(generated_mel.transpose(1, 2)[0:1, ...])
311
- generated_audio = generated_audio.squeeze()
312
 
313
  # cut or pad to match gt_wav length
314
  if generated_audio.shape[-1] > gt_wav.shape[-1]:
 
4
  import numpy as np
5
  from tqdm import tqdm
6
  from typing import Optional, Dict, Any, List, Tuple
7
+ from contextlib import nullcontext
8
 
9
  from soulxsinger.models.modules.vocoder import Vocoder
10
  from soulxsinger.models.modules.decoder import CFMDecoder
 
12
  from soulxsinger.models.modules.whisper_encoder import WhisperEncoder
13
 
14
 
15
+ def _autocast_if(enabled: bool):
16
+ """Return autocast(context) if enabled else no-op context. Use: with _autocast_if(use_amp): ..."""
17
+ return torch.amp.autocast(device_type="cuda", enabled=True) if enabled else nullcontext()
18
+
19
  class SoulXSingerSVC(nn.Module):
20
  """
21
  SoulXSinger SVC model.
 
191
  pitch_shift=0,
192
  n_steps=32,
193
  cfg=3,
194
+ use_fp16=True,
195
  ):
196
  """
197
  SVC inference pipeline. First build vocal segments based on F0 contour, then run inference for each segment and merge results.
 
204
  pitch_shift: manual pitch shift in semitones (overrides auto_shift if > 0)
205
  n_steps: number of diffusion steps for inference
206
  cfg: classifier-free guidance scale for inference
207
+ use_fp16: if True, run in FP16 except mel extraction to save memory and speed.
208
  """
209
 
210
  # calculate auto pitch shift
 
219
  else:
220
  pitch_shift = pitch_shift
221
 
222
+ use_fp16 = use_fp16 and pt_wav.is_cuda
223
+ with torch.amp.autocast(device_type="cuda", enabled=False):
224
+ pt_mel = self.mel(pt_wav.float() if pt_wav.dtype != torch.float32 else pt_wav)
225
+ if use_fp16:
226
+ pt_mel = pt_mel.half()
227
+ pt_wav = pt_wav.half()
228
+ gt_wav = gt_wav.half()
229
+ pt_f0 = pt_f0.half()
230
+ gt_f0 = gt_f0.half()
231
+
232
  # if target audio is less than 30 seconds, infer the whole audio
233
  if gt_wav.shape[-1] < 30 * self.audio_cfg.sample_rate:
234
+ with _autocast_if(use_fp16):
235
+ generated_audio = self.infer_segment(
236
+ pt_mel=pt_mel,
237
+ pt_wav=pt_wav,
238
+ gt_wav=gt_wav,
239
+ pt_f0=pt_f0,
240
+ gt_f0=gt_f0,
241
+ pitch_shift=pitch_shift,
242
+ n_steps=n_steps,
243
+ cfg=cfg,
244
+ )
245
  return generated_audio, pitch_shift
246
 
247
  # if target audio is longer than 30 seconds, build vocal segments and infer each segment
 
277
 
278
  segment_gt_wav = gt_wav[:, wav_start:wav_end]
279
  segment_gt_f0 = gt_f0[:, f0_start:f0_end]
280
+ with _autocast_if(use_fp16):
281
+ segment_generated_audio = self.infer_segment(
282
+ pt_mel=pt_mel,
283
+ pt_wav=pt_wav,
284
+ gt_wav=segment_gt_wav,
285
+ pt_f0=pt_f0,
286
+ gt_f0=segment_gt_f0,
287
+ pitch_shift=pitch_shift,
288
+ n_steps=n_steps,
289
+ cfg=cfg,
290
+ )
291
 
292
  segment_start = int(round(seg_start_sec * self.audio_cfg.sample_rate))
293
  segment_end = int(round(seg_end_sec * self.audio_cfg.sample_rate))
 
297
 
298
  return generated_audio, pitch_shift
299
 
300
+ def infer_segment(self, pt_mel, pt_wav, gt_wav, pt_f0, gt_f0, pitch_shift=0, n_steps=32, cfg=3):
 
301
  len_prompt_mel = pt_mel.shape[1]
302
  pt_f0 = F.pad(pt_f0, (0, 0, 0, max(0, len_prompt_mel - pt_f0.shape[1])))[:, :len_prompt_mel]
303
 
 
328
  )
329
 
330
  generated_audio = self.vocoder(generated_mel.transpose(1, 2)[0:1, ...])
331
+ generated_audio = generated_audio.squeeze().float()
332
 
333
  # cut or pad to match gt_wav length
334
  if generated_audio.shape[-1] > gt_wav.shape[-1]:
webui_svc.py CHANGED
@@ -175,6 +175,7 @@ class AppState:
175
  pitch_shift: int,
176
  n_step: int,
177
  cfg: float,
 
178
  seed: int,
179
  ) -> tuple[bool, str, Path | None]:
180
  try:
@@ -199,6 +200,7 @@ class AppState:
199
  args.pitch_shift = int(pitch_shift)
200
  args.n_steps = int(n_step)
201
  args.cfg = float(cfg)
 
202
 
203
  svc_process(args, self.svc_config, self.svc_model)
204
 
@@ -306,6 +308,7 @@ def _run_svc_convert(
306
  pitch_shift=0,
307
  n_step=32,
308
  cfg=1.0,
 
309
  seed=42,
310
  ):
311
  try:
@@ -325,6 +328,7 @@ def _run_svc_convert(
325
  pitch_shift=int(pitch_shift),
326
  n_step=int(n_step),
327
  cfg=float(cfg),
 
328
  seed=int(seed),
329
  )
330
  if not ok or generated is None:
@@ -351,12 +355,13 @@ def _start_svc(
351
  pitch_shift=0,
352
  n_step=32,
353
  cfg=1.0,
 
354
  seed=42,
355
  ):
356
  state = _run_svc_preprocess(prompt_audio, target_audio, prompt_vocal_sep, target_vocal_sep)
357
  if state is None:
358
  return None
359
- return _run_svc_convert(state, auto_shift, auto_mix_acc, pitch_shift, n_step, cfg, seed)
360
 
361
 
362
  def render_tab_content() -> None:
@@ -387,8 +392,10 @@ def render_tab_content() -> None:
387
  with gr.Row():
388
  auto_shift = gr.Checkbox(label="Auto pitch shift", value=True, scale=1)
389
  auto_mix_acc = gr.Checkbox(label="Auto mix accompaniment", value=True, scale=1)
 
 
390
  pitch_shift = gr.Slider(label="Pitch shift (semitones)", value=0, minimum=-36, maximum=36, step=1)
391
- n_step = gr.Slider(label="n_step", value=32, minimum=1, maximum=200, step=1)
392
  cfg = gr.Slider(label="cfg scale", value=1.0, minimum=0.0, maximum=10.0, step=0.1)
393
  seed_input = gr.Slider(label="Seed", value=42, minimum=0, maximum=10000, step=1)
394
 
@@ -411,7 +418,7 @@ def render_tab_content() -> None:
411
  outputs=[svc_state],
412
  ).then(
413
  fn=_run_svc_convert,
414
- inputs=[svc_state, auto_shift, auto_mix_acc, pitch_shift, n_step, cfg, seed_input],
415
  outputs=[output_audio],
416
  )
417
 
 
175
  pitch_shift: int,
176
  n_step: int,
177
  cfg: float,
178
+ use_fp16: bool,
179
  seed: int,
180
  ) -> tuple[bool, str, Path | None]:
181
  try:
 
200
  args.pitch_shift = int(pitch_shift)
201
  args.n_steps = int(n_step)
202
  args.cfg = float(cfg)
203
+ args.use_fp16 = bool(use_fp16)
204
 
205
  svc_process(args, self.svc_config, self.svc_model)
206
 
 
308
  pitch_shift=0,
309
  n_step=32,
310
  cfg=1.0,
311
+ use_fp16=True,
312
  seed=42,
313
  ):
314
  try:
 
328
  pitch_shift=int(pitch_shift),
329
  n_step=int(n_step),
330
  cfg=float(cfg),
331
+ use_fp16=bool(use_fp16),
332
  seed=int(seed),
333
  )
334
  if not ok or generated is None:
 
355
  pitch_shift=0,
356
  n_step=32,
357
  cfg=1.0,
358
+ use_fp16=True,
359
  seed=42,
360
  ):
361
  state = _run_svc_preprocess(prompt_audio, target_audio, prompt_vocal_sep, target_vocal_sep)
362
  if state is None:
363
  return None
364
+ return _run_svc_convert(state, auto_shift, auto_mix_acc, pitch_shift, n_step, cfg, use_fp16, seed)
365
 
366
 
367
  def render_tab_content() -> None:
 
392
  with gr.Row():
393
  auto_shift = gr.Checkbox(label="Auto pitch shift", value=True, scale=1)
394
  auto_mix_acc = gr.Checkbox(label="Auto mix accompaniment", value=True, scale=1)
395
+ with gr.Row():
396
+ use_fp16 = gr.Checkbox(label="Use FP16", value=True, scale=1)
397
  pitch_shift = gr.Slider(label="Pitch shift (semitones)", value=0, minimum=-36, maximum=36, step=1)
398
+ n_step = gr.Slider(label="diffusion steps", value=32, minimum=1, maximum=200, step=1)
399
  cfg = gr.Slider(label="cfg scale", value=1.0, minimum=0.0, maximum=10.0, step=0.1)
400
  seed_input = gr.Slider(label="Seed", value=42, minimum=0, maximum=10000, step=1)
401
 
 
418
  outputs=[svc_state],
419
  ).then(
420
  fn=_run_svc_convert,
421
+ inputs=[svc_state, auto_shift, auto_mix_acc, pitch_shift, n_step, cfg, use_fp16, seed_input],
422
  outputs=[output_audio],
423
  )
424