| import gradio as gr |
| import torch |
| import torchaudio |
| import numpy as np |
| from pathlib import Path |
| from huggingface_hub import hf_hub_download |
| from omegaconf import DictConfig, ListConfig, OmegaConf |
|
|
| |
import torch.serialization
# Allowlist the OmegaConf container classes so that torch.load can unpickle
# checkpoints containing them under PyTorch's weights_only safe-loading
# behavior (default in PyTorch 2.6+). Without this, loading the adapter
# checkpoint would raise an UnpicklingError for these globals.
torch.serialization.add_safe_globals([
    ListConfig,
    DictConfig,
    OmegaConf,
])
|
|
| from miipher_2.model.feature_cleaner import FeatureCleaner |
| from miipher_2.lightning_vocoders.lightning_module import HiFiGANLightningModule |
|
|
# Hugging Face Hub repository that hosts the pretrained Miipher-2 checkpoints.
MODEL_REPO_ID = "Atotti/miipher-2-HuBERT-HiFi-GAN-v0.1"
ADAPTER_FILENAME = "checkpoint_199k_fixed.pt"   # FeatureCleaner (adapter) weights
VOCODER_FILENAME = "epoch=77-step=137108.ckpt"  # HiFi-GAN vocoder Lightning checkpoint
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SAMPLE_RATE_INPUT = 16000   # Hz — input audio is resampled to this before feature extraction
SAMPLE_RATE_OUTPUT = 22050  # Hz — sample rate of the WAV written from the vocoder output


# Process-wide cache so both models are loaded at most once per worker.
models_cache = {}
|
|
|
|
def download_models():
    """Fetch the adapter and vocoder checkpoints from the Hugging Face Hub.

    Files are cached locally under ``./models`` by ``hf_hub_download``, so
    repeated calls are cheap after the first download.

    Returns:
        A ``(adapter_path, vocoder_path)`` tuple of local file paths.
    """
    local_paths = [
        hf_hub_download(repo_id=MODEL_REPO_ID, filename=name, cache_dir="./models")
        for name in (ADAPTER_FILENAME, VOCODER_FILENAME)
    ]
    return tuple(local_paths)
|
|
|
|
def load_models():
    """Return the (FeatureCleaner, vocoder) pair, loading them on first use.

    On the first call the checkpoints are downloaded, the models are built,
    moved to ``DEVICE``, put in eval mode, and memoized in ``models_cache``;
    every later call returns the cached instances.
    """
    cached = (models_cache.get("cleaner"), models_cache.get("vocoder"))
    if None not in cached:
        return cached

    adapter_path, vocoder_path = download_models()

    # Minimal config the FeatureCleaner needs to build its HuBERT backbone.
    cleaner_cfg = DictConfig({
        "hubert_model_name": "utter-project/mHuBERT-147",
        "hubert_layer": 6,
        "adapter_hidden_dim": 768,
    })

    print("Loading FeatureCleaner...")
    cleaner = FeatureCleaner(cleaner_cfg).to(DEVICE).eval()
    # weights_only=False: the checkpoint pickles OmegaConf containers
    # (allowlisted at import time via add_safe_globals).
    checkpoint = torch.load(adapter_path, map_location=DEVICE, weights_only=False)
    cleaner.load_state_dict(checkpoint["model_state_dict"])

    print("Loading vocoder...")
    vocoder = HiFiGANLightningModule.load_from_checkpoint(
        vocoder_path, map_location=DEVICE, weights_only=False
    )
    vocoder = vocoder.to(DEVICE).eval()

    models_cache.update(cleaner=cleaner, vocoder=vocoder)
    return cleaner, vocoder
|
|
|
|
@torch.inference_mode()
def enhance_audio(audio_path, progress=gr.Progress()):
    """Enhance one noisy audio file with Miipher-2 and return a clean WAV.

    Pipeline: load/resample to 16 kHz mono -> FeatureCleaner features
    (fp16 autocast on CUDA) -> HiFi-GAN vocoder -> clipped 22.05 kHz WAV
    written to a temporary file.

    Args:
        audio_path: Path to the input audio file (any format torchaudio reads).
        progress: Gradio progress tracker, injected by the UI.

    Returns:
        Path to a temporary ``.wav`` file containing the enhanced audio.

    Raises:
        gr.Error: On any failure, so Gradio shows the message in the UI.
            The original exception is chained (``from e``) so the full
            traceback is preserved in the server log.
    """
    import tempfile  # hoisted to function top; kept local to the handler

    try:
        progress(0, desc="Loading models...")
        cleaner, vocoder = load_models()

        progress(0.2, desc="Loading audio...")
        waveform, sr = torchaudio.load(audio_path)
        if sr != SAMPLE_RATE_INPUT:
            waveform = torchaudio.functional.resample(waveform, sr, SAMPLE_RATE_INPUT)
        # Downmix to mono and move to the inference device.
        waveform = waveform.mean(0, keepdim=True).to(DEVICE)

        progress(0.4, desc="Extracting features...")
        use_amp = DEVICE.type == "cuda"  # fp16 autocast only makes sense on GPU
        with torch.autocast(device_type=DEVICE.type, dtype=torch.float16, enabled=use_amp):
            features = cleaner(waveform)
            if features.dim() == 2:
                features = features.unsqueeze(0)  # add batch dim -> (1, time, dim)

        progress(0.7, desc="Generating enhanced audio...")
        # Vocoder consumes (batch, dim, time) — swap the last two axes.
        batch = {"input_feature": features.transpose(1, 2)}
        enhanced_audio = vocoder.generator_forward(batch)

        # Back to float32 on CPU, clipped to valid PCM range [-1, 1].
        enhanced_audio = enhanced_audio.squeeze(0).cpu().to(torch.float32).detach().numpy()
        enhanced_audio = np.clip(enhanced_audio, -1.0, 1.0)
        enhanced_tensor = torch.from_numpy(enhanced_audio)
        if enhanced_tensor.dim() == 1:
            enhanced_tensor = enhanced_tensor.unsqueeze(0)  # torchaudio.save needs (channels, time)

        # delete=False: Gradio streams the file back to the client after we return.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            torchaudio.save(tmp.name, enhanced_tensor, SAMPLE_RATE_OUTPUT)
        progress(1.0, desc="Done!")
        return tmp.name

    except Exception as e:
        # Chain the cause so the server log keeps the original traceback
        # (the bare `raise gr.Error(...)` previously discarded it).
        raise gr.Error(f"Enhancement failed: {str(e)}") from e
|
|
|
|
# Gradio UI: one upload/microphone input, one non-interactive output, one button.
# NOTE(review): the glyphs "π€", "β", and "π" in the strings below look like
# mojibake-mangled emoji (likely a microphone, an arrow, and a rocket/sparkle) —
# verify the source file's encoding and restore the intended characters.
with gr.Blocks(title="Miipher-2 Speech Enhancement", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# π€ Miipher-2 Speech Enhancement")
    gr.Markdown("Upload noisy speech β get clean speech back.")


    with gr.Row():
        input_audio = gr.Audio(label="Input (noisy)", type="filepath", sources=["upload", "microphone"])
        output_audio = gr.Audio(label="Enhanced output", type="filepath", interactive=False)


    # Wire the button to the enhancement pipeline; show_progress surfaces
    # the progress() updates emitted inside enhance_audio.
    gr.Button("π Enhance", variant="primary").click(
        fn=enhance_audio, inputs=input_audio, outputs=output_audio, show_progress=True
    )
|
|
if __name__ == "__main__":
    # Eagerly download and load both models at startup so the first user
    # request does not pay the multi-second warm-up cost.
    load_models()
    demo.launch()