File size: 4,252 Bytes
65a2251
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import gradio as gr
import torch
import torchaudio
import numpy as np
from pathlib import Path
from huggingface_hub import hf_hub_download
from omegaconf import DictConfig, ListConfig, OmegaConf

# Register the OmegaConf types with torch's restricted unpickler. This has to
# happen before any checkpoint is deserialized: torch.load(weights_only=True),
# the default in recent PyTorch releases, refuses to rebuild classes that are
# not allowlisted. Presumably some checkpoint (e.g. Lightning hyper-parameters)
# carries OmegaConf containers — verify which load actually needs this.
import torch.serialization

torch.serialization.add_safe_globals([ListConfig, DictConfig, OmegaConf])

from miipher_2.model.feature_cleaner import FeatureCleaner
from miipher_2.lightning_vocoders.lightning_module import HiFiGANLightningModule

# Hugging Face Hub repository and the two checkpoint files fetched from it.
MODEL_REPO_ID: str = "Atotti/miipher-2-HuBERT-HiFi-GAN-v0.1"
ADAPTER_FILENAME: str = "checkpoint_199k_fixed.pt"  # FeatureCleaner state dict
VOCODER_FILENAME: str = "epoch=77-step=137108.ckpt"  # Lightning HiFi-GAN checkpoint

# Use CUDA when it is available, otherwise run everything on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Input audio is resampled to the first rate before feature extraction;
# the enhanced audio is written out at the second rate.
SAMPLE_RATE_INPUT: int = 16_000
SAMPLE_RATE_OUTPUT: int = 22_050

# Process-wide cache filled by load_models() under the keys "cleaner"/"vocoder".
models_cache: dict = {}


def download_models():
    """Fetch the adapter and vocoder checkpoints from the Hub into ``./models``.

    Returns:
        A ``(adapter_path, vocoder_path)`` tuple of local file paths.
    """
    local_paths = [
        hf_hub_download(repo_id=MODEL_REPO_ID, filename=name, cache_dir="./models")
        for name in (ADAPTER_FILENAME, VOCODER_FILENAME)
    ]
    adapter_local, vocoder_local = local_paths
    return adapter_local, vocoder_local


def load_models():
    """Return the ``(cleaner, vocoder)`` pair, building it on the first call.

    The first call downloads both checkpoints via ``download_models()``,
    constructs the FeatureCleaner and the HiFi-GAN vocoder on ``DEVICE`` in
    eval mode, and stores them in ``models_cache``. Subsequent calls return
    the cached instances.

    Returns:
        tuple: ``(cleaner, vocoder)``.
    """
    if "cleaner" in models_cache and "vocoder" in models_cache:
        return models_cache["cleaner"], models_cache["vocoder"]

    adapter_path, vocoder_path = download_models()

    model_config = DictConfig({
        "hubert_model_name": "utter-project/mHuBERT-147",
        "hubert_layer": 6,
        "adapter_hidden_dim": 768,
    })

    print("Loading FeatureCleaner...")
    cleaner = FeatureCleaner(model_config).to(DEVICE).eval()
    # Deserialize onto the CPU: load_state_dict() copies each tensor into the
    # parameters that already live on DEVICE, so mapping the whole checkpoint
    # (which may hold more than the model weights) onto the GPU would only add
    # a transient memory spike.
    # NOTE(security): weights_only=False unpickles arbitrary objects. This is
    # acceptable only while MODEL_REPO_ID is trusted; no revision is pinned.
    adapter_checkpoint = torch.load(adapter_path, map_location="cpu", weights_only=False)
    cleaner.load_state_dict(adapter_checkpoint["model_state_dict"])
    del adapter_checkpoint  # free the CPU copy before loading the vocoder

    print("Loading vocoder...")
    vocoder = HiFiGANLightningModule.load_from_checkpoint(
        vocoder_path, map_location=DEVICE, weights_only=False
    ).to(DEVICE).eval()

    models_cache["cleaner"] = cleaner
    models_cache["vocoder"] = vocoder
    return cleaner, vocoder


def _load_mono_waveform(audio_path):
    """Load *audio_path* as a single-channel tensor resampled to SAMPLE_RATE_INPUT."""
    waveform, sr = torchaudio.load(audio_path)
    # Downmix first: averaging and resampling are both linear, so the result
    # matches resampling every channel first (up to float rounding), but only
    # one channel has to be resampled.
    waveform = waveform.mean(0, keepdim=True)
    if sr != SAMPLE_RATE_INPUT:
        waveform = torchaudio.functional.resample(waveform, sr, SAMPLE_RATE_INPUT)
    return waveform


def _write_temp_wav(audio):
    """Save *audio* at SAMPLE_RATE_OUTPUT to a new temporary .wav file; return its path.

    The file is intentionally not deleted here: the returned path is what the
    Gradio output component is given to display.
    """
    import tempfile

    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        out_path = tmp.name
    # Write only after the handle is closed: reopening a still-open
    # NamedTemporaryFile by name fails on Windows.
    try:
        torchaudio.save(out_path, audio, SAMPLE_RATE_OUTPUT)
    except Exception:
        Path(out_path).unlink(missing_ok=True)  # don't leave an empty file behind
        raise
    return out_path


@torch.inference_mode()
def enhance_audio(audio_path, progress=gr.Progress()):
    """Run Miipher-2 enhancement on one audio file.

    Args:
        audio_path: Filepath of the uploaded/recorded audio, or ``None`` when
            the button is clicked without any input.
        progress: Gradio progress tracker (supplied by Gradio).

    Returns:
        Path of a temporary ``.wav`` file with the enhanced audio at
        ``SAMPLE_RATE_OUTPUT`` Hz.

    Raises:
        gr.Error: if no audio was provided or any pipeline step fails.
    """
    # Checked before the try so this message is not re-wrapped below.
    if audio_path is None:
        raise gr.Error("Please upload or record an audio clip first.")

    try:
        progress(0, desc="Loading models...")
        cleaner, vocoder = load_models()

        progress(0.2, desc="Loading audio...")
        waveform = _load_mono_waveform(audio_path).to(DEVICE)

        progress(0.4, desc="Extracting features...")
        # fp16 autocast only on CUDA; with enabled=False the context is a no-op.
        use_amp = DEVICE.type == "cuda"
        with torch.autocast(device_type=DEVICE.type, dtype=torch.float16, enabled=use_amp):
            features = cleaner(waveform)
            if features.dim() == 2:
                features = features.unsqueeze(0)  # add a leading batch axis

            progress(0.7, desc="Generating enhanced audio...")
            # NOTE(review): presumably features are (batch, frames, dim) and the
            # vocoder expects (batch, dim, frames) — confirm against the module.
            batch = {"input_feature": features.transpose(1, 2)}
            enhanced_audio = vocoder.generator_forward(batch)

        # Back to float32 on the CPU, hard-limited to [-1, 1], and made 2-D
        # (channels, samples) for torchaudio.save.
        enhanced = enhanced_audio.squeeze(0).cpu().float().clamp(-1.0, 1.0)
        if enhanced.dim() == 1:
            enhanced = enhanced.unsqueeze(0)

        out_path = _write_temp_wav(enhanced)
        progress(1.0, desc="Done!")
        return out_path

    except Exception as e:
        raise gr.Error(f"Enhancement failed: {str(e)}") from e


# Gradio UI: one audio input, one read-only audio output, and a button that
# routes the input through enhance_audio.
with gr.Blocks(title="Miipher-2 Speech Enhancement", theme=gr.themes.Soft()) as demo:
    # The emoji/dash characters were previously mis-encoded (UTF-8 bytes
    # decoded as a legacy code page) and rendered as garbage in the page.
    gr.Markdown("# 🎤 Miipher-2 Speech Enhancement")
    gr.Markdown("Upload noisy speech — get clean speech back.")

    with gr.Row():
        input_audio = gr.Audio(label="Input (noisy)", type="filepath", sources=["upload", "microphone"])
        output_audio = gr.Audio(label="Enhanced output", type="filepath", interactive=False)

    # "full" is the documented show_progress value (and the default); the old
    # boolean True is not one of the accepted literals.
    gr.Button("🚀 Enhance", variant="primary").click(
        fn=enhance_audio, inputs=input_audio, outputs=output_audio, show_progress="full"
    )

def _main():
    """Warm the model cache, then start the Gradio server."""
    # Load (and, on a fresh machine, download) both models before serving so
    # the first request does not stall on checkpoint initialization.
    load_models()
    demo.launch()


if __name__ == "__main__":
    _main()