Spaces:
Running
on
Zero
Running
on
Zero
Update inference_gradio.py
Browse files- inference_gradio.py +43 -18
inference_gradio.py
CHANGED
|
@@ -75,14 +75,18 @@ class UVR5:
|
|
| 75 |
"""Small wrapper around the bundled uvr5 implementation for denoising."""
|
| 76 |
|
| 77 |
def __init__(self, model_dir):
|
|
|
|
| 78 |
code_dir = os.path.join(os.path.dirname(__file__), "uvr5")
|
| 79 |
self.model = self.load_model(model_dir, code_dir)
|
| 80 |
|
| 81 |
def load_model(self, model_dir, code_dir):
|
| 82 |
-
import sys, json
|
| 83 |
if code_dir not in sys.path:
|
| 84 |
sys.path.append(code_dir)
|
| 85 |
from multiprocess_cuda_infer import ModelData, Inference
|
|
|
|
|
|
|
|
|
|
| 86 |
model_path = os.path.join(model_dir, "Kim_Vocal_1.onnx")
|
| 87 |
config_path = os.path.join(model_dir, "MDX-Net-Kim-Vocal1.json")
|
| 88 |
with open(config_path, "r", encoding="utf-8") as f:
|
|
@@ -93,7 +97,7 @@ class UVR5:
|
|
| 93 |
result_path = model_dir,
|
| 94 |
device = 'cpu',
|
| 95 |
process_method = "MDX-Net",
|
| 96 |
-
base_dir=
|
| 97 |
**configs
|
| 98 |
)
|
| 99 |
|
|
@@ -390,11 +394,12 @@ class MMSAlignModel:
|
|
| 390 |
|
| 391 |
class WhisperxModel:
|
| 392 |
def __init__(self, model_name):
|
| 393 |
-
from whisperx import load_model
|
| 394 |
from pathlib import Path
|
|
|
|
|
|
|
| 395 |
prompt = None # "This might be a blend of Simplified Chinese and English speech, do not translate, only transcription be allowed."
|
| 396 |
|
| 397 |
-
# Prefer a local VAD model (to avoid network download /
|
| 398 |
vad_fp = Path(MODELS_PATH) / "whisperx-vad-segmentation.bin"
|
| 399 |
if not vad_fp.is_file():
|
| 400 |
logging.warning(
|
|
@@ -402,6 +407,30 @@ class WhisperxModel:
|
|
| 402 |
vad_fp,
|
| 403 |
)
|
| 404 |
vad_fp = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
|
| 406 |
self.model = load_model(
|
| 407 |
model_name,
|
|
@@ -515,21 +544,17 @@ def get_audio_slice(audio, words_info, start_time, end_time, max_len=10, sr=1600
|
|
| 515 |
def load_models(lemas_model_name, whisper_model_name, alignment_model_name, denoise_model_name): # , audiosr_name):
|
| 516 |
|
| 517 |
global transcribe_model, align_model, denoise_model, text_norm, tts_edit_model
|
| 518 |
-
# if voicecraft_model:
|
| 519 |
-
# del denoise_model
|
| 520 |
-
# del transcribe_model
|
| 521 |
-
# del align_model
|
| 522 |
-
# del voicecraft_model
|
| 523 |
-
# del audiosr
|
| 524 |
torch.cuda.empty_cache()
|
| 525 |
gc.collect()
|
| 526 |
|
| 527 |
if denoise_model_name == "UVR5":
|
| 528 |
-
#
|
| 529 |
-
#
|
| 530 |
-
#
|
| 531 |
-
|
| 532 |
-
|
|
|
|
|
|
|
| 533 |
elif denoise_model_name == "DeepFilterNet":
|
| 534 |
denoise_model = DeepFilterNet("./pretrained_models/denoiser_model.onnx")
|
| 535 |
|
|
@@ -1177,10 +1202,10 @@ def get_app():
|
|
| 1177 |
if __name__ == "__main__":
|
| 1178 |
import argparse
|
| 1179 |
|
| 1180 |
-
parser = argparse.ArgumentParser(description="
|
| 1181 |
|
| 1182 |
-
parser.add_argument("--demo-path", default="./demo", help="Path to demo directory")
|
| 1183 |
-
parser.add_argument("--tmp-path", default="./pretrained_models/
|
| 1184 |
parser.add_argument("--port", default=41020, type=int, help="App port")
|
| 1185 |
parser.add_argument("--share", action="store_true", help="Launch with public url")
|
| 1186 |
parser.add_argument("--server_name", default="0.0.0.0", type=str, help="Server name for launching the app. 127.0.0.1 for localhost; 0.0.0.0 to allow access from other machines in the local network. Might also give access to external users depends on the firewall settings.")
|
|
|
|
| 75 |
"""Small wrapper around the bundled uvr5 implementation for denoising."""
|
| 76 |
|
| 77 |
def __init__(self, model_dir):
|
| 78 |
+
# Code directory is always the local `uvr5` folder in this repo
|
| 79 |
code_dir = os.path.join(os.path.dirname(__file__), "uvr5")
|
| 80 |
self.model = self.load_model(model_dir, code_dir)
|
| 81 |
|
| 82 |
def load_model(self, model_dir, code_dir):
|
| 83 |
+
import sys, json, os
|
| 84 |
if code_dir not in sys.path:
|
| 85 |
sys.path.append(code_dir)
|
| 86 |
from multiprocess_cuda_infer import ModelData, Inference
|
| 87 |
+
# In the minimal LEMAS-TTS layout, UVR5 weights live under:
|
| 88 |
+
# <pretrained_models>/uvr5/models/MDX_Net_Models/model_data/
|
| 89 |
+
# Here `model_dir` points to that `model_data` directory.
|
| 90 |
model_path = os.path.join(model_dir, "Kim_Vocal_1.onnx")
|
| 91 |
config_path = os.path.join(model_dir, "MDX-Net-Kim-Vocal1.json")
|
| 92 |
with open(config_path, "r", encoding="utf-8") as f:
|
|
|
|
| 97 |
result_path = model_dir,
|
| 98 |
device = 'cpu',
|
| 99 |
process_method = "MDX-Net",
|
| 100 |
+
base_dir=code_dir,
|
| 101 |
**configs
|
| 102 |
)
|
| 103 |
|
|
|
|
| 394 |
|
| 395 |
class WhisperxModel:
|
| 396 |
def __init__(self, model_name):
|
|
|
|
| 397 |
from pathlib import Path
|
| 398 |
+
import whisperx.vad as wx_vad
|
| 399 |
+
from whisperx import load_model
|
| 400 |
prompt = None # "This might be a blend of Simplified Chinese and English speech, do not translate, only transcription be allowed."
|
| 401 |
|
| 402 |
+
# Prefer a local VAD model (to avoid network download / checksum issues)
|
| 403 |
vad_fp = Path(MODELS_PATH) / "whisperx-vad-segmentation.bin"
|
| 404 |
if not vad_fp.is_file():
|
| 405 |
logging.warning(
|
|
|
|
| 407 |
vad_fp,
|
| 408 |
)
|
| 409 |
vad_fp = None
|
| 410 |
+
else:
|
| 411 |
+
# Monkey-patch whisperx.vad.load_vad_model so it loads our local
|
| 412 |
+
# segmentation model without enforcing the baked-in SHA256 check.
|
| 413 |
+
def _patched_load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
|
| 414 |
+
import torch
|
| 415 |
+
from pyannote.audio import Model
|
| 416 |
+
from pyannote.audio.pipelines import VoiceActivitySegmentation
|
| 417 |
+
|
| 418 |
+
model_path = str(model_fp) if model_fp is not None else str(vad_fp)
|
| 419 |
+
model = Model.from_pretrained(model_path, use_auth_token=use_auth_token)
|
| 420 |
+
hyperparameters = {
|
| 421 |
+
"onset": vad_onset,
|
| 422 |
+
"offset": vad_offset,
|
| 423 |
+
"min_duration_on": 0.1,
|
| 424 |
+
"min_duration_off": 0.1,
|
| 425 |
+
}
|
| 426 |
+
vad_pipeline = VoiceActivitySegmentation(
|
| 427 |
+
segmentation=model,
|
| 428 |
+
device=torch.device(device),
|
| 429 |
+
)
|
| 430 |
+
vad_pipeline.instantiate(hyperparameters)
|
| 431 |
+
return vad_pipeline
|
| 432 |
+
|
| 433 |
+
wx_vad.load_vad_model = _patched_load_vad_model
|
| 434 |
|
| 435 |
self.model = load_model(
|
| 436 |
model_name,
|
|
|
|
| 544 |
def load_models(lemas_model_name, whisper_model_name, alignment_model_name, denoise_model_name): # , audiosr_name):
|
| 545 |
|
| 546 |
global transcribe_model, align_model, denoise_model, text_norm, tts_edit_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 547 |
torch.cuda.empty_cache()
|
| 548 |
gc.collect()
|
| 549 |
|
| 550 |
if denoise_model_name == "UVR5":
|
| 551 |
+
# Follow LEMAS-TTS layout but resolve from MODELS_PATH (./pretrained_models by default),
|
| 552 |
+
# so that only the main TTS checkpoints can live in hf:// mounts while all
|
| 553 |
+
# auxiliary models (UVR5, vocoder, prosody encoder, etc.) are loaded from
|
| 554 |
+
# the local `pretrained_models` folder.
|
| 555 |
+
from pathlib import Path
|
| 556 |
+
uv_root = Path(MODELS_PATH) / "uvr5" / "models" / "MDX_Net_Models" / "model_data"
|
| 557 |
+
denoise_model = UVR5(str(uv_root))
|
| 558 |
elif denoise_model_name == "DeepFilterNet":
|
| 559 |
denoise_model = DeepFilterNet("./pretrained_models/denoiser_model.onnx")
|
| 560 |
|
|
|
|
| 1202 |
if __name__ == "__main__":
|
| 1203 |
import argparse
|
| 1204 |
|
| 1205 |
+
parser = argparse.ArgumentParser(description="LEMAS-Edit gradio app.")
|
| 1206 |
|
| 1207 |
+
parser.add_argument("--demo-path", default="./pretrained_models/demo", help="Path to demo directory")
|
| 1208 |
+
parser.add_argument("--tmp-path", default="./pretrained_models/tmp", help="Path to tmp directory")
|
| 1209 |
parser.add_argument("--port", default=41020, type=int, help="App port")
|
| 1210 |
parser.add_argument("--share", action="store_true", help="Launch with public url")
|
| 1211 |
parser.add_argument("--server_name", default="0.0.0.0", type=str, help="Server name for launching the app. 127.0.0.1 for localhost; 0.0.0.0 to allow access from other machines in the local network. Might also give access to external users depends on the firewall settings.")
|