"""Gradio demo for Miipher-2 speech enhancement.

Downloads a HuBERT-based FeatureCleaner (adapter) and a HiFi-GAN vocoder
from the Hugging Face Hub, then serves a simple "upload noisy speech ->
get enhanced speech" web UI.
"""

import tempfile

import gradio as gr
import numpy as np
import torch
import torchaudio
from pathlib import Path
from huggingface_hub import hf_hub_download
from omegaconf import DictConfig, ListConfig, OmegaConf

# Allowlist omegaconf container types before any torch.load call so
# checkpoints that pickle OmegaConf objects can be deserialized.
import torch.serialization

torch.serialization.add_safe_globals([
    ListConfig,
    DictConfig,
    OmegaConf,
])

from miipher_2.model.feature_cleaner import FeatureCleaner
from miipher_2.lightning_vocoders.lightning_module import HiFiGANLightningModule

MODEL_REPO_ID = "Atotti/miipher-2-HuBERT-HiFi-GAN-v0.1"
ADAPTER_FILENAME = "checkpoint_199k_fixed.pt"
VOCODER_FILENAME = "epoch=77-step=137108.ckpt"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SAMPLE_RATE_INPUT = 16000   # rate the HuBERT feature extractor is fed
SAMPLE_RATE_OUTPUT = 22050  # rate the HiFi-GAN vocoder emits

# Process-wide cache so the expensive model load happens only once.
models_cache = {}


def download_models():
    """Fetch both checkpoints from the Hub (cached under ./models).

    Returns:
        tuple[str, str]: local paths to (adapter checkpoint, vocoder checkpoint).
    """
    adapter_path = hf_hub_download(
        repo_id=MODEL_REPO_ID, filename=ADAPTER_FILENAME, cache_dir="./models"
    )
    vocoder_path = hf_hub_download(
        repo_id=MODEL_REPO_ID, filename=VOCODER_FILENAME, cache_dir="./models"
    )
    return adapter_path, vocoder_path


def load_models():
    """Load (or return cached) FeatureCleaner and vocoder, both in eval mode.

    Returns:
        tuple: (cleaner, vocoder) modules placed on DEVICE.
    """
    if "cleaner" in models_cache and "vocoder" in models_cache:
        return models_cache["cleaner"], models_cache["vocoder"]

    adapter_path, vocoder_path = download_models()

    model_config = DictConfig({
        "hubert_model_name": "utter-project/mHuBERT-147",
        "hubert_layer": 6,
        "adapter_hidden_dim": 768,
    })

    print("Loading FeatureCleaner...")
    cleaner = FeatureCleaner(model_config).to(DEVICE).eval()
    # NOTE(review): weights_only=False fully unpickles a downloaded file —
    # only acceptable because the repo above is trusted. The add_safe_globals
    # allowlist at the top only matters for weights_only=True loads.
    adapter_checkpoint = torch.load(adapter_path, map_location=DEVICE, weights_only=False)
    cleaner.load_state_dict(adapter_checkpoint["model_state_dict"])

    print("Loading vocoder...")
    vocoder = HiFiGANLightningModule.load_from_checkpoint(
        vocoder_path, map_location=DEVICE, weights_only=False
    ).to(DEVICE).eval()

    models_cache["cleaner"] = cleaner
    models_cache["vocoder"] = vocoder
    return cleaner, vocoder


@torch.inference_mode()
def enhance_audio(audio_path, progress=gr.Progress()):
    """Run the full enhancement pipeline on one audio file.

    Args:
        audio_path: path to the input recording (any torchaudio-readable format).
        progress: Gradio progress tracker, injected by the UI.

    Returns:
        str: path to a temporary WAV file with the enhanced 22.05 kHz audio.

    Raises:
        gr.Error: if any stage of the pipeline fails.
    """
    try:
        progress(0, desc="Loading models...")
        cleaner, vocoder = load_models()

        progress(0.2, desc="Loading audio...")
        waveform, sr = torchaudio.load(audio_path)
        if sr != SAMPLE_RATE_INPUT:
            waveform = torchaudio.functional.resample(waveform, sr, SAMPLE_RATE_INPUT)
        # Downmix to mono; the cleaner consumes a single channel.
        waveform = waveform.mean(0, keepdim=True).to(DEVICE)

        progress(0.4, desc="Extracting features...")
        use_amp = DEVICE.type == "cuda"
        with torch.autocast(device_type=DEVICE.type, dtype=torch.float16, enabled=use_amp):
            features = cleaner(waveform)
        if features.dim() == 2:
            features = features.unsqueeze(0)  # ensure a batch dimension
        # Autocast can leave features in fp16 while the vocoder runs in fp32
        # outside the autocast context; a no-op cast when already fp32.
        features = features.float()

        progress(0.7, desc="Generating enhanced audio...")
        batch = {"input_feature": features.transpose(1, 2)}
        enhanced = vocoder.generator_forward(batch)
        # Clamp to the valid waveform range before writing to disk.
        enhanced = torch.clamp(
            enhanced.squeeze(0).detach().to(torch.float32).cpu(), -1.0, 1.0
        )
        if enhanced.dim() == 1:
            enhanced = enhanced.unsqueeze(0)  # torchaudio.save wants (channels, samples)

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            torchaudio.save(tmp.name, enhanced, SAMPLE_RATE_OUTPUT)
        progress(1.0, desc="Done!")
        return tmp.name
    except gr.Error:
        raise  # already user-facing; don't wrap it a second time
    except Exception as e:
        raise gr.Error(f"Enhancement failed: {str(e)}") from e


with gr.Blocks(title="Miipher-2 Speech Enhancement", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎤 Miipher-2 Speech Enhancement")
    gr.Markdown("Upload noisy speech — get clean speech back.")
    with gr.Row():
        input_audio = gr.Audio(
            label="Input (noisy)", type="filepath", sources=["upload", "microphone"]
        )
        output_audio = gr.Audio(label="Enhanced output", type="filepath", interactive=False)
    gr.Button("🚀 Enhance", variant="primary").click(
        fn=enhance_audio, inputs=input_audio, outputs=output_audio, show_progress=True
    )

if __name__ == "__main__":
    # Warm the model cache before serving so the first request is fast.
    load_models()
    demo.launch()