# NOTE: the "Spaces / Sleeping" lines originally here were Hugging Face Space
# status-banner text captured along with the page, not part of the program.
| #!/usr/bin/python3 | |
| # -*- coding: utf-8 -*- | |
"""
docker build -t denoise:v20250626_1616 .

docker stop denoise_7865 && docker rm denoise_7865

docker run -itd \
--name denoise_7865 \
--restart=always \
--network host \
-e server_port=7865 \
-e hf_token=hf_**REDACTED** \
denoise:v20250626_1616 /bin/bash
"""
| import argparse | |
| import json | |
| from functools import lru_cache | |
| import logging | |
| from pathlib import Path | |
| import platform | |
| import shutil | |
| import tempfile | |
| import time | |
| from typing import Dict, Tuple | |
| import uuid | |
| import zipfile | |
| import gradio as gr | |
| from huggingface_hub import snapshot_download | |
| import librosa | |
| import librosa.display | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from scipy.io import wavfile | |
| import log | |
| from project_settings import environment, project_path, log_directory | |
| from toolbox.os.command import Command | |
| from toolbox.torchaudio.models.dfnet2.inference_dfnet2 import InferenceDfNet2 | |
| from toolbox.torchaudio.models.dtln.inference_dtln import InferenceDTLN | |
| from toolbox.torchaudio.models.frcrn.inference_frcrn import InferenceFRCRN | |
| from toolbox.torchaudio.models.mpnet.inference_mpnet import InferenceMPNet | |
# Configure size-rotating file logging for the whole process (project helper),
# then grab the logger used by the gradio callbacks below.
log.setup_size_rotating(log_directory=log_directory)
logger = logging.getLogger("main")
def get_args():
    """Build and parse the command line arguments for the demo service.

    Defaults for ``--hf_token`` and ``--server_port`` come from the process
    environment so the service can be configured via ``docker run -e ...``.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--examples_dir",
        default=(project_path / "data/examples").as_posix(),
        type=str,
    )
    arg_parser.add_argument(
        "--models_repo_id",
        default="qgyd2021/cc_denoise",
        type=str,
    )
    arg_parser.add_argument(
        "--trained_model_dir",
        default=(project_path / "trained_models").as_posix(),
        type=str,
    )
    arg_parser.add_argument(
        "--hf_token",
        default=environment.get("hf_token"),
        type=str,
    )
    arg_parser.add_argument(
        "--server_port",
        default=environment.get("server_port", 7860),
        type=int,
    )
    return arg_parser.parse_args()
def save_input_audio(sample_rate: int, signal: np.ndarray) -> str:
    """Persist an int16 waveform to a uniquely named temp wav file.

    :param sample_rate: sample rate of ``signal`` in Hz.
    :param signal: PCM samples; must already be ``np.int16``.
    :return: posix path of the written wav file.
    :raises AssertionError: when ``signal`` has any other dtype.
    """
    if signal.dtype != np.int16:
        raise AssertionError(f"only support dtype np.int16, however: {signal.dtype}")

    out_dir = Path(tempfile.gettempdir()) / "input_audio"
    out_dir.mkdir(parents=True, exist_ok=True)

    out_file = (out_dir / f"{uuid.uuid4()}.wav").as_posix()
    wavfile.write(out_file, sample_rate, signal)
    return out_file
def convert_sample_rate(signal: np.ndarray, sample_rate: int, target_sample_rate: int) -> np.ndarray:
    """Resample an int16 signal to ``target_sample_rate`` via a temp wav file.

    Round-trips through disk so librosa performs the resampling on load.

    Fix: the intermediate file written by ``save_input_audio`` was never
    removed, leaking one wav file into the temp directory per conversion.

    :return: resampled signal as ``np.int16``.
    """
    filename = save_input_audio(sample_rate, signal)
    try:
        signal, _ = librosa.load(filename, sr=target_sample_rate)
    finally:
        # always remove the intermediate file, even if librosa.load raises
        Path(filename).unlink(missing_ok=True)
    signal = np.array(signal * (1 << 15), dtype=np.int16)
    return signal
def shell(cmd: str):
    # Run an arbitrary shell command via the project Command helper and return
    # its output (used by the gradio "shell" tab).
    # SECURITY NOTE(review): this hands remote command execution on the host to
    # anyone who can reach the web UI — keep the "shell" tab access-restricted
    # in any non-trusted deployment.
    return Command.popen(cmd)
def get_infer_cls_by_model_name(model_name: str):
    """Map a model (zip file) stem to its inference wrapper class.

    Matching is by substring, checked in a fixed order.

    :param model_name: model file stem, e.g. ``"dtln-nx-...-epoch"``.
    :raises AssertionError: when no known engine keyword occurs in the name.
    """
    # idiom fix: `x in s` instead of calling `s.__contains__(x)` directly
    if "dtln" in model_name:
        infer_cls = InferenceDTLN
    elif "dfnet2" in model_name:
        infer_cls = InferenceDfNet2
    elif "frcrn" in model_name:
        infer_cls = InferenceFRCRN
    elif "mpnet" in model_name:
        infer_cls = InferenceMPNet
    else:
        # fix: the bare AssertionError gave no hint about which name failed
        raise AssertionError(f"unknown model name: {model_name}")
    return infer_cls
# Registry of available denoise engines keyed by model zip stem; each value
# holds the inference class and the kwargs needed to instantiate it.
# None until main() populates it from the trained-models directory.
denoise_engines: Dict[str, dict] = None
@lru_cache(maxsize=10)
def load_denoise_model(infer_cls, **kwargs):
    """Instantiate an inference engine, caching it by class + kwargs.

    Fix: previously a fresh engine was constructed (model loaded from disk)
    on every button click; `lru_cache` — already imported at the top of the
    file but unused — reuses the engine across requests. kwargs values are
    plain strings (file paths), so the cache key is hashable.

    NOTE(review): the cached engine is shared between requests; this assumes
    the inference wrappers are safe to reuse sequentially (gradio's default
    queue serializes calls) — confirm before enabling concurrency.
    """
    infer_engine = infer_cls(**kwargs)
    return infer_engine
def generate_spectrogram(signal: np.ndarray, sample_rate: int = 8000, title: str = "Spectrogram"):
    """Render a dB-scaled STFT magnitude spectrogram of ``signal`` to a png.

    :param signal: float waveform (roughly in [-1, 1]).
    :param sample_rate: sample rate used for the frequency axis.
    :param title: figure title.
    :return: path of the temporary png file (caller is responsible for it).
    """
    mag = np.abs(librosa.stft(signal))
    # NOTE(review): ref=20 is an absolute amplitude reference, not np.max as
    # the commented-out original used — presumably deliberate; confirm.
    mag_db = librosa.amplitude_to_db(mag, ref=20)

    fig = plt.figure(figsize=(10, 4))
    try:
        librosa.display.specshow(mag_db, sr=sample_rate)
        plt.title(title)
        temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        # fix: close the open handle before savefig — the fd was leaked and
        # an open NamedTemporaryFile cannot be reopened by name on Windows
        temp_file.close()
        plt.savefig(temp_file.name, bbox_inches="tight")
    finally:
        # fix: close the figure even when savefig raises (figure leak)
        plt.close(fig)
    return temp_file.name
def when_click_denoise_button(noisy_audio_file_t = None, noisy_audio_microphone_t = None, engine: str = None):
    """Gradio callback: denoise the uploaded/recorded audio with ``engine``.

    :param noisy_audio_file_t: (sample_rate, int16 signal) from the file tab.
    :param noisy_audio_microphone_t: same shape, from the microphone tab.
    :param engine: key into the module-level ``denoise_engines`` registry.
    :return: (denoise_audio_t, noise_audio_t, message json,
              noisy/denoise/noise spectrogram image paths).
    :raises gr.Error: on missing input, unknown engine, or inference failure.
    """
    if noisy_audio_file_t is None and noisy_audio_microphone_t is None:
        raise gr.Error("audio file and microphone is null.")
    if noisy_audio_file_t is not None and noisy_audio_microphone_t is not None:
        gr.Warning("both audio file and microphone file is provided, audio file taking priority.")
    noisy_audio_t: Tuple = noisy_audio_file_t or noisy_audio_microphone_t

    sample_rate, signal = noisy_audio_t
    if sample_rate != 8000:
        signal = convert_sample_rate(signal, sample_rate, 8000)
        sample_rate = 8000
    # fix: use true division — `shape[-1] // 8000` truncated any sub-second
    # clip to 0 and made the fpr computation below raise ZeroDivisionError.
    audio_duration = signal.shape[-1] / 8000
    # NOTE(review): with the microphone source, gradio reports 44100 as the
    # sample rate although the signal actually appears to be 8000 Hz —
    # translated from the original Chinese test note; confirm.
    logger.info(f"run denoise; engine: {engine}, sample_rate: {sample_rate}, signal dtype: {signal.dtype}, signal shape: {signal.shape}")

    # int16 PCM -> float32 in [-1, 1)
    noisy_audio = np.array(signal / (1 << 15), dtype=np.float32)

    infer_engine_param = denoise_engines.get(engine)
    if infer_engine_param is None:
        raise gr.Error(f"invalid denoise engine: {engine}.")
    try:
        infer_cls = infer_engine_param["infer_cls"]
        kwargs = infer_engine_param["kwargs"]
        infer_engine = load_denoise_model(infer_cls=infer_cls, **kwargs)

        begin = time.time()
        denoise_audio = infer_engine.enhancement_by_ndarray(noisy_audio)
        time_cost = time.time() - begin

        # guard against an empty signal (duration 0) even with true division
        fpr = time_cost / audio_duration if audio_duration > 0 else float("inf")

        info = {
            "time_cost": round(time_cost, 4),
            "audio_duration": round(audio_duration, 4),
            "fpr": round(fpr, 4)
        }
        message = json.dumps(info, ensure_ascii=False, indent=4)

        # residual = what the model removed
        noise_audio = noisy_audio - denoise_audio

        noisy_mag_db = generate_spectrogram(noisy_audio, title="noisy")
        denoise_mag_db = generate_spectrogram(denoise_audio, title="denoise")
        noise_mag_db = generate_spectrogram(noise_audio, title="noise")

        # back to int16 PCM for the gr.Audio outputs
        denoise_audio = np.array(denoise_audio * (1 << 15), dtype=np.int16)
        noise_audio = np.array(noise_audio * (1 << 15), dtype=np.int16)
    except Exception as e:
        raise gr.Error(f"enhancement failed, error type: {type(e)}, error text: {str(e)}.")

    denoise_audio_t = (sample_rate, denoise_audio)
    noise_audio_t = (sample_rate, noise_audio)
    return denoise_audio_t, noise_audio_t, message, noisy_mag_db, denoise_mag_db, noise_mag_db
def main():
    """Download models if needed, build the engine registry, serve the UI."""
    args = get_args()

    examples_dir = Path(args.examples_dir)
    trained_model_dir = Path(args.trained_model_dir)

    # download models on first run only (directory absent)
    if not trained_model_dir.exists():
        trained_model_dir.mkdir(parents=True, exist_ok=True)
        _ = snapshot_download(
            repo_id=args.models_repo_id,
            local_dir=trained_model_dir.as_posix(),
            token=args.hf_token,
        )

    # engines
    # fix: glob the directory given by --trained_model_dir instead of the
    # hard-coded `project_path / "trained_models"`, which silently ignored
    # the command line argument (defaults are identical, so behavior is
    # unchanged for the default invocation).
    global denoise_engines
    denoise_engines = {
        filename.stem: {
            "infer_cls": get_infer_cls_by_model_name(filename.stem),
            "kwargs": {
                "pretrained_model_path_or_zip_file": filename.as_posix()
            }
        }
        for filename in trained_model_dir.glob("*.zip")
        if filename.name != "examples.zip"
    }

    # choices
    denoise_engine_choices = list(denoise_engines.keys())
    # fix: fail with a clear message instead of a bare IndexError below when
    # no model zip was found
    if not denoise_engine_choices:
        raise AssertionError(f"no model zip file found in: {trained_model_dir.as_posix()}")

    # extract bundled example wavs on first run
    if not examples_dir.exists():
        example_zip_file = trained_model_dir / "examples.zip"
        with zipfile.ZipFile(example_zip_file.as_posix(), "r") as f_zip:
            out_root = examples_dir
            if out_root.exists():
                shutil.rmtree(out_root.as_posix())
            out_root.mkdir(parents=True, exist_ok=True)
            f_zip.extractall(path=out_root)

    # examples table: (file, microphone=None, engine)
    examples = list()
    for filename in examples_dir.glob("**/*.wav"):
        examples.append([
            filename.as_posix(),
            None,
            denoise_engine_choices[0],
        ])

    # ui
    with gr.Blocks() as blocks:
        gr.Markdown(value="denoise.")
        with gr.Tabs():
            with gr.TabItem("denoise"):
                with gr.Row():
                    # left column: inputs
                    with gr.Column(variant="panel", scale=5):
                        with gr.Tabs():
                            with gr.TabItem("file"):
                                dn_noisy_audio_file = gr.Audio(label="noisy_audio")
                            with gr.TabItem("microphone"):
                                dn_noisy_audio_microphone = gr.Audio(sources="microphone", label="noisy_audio")
                        dn_engine = gr.Dropdown(choices=denoise_engine_choices, value=denoise_engine_choices[0], label="engine")
                        dn_button = gr.Button(variant="primary")
                    # right column: outputs
                    with gr.Column(variant="panel", scale=5):
                        with gr.Tabs():
                            with gr.TabItem("audio"):
                                dn_denoise_audio = gr.Audio(label="denoise_audio")
                                dn_noise_audio = gr.Audio(label="noise_audio")
                                dn_message = gr.Textbox(lines=1, max_lines=20, label="message")
                            with gr.TabItem("mag_db"):
                                dn_noisy_mag_db = gr.Image(label="noisy_mag_db")
                                dn_denoise_mag_db = gr.Image(label="denoise_mag_db")
                                dn_noise_mag_db = gr.Image(label="noise_mag_db")
                dn_button.click(
                    when_click_denoise_button,
                    inputs=[dn_noisy_audio_file, dn_noisy_audio_microphone, dn_engine],
                    outputs=[dn_denoise_audio, dn_noise_audio, dn_message, dn_noisy_mag_db, dn_denoise_mag_db, dn_noise_mag_db]
                )
                gr.Examples(
                    examples=examples,
                    inputs=[dn_noisy_audio_file, dn_noisy_audio_microphone, dn_engine],
                    outputs=[dn_denoise_audio, dn_noise_audio, dn_message, dn_noisy_mag_db, dn_denoise_mag_db, dn_noise_mag_db],
                    fn=when_click_denoise_button,
                )
            # SECURITY NOTE(review): this tab gives anyone with UI access
            # arbitrary command execution on the host; restrict access.
            with gr.TabItem("shell"):
                shell_text = gr.Textbox(label="cmd")
                shell_button = gr.Button("run")
                shell_output = gr.Textbox(label="output")
                shell_button.click(
                    shell,
                    inputs=[shell_text,],
                    outputs=[shell_output],
                )

    # http://127.0.0.1:7865/
    blocks.queue().launch(
        # fix: `False if Windows else False` was always False
        share=False,
        # bind localhost only on Windows dev boxes; all interfaces otherwise
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=args.server_port
    )
    return
# Script entry point.
if __name__ == "__main__":
    main()