import gradio as gr
import torch
import torchaudio
import numpy as np
from pathlib import Path
from huggingface_hub import hf_hub_download
from omegaconf import DictConfig, ListConfig, OmegaConf
# Register the omegaconf container types as safe globals up front, so the
# allowlist is already in place before any torch.load call is made.
import torch.serialization

torch.serialization.add_safe_globals([ListConfig, DictConfig, OmegaConf])
from miipher_2.model.feature_cleaner import FeatureCleaner
from miipher_2.lightning_vocoders.lightning_module import HiFiGANLightningModule
# Hugging Face Hub repository that both checkpoints are downloaded from.
MODEL_REPO_ID = "Atotti/miipher-2-HuBERT-HiFi-GAN-v0.1"
# Checkpoint whose "model_state_dict" entry is loaded into the FeatureCleaner.
ADAPTER_FILENAME = "checkpoint_199k_fixed.pt"
# Checkpoint the HiFiGANLightningModule vocoder is restored from.
VOCODER_FILENAME = "epoch=77-step=137108.ckpt"
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Sample rate (Hz) the input audio is resampled to before feature extraction.
SAMPLE_RATE_INPUT = 16000
# Sample rate (Hz) the enhanced audio is written out with.
SAMPLE_RATE_OUTPUT = 22050
# Process-wide cache of loaded models, filled under the keys "cleaner" and
# "vocoder" by load_models().
models_cache = {}
def download_models():
    """Fetch the adapter and vocoder checkpoints from the Hugging Face Hub.

    Both files are requested from ``MODEL_REPO_ID`` with ``./models`` as the
    download cache directory.

    Returns:
        A ``(adapter_path, vocoder_path)`` tuple holding the values returned
        by ``hf_hub_download`` for ``ADAPTER_FILENAME`` and
        ``VOCODER_FILENAME`` respectively.
    """
    # The generator is consumed in order, so the adapter is requested first
    # and the vocoder second.
    adapter_path, vocoder_path = (
        hf_hub_download(
            repo_id=MODEL_REPO_ID, filename=checkpoint_name, cache_dir="./models"
        )
        for checkpoint_name in (ADAPTER_FILENAME, VOCODER_FILENAME)
    )
    return adapter_path, vocoder_path
def load_models():
    """Return the ``(cleaner, vocoder)`` pair, building it on first use.

    When both models are already present in ``models_cache`` they are returned
    directly. Otherwise the checkpoints are downloaded, a ``FeatureCleaner``
    and a ``HiFiGANLightningModule`` are created on ``DEVICE`` in eval mode,
    stored in ``models_cache`` and returned.

    Returns:
        A ``(cleaner, vocoder)`` tuple.
    """
    try:
        return models_cache["cleaner"], models_cache["vocoder"]
    except KeyError:
        # At least one of the two models is not cached yet; (re)build both.
        pass

    adapter_ckpt_path, vocoder_ckpt_path = download_models()

    cleaner_config = DictConfig(
        {
            "hubert_model_name": "utter-project/mHuBERT-147",
            "hubert_layer": 6,
            "adapter_hidden_dim": 768,
        }
    )

    print("Loading FeatureCleaner...")
    feature_cleaner = FeatureCleaner(cleaner_config)
    feature_cleaner = feature_cleaner.to(DEVICE).eval()
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects, so this is only safe while the checkpoint source is trusted.
    checkpoint = torch.load(
        adapter_ckpt_path, map_location=DEVICE, weights_only=False
    )
    feature_cleaner.load_state_dict(checkpoint["model_state_dict"])

    print("Loading vocoder...")
    # NOTE(review): same trust assumption as above for weights_only=False.
    hifigan = HiFiGANLightningModule.load_from_checkpoint(
        vocoder_ckpt_path, map_location=DEVICE, weights_only=False
    )
    hifigan = hifigan.to(DEVICE).eval()

    models_cache.update(cleaner=feature_cleaner, vocoder=hifigan)
    return feature_cleaner, hifigan
@torch.inference_mode()
def enhance_audio(audio_path, progress=gr.Progress()):
    """Enhance a speech recording and return the path of the resulting WAV.

    The audio is loaded, resampled to ``SAMPLE_RATE_INPUT`` when its rate
    differs, averaged along dim 0 into a single channel, passed through the
    cleaner and the vocoder, clipped to ``[-1.0, 1.0]`` and written to a
    temporary ``.wav`` file at ``SAMPLE_RATE_OUTPUT``.

    Args:
        audio_path: Path of the audio file to enhance. Empty values (``None``
            or ``""``) are rejected with a user-facing error.
        progress: Gradio progress tracker used to report the current stage.

    Returns:
        The path of the temporary ``.wav`` file holding the enhanced audio.
        The file is created with ``delete=False`` and is not removed here.

    Raises:
        gr.Error: If no input was provided or any processing step fails.
    """
    import tempfile

    # Fail early with a clear message instead of letting the audio loader
    # produce an opaque error for a missing input.
    if not audio_path:
        raise gr.Error("Please provide an audio file to enhance.")

    try:
        progress(0, desc="Loading models...")
        cleaner, vocoder = load_models()

        progress(0.2, desc="Loading audio...")
        waveform, sr = torchaudio.load(audio_path)
        if sr != SAMPLE_RATE_INPUT:
            waveform = torchaudio.functional.resample(waveform, sr, SAMPLE_RATE_INPUT)
        waveform = waveform.mean(0, keepdim=True).to(DEVICE)

        progress(0.4, desc="Extracting features...")
        # Mixed precision is only enabled on CUDA; on CPU the context is a no-op.
        use_amp = DEVICE.type == "cuda"
        # Feature extraction and vocoding share one autocast context so the
        # vocoder never receives half-precision features outside autocast.
        with torch.autocast(device_type=DEVICE.type, dtype=torch.float16, enabled=use_amp):
            features = cleaner(waveform)
            if features.dim() == 2:
                # Add the leading dimension expected by the transpose below.
                features = features.unsqueeze(0)

            progress(0.7, desc="Generating enhanced audio...")
            batch = {"input_feature": features.transpose(1, 2)}
            enhanced_audio = vocoder.generator_forward(batch)

        # Back to float32 on the CPU before handing the samples to NumPy.
        enhanced_audio = enhanced_audio.squeeze(0).cpu().to(torch.float32).detach().numpy()
        enhanced_audio = np.clip(enhanced_audio, -1.0, 1.0)
        enhanced_tensor = torch.from_numpy(enhanced_audio)
        if enhanced_tensor.dim() == 1:
            # torchaudio.save is given a 2-D tensor.
            enhanced_tensor = enhanced_tensor.unsqueeze(0)

        # Reserve a unique file name, then close the handle before writing:
        # writing by name to a still-open temporary file is not portable.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            output_path = tmp.name
        try:
            torchaudio.save(output_path, enhanced_tensor, SAMPLE_RATE_OUTPUT)
        except Exception:
            # Do not leave an empty or partial file behind when saving fails.
            Path(output_path).unlink(missing_ok=True)
            raise

        progress(1.0, desc="Done!")
        return output_path
    except Exception as e:
        # Surface any failure in the UI while keeping the original cause chained.
        raise gr.Error(f"Enhancement failed: {str(e)}") from e
# Gradio UI: an input/output audio pair side by side and a button that runs
# enhance_audio on the input and shows the result in the output player.
# NOTE(review): the emoji in the heading and button label and the arrow in the
# description were reconstructed from mis-encoded text — confirm against the
# intended UI copy.
with gr.Blocks(title="Miipher-2 Speech Enhancement", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎤 Miipher-2 Speech Enhancement")
    gr.Markdown("Upload noisy speech → get clean speech back.")
    with gr.Row():
        # type="filepath" means enhance_audio receives a path, not raw samples.
        input_audio = gr.Audio(label="Input (noisy)", type="filepath", sources=["upload", "microphone"])
        output_audio = gr.Audio(label="Enhanced output", type="filepath", interactive=False)
    gr.Button("🚀 Enhance", variant="primary").click(
        fn=enhance_audio, inputs=input_audio, outputs=output_audio, show_progress=True
    )
if __name__ == "__main__":
    # Populate models_cache before serving so the first request does not have
    # to download and load the models itself.
    load_models()
    demo.launch()