Files changed (1) hide show
  1. app.py +322 -159
app.py CHANGED
@@ -3,74 +3,128 @@ import gradio as gr
3
  import torch
4
  import torchaudio
5
  import librosa
 
6
  from modules.commons import build_model, load_checkpoint, recursive_munch
7
  import yaml
8
  from hf_utils import load_custom_model_from_hf
9
  import numpy as np
10
  from pydub import AudioSegment
11
 
12
- # Load model and configuration
 
 
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
 
15
- dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
16
- "DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth",
17
- "config_dit_mel_seed_uvit_whisper_small_wavenet.yml")
18
- # dit_checkpoint_path = "E:/DiT_epoch_00018_step_801000.pth"
19
- # dit_config_path = "configs/config_dit_mel_seed_uvit_whisper_small_encoder_wavenet.yml"
 
 
 
 
20
  config = yaml.safe_load(open(dit_config_path, 'r'))
21
  model_params = recursive_munch(config['model_params'])
22
  model = build_model(model_params, stage='DiT')
23
  hop_length = config['preprocess_params']['spect_params']['hop_length']
24
  sr = config['preprocess_params']['sr']
25
 
26
- # Load checkpoints
27
- model, _, _, _ = load_checkpoint(model, None, dit_checkpoint_path,
28
- load_only_params=True, ignore_modules=[], is_distributed=False)
 
 
29
  for key in model:
30
  model[key].eval()
31
  model[key].to(device)
32
- model.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
33
 
34
- # Load additional modules
35
- from modules.campplus.DTDNN import CAMPPlus
36
-
37
- campplus_ckpt_path = load_custom_model_from_hf("funasr/campplus", "campplus_cn_common.bin", config_filename=None)
38
- campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
39
- campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
40
- campplus_model.eval()
41
- campplus_model.to(device)
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  from modules.bigvgan import bigvgan
44
 
45
- bigvgan_model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_22khz_80band_256x', use_cuda_kernel=False)
46
-
47
- # remove weight norm in the model and set to eval mode
 
48
  bigvgan_model.remove_weight_norm()
49
  bigvgan_model = bigvgan_model.eval().to(device)
50
 
 
 
 
51
  ckpt_path, config_path = load_custom_model_from_hf("Plachta/FAcodec", 'pytorch_model.bin', 'config.yml')
52
-
53
  codec_config = yaml.safe_load(open(config_path))
54
  codec_model_params = recursive_munch(codec_config['model_params'])
55
  codec_encoder = build_model(codec_model_params, stage="codec")
56
 
57
  ckpt_params = torch.load(ckpt_path, map_location="cpu")
58
-
59
  for key in codec_encoder:
60
  codec_encoder[key].load_state_dict(ckpt_params[key], strict=False)
61
  _ = [codec_encoder[key].eval() for key in codec_encoder]
62
  _ = [codec_encoder[key].to(device) for key in codec_encoder]
63
 
64
- # whisper
 
 
65
  from transformers import AutoFeatureExtractor, WhisperModel
66
 
67
- whisper_name = model_params.speech_tokenizer.whisper_name if hasattr(model_params.speech_tokenizer,
68
- 'whisper_name') else "openai/whisper-small"
 
 
 
69
  whisper_model = WhisperModel.from_pretrained(whisper_name, torch_dtype=torch.float16).to(device)
70
  del whisper_model.decoder
71
  whisper_feature_extractor = AutoFeatureExtractor.from_pretrained(whisper_name)
72
 
73
- # Generate mel spectrograms
 
 
74
  mel_fn_args = {
75
  "n_fft": config['preprocess_params']['spect_params']['n_fft'],
76
  "win_size": config['preprocess_params']['spect_params']['win_length'],
@@ -82,51 +136,62 @@ mel_fn_args = {
82
  "center": False
83
  }
84
  from modules.audio import mel_spectrogram
85
-
86
  to_mel = lambda x: mel_spectrogram(x, **mel_fn_args)
87
 
88
- # f0 conditioned model
89
- dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
90
- "DiT_seed_v2_uvit_whisper_base_f0_44k_bigvgan_pruned_ft_ema.pth",
91
- "config_dit_mel_seed_uvit_whisper_base_f0_44k.yml")
92
-
93
- config = yaml.safe_load(open(dit_config_path, 'r'))
94
- model_params = recursive_munch(config['model_params'])
95
- model_f0 = build_model(model_params, stage='DiT')
96
- hop_length = config['preprocess_params']['spect_params']['hop_length']
97
- sr = config['preprocess_params']['sr']
98
-
99
- # Load checkpoints
100
- model_f0, _, _, _ = load_checkpoint(model_f0, None, dit_checkpoint_path,
101
- load_only_params=True, ignore_modules=[], is_distributed=False)
 
 
 
 
 
 
102
  for key in model_f0:
103
  model_f0[key].eval()
104
  model_f0[key].to(device)
 
105
  model_f0.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
106
 
107
- # f0 extractor
108
  from modules.rmvpe import RMVPE
109
 
110
  model_path = load_custom_model_from_hf("lj1995/VoiceConversionWebUI", "rmvpe.pt", None)
111
  rmvpe = RMVPE(model_path, is_half=False, device=device)
112
 
113
  mel_fn_args_f0 = {
114
- "n_fft": config['preprocess_params']['spect_params']['n_fft'],
115
- "win_size": config['preprocess_params']['spect_params']['win_length'],
116
- "hop_size": config['preprocess_params']['spect_params']['hop_length'],
117
- "num_mels": config['preprocess_params']['spect_params']['n_mels'],
118
- "sampling_rate": sr,
119
  "fmin": 0,
120
  "fmax": None,
121
  "center": False
122
  }
123
  to_mel_f0 = lambda x: mel_spectrogram(x, **mel_fn_args_f0)
124
- bigvgan_44k_model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_44khz_128band_512x', use_cuda_kernel=False)
125
 
126
- # remove weight norm in the model and set to eval mode
 
 
 
127
  bigvgan_44k_model.remove_weight_norm()
128
  bigvgan_44k_model = bigvgan_44k_model.eval().to(device)
129
 
 
 
 
130
  def adjust_f0_semitones(f0_sequence, n_semitones):
131
  factor = 2 ** (n_semitones / 12)
132
  return f0_sequence * factor
@@ -137,39 +202,86 @@ def crossfade(chunk1, chunk2, overlap):
137
  chunk2[:overlap] = chunk2[:overlap] * fade_in + chunk1[-overlap:] * fade_out
138
  return chunk2
139
 
140
- # streaming and chunk processing related params
141
  bitrate = "320k"
142
  overlap_frame_len = 16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  @spaces.GPU
144
  @torch.no_grad()
145
  @torch.inference_mode()
146
- def voice_conversion(source, target, diffusion_steps, length_adjust, inference_cfg_rate, f0_condition, auto_f0_adjust, pitch_shift):
 
 
147
  inference_module = model if not f0_condition else model_f0
148
  mel_fn = to_mel if not f0_condition else to_mel_f0
149
  bigvgan_fn = bigvgan_model if not f0_condition else bigvgan_44k_model
150
- sr = 22050 if not f0_condition else 44100
151
- hop_length = 256 if not f0_condition else 512
152
- max_context_window = sr // hop_length * 30
153
- overlap_wave_len = overlap_frame_len * hop_length
 
 
154
  # Load audio
155
- source_audio = librosa.load(source, sr=sr)[0]
156
- ref_audio = librosa.load(target, sr=sr)[0]
157
 
158
- # Process audio
159
  source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(device)
160
- ref_audio = torch.tensor(ref_audio[:sr * 25]).unsqueeze(0).float().to(device)
 
 
 
 
161
 
162
- # Resample
163
- ref_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
164
- converted_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)
165
- # if source audio less than 30 seconds, whisper can handle in one forward
166
  if converted_waves_16k.size(-1) <= 16000 * 30:
167
- alt_inputs = whisper_feature_extractor([converted_waves_16k.squeeze(0).cpu().numpy()],
168
- return_tensors="pt",
169
- return_attention_mask=True,
170
- sampling_rate=16000)
 
 
171
  alt_input_features = whisper_model._mask_input_features(
172
- alt_inputs.input_features, attention_mask=alt_inputs.attention_mask).to(device)
 
 
173
  alt_outputs = whisper_model.encoder(
174
  alt_input_features.to(whisper_model.encoder.dtype),
175
  head_mask=None,
@@ -180,21 +292,28 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c
180
  S_alt = alt_outputs.last_hidden_state.to(torch.float32)
181
  S_alt = S_alt[:, :converted_waves_16k.size(-1) // 320 + 1]
182
  else:
183
- overlapping_time = 5 # 5 seconds
184
  S_alt_list = []
185
  buffer = None
186
  traversed_time = 0
187
  while traversed_time < converted_waves_16k.size(-1):
188
- if buffer is None: # first chunk
189
  chunk = converted_waves_16k[:, traversed_time:traversed_time + 16000 * 30]
190
  else:
191
- chunk = torch.cat([buffer, converted_waves_16k[:, traversed_time:traversed_time + 16000 * (30 - overlapping_time)]], dim=-1)
192
- alt_inputs = whisper_feature_extractor([chunk.squeeze(0).cpu().numpy()],
193
- return_tensors="pt",
194
- return_attention_mask=True,
195
- sampling_rate=16000)
 
 
 
 
 
196
  alt_input_features = whisper_model._mask_input_features(
197
- alt_inputs.input_features, attention_mask=alt_inputs.attention_mask).to(device)
 
 
198
  alt_outputs = whisper_model.encoder(
199
  alt_input_features.to(whisper_model.encoder.dtype),
200
  head_mask=None,
@@ -202,30 +321,35 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c
202
  output_hidden_states=False,
203
  return_dict=True,
204
  )
205
- S_alt = alt_outputs.last_hidden_state.to(torch.float32)
206
- S_alt = S_alt[:, :chunk.size(-1) // 320 + 1]
207
  if traversed_time == 0:
208
- S_alt_list.append(S_alt)
209
  else:
210
- S_alt_list.append(S_alt[:, 50 * overlapping_time:])
211
  buffer = chunk[:, -16000 * overlapping_time:]
212
  traversed_time += 30 * 16000 if traversed_time == 0 else chunk.size(-1) - 16000 * overlapping_time
 
213
  S_alt = torch.cat(S_alt_list, dim=1)
214
 
215
- ori_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
216
- ori_inputs = whisper_feature_extractor([ori_waves_16k.squeeze(0).cpu().numpy()],
217
- return_tensors="pt",
218
- return_attention_mask=True)
 
 
 
219
  ori_input_features = whisper_model._mask_input_features(
220
- ori_inputs.input_features, attention_mask=ori_inputs.attention_mask).to(device)
221
- with torch.no_grad():
222
- ori_outputs = whisper_model.encoder(
223
- ori_input_features.to(whisper_model.encoder.dtype),
224
- head_mask=None,
225
- output_attentions=False,
226
- output_hidden_states=False,
227
- return_dict=True,
228
- )
 
229
  S_ori = ori_outputs.last_hidden_state.to(torch.float32)
230
  S_ori = S_ori[:, :ori_waves_16k.size(-1) // 320 + 1]
231
 
@@ -235,15 +359,12 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c
235
  target_lengths = torch.LongTensor([int(mel.size(2) * length_adjust)]).to(mel.device)
236
  target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device)
237
 
238
- feat2 = torchaudio.compliance.kaldi.fbank(ref_waves_16k,
239
- num_mel_bins=80,
240
- dither=0,
241
- sample_frequency=16000)
242
- feat2 = feat2 - feat2.mean(dim=0, keepdim=True)
243
- style2 = campplus_model(feat2.unsqueeze(0))
244
 
 
245
  if f0_condition:
246
- F0_ori = rmvpe.infer_from_audio(ref_waves_16k[0], thred=0.5)
247
  F0_alt = rmvpe.infer_from_audio(converted_waves_16k[0], thred=0.5)
248
 
249
  F0_ori = torch.from_numpy(F0_ori).to(device)[None]
@@ -258,118 +379,160 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c
258
  median_log_f0_ori = torch.median(voiced_log_f0_ori)
259
  median_log_f0_alt = torch.median(voiced_log_f0_alt)
260
 
261
- # shift alt log f0 level to ori log f0 level
262
  shifted_log_f0_alt = log_f0_alt.clone()
263
- if auto_f0_adjust:
264
  shifted_log_f0_alt[F0_alt > 1] = log_f0_alt[F0_alt > 1] - median_log_f0_alt + median_log_f0_ori
 
265
  shifted_f0_alt = torch.exp(shifted_log_f0_alt)
266
  if pitch_shift != 0:
267
  shifted_f0_alt[F0_alt > 1] = adjust_f0_semitones(shifted_f0_alt[F0_alt > 1], pitch_shift)
268
  else:
269
  F0_ori = None
270
- F0_alt = None
271
  shifted_f0_alt = None
272
 
273
  # Length regulation
274
- cond, _, codes, commitment_loss, codebook_loss = inference_module.length_regulator(S_alt, ylens=target_lengths, n_quantizers=3, f0=shifted_f0_alt)
275
- prompt_condition, _, codes, commitment_loss, codebook_loss = inference_module.length_regulator(S_ori, ylens=target2_lengths, n_quantizers=3, f0=F0_ori)
 
 
 
 
276
 
277
  max_source_window = max_context_window - mel2.size(2)
278
- # split source condition (cond) into chunks
279
  processed_frames = 0
280
  generated_wave_chunks = []
281
- # generate chunk by chunk and stream the output
 
282
  while processed_frames < cond.size(1):
283
  chunk_cond = cond[:, processed_frames:processed_frames + max_source_window]
284
  is_last_chunk = processed_frames + max_source_window >= cond.size(1)
 
285
  cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)
286
- with torch.autocast(device_type='cuda', dtype=torch.float16):
287
- # Voice Conversion
288
- vc_target = inference_module.cfm.inference(cat_condition,
289
- torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
290
- mel2, style2, None, diffusion_steps,
291
- inference_cfg_rate=inference_cfg_rate)
 
 
292
  vc_target = vc_target[:, :, mel2.size(-1):]
 
293
  vc_wave = bigvgan_fn(vc_target.float())[0]
 
294
  if processed_frames == 0:
295
  if is_last_chunk:
296
  output_wave = vc_wave[0].cpu().numpy()
297
  generated_wave_chunks.append(output_wave)
298
- output_wave = (output_wave * 32768.0).astype(np.int16)
 
299
  mp3_bytes = AudioSegment(
300
- output_wave.tobytes(), frame_rate=sr,
301
- sample_width=output_wave.dtype.itemsize, channels=1
 
 
302
  ).export(format="mp3", bitrate=bitrate).read()
303
- yield mp3_bytes, (sr, np.concatenate(generated_wave_chunks))
304
  break
 
305
  output_wave = vc_wave[0, :-overlap_wave_len].cpu().numpy()
306
  generated_wave_chunks.append(output_wave)
307
  previous_chunk = vc_wave[0, -overlap_wave_len:]
308
  processed_frames += vc_target.size(2) - overlap_frame_len
309
- output_wave = (output_wave * 32768.0).astype(np.int16)
 
310
  mp3_bytes = AudioSegment(
311
- output_wave.tobytes(), frame_rate=sr,
312
- sample_width=output_wave.dtype.itemsize, channels=1
 
 
313
  ).export(format="mp3", bitrate=bitrate).read()
314
  yield mp3_bytes, None
 
315
  elif is_last_chunk:
316
  output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0].cpu().numpy(), overlap_wave_len)
317
  generated_wave_chunks.append(output_wave)
318
  processed_frames += vc_target.size(2) - overlap_frame_len
319
- output_wave = (output_wave * 32768.0).astype(np.int16)
 
320
  mp3_bytes = AudioSegment(
321
- output_wave.tobytes(), frame_rate=sr,
322
- sample_width=output_wave.dtype.itemsize, channels=1
 
 
323
  ).export(format="mp3", bitrate=bitrate).read()
324
- yield mp3_bytes, (sr, np.concatenate(generated_wave_chunks))
325
  break
 
326
  else:
327
  output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0, :-overlap_wave_len].cpu().numpy(), overlap_wave_len)
328
  generated_wave_chunks.append(output_wave)
329
  previous_chunk = vc_wave[0, -overlap_wave_len:]
330
  processed_frames += vc_target.size(2) - overlap_frame_len
331
- output_wave = (output_wave * 32768.0).astype(np.int16)
 
332
  mp3_bytes = AudioSegment(
333
- output_wave.tobytes(), frame_rate=sr,
334
- sample_width=output_wave.dtype.itemsize, channels=1
 
 
335
  ).export(format="mp3", bitrate=bitrate).read()
336
  yield mp3_bytes, None
337
 
338
-
 
 
339
  if __name__ == "__main__":
340
- description = ("State-of-the-Art zero-shot voice conversion/singing voice conversion. For local deployment please check [GitHub repository](https://github.com/Plachtaa/seed-vc) "
341
- "for details and updates.<br>Note that any reference audio will be forcefully clipped to 25s if beyond this length.<br> "
342
- "If total duration of source and reference audio exceeds 30s, source audio will be processed in chunks.<br> "
343
- "无需训练的 zero-shot 语音/歌声转换模型,若需本地部署查看[GitHub页面](https://github.com/Plachtaa/seed-vc)<br>"
344
- "请注意,参考音频若超过 25 秒,则会被自动裁剪至此长度。<br>若源音频和参考音频的总时长超过 30 秒,源音频将被分段处理。")
 
 
 
 
 
345
  inputs = [
346
  gr.Audio(type="filepath", label="Source Audio / 源音频"),
347
  gr.Audio(type="filepath", label="Reference Audio / 参考音频"),
348
- gr.Slider(minimum=1, maximum=200, value=25, step=1, label="Diffusion Steps / 扩散步数", info="25 by default, 50~100 for best quality / 默认为 25,50~100 为最佳质量"),
349
- gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Adjust / 长度调整", info="<1.0 for speed-up speech, >1.0 for slow-down speech / <1.0 加速语速,>1.0 减慢语速"),
350
- gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.7, label="Inference CFG Rate", info="has subtle influence / 有微小影响"),
351
- gr.Checkbox(label="Use F0 conditioned model / 启用F0输入", value=False, info="Must set to true for singing voice conversion / 歌声转换时必须勾选"),
 
 
 
 
 
 
352
  gr.Checkbox(label="Auto F0 adjust / 自动F0调整", value=True,
353
- info="Roughly adjust F0 to match target voice. Only works when F0 conditioned model is used. / 粗略调整 F0 以匹配目标音色,仅在勾选 '启用F0输入' 时生效"),
354
- gr.Slider(label='Pitch shift / 音调变换', minimum=-24, maximum=24, step=1, value=0, info="Pitch shift in semitones, only works when F0 conditioned model is used / 半音数的音高变换,仅在勾选 '启用F0输入' 时生效"),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
355
  ]
356
 
357
- examples = [["examples/source/yae_0.wav", "examples/reference/dingzhen_0.wav", 25, 1.0, 0.7, False, True, 0],
358
- ["examples/source/jay_0.wav", "examples/reference/azuma_0.wav", 25, 1.0, 0.7, False, True, 0],
359
- ["examples/source/Wiz Khalifa,Charlie Puth - See You Again [vocals]_[cut_28sec].wav",
360
- "examples/reference/kobe_0.wav", 50, 1.0, 0.7, True, False, -6],
361
- ["examples/source/TECHNOPOLIS - 2085 [vocals]_[cut_14sec].wav",
362
- "examples/reference/trump_0.wav", 50, 1.0, 0.7, True, False, -12],
363
- ]
364
-
365
- outputs = [gr.Audio(label="Stream Output Audio / 流式输出", streaming=True, format='mp3'),
366
- gr.Audio(label="Full Output Audio / 完整输出", streaming=False, format='wav')]
367
-
368
- gr.Interface(fn=voice_conversion,
369
- description=description,
370
- inputs=inputs,
371
- outputs=outputs,
372
- title="Seed Voice Conversion",
373
- examples=examples,
374
- cache_examples=False,
375
- ).launch()
 
3
  import torch
4
  import torchaudio
5
  import librosa
6
+ import torch.nn as nn
7
  from modules.commons import build_model, load_checkpoint, recursive_munch
8
  import yaml
9
  from hf_utils import load_custom_model_from_hf
10
  import numpy as np
11
  from pydub import AudioSegment
12
 
13
+ # =========================================================
14
+ # Device
15
+ # =========================================================
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
 
18
+ # =========================================================
19
+ # Load Seed-VC DiT model (non-f0)
20
+ # =========================================================
21
+ dit_checkpoint_path, dit_config_path = load_custom_model_from_hf(
22
+ "Plachta/Seed-VC",
23
+ "DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth",
24
+ "config_dit_mel_seed_uvit_whisper_small_wavenet.yml"
25
+ )
26
+
27
  config = yaml.safe_load(open(dit_config_path, 'r'))
28
  model_params = recursive_munch(config['model_params'])
29
  model = build_model(model_params, stage='DiT')
30
  hop_length = config['preprocess_params']['spect_params']['hop_length']
31
  sr = config['preprocess_params']['sr']
32
 
33
+ model, _, _, _ = load_checkpoint(
34
+ model, None, dit_checkpoint_path,
35
+ load_only_params=True, ignore_modules=[],
36
+ is_distributed=False
37
+ )
38
  for key in model:
39
  model[key].eval()
40
  model[key].to(device)
 
41
 
42
+ # Cache setup
43
+ model.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
 
 
 
 
 
 
44
 
45
+ # =========================================================
46
+ # Speaker embedding: ECAPA (SpeechBrain) replacement
47
+ # - This reduces CN accent bias vs campplus_cn_common
48
+ # - Fallback to original CAMPPlus if SpeechBrain not available
49
+ # =========================================================
50
+ USE_ECAPA = True
51
+ spk_encoder = None
52
+
53
+ try:
54
+ from speechbrain.pretrained import EncoderClassifier
55
+ spk_encoder = EncoderClassifier.from_hparams(
56
+ source="speechbrain/spkrec-ecapa-voxceleb",
57
+ run_opts={"device": str(device)}
58
+ )
59
+ spk_encoder.eval()
60
+ except Exception as e:
61
+ # If SpeechBrain isn't installed/available, fallback to CAMPPlus
62
+ USE_ECAPA = False
63
+ spk_encoder = None
64
+ print("[WARN] SpeechBrain ECAPA not available. Falling back to CAMPPlus. Error:", str(e))
65
+
66
+ # CAMPPlus fallback (original)
67
+ campplus_model = None
68
+ if not USE_ECAPA:
69
+ from modules.campplus.DTDNN import CAMPPlus
70
+ campplus_ckpt_path = load_custom_model_from_hf(
71
+ "funasr/campplus",
72
+ "campplus_cn_common.bin",
73
+ config_filename=None
74
+ )
75
+ campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
76
+ campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
77
+ campplus_model.eval()
78
+ campplus_model.to(device)
79
+
80
+ # A small projection to map ECAPA embedding dim -> expected style dim
81
+ # We build it lazily at first inference once we know ECAPA dim.
82
+ style_proj = None
83
+ STYLE_DIM_EXPECTED = 192 # CAMPPlus embedding_size used originally in this app
84
+
85
+ # =========================================================
86
+ # Vocoder (BigVGAN)
87
+ # =========================================================
88
  from modules.bigvgan import bigvgan
89
 
90
+ bigvgan_model = bigvgan.BigVGAN.from_pretrained(
91
+ 'nvidia/bigvgan_v2_22khz_80band_256x',
92
+ use_cuda_kernel=False
93
+ )
94
  bigvgan_model.remove_weight_norm()
95
  bigvgan_model = bigvgan_model.eval().to(device)
96
 
97
+ # =========================================================
98
+ # Codec (FAcodec)
99
+ # =========================================================
100
  ckpt_path, config_path = load_custom_model_from_hf("Plachta/FAcodec", 'pytorch_model.bin', 'config.yml')
 
101
  codec_config = yaml.safe_load(open(config_path))
102
  codec_model_params = recursive_munch(codec_config['model_params'])
103
  codec_encoder = build_model(codec_model_params, stage="codec")
104
 
105
  ckpt_params = torch.load(ckpt_path, map_location="cpu")
 
106
  for key in codec_encoder:
107
  codec_encoder[key].load_state_dict(ckpt_params[key], strict=False)
108
  _ = [codec_encoder[key].eval() for key in codec_encoder]
109
  _ = [codec_encoder[key].to(device) for key in codec_encoder]
110
 
111
+ # =========================================================
112
+ # Whisper encoder (content)
113
+ # =========================================================
114
  from transformers import AutoFeatureExtractor, WhisperModel
115
 
116
+ whisper_name = (
117
+ model_params.speech_tokenizer.whisper_name
118
+ if hasattr(model_params.speech_tokenizer, 'whisper_name')
119
+ else "openai/whisper-small"
120
+ )
121
  whisper_model = WhisperModel.from_pretrained(whisper_name, torch_dtype=torch.float16).to(device)
122
  del whisper_model.decoder
123
  whisper_feature_extractor = AutoFeatureExtractor.from_pretrained(whisper_name)
124
 
125
+ # =========================================================
126
+ # Mel Spectrogram
127
+ # =========================================================
128
  mel_fn_args = {
129
  "n_fft": config['preprocess_params']['spect_params']['n_fft'],
130
  "win_size": config['preprocess_params']['spect_params']['win_length'],
 
136
  "center": False
137
  }
138
  from modules.audio import mel_spectrogram
 
139
  to_mel = lambda x: mel_spectrogram(x, **mel_fn_args)
140
 
141
+ # =========================================================
142
+ # Load Seed-VC DiT model (f0 conditioned)
143
+ # =========================================================
144
+ dit_checkpoint_path, dit_config_path = load_custom_model_from_hf(
145
+ "Plachta/Seed-VC",
146
+ "DiT_seed_v2_uvit_whisper_base_f0_44k_bigvgan_pruned_ft_ema.pth",
147
+ "config_dit_mel_seed_uvit_whisper_base_f0_44k.yml"
148
+ )
149
+
150
+ config_f0 = yaml.safe_load(open(dit_config_path, 'r'))
151
+ model_params_f0 = recursive_munch(config_f0['model_params'])
152
+ model_f0 = build_model(model_params_f0, stage='DiT')
153
+ hop_length_f0 = config_f0['preprocess_params']['spect_params']['hop_length']
154
+ sr_f0 = config_f0['preprocess_params']['sr']
155
+
156
+ model_f0, _, _, _ = load_checkpoint(
157
+ model_f0, None, dit_checkpoint_path,
158
+ load_only_params=True, ignore_modules=[],
159
+ is_distributed=False
160
+ )
161
  for key in model_f0:
162
  model_f0[key].eval()
163
  model_f0[key].to(device)
164
+
165
  model_f0.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
166
 
167
+ # F0 extractor
168
  from modules.rmvpe import RMVPE
169
 
170
  model_path = load_custom_model_from_hf("lj1995/VoiceConversionWebUI", "rmvpe.pt", None)
171
  rmvpe = RMVPE(model_path, is_half=False, device=device)
172
 
173
  mel_fn_args_f0 = {
174
+ "n_fft": config_f0['preprocess_params']['spect_params']['n_fft'],
175
+ "win_size": config_f0['preprocess_params']['spect_params']['win_length'],
176
+ "hop_size": config_f0['preprocess_params']['spect_params']['hop_length'],
177
+ "num_mels": config_f0['preprocess_params']['spect_params']['n_mels'],
178
+ "sampling_rate": sr_f0,
179
  "fmin": 0,
180
  "fmax": None,
181
  "center": False
182
  }
183
  to_mel_f0 = lambda x: mel_spectrogram(x, **mel_fn_args_f0)
 
184
 
185
+ bigvgan_44k_model = bigvgan.BigVGAN.from_pretrained(
186
+ 'nvidia/bigvgan_v2_44khz_128band_512x',
187
+ use_cuda_kernel=False
188
+ )
189
  bigvgan_44k_model.remove_weight_norm()
190
  bigvgan_44k_model = bigvgan_44k_model.eval().to(device)
191
 
192
+ # =========================================================
193
+ # Helpers
194
+ # =========================================================
195
  def adjust_f0_semitones(f0_sequence, n_semitones):
196
  factor = 2 ** (n_semitones / 12)
197
  return f0_sequence * factor
 
202
  chunk2[:overlap] = chunk2[:overlap] * fade_in + chunk1[-overlap:] * fade_out
203
  return chunk2
204
 
205
+ # Streaming and chunk params
206
  bitrate = "320k"
207
  overlap_frame_len = 16
208
+
209
def get_style_embedding(ref_waves_16k: torch.Tensor) -> torch.Tensor:
    """Compute a speaker-style embedding for a reference waveform.

    Args:
        ref_waves_16k: (B, T) float waveform tensor sampled at 16 kHz.
            (assumed from the callers' resample-to-16k — TODO confirm)

    Returns:
        A (B, STYLE_DIM_EXPECTED) speaker embedding tensor on `device`.

    Uses the SpeechBrain ECAPA encoder when available; otherwise falls
    back to the original CAMPPlus model.
    """
    global style_proj

    if USE_ECAPA and spk_encoder is not None:
        with torch.no_grad():
            # SpeechBrain ECAPA returns (B, 1, D) or (B, D) depending on
            # the library version; normalise to (B, D).
            emb = spk_encoder.encode_batch(ref_waves_16k)
            if emb.dim() == 3:
                emb = emb.squeeze(1)
            style2 = emb.to(device)

            # Project to the style dim the DiT was trained with, if needed.
            if style2.size(-1) != STYLE_DIM_EXPECTED:
                if style_proj is None:
                    # NOTE(review): this projection is *untrained* — a random
                    # linear map scrambles speaker information and should be
                    # trained or replaced. At minimum, seed its initialisation
                    # so results are reproducible across runs. fork_rng()
                    # restores the global CPU/CUDA RNG state afterwards so
                    # unrelated sampling (e.g. diffusion noise) is unaffected.
                    with torch.random.fork_rng():
                        torch.manual_seed(0)
                        style_proj = nn.Linear(
                            style2.size(-1), STYLE_DIM_EXPECTED
                        ).to(device)
                    style_proj.eval()
                style2 = style_proj(style2)
        return style2

    # Fallback: CAMPPlus on 80-dim kaldi fbank features (original behaviour).
    feat2 = torchaudio.compliance.kaldi.fbank(
        ref_waves_16k,
        num_mel_bins=80,
        dither=0,
        sample_frequency=16000,
    )
    # Per-utterance cepstral mean normalisation over the time axis.
    feat2 = feat2 - feat2.mean(dim=0, keepdim=True)
    return campplus_model(feat2.unsqueeze(0))
243
+
244
+ # =========================================================
245
+ # Voice Conversion
246
+ # =========================================================
247
  @spaces.GPU
248
  @torch.no_grad()
249
  @torch.inference_mode()
250
+ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_cfg_rate,
251
+ f0_condition, auto_f0_adjust, pitch_shift):
252
+
253
  inference_module = model if not f0_condition else model_f0
254
  mel_fn = to_mel if not f0_condition else to_mel_f0
255
  bigvgan_fn = bigvgan_model if not f0_condition else bigvgan_44k_model
256
+ sr_local = 22050 if not f0_condition else 44100
257
+ hop_local = 256 if not f0_condition else 512
258
+
259
+ max_context_window = sr_local // hop_local * 30
260
+ overlap_wave_len = overlap_frame_len * hop_local
261
+
262
  # Load audio
263
+ source_audio = librosa.load(source, sr=sr_local)[0]
264
+ ref_audio = librosa.load(target, sr=sr_local)[0]
265
 
 
266
  source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(device)
267
+ ref_audio = torch.tensor(ref_audio[:sr_local * 25]).unsqueeze(0).float().to(device)
268
+
269
+ # Resample for whisper and speaker embedding
270
+ ref_waves_16k = torchaudio.functional.resample(ref_audio, sr_local, 16000)
271
+ converted_waves_16k = torchaudio.functional.resample(source_audio, sr_local, 16000)
272
 
273
+ # Whisper content encoding (S_alt)
 
 
 
274
  if converted_waves_16k.size(-1) <= 16000 * 30:
275
+ alt_inputs = whisper_feature_extractor(
276
+ [converted_waves_16k.squeeze(0).cpu().numpy()],
277
+ return_tensors="pt",
278
+ return_attention_mask=True,
279
+ sampling_rate=16000
280
+ )
281
  alt_input_features = whisper_model._mask_input_features(
282
+ alt_inputs.input_features, attention_mask=alt_inputs.attention_mask
283
+ ).to(device)
284
+
285
  alt_outputs = whisper_model.encoder(
286
  alt_input_features.to(whisper_model.encoder.dtype),
287
  head_mask=None,
 
292
  S_alt = alt_outputs.last_hidden_state.to(torch.float32)
293
  S_alt = S_alt[:, :converted_waves_16k.size(-1) // 320 + 1]
294
  else:
295
+ overlapping_time = 5
296
  S_alt_list = []
297
  buffer = None
298
  traversed_time = 0
299
  while traversed_time < converted_waves_16k.size(-1):
300
+ if buffer is None:
301
  chunk = converted_waves_16k[:, traversed_time:traversed_time + 16000 * 30]
302
  else:
303
+ chunk = torch.cat(
304
+ [buffer, converted_waves_16k[:, traversed_time:traversed_time + 16000 * (30 - overlapping_time)]],
305
+ dim=-1
306
+ )
307
+ alt_inputs = whisper_feature_extractor(
308
+ [chunk.squeeze(0).cpu().numpy()],
309
+ return_tensors="pt",
310
+ return_attention_mask=True,
311
+ sampling_rate=16000
312
+ )
313
  alt_input_features = whisper_model._mask_input_features(
314
+ alt_inputs.input_features, attention_mask=alt_inputs.attention_mask
315
+ ).to(device)
316
+
317
  alt_outputs = whisper_model.encoder(
318
  alt_input_features.to(whisper_model.encoder.dtype),
319
  head_mask=None,
 
321
  output_hidden_states=False,
322
  return_dict=True,
323
  )
324
+ S_alt_chunk = alt_outputs.last_hidden_state.to(torch.float32)
325
+ S_alt_chunk = S_alt_chunk[:, :chunk.size(-1) // 320 + 1]
326
  if traversed_time == 0:
327
+ S_alt_list.append(S_alt_chunk)
328
  else:
329
+ S_alt_list.append(S_alt_chunk[:, 50 * overlapping_time:])
330
  buffer = chunk[:, -16000 * overlapping_time:]
331
  traversed_time += 30 * 16000 if traversed_time == 0 else chunk.size(-1) - 16000 * overlapping_time
332
+
333
  S_alt = torch.cat(S_alt_list, dim=1)
334
 
335
+ # Whisper prompt (S_ori)
336
+ ori_waves_16k = torchaudio.functional.resample(ref_audio, sr_local, 16000)
337
+ ori_inputs = whisper_feature_extractor(
338
+ [ori_waves_16k.squeeze(0).cpu().numpy()],
339
+ return_tensors="pt",
340
+ return_attention_mask=True
341
+ )
342
  ori_input_features = whisper_model._mask_input_features(
343
+ ori_inputs.input_features, attention_mask=ori_inputs.attention_mask
344
+ ).to(device)
345
+
346
+ ori_outputs = whisper_model.encoder(
347
+ ori_input_features.to(whisper_model.encoder.dtype),
348
+ head_mask=None,
349
+ output_attentions=False,
350
+ output_hidden_states=False,
351
+ return_dict=True,
352
+ )
353
  S_ori = ori_outputs.last_hidden_state.to(torch.float32)
354
  S_ori = S_ori[:, :ori_waves_16k.size(-1) // 320 + 1]
355
 
 
359
  target_lengths = torch.LongTensor([int(mel.size(2) * length_adjust)]).to(mel.device)
360
  target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device)
361
 
362
+ # Speaker embedding (ECAPA or fallback)
363
+ style2 = get_style_embedding(ref_waves_16k)
 
 
 
 
364
 
365
+ # f0 handling
366
  if f0_condition:
367
+ F0_ori = rmvpe.infer_from_audio(ori_waves_16k[0], thred=0.5)
368
  F0_alt = rmvpe.infer_from_audio(converted_waves_16k[0], thred=0.5)
369
 
370
  F0_ori = torch.from_numpy(F0_ori).to(device)[None]
 
379
  median_log_f0_ori = torch.median(voiced_log_f0_ori)
380
  median_log_f0_alt = torch.median(voiced_log_f0_alt)
381
 
 
382
  shifted_log_f0_alt = log_f0_alt.clone()
383
+ if auto_f0_adjust and voiced_F0_alt.numel() > 0 and voiced_F0_ori.numel() > 0:
384
  shifted_log_f0_alt[F0_alt > 1] = log_f0_alt[F0_alt > 1] - median_log_f0_alt + median_log_f0_ori
385
+
386
  shifted_f0_alt = torch.exp(shifted_log_f0_alt)
387
  if pitch_shift != 0:
388
  shifted_f0_alt[F0_alt > 1] = adjust_f0_semitones(shifted_f0_alt[F0_alt > 1], pitch_shift)
389
  else:
390
  F0_ori = None
 
391
  shifted_f0_alt = None
392
 
393
  # Length regulation
394
+ cond, _, _, _, _ = inference_module.length_regulator(
395
+ S_alt, ylens=target_lengths, n_quantizers=3, f0=shifted_f0_alt
396
+ )
397
+ prompt_condition, _, _, _, _ = inference_module.length_regulator(
398
+ S_ori, ylens=target2_lengths, n_quantizers=3, f0=F0_ori
399
+ )
400
 
401
  max_source_window = max_context_window - mel2.size(2)
402
+
403
  processed_frames = 0
404
  generated_wave_chunks = []
405
+ previous_chunk = None
406
+
407
  while processed_frames < cond.size(1):
408
  chunk_cond = cond[:, processed_frames:processed_frames + max_source_window]
409
  is_last_chunk = processed_frames + max_source_window >= cond.size(1)
410
+
411
  cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)
412
+
413
+ with torch.autocast(device_type='cuda', dtype=torch.float16) if device.type == "cuda" else torch.no_grad():
414
+ vc_target = inference_module.cfm.inference(
415
+ cat_condition,
416
+ torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
417
+ mel2, style2, None, diffusion_steps,
418
+ inference_cfg_rate=inference_cfg_rate
419
+ )
420
  vc_target = vc_target[:, :, mel2.size(-1):]
421
+
422
  vc_wave = bigvgan_fn(vc_target.float())[0]
423
+
424
  if processed_frames == 0:
425
  if is_last_chunk:
426
  output_wave = vc_wave[0].cpu().numpy()
427
  generated_wave_chunks.append(output_wave)
428
+ output_i16 = (output_wave * 32768.0).astype(np.int16)
429
+
430
  mp3_bytes = AudioSegment(
431
+ output_i16.tobytes(),
432
+ frame_rate=sr_local,
433
+ sample_width=output_i16.dtype.itemsize,
434
+ channels=1
435
  ).export(format="mp3", bitrate=bitrate).read()
436
+ yield mp3_bytes, (sr_local, np.concatenate(generated_wave_chunks))
437
  break
438
+
439
  output_wave = vc_wave[0, :-overlap_wave_len].cpu().numpy()
440
  generated_wave_chunks.append(output_wave)
441
  previous_chunk = vc_wave[0, -overlap_wave_len:]
442
  processed_frames += vc_target.size(2) - overlap_frame_len
443
+
444
+ output_i16 = (output_wave * 32768.0).astype(np.int16)
445
  mp3_bytes = AudioSegment(
446
+ output_i16.tobytes(),
447
+ frame_rate=sr_local,
448
+ sample_width=output_i16.dtype.itemsize,
449
+ channels=1
450
  ).export(format="mp3", bitrate=bitrate).read()
451
  yield mp3_bytes, None
452
+
453
  elif is_last_chunk:
454
  output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0].cpu().numpy(), overlap_wave_len)
455
  generated_wave_chunks.append(output_wave)
456
  processed_frames += vc_target.size(2) - overlap_frame_len
457
+
458
+ output_i16 = (output_wave * 32768.0).astype(np.int16)
459
  mp3_bytes = AudioSegment(
460
+ output_i16.tobytes(),
461
+ frame_rate=sr_local,
462
+ sample_width=output_i16.dtype.itemsize,
463
+ channels=1
464
  ).export(format="mp3", bitrate=bitrate).read()
465
+ yield mp3_bytes, (sr_local, np.concatenate(generated_wave_chunks))
466
  break
467
+
468
  else:
469
  output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0, :-overlap_wave_len].cpu().numpy(), overlap_wave_len)
470
  generated_wave_chunks.append(output_wave)
471
  previous_chunk = vc_wave[0, -overlap_wave_len:]
472
  processed_frames += vc_target.size(2) - overlap_frame_len
473
+
474
+ output_i16 = (output_wave * 32768.0).astype(np.int16)
475
  mp3_bytes = AudioSegment(
476
+ output_i16.tobytes(),
477
+ frame_rate=sr_local,
478
+ sample_width=output_i16.dtype.itemsize,
479
+ channels=1
480
  ).export(format="mp3", bitrate=bitrate).read()
481
  yield mp3_bytes, None
482
 
483
+ # =========================================================
484
+ # Gradio UI
485
+ # =========================================================
486
if __name__ == "__main__":
    # HTML/Markdown blurb rendered above the interface.
    app_description = (
        "State-of-the-Art zero-shot voice conversion/singing voice conversion. "
        "For local deployment please check GitHub repository for details and updates.<br>"
        "Note: reference audio will be clipped to 25s if longer.<br>"
        "If total duration exceeds 30s, source audio will be processed in chunks.<br>"
        "<br>"
        "Hindi tip: Use Hindi SOURCE + Hindi REFERENCE for best Hindi output. "
        "This app converts voice (audio→audio), it does not do text-to-speech."
    )

    # Widgets, in the positional order that voice_conversion expects its arguments.
    ui_inputs = [
        gr.Audio(type="filepath", label="Source Audio / 源音频"),
        gr.Audio(type="filepath", label="Reference Audio / 参考音频"),
        gr.Slider(
            label="Diffusion Steps / 扩散步数",
            minimum=1,
            maximum=200,
            step=1,
            value=25,
            info="25 by default, 50~100 for best quality / 默认为 25,50~100 为最佳质量",
        ),
        gr.Slider(
            label="Length Adjust / 长度调整",
            minimum=0.5,
            maximum=2.0,
            step=0.1,
            value=1.0,
            info="<1.0 speed-up, >1.0 slow-down / <1.0 加速,>1.0 减速",
        ),
        gr.Slider(
            label="Inference CFG Rate",
            minimum=0.0,
            maximum=1.0,
            step=0.1,
            value=0.7,
            info="subtle influence / 有微小影响",
        ),
        gr.Checkbox(
            label="Use F0 conditioned model / 启用F0输入",
            value=False,
            info="Must set to true for singing voice conversion / 歌声转换时必须勾选",
        ),
        gr.Checkbox(
            label="Auto F0 adjust / 自动F0调整",
            value=True,
            info="Roughly adjust F0 to match target voice. Only when F0 model is used.",
        ),
        gr.Slider(
            label='Pitch shift / 音调变换',
            minimum=-24,
            maximum=24,
            step=1,
            value=0,
            info="Semitones. Only when F0 model is used / 半音,仅F0模型生效",
        ),
    ]

    # Pre-filled example rows: (source, reference, steps, length, cfg, f0, auto-f0, shift).
    ui_examples = [
        ["examples/source/yae_0.wav", "examples/reference/dingzhen_0.wav", 25, 1.0, 0.7, False, True, 0],
        ["examples/source/jay_0.wav", "examples/reference/azuma_0.wav", 25, 1.0, 0.7, False, True, 0],
        ["examples/source/Wiz Khalifa,Charlie Puth - See You Again [vocals]_[cut_28sec].wav",
         "examples/reference/kobe_0.wav", 50, 1.0, 0.7, True, False, -6],
        ["examples/source/TECHNOPOLIS - 2085 [vocals]_[cut_14sec].wav",
         "examples/reference/trump_0.wav", 50, 1.0, 0.7, True, False, -12],
    ]

    # Two audio outputs: a streaming mp3 preview plus the final full wav.
    ui_outputs = [
        gr.Audio(label="Stream Output Audio / 流式输出", streaming=True, format='mp3'),
        gr.Audio(label="Full Output Audio / 完整输出", streaming=False, format='wav'),
    ]

    demo = gr.Interface(
        fn=voice_conversion,
        title="Seed Voice Conversion (ECAPA speaker embedding)",
        description=app_description,
        inputs=ui_inputs,
        outputs=ui_outputs,
        examples=ui_examples,
        cache_examples=False,
    )
    demo.launch()