Spaces:
Running
on
Zero
Running
on
Zero
Update inference_gradio.py
Browse files- inference_gradio.py +35 -9
inference_gradio.py
CHANGED
|
@@ -110,7 +110,6 @@ class UVR5:
|
|
| 110 |
|
| 111 |
denoise_model = UVR5(
|
| 112 |
model_dir=Path(PRETRAINED_ROOT) / "uvr5",
|
| 113 |
-
code_dir=REPO_ROOT / "uvr5",
|
| 114 |
)
|
| 115 |
|
| 116 |
def load_wav(audio_info, sr=16000, channel=1):
|
|
@@ -130,11 +129,9 @@ def load_wav(audio_info, sr=16000, channel=1):
|
|
| 130 |
|
| 131 |
|
| 132 |
def denoise(audio_info):
|
| 133 |
-
|
| 134 |
denoised_audio, sr = denoise_model.denoise(audio_info)
|
| 135 |
-
|
| 136 |
-
print("save denoised audio:", save_path)
|
| 137 |
-
return save_path
|
| 138 |
|
| 139 |
def cancel_denoise(audio_info):
|
| 140 |
return audio_info
|
|
@@ -240,8 +237,22 @@ def infer(
|
|
| 240 |
if not os.path.isfile(ckpt_resolved):
|
| 241 |
return None, "Checkpoint not found!", ""
|
| 242 |
|
| 243 |
-
|
| 244 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
|
| 246 |
# Automatically enable prosody encoder when using the prosody checkpoint
|
| 247 |
use_prosody_encoder = True if "prosody" in str(ckpt_resolved) else False
|
|
@@ -274,6 +285,9 @@ def infer(
|
|
| 274 |
)
|
| 275 |
except Exception as e:
|
| 276 |
traceback.print_exc()
|
|
|
|
|
|
|
|
|
|
| 277 |
return None, f"Error loading model: {str(e)}", ""
|
| 278 |
|
| 279 |
print("Model loaded >>", file_checkpoint, use_ema)
|
|
@@ -284,7 +298,7 @@ def infer(
|
|
| 284 |
try:
|
| 285 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
| 286 |
tts_api.infer(
|
| 287 |
-
ref_file=
|
| 288 |
ref_text=ref_text.strip(),
|
| 289 |
gen_text=gen_text.strip(),
|
| 290 |
nfe_step=nfe_step,
|
|
@@ -303,6 +317,10 @@ def infer(
|
|
| 303 |
except Exception as e:
|
| 304 |
traceback.print_exc()
|
| 305 |
return None, f"Inference error: {str(e)}", ""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
|
| 307 |
|
| 308 |
def get_gpu_stats():
|
|
@@ -457,7 +475,15 @@ with gr.Blocks(title="LEMAS-TTS Inference") as app:
|
|
| 457 |
with gr.Row():
|
| 458 |
denoise_btn = gr.Button(value="Denoise")
|
| 459 |
cancel_btn = gr.Button(value="Cancel Denoise")
|
| 460 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 461 |
|
| 462 |
gen_text = gr.Textbox(label="Text to Generate", placeholder="Enter the text you want to generate...")
|
| 463 |
|
|
|
|
| 110 |
|
| 111 |
denoise_model = UVR5(
|
| 112 |
model_dir=Path(PRETRAINED_ROOT) / "uvr5",
|
|
|
|
| 113 |
)
|
| 114 |
|
| 115 |
def load_wav(audio_info, sr=16000, channel=1):
|
|
|
|
| 129 |
|
| 130 |
|
| 131 |
def denoise(audio_info):
|
| 132 |
+
# Return a numpy waveform tuple for direct playback in Gradio.
|
| 133 |
denoised_audio, sr = denoise_model.denoise(audio_info)
|
| 134 |
+
return (sr, denoised_audio)
|
|
|
|
|
|
|
| 135 |
|
| 136 |
def cancel_denoise(audio_info):
|
| 137 |
return audio_info
|
|
|
|
| 237 |
if not os.path.isfile(ckpt_resolved):
|
| 238 |
return None, "Checkpoint not found!", ""
|
| 239 |
|
| 240 |
+
# Prepare reference audio:
|
| 241 |
+
# - `ref_audio` from Gradio is a filepath (original reference)
|
| 242 |
+
# - `denoise_audio` is an optional (sr, numpy_array) tuple from UVR5.
|
| 243 |
+
# If provided, dump it to a temporary wav file and use that as ref_file.
|
| 244 |
+
ref_audio_path = ref_audio
|
| 245 |
+
tmp_ref_path = None
|
| 246 |
+
if denoise_audio is not None:
|
| 247 |
+
try:
|
| 248 |
+
sr_d, wav_d = denoise_audio
|
| 249 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f_ref:
|
| 250 |
+
sf.write(f_ref.name, wav_d, int(sr_d), format="wav", subtype="PCM_24")
|
| 251 |
+
tmp_ref_path = f_ref.name
|
| 252 |
+
ref_audio_path = f_ref.name
|
| 253 |
+
except Exception as e:
|
| 254 |
+
traceback.print_exc()
|
| 255 |
+
return None, f"Error preparing denoised reference audio: {str(e)}", ""
|
| 256 |
|
| 257 |
# Automatically enable prosody encoder when using the prosody checkpoint
|
| 258 |
use_prosody_encoder = True if "prosody" in str(ckpt_resolved) else False
|
|
|
|
| 285 |
)
|
| 286 |
except Exception as e:
|
| 287 |
traceback.print_exc()
|
| 288 |
+
# Cleanup temp ref file if it was created
|
| 289 |
+
if tmp_ref_path is not None and os.path.isfile(tmp_ref_path):
|
| 290 |
+
os.remove(tmp_ref_path)
|
| 291 |
return None, f"Error loading model: {str(e)}", ""
|
| 292 |
|
| 293 |
print("Model loaded >>", file_checkpoint, use_ema)
|
|
|
|
| 298 |
try:
|
| 299 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
| 300 |
tts_api.infer(
|
| 301 |
+
ref_file=ref_audio_path,
|
| 302 |
ref_text=ref_text.strip(),
|
| 303 |
gen_text=gen_text.strip(),
|
| 304 |
nfe_step=nfe_step,
|
|
|
|
| 317 |
except Exception as e:
|
| 318 |
traceback.print_exc()
|
| 319 |
return None, f"Inference error: {str(e)}", ""
|
| 320 |
+
finally:
|
| 321 |
+
# Remove temporary reference file if created
|
| 322 |
+
if tmp_ref_path is not None and os.path.isfile(tmp_ref_path):
|
| 323 |
+
os.remove(tmp_ref_path)
|
| 324 |
|
| 325 |
|
| 326 |
def get_gpu_stats():
|
|
|
|
| 475 |
with gr.Row():
|
| 476 |
denoise_btn = gr.Button(value="Denoise")
|
| 477 |
cancel_btn = gr.Button(value="Cancel Denoise")
|
| 478 |
+
# Use numpy type here so we can reuse the waveform directly in Python.
|
| 479 |
+
denoise_audio = gr.Audio(
|
| 480 |
+
label="Denoised Audio",
|
| 481 |
+
value=None,
|
| 482 |
+
type="numpy",
|
| 483 |
+
interactive=True,
|
| 484 |
+
show_download_button=True,
|
| 485 |
+
editable=True,
|
| 486 |
+
)
|
| 487 |
|
| 488 |
gen_text = gr.Textbox(label="Text to Generate", placeholder="Enter the text you want to generate...")
|
| 489 |
|