|
|
import argparse |
|
|
from pathlib import Path |
|
|
import sys |
|
|
|
|
|
import numpy as np |
|
|
import soundfile as sf |
|
|
import librosa |
|
|
from tflite_runtime.interpreter import Interpreter |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
# Directory that holds the "<model_name>.tflite" files. The path is relative, so it
# is resolved against the current working directory, not this script's location.
TFLITE_DIR: Path = Path('./')
|
|
|
|
|
|
|
|
# STFT frame / FFT length in samples: 320 samples = 20 ms at the 16 kHz rate the
# model pipeline works at.
WIN_LEN: int = 320


# Hop between consecutive frames: 160 samples = 10 ms, i.e. 50% frame overlap.
HOP_SIZE: int = WIN_LEN // 2
|
|
|
|
|
|
|
|
def vorbis_window(window_len: int) -> np.ndarray:
    """Build a Vorbis (sine-of-squared-sine) window.

    w[n] = sin(pi/2 * sin^2(pi/2 * (n + 0.5) / (window_len / 2))), n = 0 .. window_len - 1.

    For an even ``window_len`` the window is power-complementary:
    w[n]**2 + w[n + window_len // 2]**2 == 1. This makes it suitable for
    50%-overlap analysis and synthesis with the same window.

    Args:
        window_len: number of samples in the window.

    Returns:
        1-D float32 array of length ``window_len``.
    """
    half_len = window_len / 2
    n = np.arange(window_len)
    inner = np.sin(0.5 * np.pi * (n + 0.5) / half_len)
    # np.square(inner) is computed exactly as inner * inner.
    return np.sin(0.5 * np.pi * np.square(inner)).astype(np.float32)
|
|
|
|
|
|
|
|
def get_wnorm(window_len: int, frame_size: int) -> float:
    """Return the spectrum scaling factor 1 / (window_len**2 / (2 * frame_size)).

    In this script the STFT is multiplied by this factor before inference,
    and the reconstructed waveform is divided by it afterwards.

    Args:
        window_len: analysis window length in samples.
        frame_size: hop size between frames in samples.

    Returns:
        The scaling factor as a float.
    """
    energy_scale = window_len ** 2 / (2 * frame_size)
    return 1.0 / energy_scale
|
|
|
|
|
|
|
|
|
|
|
# Analysis/synthesis window, computed once at import time and passed to
# librosa.stft (preprocessing) and librosa.istft (postprocessing).
_WIN: np.ndarray = vorbis_window(WIN_LEN)


# Scale factor: the STFT is multiplied by it in preprocessing and the output
# waveform is divided by it in postprocessing.
_WNORM: float = get_wnorm(WIN_LEN, HOP_SIZE)
|
|
|
|
|
|
|
|
def preprocessing(waveform_16k: np.ndarray) -> np.ndarray:
    """Turn a 16 kHz mono waveform into the model's real/imag STFT input.

    Args:
        waveform_16k: 1-D float array sampled at 16 kHz, mono, roughly in [-1, 1].

    Returns:
        float32 array of shape [1, T, F, 2]. T is the number of frames and
        F = WIN_LEN // 2 + 1 frequency bins. The last axis holds (real, imag).
        The spectrum is scaled by ``_WNORM``.
    """
    samples = waveform_16k.astype(np.float32, copy=False)
    # center=False: frames start at sample 0 with no edge padding, so librosa
    # does not use pad_mode here.
    stft = librosa.stft(
        y=samples,
        n_fft=WIN_LEN,
        hop_length=HOP_SIZE,
        win_length=WIN_LEN,
        window=_WIN,
        center=False,
        pad_mode="reflect",
    )
    # librosa returns [F, T]; switch to time-major [T, F] and apply the scale.
    frames = (np.transpose(stft) * _WNORM).astype(np.complex64)
    real_imag = np.stack((frames.real, frames.imag), axis=-1).astype(np.float32)
    # Add the leading batch axis -> [1, T, F, 2].
    return np.expand_dims(real_imag, axis=0)
|
|
|
|
|
|
|
|
def postprocessing(spec_e: np.ndarray) -> np.ndarray:
    """Convert the model's real/imag STFT output back into a 16 kHz waveform.

    Args:
        spec_e: float array of shape [1, T, F, 2] with (real, imag) on the last axis.

    Returns:
        1-D float32 waveform at 16 kHz. The signal is advanced by 2 * WIN_LEN
        samples and the tail is zero-filled. If the inverse STFT yields fewer
        than 2 * WIN_LEN samples, the result is 2 * WIN_LEN zeros.
    """
    real_imag = spec_e[0].astype(np.float32)
    # Rebuild the complex spectrum in librosa's [F, T] layout.
    complex_spec = (real_imag[..., 0] + 1j * real_imag[..., 1]).T.astype(np.complex64)

    # NOTE(review): analysis (preprocessing) uses center=False while synthesis
    # uses center=True. The resulting offset is presumably accounted for by the
    # shift below; confirm against the model's reference pipeline.
    signal = librosa.istft(
        complex_spec,
        hop_length=HOP_SIZE,
        win_length=WIN_LEN,
        window=_WIN,
        center=True,
        length=None,
    ).astype(np.float32)

    # Undo the scaling that preprocessing applied to the spectrum.
    signal = signal / _WNORM

    # Advance by 2 * WIN_LEN samples (presumably the model's algorithmic delay;
    # TODO confirm) and pad the end with zeros of the same length.
    shift = WIN_LEN * 2
    signal = np.concatenate([signal[shift:], np.zeros(shift, dtype=np.float32)])
    return signal.astype(np.float32)
|
|
|
|
|
|
|
|
|
|
|
def to_mono(audio: np.ndarray) -> np.ndarray:
    """Down-mix *audio* to a single channel.

    A 1-D array is returned as-is (same object). Otherwise the values along
    axis 1 are averaged. soundfile returns multi-channel audio with that
    (frames, channels) layout.
    """
    return audio if audio.ndim == 1 else audio.mean(axis=1)
|
|
|
|
|
|
|
|
def ensure_16k(waveform: np.ndarray, sr: int, target_sr: int = 16000) -> np.ndarray:
    """Return *waveform* as float32 sampled at *target_sr* Hz.

    If *sr* already equals *target_sr*, only the dtype conversion happens.
    The conversion does not copy when the input is already float32.
    Otherwise the signal is resampled with ``librosa.resample``.

    Args:
        waveform: 1-D audio samples.
        sr: sample rate of *waveform* in Hz.
        target_sr: desired sample rate in Hz (default 16000).
    """
    as_f32 = waveform.astype(np.float32, copy=False)
    if sr != target_sr:
        return librosa.resample(as_f32, orig_sr=sr, target_sr=target_sr)
    return as_f32
|
|
|
|
|
|
|
|
def resample_back(waveform_16k: np.ndarray, target_sr: int) -> np.ndarray:
    """Resample a 16 kHz waveform to *target_sr* Hz.

    When *target_sr* is 16000 the input is returned unchanged (no dtype
    conversion). Otherwise it is cast to float32 and passed through
    ``librosa.resample``.
    """
    if target_sr != 16000:
        as_f32 = waveform_16k.astype(np.float32, copy=False)
        return librosa.resample(as_f32, orig_sr=16000, target_sr=target_sr)
    return waveform_16k
|
|
|
|
|
|
|
|
def pcm16_safe(x: np.ndarray) -> np.ndarray:
    """Convert a float waveform to 16-bit PCM samples.

    Values are clipped to [-1, 1], scaled by 32767 and rounded to the nearest
    integer. Rounding keeps the quantization error within +/-0.5 LSB. A bare
    ``astype`` truncates toward zero instead, which makes the zero code twice as
    wide as the others and biases every sample toward zero. NaN samples become
    0 and +/-inf saturate to full scale. Otherwise a bad model output would hit
    numpy's undefined float->int cast for non-finite values.

    Args:
        x: float array of samples, nominally in [-1, 1].

    Returns:
        int16 array with the same shape as *x*.
    """
    x = np.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0)
    x = np.clip(x, -1.0, 1.0)
    return np.rint(x * 32767.0).astype(np.int16)
|
|
|
|
|
|
|
|
|
|
|
def _run_streaming(interpreter, spec: np.ndarray, desc: str) -> np.ndarray:
    """Feed *spec* to *interpreter* one frame at a time and stack the outputs.

    Args:
        interpreter: a TFLite Interpreter whose tensors are already allocated.
        spec: [1, T, F, 2] array from ``preprocessing``. Each step sends the
            slice ``spec[:, t:t + 1]`` to the model's first input tensor.
        desc: label for the tqdm progress bar.

    Returns:
        The per-frame outputs of the model's first output tensor, concatenated
        along axis 1, as float32.
    """
    in_index = interpreter.get_input_details()[0]["index"]
    out_index = interpreter.get_output_details()[0]["index"]

    chunks = []
    for t in tqdm(range(spec.shape[1]), desc=desc, unit="frm", leave=False):
        step_in = np.ascontiguousarray(spec[:, t:t + 1], dtype=np.float32)
        interpreter.set_tensor(in_index, step_in)
        interpreter.invoke()
        step_out = interpreter.get_tensor(out_index)
        chunks.append(np.ascontiguousarray(step_out, dtype=np.float32))
    return np.concatenate(chunks, axis=1).astype(np.float32)


def enhance_file(in_path: Path, out_path: Path, model_name: str) -> None:
    """Enhance one audio file with the named TFLite model and write the result.

    Steps:
    1. Read *in_path* with soundfile and down-mix it to mono float32.
    2. Resample to 16 kHz and compute the STFT.
    3. Load ``TFLITE_DIR / "<model_name>.tflite"`` into a new Interpreter
       (one per call) and run it frame by frame.
    4. Invert the STFT and resample back to the input's sample rate.
    5. Write 16-bit PCM to *out_path*, creating parent directories as needed.

    Because a fresh Interpreter is built per call, model state is presumably
    not shared between files; TODO confirm the model keeps no external state.

    Raises:
        Whatever soundfile, librosa or the TFLite runtime raise (for example
        for unreadable input or a missing model file).
    """
    raw, sr_in = sf.read(str(in_path), always_2d=False)
    mono = to_mono(raw).astype(np.float32, copy=False)
    spec = preprocessing(ensure_16k(mono, sr_in, 16000))

    model_file = TFLITE_DIR / (model_name + '.tflite')
    interpreter = Interpreter(model_path=str(model_file))
    interpreter.allocate_tensors()

    spec_e = _run_streaming(interpreter, spec, f"{in_path.name}")

    enhanced = resample_back(postprocessing(spec_e), sr_in)

    out_path.parent.mkdir(parents=True, exist_ok=True)
    sf.write(str(out_path), pcm16_safe(enhanced), sr_in, subtype="PCM_16")
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point: enhance every WAV file in --noisy_dir.

    Each ``<stem>.wav`` in --noisy_dir (non-recursive; the extension is matched
    case-insensitively) is written to --enhanced_dir as ``<stem>_<model>.wav``.

    Exit status:
        1 if --noisy_dir is not a directory.
        0 if there are no WAV files to process.
        1 if the selected model file does not exist.
        1 if any file failed to process. Each failure is reported on stderr
          and processing continues with the next file.
        0 otherwise.
    """
    parser = argparse.ArgumentParser(description="Enhance WAV files with a DPDFNet TFLite model (streaming).")
    parser.add_argument("--noisy_dir", type=str, required=True, help="Folder with noisy *.wav files (non-recursive).")
    parser.add_argument("--enhanced_dir", type=str, required=True, help="Output folder for enhanced WAVs.")
    parser.add_argument(
        "--model_name",
        type=str,
        default="dpdfnet8",
        choices=["baseline", "dpdfnet2", "dpdfnet4", "dpdfnet8"],
        help=(
            "Name of the model to use. Options: "
            "'baseline', 'dpdfnet2', 'dpdfnet4', 'dpdfnet8'. "
            "Default is 'dpdfnet8'."
        ),
    )
    args = parser.parse_args()

    noisy_dir = Path(args.noisy_dir)
    enhanced_dir = Path(args.enhanced_dir)
    model_name = args.model_name

    if not noisy_dir.is_dir():
        print(f"ERROR: --noisy_dir does not exist or is not a directory: {noisy_dir}", file=sys.stderr)
        sys.exit(1)

    # Match the extension case-insensitively. glob("*.wav") silently skips
    # "*.WAV" files on case-sensitive filesystems.
    wavs = sorted(
        p for p in noisy_dir.iterdir()
        if p.is_file() and p.suffix.lower() == ".wav"
    )
    if not wavs:
        print(f"No .wav files found in {noisy_dir} (non-recursive).")
        sys.exit(0)

    # Fail fast on a missing model instead of reporting the same load error
    # once per input file. This is the same path enhance_file() loads from.
    model_path = TFLITE_DIR / (model_name + '.tflite')
    if not model_path.is_file():
        print(f"ERROR: model file not found: {model_path}", file=sys.stderr)
        sys.exit(1)

    print(f"Model: {model_name}")
    print(f"Input : {noisy_dir}")
    print(f"Output: {enhanced_dir}")
    print(f"Found {len(wavs)} file(s). Enhancing...\n")

    failed = []
    for wav in wavs:
        out_path = enhanced_dir / (wav.stem + f'_{model_name}.wav')
        try:
            enhance_file(wav, out_path, model_name)
        except Exception as e:
            # Top-level boundary: report the failure and keep processing the
            # remaining files (best-effort batch behavior).
            print(f"[SKIP] {wav.name} due to error: {e}", file=sys.stderr)
            failed.append(wav.name)

    print("\nProcessing complete. Outputs saved in:", enhanced_dir)

    # Signal partial failure to shell scripts / CI instead of always exiting 0.
    if failed:
        print(f"{len(failed)} of {len(wavs)} file(s) failed.", file=sys.stderr)
        sys.exit(1)
|
|
|
|
|
|
|
|
# Run the CLI only when the file is executed as a script, not when imported.
if __name__ == "__main__":


    main()
|
|
|