Approximetal committed on
Commit
3e1b384
·
verified ·
1 Parent(s): 9f66cd3

Update inference_gradio.py

Browse files
Files changed (1) hide show
  1. inference_gradio.py +35 -9
inference_gradio.py CHANGED
@@ -110,7 +110,6 @@ class UVR5:
110
 
111
  denoise_model = UVR5(
112
  model_dir=Path(PRETRAINED_ROOT) / "uvr5",
113
- code_dir=REPO_ROOT / "uvr5",
114
  )
115
 
116
  def load_wav(audio_info, sr=16000, channel=1):
@@ -130,11 +129,9 @@ def load_wav(audio_info, sr=16000, channel=1):
130
 
131
 
132
  def denoise(audio_info):
133
- save_path = "./denoised_audio.wav"
134
  denoised_audio, sr = denoise_model.denoise(audio_info)
135
- sf.write(save_path, denoised_audio, sr, format='wav', subtype='PCM_24')
136
- print("save denoised audio:", save_path)
137
- return save_path
138
 
139
  def cancel_denoise(audio_info):
140
  return audio_info
@@ -240,8 +237,22 @@ def infer(
240
  if not os.path.isfile(ckpt_resolved):
241
  return None, "Checkpoint not found!", ""
242
 
243
- if denoise_audio:
244
- ref_audio = denoise_audio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
  # Automatically enable prosody encoder when using the prosody checkpoint
247
  use_prosody_encoder = True if "prosody" in str(ckpt_resolved) else False
@@ -274,6 +285,9 @@ def infer(
274
  )
275
  except Exception as e:
276
  traceback.print_exc()
 
 
 
277
  return None, f"Error loading model: {str(e)}", ""
278
 
279
  print("Model loaded >>", file_checkpoint, use_ema)
@@ -284,7 +298,7 @@ def infer(
284
  try:
285
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
286
  tts_api.infer(
287
- ref_file=ref_audio,
288
  ref_text=ref_text.strip(),
289
  gen_text=gen_text.strip(),
290
  nfe_step=nfe_step,
@@ -303,6 +317,10 @@ def infer(
303
  except Exception as e:
304
  traceback.print_exc()
305
  return None, f"Inference error: {str(e)}", ""
 
 
 
 
306
 
307
 
308
  def get_gpu_stats():
@@ -457,7 +475,15 @@ with gr.Blocks(title="LEMAS-TTS Inference") as app:
457
  with gr.Row():
458
  denoise_btn = gr.Button(value="Denoise")
459
  cancel_btn = gr.Button(value="Cancel Denoise")
460
- denoise_audio = gr.Audio(label="Denoised Audio", value=None, type="filepath", interactive=True, show_download_button=True, editable=True)
 
 
 
 
 
 
 
 
461
 
462
  gen_text = gr.Textbox(label="Text to Generate", placeholder="Enter the text you want to generate...")
463
 
 
110
 
111
  denoise_model = UVR5(
112
  model_dir=Path(PRETRAINED_ROOT) / "uvr5",
 
113
  )
114
 
115
  def load_wav(audio_info, sr=16000, channel=1):
 
129
 
130
 
131
  def denoise(audio_info):
132
+ # Return a numpy waveform tuple for direct playback in Gradio.
133
  denoised_audio, sr = denoise_model.denoise(audio_info)
134
+ return (sr, denoised_audio)
 
 
135
 
136
  def cancel_denoise(audio_info):
137
  return audio_info
 
237
  if not os.path.isfile(ckpt_resolved):
238
  return None, "Checkpoint not found!", ""
239
 
240
+ # Prepare reference audio:
241
+ # - `ref_audio` from Gradio is a filepath (original reference)
242
+ # - `denoise_audio` is an optional (sr, numpy_array) tuple from UVR5.
243
+ # If provided, dump it to a temporary wav file and use that as ref_file.
244
+ ref_audio_path = ref_audio
245
+ tmp_ref_path = None
246
+ if denoise_audio is not None:
247
+ try:
248
+ sr_d, wav_d = denoise_audio
249
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f_ref:
250
+ sf.write(f_ref.name, wav_d, int(sr_d), format="wav", subtype="PCM_24")
251
+ tmp_ref_path = f_ref.name
252
+ ref_audio_path = f_ref.name
253
+ except Exception as e:
254
+ traceback.print_exc()
255
+ return None, f"Error preparing denoised reference audio: {str(e)}", ""
256
 
257
  # Automatically enable prosody encoder when using the prosody checkpoint
258
  use_prosody_encoder = True if "prosody" in str(ckpt_resolved) else False
 
285
  )
286
  except Exception as e:
287
  traceback.print_exc()
288
+ # Cleanup temp ref file if it was created
289
+ if tmp_ref_path is not None and os.path.isfile(tmp_ref_path):
290
+ os.remove(tmp_ref_path)
291
  return None, f"Error loading model: {str(e)}", ""
292
 
293
  print("Model loaded >>", file_checkpoint, use_ema)
 
298
  try:
299
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
300
  tts_api.infer(
301
+ ref_file=ref_audio_path,
302
  ref_text=ref_text.strip(),
303
  gen_text=gen_text.strip(),
304
  nfe_step=nfe_step,
 
317
  except Exception as e:
318
  traceback.print_exc()
319
  return None, f"Inference error: {str(e)}", ""
320
+ finally:
321
+ # Remove temporary reference file if created
322
+ if tmp_ref_path is not None and os.path.isfile(tmp_ref_path):
323
+ os.remove(tmp_ref_path)
324
 
325
 
326
  def get_gpu_stats():
 
475
  with gr.Row():
476
  denoise_btn = gr.Button(value="Denoise")
477
  cancel_btn = gr.Button(value="Cancel Denoise")
478
+ # Use numpy type here so we can reuse the waveform directly in Python.
479
+ denoise_audio = gr.Audio(
480
+ label="Denoised Audio",
481
+ value=None,
482
+ type="numpy",
483
+ interactive=True,
484
+ show_download_button=True,
485
+ editable=True,
486
+ )
487
 
488
  gen_text = gr.Textbox(label="Text to Generate", placeholder="Enter the text you want to generate...")
489