Upload MILFER denoiser v1.0
Browse files- .gitattributes +0 -34
- README.md +159 -0
- milfer.py +200 -0
- run.sh +7 -0
- weights/decoder_state_dict.pt +3 -0
- weights/feature_predictor_config.json +81 -0
- weights/feature_predictor_state_dict.pt +3 -0
- weights/milfer_config.json +18 -0
.gitattributes
CHANGED
|
@@ -1,35 +1 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
language:
|
| 3 |
+
- ru
|
| 4 |
+
- en
|
| 5 |
+
tags:
|
| 6 |
+
- audio
|
| 7 |
+
- speech
|
| 8 |
+
- audio-restoration
|
| 9 |
+
- pytorch
|
| 10 |
+
- cuda
|
| 11 |
+
pipeline_tag: audio-to-audio
|
| 12 |
+
license: other
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
# MILFER
|
| 16 |
+
|
| 17 |
+
MILFER is a standalone PyTorch audio-to-audio model for speech-preserving audio
|
| 18 |
+
restoration. It takes an input audio file, extracts SSL speech features, and
|
| 19 |
+
reconstructs a 48 kHz waveform with the bundled neural decoder.
|
| 20 |
+
|
| 21 |
+
The bundled checkpoint is `milfer_lora100h_step001000`.
|
| 22 |
+
|
| 23 |
+
## Highlights
|
| 24 |
+
|
| 25 |
+
- Pure PyTorch inference, no TorchScript runtime required.
|
| 26 |
+
- CUDA fp16 inference by default when a CUDA GPU is available.
|
| 27 |
+
- Accepts common audio formats supported by `torchaudio`, including wav and mp3.
|
| 28 |
+
- Emits a mono 48 kHz wav file.
|
| 29 |
+
- Tuned to preserve more game/dialogue sound character than the base checkpoint.
|
| 30 |
+
|
| 31 |
+
## Quick Start
|
| 32 |
+
|
| 33 |
+
```bash
|
| 34 |
+
python milfer.py input.wav output.wav
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
For CUDA fp16:
|
| 38 |
+
|
| 39 |
+
```bash
|
| 40 |
+
python milfer.py input.mp3 output.wav --device cuda --precision fp16
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
For repeated inference in the same Python process, compile the feature model and
|
| 44 |
+
run one warm-up pass first:
|
| 45 |
+
|
| 46 |
+
```bash
|
| 47 |
+
python milfer.py input.mp3 output.wav --device cuda --precision fp16 --compile-feature
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
The helper script does the same thing with the local Python environment:
|
| 51 |
+
|
| 52 |
+
```bash
|
| 53 |
+
./run.sh input.wav output.wav --device cuda --precision fp16
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
## Files
|
| 57 |
+
|
| 58 |
+
```text
|
| 59 |
+
milfer.py
|
| 60 |
+
run.sh
|
| 61 |
+
weights/
|
| 62 |
+
decoder_state_dict.pt
|
| 63 |
+
feature_predictor_config.json
|
| 64 |
+
feature_predictor_state_dict.pt
|
| 65 |
+
milfer_config.json
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
## Requirements
|
| 69 |
+
|
| 70 |
+
Tested with:
|
| 71 |
+
|
| 72 |
+
- Python 3.10
|
| 73 |
+
- PyTorch 2.6.0 + CUDA 12.4
|
| 74 |
+
- torchaudio 2.6.0
|
| 75 |
+
- transformers
|
| 76 |
+
- soundfile
|
| 77 |
+
- descript-audio-codec
|
| 78 |
+
|
| 79 |
+
## Clean Input Check
|
| 80 |
+
|
| 81 |
+
The table below measures how much MILFER changes already-clean clips. It is a
|
| 82 |
+
sanity check, not a denoising benchmark.
|
| 83 |
+
|
| 84 |
+
Evaluation set: `prompts_5kh`, 250 mono wav clips, 44.1 kHz, 19.0 minutes total.
|
| 85 |
+
Higher is better for STOI, eSTOI, PESQ-WB, and MOS predictors. Lower is better
|
| 86 |
+
for LSD and clipped samples.
|
| 87 |
+
|
| 88 |
+
| Subset | Files | STOI | eSTOI | PESQ-WB | LSD 16 kHz | Clipped Samples |
|
| 89 |
+
| --- | ---: | ---: | ---: | ---: | ---: | ---: |
|
| 90 |
+
| all clips | 250 | 0.9241 | 0.8719 | 2.1653 | 11.825 dB | 0.0006% |
|
| 91 |
+
| duration >= 1 s | 232 | 0.9288 | 0.8767 | 2.1917 | 11.683 dB | 0.0005% |
|
| 92 |
+
|
| 93 |
+
No-reference MOS predictors on the original and processed outputs:
|
| 94 |
+
|
| 95 |
+
| Subset | Audio | UTMOS | DistillMOS | NISQA-TTS |
|
| 96 |
+
| --- | --- | ---: | ---: | ---: |
|
| 97 |
+
| all clips | original | 2.9998 | 3.9392 | 3.6311 |
|
| 98 |
+
| all clips | MILFER | 2.9741 | 3.8080 | 3.7021 |
|
| 99 |
+
| all clips | delta | -0.0258 | -0.1313 | +0.0710 |
|
| 100 |
+
| duration >= 1 s | original | 3.0120 | 3.9829 | 3.6603 |
|
| 101 |
+
| duration >= 1 s | MILFER | 2.9977 | 3.8554 | 3.7483 |
|
| 102 |
+
| duration >= 1 s | delta | -0.0143 | -0.1275 | +0.0880 |
|
| 103 |
+
|
| 104 |
+
Very short clips can make intelligibility metrics unstable, so the filtered row
|
| 105 |
+
excludes clips shorter than one second.
|
| 106 |
+
|
| 107 |
+
## Degraded-Input Evaluation
|
| 108 |
+
|
| 109 |
+
For a cleaner-style benchmark, the clean prompts were synthetically degraded and
|
| 110 |
+
then processed with MILFER. Metrics compare either the degraded input or the
|
| 111 |
+
MILFER output against the original clean prompt. The table uses the
|
| 112 |
+
`duration >= 1 s` subset: 232 clips, 18.8 minutes total.
|
| 113 |
+
|
| 114 |
+
Degradation profiles:
|
| 115 |
+
|
| 116 |
+
- `noisy_room`: additive noise, room response, light band-limiting.
|
| 117 |
+
- `radio_clip`: band-pass channel, saturation, quantization, hiss.
|
| 118 |
+
- `mixed_hard`: noise, reverb, band-limiting, downsampling, saturation.
|
| 119 |
+
|
| 120 |
+
Full-reference metrics:
|
| 121 |
+
|
| 122 |
+
| Profile | STOI | eSTOI | PESQ-WB | LSD 16 kHz |
|
| 123 |
+
| --- | ---: | ---: | ---: | ---: |
|
| 124 |
+
| noisy_room degraded | 0.8830 | 0.7128 | 1.2104 | 18.121 dB |
|
| 125 |
+
| noisy_room MILFER | 0.9020 | 0.8145 | 1.7757 | 13.174 dB |
|
| 126 |
+
| noisy_room delta | +0.0190 | +0.1017 | +0.5653 | -4.947 dB |
|
| 127 |
+
| mixed_hard degraded | 0.8617 | 0.7079 | 1.1851 | 23.143 dB |
|
| 128 |
+
| mixed_hard MILFER | 0.8948 | 0.8068 | 1.7321 | 13.904 dB |
|
| 129 |
+
| mixed_hard delta | +0.0331 | +0.0989 | +0.5470 | -9.239 dB |
|
| 130 |
+
| radio_clip degraded | 0.9185 | 0.8528 | 1.8765 | 19.167 dB |
|
| 131 |
+
| radio_clip MILFER | 0.9040 | 0.8397 | 1.9412 | 14.237 dB |
|
| 132 |
+
| radio_clip delta | -0.0145 | -0.0131 | +0.0647 | -4.930 dB |
|
| 133 |
+
|
| 134 |
+
No-reference MOS predictors:
|
| 135 |
+
|
| 136 |
+
| Profile | UTMOS | DistillMOS | NISQA-TTS |
|
| 137 |
+
| --- | ---: | ---: | ---: |
|
| 138 |
+
| noisy_room degraded | 1.4220 | 2.7625 | 1.8844 |
|
| 139 |
+
| noisy_room MILFER | 3.0324 | 3.7896 | 3.7557 |
|
| 140 |
+
| noisy_room delta | +1.6104 | +1.0270 | +1.8713 |
|
| 141 |
+
| mixed_hard degraded | 1.3709 | 2.4796 | 2.1738 |
|
| 142 |
+
| mixed_hard MILFER | 2.9478 | 3.7044 | 3.7243 |
|
| 143 |
+
| mixed_hard delta | +1.5769 | +1.2248 | +1.5505 |
|
| 144 |
+
| radio_clip degraded | 1.4901 | 2.9555 | 2.5081 |
|
| 145 |
+
| radio_clip MILFER | 2.8414 | 3.7817 | 3.5840 |
|
| 146 |
+
| radio_clip delta | +1.3513 | +0.8262 | +1.0759 |
|
| 147 |
+
|
| 148 |
+
## Notes
|
| 149 |
+
|
| 150 |
+
- Input audio is mixed to mono and resampled to 16 kHz for feature extraction.
|
| 151 |
+
- Output is written as mono 48 kHz PCM wav.
|
| 152 |
+
- Very long files can be processed, but peak memory depends on input duration.
|
| 153 |
+
- This is an experimental checkpoint. It can still change ambience, effects,
|
| 154 |
+
music, and non-speech sounds.
|
| 155 |
+
|
| 156 |
+
## License
|
| 157 |
+
|
| 158 |
+
License is not specified in this package. Set the final license field before
|
| 159 |
+
publishing if you need redistributable model weights.
|
milfer.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""MILFER command line inference."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import json
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Sequence
|
| 10 |
+
|
| 11 |
+
import soundfile as sf
|
| 12 |
+
import torch
|
| 13 |
+
import torchaudio
|
| 14 |
+
from dac.model.dac import Decoder
|
| 15 |
+
from transformers import Wav2Vec2BertConfig, Wav2Vec2BertModel
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
ROOT = Path(__file__).resolve().parent
|
| 19 |
+
WEIGHTS = ROOT / "weights"
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def build_argparser() -> argparse.ArgumentParser:
|
| 23 |
+
parser = argparse.ArgumentParser(description="Run MILFER audio processing.")
|
| 24 |
+
parser.add_argument("input", type=Path, help="Input audio path")
|
| 25 |
+
parser.add_argument("output", type=Path, help="Output wav path")
|
| 26 |
+
parser.add_argument("--weights", type=Path, default=WEIGHTS)
|
| 27 |
+
parser.add_argument("--device", default="auto", help="auto, cuda, or cpu")
|
| 28 |
+
parser.add_argument(
|
| 29 |
+
"--precision",
|
| 30 |
+
choices=("auto", "fp32", "fp16"),
|
| 31 |
+
default="auto",
|
| 32 |
+
help="auto uses fp16 on CUDA and fp32 on CPU.",
|
| 33 |
+
)
|
| 34 |
+
parser.add_argument(
|
| 35 |
+
"--compile-feature",
|
| 36 |
+
action="store_true",
|
| 37 |
+
help="Compile the feature predictor. Slow first call, faster repeated calls.",
|
| 38 |
+
)
|
| 39 |
+
return parser
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def resolve_device(name: str) -> torch.device:
|
| 43 |
+
if name == "auto":
|
| 44 |
+
name = "cuda" if torch.cuda.is_available() else "cpu"
|
| 45 |
+
return torch.device(name)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def use_fp16(precision: str, device: torch.device) -> bool:
|
| 49 |
+
if precision == "fp16":
|
| 50 |
+
if device.type != "cuda":
|
| 51 |
+
raise ValueError("--precision fp16 requires CUDA")
|
| 52 |
+
return True
|
| 53 |
+
if precision == "fp32":
|
| 54 |
+
return False
|
| 55 |
+
return device.type == "cuda"
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def load_audio(path: Path, sample_rate: int) -> tuple[torch.Tensor, float]:
|
| 59 |
+
waveform, source_sr = torchaudio.load(str(path))
|
| 60 |
+
if waveform.ndim == 2 and waveform.shape[0] > 1:
|
| 61 |
+
waveform = waveform.mean(dim=0)
|
| 62 |
+
else:
|
| 63 |
+
waveform = waveform.view(-1)
|
| 64 |
+
waveform = waveform.to(torch.float32).contiguous()
|
| 65 |
+
duration = waveform.numel() / float(source_sr)
|
| 66 |
+
if source_sr != sample_rate:
|
| 67 |
+
waveform = torchaudio.functional.resample(
|
| 68 |
+
waveform,
|
| 69 |
+
source_sr,
|
| 70 |
+
sample_rate,
|
| 71 |
+
).contiguous()
|
| 72 |
+
return waveform, duration
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def extract_features(
|
| 76 |
+
waveforms: Sequence[torch.Tensor],
|
| 77 |
+
device: torch.device,
|
| 78 |
+
sampling_rate: int = 16_000,
|
| 79 |
+
padding_value: float = 1.0,
|
| 80 |
+
) -> torch.Tensor:
|
| 81 |
+
mel_features: list[torch.Tensor] = []
|
| 82 |
+
for waveform in waveforms:
|
| 83 |
+
waveform = waveform.to(device=device, dtype=torch.float32)
|
| 84 |
+
if waveform.ndim > 1:
|
| 85 |
+
waveform = waveform[0]
|
| 86 |
+
feature = torchaudio.compliance.kaldi.fbank(
|
| 87 |
+
waveform=waveform.unsqueeze(0),
|
| 88 |
+
sample_frequency=sampling_rate,
|
| 89 |
+
num_mel_bins=80,
|
| 90 |
+
frame_length=25,
|
| 91 |
+
frame_shift=10,
|
| 92 |
+
dither=0.0,
|
| 93 |
+
preemphasis_coefficient=0.97,
|
| 94 |
+
remove_dc_offset=True,
|
| 95 |
+
window_type="povey",
|
| 96 |
+
use_energy=False,
|
| 97 |
+
energy_floor=1.192092955078125e-07,
|
| 98 |
+
)
|
| 99 |
+
mean = feature.mean(0, keepdim=True)
|
| 100 |
+
var = feature.var(0, keepdim=True)
|
| 101 |
+
feature = (feature - mean) / torch.sqrt(var + 1e-5)
|
| 102 |
+
mel_features.append(feature)
|
| 103 |
+
|
| 104 |
+
target_frames = max(feature.shape[0] for feature in mel_features)
|
| 105 |
+
if target_frames % 2:
|
| 106 |
+
target_frames += 1
|
| 107 |
+
batch = torch.full(
|
| 108 |
+
(len(mel_features), target_frames, 80),
|
| 109 |
+
padding_value,
|
| 110 |
+
dtype=torch.float32,
|
| 111 |
+
device=device,
|
| 112 |
+
)
|
| 113 |
+
for index, feature in enumerate(mel_features):
|
| 114 |
+
batch[index, : feature.shape[0]] = feature
|
| 115 |
+
|
| 116 |
+
return batch.reshape(len(mel_features), target_frames // 2, 160)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def load_feature_predictor(weights: Path, device: torch.device) -> torch.nn.Module:
|
| 120 |
+
config = Wav2Vec2BertConfig.from_json_file(str(weights / "feature_predictor_config.json"))
|
| 121 |
+
model = Wav2Vec2BertModel(config)
|
| 122 |
+
state = torch.load(weights / "feature_predictor_state_dict.pt", map_location="cpu")
|
| 123 |
+
model.load_state_dict(state)
|
| 124 |
+
model.eval()
|
| 125 |
+
return model.to(device)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def load_decoder(weights: Path, device: torch.device) -> torch.nn.Module:
|
| 129 |
+
with (weights / "milfer_config.json").open("r", encoding="utf-8") as file:
|
| 130 |
+
config = json.load(file)
|
| 131 |
+
decoder_config = config["decoder"]
|
| 132 |
+
decoder = Decoder(
|
| 133 |
+
input_channel=decoder_config["input_channel"],
|
| 134 |
+
channels=decoder_config["channels"],
|
| 135 |
+
rates=decoder_config["rates"],
|
| 136 |
+
)
|
| 137 |
+
state = torch.load(weights / "decoder_state_dict.pt", map_location="cpu")
|
| 138 |
+
decoder.load_state_dict(state)
|
| 139 |
+
for module in decoder.modules():
|
| 140 |
+
try:
|
| 141 |
+
torch.nn.utils.remove_weight_norm(module)
|
| 142 |
+
except (AttributeError, ValueError):
|
| 143 |
+
pass
|
| 144 |
+
decoder.eval()
|
| 145 |
+
return decoder.to(device)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
@torch.inference_mode()
|
| 149 |
+
def run(args: argparse.Namespace) -> None:
|
| 150 |
+
torch.set_grad_enabled(False)
|
| 151 |
+
if torch.cuda.is_available():
|
| 152 |
+
torch.backends.cudnn.benchmark = True
|
| 153 |
+
|
| 154 |
+
device = resolve_device(args.device)
|
| 155 |
+
half = use_fp16(args.precision, device)
|
| 156 |
+
waveform, duration = load_audio(args.input, sample_rate=16_000)
|
| 157 |
+
expected_samples = int(round(duration * 48_000))
|
| 158 |
+
|
| 159 |
+
feature_predictor = load_feature_predictor(args.weights, device)
|
| 160 |
+
decoder = load_decoder(args.weights, device)
|
| 161 |
+
if half:
|
| 162 |
+
feature_predictor = feature_predictor.half()
|
| 163 |
+
decoder = decoder.half()
|
| 164 |
+
if args.compile_feature:
|
| 165 |
+
if device.type != "cuda":
|
| 166 |
+
raise ValueError("--compile-feature requires CUDA")
|
| 167 |
+
feature_predictor = torch.compile(
|
| 168 |
+
feature_predictor,
|
| 169 |
+
mode="reduce-overhead",
|
| 170 |
+
fullgraph=False,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
features = extract_features([waveform], device=device)
|
| 174 |
+
if half:
|
| 175 |
+
features = features.half()
|
| 176 |
+
hidden = feature_predictor(input_features=features).last_hidden_state
|
| 177 |
+
restored = decoder(hidden.transpose(1, 2))[0].view(-1).float().cpu()
|
| 178 |
+
|
| 179 |
+
if restored.numel() < expected_samples:
|
| 180 |
+
restored = torch.nn.functional.pad(restored, (0, expected_samples - restored.numel()))
|
| 181 |
+
elif restored.numel() > expected_samples:
|
| 182 |
+
restored = restored[:expected_samples]
|
| 183 |
+
|
| 184 |
+
args.output.parent.mkdir(parents=True, exist_ok=True)
|
| 185 |
+
sf.write(
|
| 186 |
+
str(args.output),
|
| 187 |
+
restored.clamp(-1.0, 1.0).numpy(),
|
| 188 |
+
48_000,
|
| 189 |
+
subtype="PCM_16",
|
| 190 |
+
)
|
| 191 |
+
print(f"wrote={args.output}")
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def main(argv: Sequence[str] | None = None) -> None:
|
| 195 |
+
args = build_argparser().parse_args(argv)
|
| 196 |
+
run(args)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
if __name__ == "__main__":
|
| 200 |
+
main()
|
run.sh
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
PYTHON_BIN="${MILFER_PYTHON:-${ROOT}/.venv/bin/python}"
|
| 6 |
+
|
| 7 |
+
exec "${PYTHON_BIN}" "${ROOT}/milfer.py" "$@"
|
weights/decoder_state_dict.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f97bb316f4dfc4186463dfdb820cd6d9e31159b483ded78385632f41dc34cec8
|
| 3 |
+
size 209835977
|
weights/feature_predictor_config.json
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"activation_dropout": 0.0,
|
| 3 |
+
"adapter_act": "relu",
|
| 4 |
+
"adapter_kernel_size": 3,
|
| 5 |
+
"adapter_stride": 2,
|
| 6 |
+
"add_adapter": false,
|
| 7 |
+
"apply_spec_augment": false,
|
| 8 |
+
"architectures": [
|
| 9 |
+
"Wav2Vec2BertModel"
|
| 10 |
+
],
|
| 11 |
+
"attention_dropout": 0.0,
|
| 12 |
+
"bos_token_id": 1,
|
| 13 |
+
"classifier_proj_size": 768,
|
| 14 |
+
"codevector_dim": 768,
|
| 15 |
+
"conformer_conv_dropout": 0.1,
|
| 16 |
+
"contrastive_logits_temperature": 0.1,
|
| 17 |
+
"conv_depthwise_kernel_size": 31,
|
| 18 |
+
"ctc_loss_reduction": "sum",
|
| 19 |
+
"ctc_zero_infinity": false,
|
| 20 |
+
"diversity_loss_weight": 0.1,
|
| 21 |
+
"dtype": "float32",
|
| 22 |
+
"eos_token_id": 2,
|
| 23 |
+
"feat_proj_dropout": 0.0,
|
| 24 |
+
"feat_quantizer_dropout": 0.0,
|
| 25 |
+
"feature_projection_input_dim": 160,
|
| 26 |
+
"final_dropout": 0.1,
|
| 27 |
+
"hidden_act": "swish",
|
| 28 |
+
"hidden_dropout": 0.0,
|
| 29 |
+
"hidden_size": 1024,
|
| 30 |
+
"initializer_range": 0.02,
|
| 31 |
+
"intermediate_size": 4096,
|
| 32 |
+
"layer_norm_eps": 1e-05,
|
| 33 |
+
"layerdrop": 0.0,
|
| 34 |
+
"left_max_position_embeddings": 64,
|
| 35 |
+
"mask_feature_length": 10,
|
| 36 |
+
"mask_feature_min_masks": 0,
|
| 37 |
+
"mask_feature_prob": 0.0,
|
| 38 |
+
"mask_time_length": 10,
|
| 39 |
+
"mask_time_min_masks": 2,
|
| 40 |
+
"mask_time_prob": 0.05,
|
| 41 |
+
"max_source_positions": 5000,
|
| 42 |
+
"model_type": "wav2vec2-bert",
|
| 43 |
+
"num_adapter_layers": 1,
|
| 44 |
+
"num_attention_heads": 16,
|
| 45 |
+
"num_codevector_groups": 2,
|
| 46 |
+
"num_codevectors_per_group": 320,
|
| 47 |
+
"num_hidden_layers": 8,
|
| 48 |
+
"num_negatives": 100,
|
| 49 |
+
"output_hidden_size": 1024,
|
| 50 |
+
"pad_token_id": 0,
|
| 51 |
+
"position_embeddings_type": "relative_key",
|
| 52 |
+
"proj_codevector_dim": 768,
|
| 53 |
+
"right_max_position_embeddings": 8,
|
| 54 |
+
"rotary_embedding_base": 10000,
|
| 55 |
+
"tdnn_dilation": [
|
| 56 |
+
1,
|
| 57 |
+
2,
|
| 58 |
+
3,
|
| 59 |
+
1,
|
| 60 |
+
1
|
| 61 |
+
],
|
| 62 |
+
"tdnn_dim": [
|
| 63 |
+
512,
|
| 64 |
+
512,
|
| 65 |
+
512,
|
| 66 |
+
512,
|
| 67 |
+
1500
|
| 68 |
+
],
|
| 69 |
+
"tdnn_kernel": [
|
| 70 |
+
5,
|
| 71 |
+
3,
|
| 72 |
+
3,
|
| 73 |
+
1,
|
| 74 |
+
1
|
| 75 |
+
],
|
| 76 |
+
"transformers_version": "5.9.0",
|
| 77 |
+
"use_intermediate_ffn_before_adapter": false,
|
| 78 |
+
"use_weighted_layer_sum": false,
|
| 79 |
+
"vocab_size": null,
|
| 80 |
+
"xvector_output_dim": 512
|
| 81 |
+
}
|
weights/feature_predictor_state_dict.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8837be29d1a70cb00c8b295ae520dab1ff411b7214ec411900183d550650915d
|
| 3 |
+
size 774540736
|
weights/milfer_config.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"name": "milfer",
|
| 3 |
+
"feature_sample_rate": 16000,
|
| 4 |
+
"target_sample_rate": 48000,
|
| 5 |
+
"feature_dim": 1024,
|
| 6 |
+
"checkpoint": "milfer_lora100h_step001000",
|
| 7 |
+
"decoder": {
|
| 8 |
+
"input_channel": 1024,
|
| 9 |
+
"channels": 1536,
|
| 10 |
+
"rates": [
|
| 11 |
+
8,
|
| 12 |
+
5,
|
| 13 |
+
4,
|
| 14 |
+
3,
|
| 15 |
+
2
|
| 16 |
+
]
|
| 17 |
+
}
|
| 18 |
+
}
|