# Hugging Face Space status: Running on Zero
| import shlex | |
| import subprocess | |
| import spaces | |
| import torch | |
# Install packages for Mamba.
def install_mamba():
    """Install the prebuilt ``mamba_ssm`` wheel and a pinned NumPy with pip.

    Each package is installed by running pip as a module of the interpreter
    that is executing this script, so the packages land in the environment
    this process imports from. A failed install is reported on stderr but
    does not raise: installation remains best-effort, and a missing package
    surfaces later as an ordinary ImportError.
    """
    import sys

    # Earlier setup steps, kept for reference (currently disabled):
    #subprocess.run(shlex.split("pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118"))
    #subprocess.run(shlex.split("pip install https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.4.0/causal_conv1d-1.4.0+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"))
    packages = (
        "https://github.com/state-spaces/mamba/releases/download/v2.2.2/mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl",
        "numpy==1.26.4",
    )
    for package in packages:
        # An argument list (no shell, no string splitting) keeps the URL intact.
        result = subprocess.run([sys.executable, "-m", "pip", "install", package])
        if result.returncode != 0:
            print(
                f"WARNING: pip install {package} failed "
                f"(exit code {result.returncode})",
                file=sys.stderr,
            )

    # Lists the working directory on the process output; looks like a
    # debugging aid for checking which files are present at startup.
    subprocess.run(["ls"])
| install_mamba() | |
| import gradio as gr | |
| import torch | |
| import yaml | |
| import librosa | |
| import librosa.display | |
| import matplotlib | |
| from huggingface_hub import hf_hub_download | |
| from models.stfts import mag_phase_stft, mag_phase_istft | |
| from models.generator import SEMamba | |
| from models.pcs400 import cal_pcs | |
# Model assets are read from paths relative to the working directory.
# Fetching them from the Hub instead is currently disabled:
#ckpt = hf_hub_download("rc19477/Speech_Enhancement_Mamba",
#                       "ckpts/SEMamba_advanced.pth")
#cfg_f = hf_hub_download("rc19477/Speech_Enhancement_Mamba",
#                        "recipes/SEMamba_advanced.yaml")
ckpt = "ckpts/SEMamba_advanced.pth"
cfg_f = "recipes/SEMamba_advanced.yaml"

# Parse the YAML recipe and pull out the settings used at inference time.
with open(cfg_f) as f:
    cfg = yaml.safe_load(f)

stft_cfg, model_cfg = cfg["stft_cfg"], cfg["model_cfg"]
sr, n_fft, hop_size, win_size = (
    stft_cfg[key] for key in ("sampling_rate", "n_fft", "hop_size", "win_size")
)
compress_ff = model_cfg["compress_factor"]

# Build the network on the GPU when one is available, otherwise on the CPU,
# then restore the weights stored under the checkpoint's "generator" key.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = SEMamba(cfg).to(device)
sdict = torch.load(ckpt, map_location=device)
model.load_state_dict(sdict["generator"])
model.eval()
def _to_mono_float32(samples):
    """Return *samples* as a 1-D float32 NumPy array.

    Integer PCM is scaled into [-1, 1]. Two-dimensional input is averaged
    over its second axis, assuming the (samples, channels) layout that
    Gradio uses for multi-channel numpy audio.
    """
    import numpy as np

    data = np.asarray(samples)
    if np.issubdtype(data.dtype, np.integer):
        data = data.astype(np.float32) / np.iinfo(data.dtype).max
    else:
        data = data.astype(np.float32)
    if data.ndim > 1:
        data = data.mean(axis=1)
    return data


def _spectrogram_figure(wav, sample_rate):
    """Render a dB-scaled spectrogram of *wav* as a Matplotlib figure."""
    import numpy as np
    from matplotlib.figure import Figure

    D = librosa.stft(wav, n_fft=1024, hop_length=512)
    S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)
    # A bare Figure (no pyplot) is not registered in pyplot's global figure
    # list, so figures do not accumulate across requests.
    fig = Figure(figsize=(6, 3))
    ax = fig.subplots()
    img = librosa.display.specshow(
        S_db, sr=sample_rate, hop_length=512, x_axis="time", y_axis="hz", ax=ax
    )
    ax.set_title("Enhanced Spectrogram")
    # The mappable is passed explicitly: specshow does not set pyplot's
    # "current image" when it is given an axes object.
    fig.colorbar(img, ax=ax, format="%+2.0f dB")
    return fig


def enhance(audio):
    """Enhance a noisy speech clip and plot the spectrogram of the result.

    Args:
        audio: ``(sample_rate, samples)`` tuple as delivered by a
            ``gr.Audio(type="numpy")`` input, or ``None`` when nothing was
            provided.

    Returns:
        ``((sample_rate, enhanced_samples), figure)``, where the audio is
        returned at the input sample rate and ``figure`` is a Matplotlib
        figure of its spectrogram. ``(None, None)`` is returned when there
        is no audio or the clip contains no samples.
    """
    if audio is None:
        return None, None
    orig_sr, wav_np = audio

    # Gradio may hand over integer PCM and/or multi-channel data, while the
    # resampler needs floating-point input and the model path below treats
    # the waveform as one-dimensional.
    wav_np = _to_mono_float32(wav_np)
    if wav_np.size == 0:
        return None, None

    if orig_sr != sr:
        wav_np = librosa.resample(wav_np, orig_sr=orig_sr, target_sr=sr)

    wav = torch.from_numpy(wav_np).float().to(device)
    # Scale to unit RMS before the model and undo the scaling afterwards.
    # A silent clip has zero energy; use a neutral factor instead of
    # dividing by zero.
    energy = torch.sum(wav ** 2)
    if energy > 0:
        norm = torch.sqrt(len(wav) / energy)
    else:
        norm = torch.ones((), device=device)
    wav = (wav * norm).unsqueeze(0)

    # Inference only: without no_grad the output would carry autograd state
    # and could not be converted to NumPy.
    with torch.no_grad():
        # NOTE(review): stft_cfg is forwarded wholesale, so mag_phase_stft /
        # mag_phase_istft must accept every key in it (e.g. sampling_rate) --
        # confirm against models.stfts.
        amp, pha, _ = mag_phase_stft(wav, **stft_cfg, compress_factor=model_cfg["compress_factor"])
        amp_g, pha_g = model(amp, pha)
        out = mag_phase_istft(amp_g, pha_g, **stft_cfg, compress_factor=model_cfg["compress_factor"])
    out = (out / norm).squeeze().cpu().numpy()

    if orig_sr != sr:
        out = librosa.resample(out, orig_sr=sr, target_sr=orig_sr)

    return (orig_sr, out), _spectrogram_figure(out, orig_sr)
# --- Layout with Blocks ---
# Two-column page: the noisy input and the trigger button on the left, the
# enhanced audio and its spectrogram on the right, followed by a cached
# example and a link to the paper.
with gr.Blocks(css=".gr-box {border: none !important}") as demo:
    gr.Markdown("<h1 style='text-align: center;'>🎧 <a href='https://github.com/RoyChao19477/SEMamba' target='_blank'>SEMamba</a>: Speech Enhancement</h1>")
    gr.Markdown("Enhance real-world noisy speech using Mamba. Upload or record an audio clip and view the spectrogram.")

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(sources=["upload", "microphone"], type="numpy", label="Upload or Record", elem_id="input-audio")
            run_btn = gr.Button("Enhance Now", variant="primary")
        with gr.Column():
            enhanced_audio = gr.Audio(label="Enhanced Output", type="numpy")
            spec_plot = gr.Plot(label="Spectrogram")

    # The button and the example below both run `enhance` and fill the same
    # two output components.
    run_btn.click(enhance, inputs=audio_input, outputs=[enhanced_audio, spec_plot])

    gr.Examples(
        examples=[
            ["examples/noisy_sample_16k.wav"],
        ],
        inputs=audio_input,
        outputs=[enhanced_audio, spec_plot],
        fn=enhance,
        cache_examples=True,
        label="Try These Examples",
    )

    gr.Markdown("<p style='text-align: center'><a href='https://arxiv.org/abs/2405.15144' target='_blank'>SEMamba: Mamba for Long-Context Speech Enhancement (SLT 2024)</a></p>")

demo.launch()