# miipher / app.py
# Uploaded by dkakaie — commit 65a2251 "Create app.py" (verified)
import gradio as gr
import torch
import torchaudio
import numpy as np
from pathlib import Path
from huggingface_hub import hf_hub_download
from omegaconf import DictConfig, ListConfig, OmegaConf
# Register the omegaconf config types as safe globals for torch's restricted
# unpickler, presumably because they appear inside the checkpoints' pickled
# hyper-parameters. This must run before any torch.load call.
# NOTE(review): every torch.load / load_from_checkpoint call in this file
# passes weights_only=False. In that mode the allowlist is not consulted, so
# this only matters if those calls are switched to weights_only=True.
import torch.serialization

_OMEGACONF_SAFE_GLOBALS = (ListConfig, DictConfig, OmegaConf)
torch.serialization.add_safe_globals(list(_OMEGACONF_SAFE_GLOBALS))
from miipher_2.model.feature_cleaner import FeatureCleaner
from miipher_2.lightning_vocoders.lightning_module import HiFiGANLightningModule
# --- Model / runtime configuration -----------------------------------------
MODEL_REPO_ID = "Atotti/miipher-2-HuBERT-HiFi-GAN-v0.1"  # Hugging Face Hub repo
ADAPTER_FILENAME = "checkpoint_199k_fixed.pt"  # weights loaded into FeatureCleaner
VOCODER_FILENAME = "epoch=77-step=137108.ckpt"  # HiFiGANLightningModule checkpoint

# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

SAMPLE_RATE_INPUT = 16000  # Hz; input audio is resampled to this before cleaning
SAMPLE_RATE_OUTPUT = 22050  # Hz; sample rate written to the enhanced output file

# Filled by load_models() with the "cleaner" and "vocoder" entries on first use.
models_cache = {}
def download_models():
    """Download both Miipher-2 checkpoints from the Hub into ``./models``.

    The adapter checkpoint is fetched first, then the vocoder checkpoint.
    Both come from ``MODEL_REPO_ID``.

    Returns:
        ``(adapter_path, vocoder_path)``: the local file paths returned by
        ``hf_hub_download`` for ``ADAPTER_FILENAME`` and ``VOCODER_FILENAME``.
    """
    wanted_files = (ADAPTER_FILENAME, VOCODER_FILENAME)
    local_paths = tuple(
        hf_hub_download(repo_id=MODEL_REPO_ID, filename=file_name, cache_dir="./models")
        for file_name in wanted_files
    )
    return local_paths
def load_models():
    """Return the ``(cleaner, vocoder)`` pair, building it only once.

    On the first call:

    * Both checkpoints are fetched via ``download_models()``.
    * A ``FeatureCleaner`` is built from a fixed config and given the adapter
      checkpoint's ``"model_state_dict"``.
    * The HiFi-GAN vocoder is restored with
      ``HiFiGANLightningModule.load_from_checkpoint``.

    Both models are moved to ``DEVICE``, switched to eval mode and stored in
    ``models_cache``. Later calls return the cached pair directly.

    NOTE(review): both checkpoints are unpickled with ``weights_only=False``,
    which can execute arbitrary code. This is only acceptable because the
    files come from the fixed repo ``MODEL_REPO_ID``. Consider pinning a
    ``revision`` in ``download_models``.
    """
    if {"cleaner", "vocoder"} <= models_cache.keys():
        return models_cache["cleaner"], models_cache["vocoder"]

    adapter_ckpt_path, vocoder_ckpt_path = download_models()

    # The cleaner architecture is built from this config; only its weights
    # come from the adapter checkpoint.
    cleaner_cfg = DictConfig(
        {
            "hubert_model_name": "utter-project/mHuBERT-147",
            "hubert_layer": 6,
            "adapter_hidden_dim": 768,
        }
    )

    print("Loading FeatureCleaner...")
    cleaner = FeatureCleaner(cleaner_cfg)
    cleaner = cleaner.to(DEVICE).eval()
    adapter_state = torch.load(adapter_ckpt_path, map_location=DEVICE, weights_only=False)
    cleaner.load_state_dict(adapter_state["model_state_dict"])

    print("Loading vocoder...")
    vocoder = HiFiGANLightningModule.load_from_checkpoint(
        vocoder_ckpt_path, map_location=DEVICE, weights_only=False
    )
    vocoder = vocoder.to(DEVICE).eval()

    models_cache.update(cleaner=cleaner, vocoder=vocoder)
    return cleaner, vocoder
@torch.inference_mode()
def enhance_audio(audio_path, progress=gr.Progress()):
    """Enhance a speech recording with the Miipher-2 cleaner + HiFi-GAN vocoder.

    Args:
        audio_path: Path to the input audio file, from the
            ``gr.Audio(type="filepath")`` input. May be ``None`` when the
            button is clicked without any audio.
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        Path to a temporary ``.wav`` file with the enhanced audio at
        ``SAMPLE_RATE_OUTPUT`` Hz. The file is not deleted here, because Gradio
        has to read it after this function returns.

    Raises:
        gr.Error: if no audio was given, or if any pipeline step fails. In the
            latter case the original exception is chained as ``__cause__``.
    """
    import tempfile

    # Checked outside the try-block so this message is not re-wrapped as
    # "Enhancement failed: ...".
    if audio_path is None:
        raise gr.Error("Please upload or record some audio first.")

    try:
        progress(0, desc="Loading models...")
        cleaner, vocoder = load_models()

        progress(0.2, desc="Loading audio...")
        waveform, sr = torchaudio.load(audio_path)
        # Downmix to mono before resampling, so only one channel is resampled.
        # Both operations are linear, so the order does not affect the result.
        waveform = waveform.mean(0, keepdim=True)
        if sr != SAMPLE_RATE_INPUT:
            waveform = torchaudio.functional.resample(waveform, sr, SAMPLE_RATE_INPUT)
        waveform = waveform.to(DEVICE)

        progress(0.4, desc="Extracting features...")
        use_amp = DEVICE.type == "cuda"
        with torch.autocast(device_type=DEVICE.type, dtype=torch.float16, enabled=use_amp):
            features = cleaner(waveform)
        # Under autocast the cleaner may return float16 features. The vocoder
        # runs outside autocast with float32 weights, so cast back first to
        # avoid a dtype-mismatch error on CUDA. On CPU this is a no-op.
        features = features.float()
        if features.dim() == 2:
            features = features.unsqueeze(0)  # add a batch dimension

        progress(0.7, desc="Generating enhanced audio...")
        # Presumably the vocoder wants (batch, feature_dim, time) while the
        # cleaner yields (batch, time, feature_dim). TODO: confirm.
        batch = {"input_feature": features.transpose(1, 2)}
        enhanced = vocoder.generator_forward(batch)
        enhanced = enhanced.squeeze(0).float().clamp(-1.0, 1.0).cpu()
        if enhanced.dim() == 1:
            enhanced = enhanced.unsqueeze(0)  # torchaudio.save wants (channels, time)

        # Close the temp file before writing to it by name. Re-opening a
        # still-open NamedTemporaryFile fails on Windows.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            out_path = tmp.name
        torchaudio.save(out_path, enhanced, SAMPLE_RATE_OUTPUT)

        progress(1.0, desc="Done!")
        return out_path
    except Exception as e:
        raise gr.Error(f"Enhancement failed: {e}") from e
# Gradio UI layout:
# - Title and description at the top.
# - Input and output audio side by side.
# - An "Enhance" button below them that runs enhance_audio on the input file
#   and shows the returned .wav path in the output player.
with gr.Blocks(title="Miipher-2 Speech Enhancement", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎀 Miipher-2 Speech Enhancement")
    # Fixed mojibake: "β€”" was a UTF-8 em dash mis-decoded as cp1253.
    gr.Markdown("Upload noisy speech — get clean speech back.")
    with gr.Row():
        input_audio = gr.Audio(
            label="Input (noisy)", type="filepath", sources=["upload", "microphone"]
        )
        output_audio = gr.Audio(label="Enhanced output", type="filepath", interactive=False)
    # Fixed mojibake: "πŸš€" was the rocket emoji mis-decoded as cp1253.
    enhance_button = gr.Button("🚀 Enhance", variant="primary")
    enhance_button.click(
        fn=enhance_audio,
        inputs=input_audio,
        outputs=output_audio,
        show_progress="full",  # documented value (Gradio 4+), not a bool
    )
if __name__ == "__main__":
    # Fetch and load both models at startup, so models_cache is already
    # populated when the first request reaches enhance_audio.
    load_models()
    demo.launch()