DeepFilterNet3 / app.py
humairmunirawn's picture
Update app.py
b70f1bc verified
import glob
import math
import os
import sys
import tempfile
import time
from typing import List, Optional, Tuple, Union
from dataclasses import dataclass
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import soundfile as sf
from loguru import logger
from PIL import Image
from torch import Tensor
from scipy import signal
import torch
import torchaudio
print(torch.__version__)
from torchaudio.compliance.kaldi import resample_waveform
print("Kaldi module loaded successfully!")
# Mock torchaudio.backend.common.AudioMetaData for df package compatibility
class MockAudioMetaData:
"""Mock AudioMetaData to satisfy df package imports"""
def __init__(self, sample_rate, num_frames, num_channels, bits_per_sample, encoding):
self.sample_rate = sample_rate
self.num_frames = num_frames
self.num_channels = num_channels
self.bits_per_sample = bits_per_sample
self.encoding = encoding
# Create mock torchaudio module
class MockTorchaudio:
class backend:
class common:
AudioMetaData = MockAudioMetaData
sys.modules['torchaudio'] = MockTorchaudio()
sys.modules['torchaudio.backend'] = MockTorchaudio.backend()
sys.modules['torchaudio.backend.common'] = MockTorchaudio.backend.common()
# Now import df package
from df import config
from df.enhance import enhance, init_df
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, df, _ = init_df("./DeepFilterNet2", config_allow_defaults=True)
model = model.to(device=device).eval()
fig_noisy: plt.Figure
fig_enh: plt.Figure
ax_noisy: plt.Axes
ax_enh: plt.Axes
fig_noisy, ax_noisy = plt.subplots(figsize=(15.2, 4))
fig_noisy.set_tight_layout(True)
fig_enh, ax_enh = plt.subplots(figsize=(15.2, 4))
fig_enh.set_tight_layout(True)
NOISES = {
"None": None,
"Kitchen": "samples/dkitchen.wav",
"Living Room": "samples/dliving.wav",
"River": "samples/nriver.wav",
"Cafe": "samples/scafe.wav",
}
@dataclass
class AudioMetaData:
"""Simple audio metadata container to replace torchaudio.backend.common.AudioMetaData"""
sample_rate: int
num_frames: int
num_channels: int
bits_per_sample: int
encoding: str
def load_audio(file_path: str, sr: int) -> Tuple[Tensor, AudioMetaData]:
"""Load audio file using soundfile and resample if necessary.
Args:
file_path: Path to audio file
sr: Target sample rate
Returns:
audio: Torch tensor of shape [channels, samples]
meta: AudioMetaData with file information
"""
try:
# Read audio using soundfile
audio_np, sample_rate = sf.read(file_path, dtype='float32')
# Handle mono/stereo
if audio_np.ndim == 1:
audio_np = audio_np[np.newaxis, :] # Add channel dimension
num_channels = 1
else:
audio_np = audio_np.T # Convert [samples, channels] to [channels, samples]
num_channels = audio_np.shape[0]
# Get file info for metadata
info = sf.info(file_path)
num_frames = info.frames
# Create metadata
meta = AudioMetaData(
sample_rate=sample_rate,
num_frames=num_frames,
num_channels=num_channels,
bits_per_sample=-1, # Not directly available from soundfile
encoding=info.format
)
# Convert to torch tensor
audio = torch.from_numpy(audio_np).float()
# Resample if necessary
if sample_rate != sr:
audio = resample_audio(audio, sample_rate, sr)
meta.sample_rate = sr
return audio, meta
except Exception as e:
logger.error(f"Error loading audio file {file_path}: {e}")
raise
def save_audio(file_path: str, audio: Tensor, sr: int) -> None:
"""Save audio tensor to file using soundfile.
Args:
file_path: Output file path
audio: Audio tensor of shape [channels, samples] or [samples]
sr: Sample rate
"""
try:
# Convert tensor to numpy
audio_np = audio.cpu().numpy()
# Handle tensor shape
if audio_np.ndim == 3:
audio_np = audio_np.squeeze(0)
# Convert [channels, samples] to [samples, channels] for soundfile
if audio_np.ndim == 2:
audio_np = audio_np.T
# Ensure float32
audio_np = audio_np.astype(np.float32)
# Clip to valid range
audio_np = np.clip(audio_np, -1.0, 1.0)
# Save using soundfile
sf.write(file_path, audio_np, sr)
logger.info(f"Saved audio to {file_path}")
except Exception as e:
logger.error(f"Error saving audio to {file_path}: {e}")
raise
def resample_audio(audio: Tensor, sr_orig: int, sr_target: int) -> Tensor:
"""Resample audio using scipy.signal.resample_poly.
Args:
audio: Audio tensor of shape [channels, samples]
sr_orig: Original sample rate
sr_target: Target sample rate
Returns:
Resampled audio tensor
"""
if sr_orig == sr_target:
return audio
# Convert to numpy for resampling
audio_np = audio.cpu().numpy()
# Calculate gcd for polyphase resampling
from math import gcd
g = gcd(sr_orig, sr_target)
up = sr_target // g
down = sr_orig // g
logger.debug(f"Resampling from {sr_orig} to {sr_target} (up={up}, down={down})")
# Resample each channel
if audio_np.ndim == 2:
resampled = np.zeros((audio_np.shape[0], int(audio_np.shape[1] * sr_target / sr_orig)))
for ch in range(audio_np.shape[0]):
resampled[ch] = signal.resample_poly(audio_np[ch], up, down)
else:
resampled = signal.resample_poly(audio_np, up, down)
return torch.from_numpy(resampled).float()
def mix_at_snr(clean, noise, snr, eps=1e-10):
"""Mix clean and noise signal at a given SNR.
Args:
clean: 1D Tensor with the clean signal to mix.
noise: 1D Tensor of shape.
snr: Signal to noise ratio.
Returns:
clean: 1D Tensor with gain changed according to the snr.
noise: 1D Tensor with the combined noise channels.
mix: 1D Tensor with added clean and noise signals.
"""
clean = torch.as_tensor(clean).mean(0, keepdim=True)
noise = torch.as_tensor(noise).mean(0, keepdim=True)
if noise.shape[1] < clean.shape[1]:
noise = noise.repeat((1, int(math.ceil(clean.shape[1] / noise.shape[1]))))
max_start = int(noise.shape[1] - clean.shape[1])
start = torch.randint(0, max_start, ()).item() if max_start > 0 else 0
logger.debug(f"start: {start}, {clean.shape}")
noise = noise[:, start : start + clean.shape[1]]
E_speech = torch.mean(clean.pow(2)) + eps
E_noise = torch.mean(noise.pow(2))
K = torch.sqrt((E_noise / E_speech) * 10 ** (snr / 10) + eps)
noise = noise / K
mixture = clean + noise
logger.debug(f"mixture: {mixture.shape}")
assert torch.isfinite(mixture).all()
max_m = mixture.abs().max()
if max_m > 1:
logger.warning(f"Clipping detected during mixing. Reducing gain by {1/max_m}")
clean, noise, mixture = clean / max_m, noise / max_m, mixture / max_m
return clean, noise, mixture
def load_audio_gradio(
audio_or_file: Union[None, str, Tuple[int, np.ndarray]], sr: int
) -> Optional[Tuple[Tensor, AudioMetaData]]:
"""Load audio from file or gradio microphone input.
Args:
audio_or_file: Path to audio file, tuple from gradio mic, or None
sr: Target sample rate
Returns:
Tuple of (audio tensor, metadata) or None
"""
if audio_or_file is None:
return None
if isinstance(audio_or_file, str):
if audio_or_file.lower() == "none":
return None
# Load from file path
audio, meta = load_audio(audio_or_file, sr)
else:
# Handle gradio microphone input
meta = AudioMetaData(
sample_rate=-1,
num_frames=-1,
num_channels=-1,
bits_per_sample=-1,
encoding=""
)
assert isinstance(audio_or_file, (tuple, list))
sample_rate, audio_np = audio_or_file
# Gradio returns [samples, channels], reshape if needed
audio_np = audio_np.reshape(audio_np.shape[0], -1).T
# Handle different integer formats
if audio_np.dtype == np.int16:
audio_np = (audio_np / (1 << 15)).astype(np.float32)
elif audio_np.dtype == np.int32:
audio_np = (audio_np / (1 << 31)).astype(np.float32)
audio = torch.from_numpy(audio_np).float()
# Resample if necessary
if sample_rate != sr:
audio = resample_audio(audio, sample_rate, sr)
meta.sample_rate = sr
return audio, meta
def demo_fn(speech_upl: str, noise_type: str, snr: int, mic_input: Optional[str] = None):
"""Main demo function for audio denoising.
Args:
speech_upl: Path to uploaded speech file
noise_type: Type of noise to add
snr: Signal-to-noise ratio
mic_input: Path to microphone input file
Returns:
Tuple of (noisy_audio_path, noisy_spectrogram, enhanced_audio_path, enhanced_spectrogram)
"""
if mic_input:
speech_upl = mic_input
sr = config("sr", 48000, int, section="df")
logger.info(f"Got parameters speech_upl: {speech_upl}, noise: {noise_type}, snr: {snr}")
snr = int(snr)
noise_fn = NOISES[noise_type]
meta = AudioMetaData(-1, -1, -1, -1, "")
max_s = 10 # limit to 10 seconds
if speech_upl is not None:
sample, meta = load_audio(speech_upl, sr)
max_len = max_s * sr
if sample.shape[-1] > max_len:
start = torch.randint(0, sample.shape[-1] - max_len, ()).item()
sample = sample[..., start : start + max_len]
else:
sample, meta = load_audio("samples/p232_013_clean.wav", sr)
sample = sample[..., : max_s * sr]
if sample.dim() > 1 and sample.shape[0] > 1:
assert (
sample.shape[1] > sample.shape[0]
), f"Expecting channels first, but got {sample.shape}"
sample = sample.mean(dim=0, keepdim=True)
logger.info(f"Loaded sample with shape {sample.shape}")
if noise_fn is not None:
noise, _ = load_audio(noise_fn, sr)
logger.info(f"Loaded noise with shape {noise.shape}")
_, _, sample = mix_at_snr(sample, noise, snr)
logger.info("Start denoising audio")
enhanced = enhance(model, df, sample)
logger.info("Denoising finished")
# Apply fade-in limiter
lim = torch.linspace(0.0, 1.0, int(sr * 0.15)).unsqueeze(0)
lim = torch.cat((lim, torch.ones(1, enhanced.shape[1] - lim.shape[1])), dim=1)
enhanced = enhanced * lim
# Resample back to original sample rate if needed
if meta.sample_rate != sr:
enhanced = resample_audio(enhanced, sr, meta.sample_rate)
sample = resample_audio(sample, sr, meta.sample_rate)
sr = meta.sample_rate
# Save audio files
noisy_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False).name
save_audio(noisy_wav, sample, sr)
enhanced_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False).name
save_audio(enhanced_wav, enhanced, sr)
logger.info(f"saved audios: {noisy_wav}, {enhanced_wav}")
# Generate spectrograms
ax_noisy.clear()
ax_enh.clear()
noisy_im = spec_im(sample, sr=sr, figure=fig_noisy, ax=ax_noisy)
enh_im = spec_im(enhanced, sr=sr, figure=fig_enh, ax=ax_enh)
# Cleanup temporary files (except the ones we want to return)
filter = [speech_upl, noisy_wav, enhanced_wav]
if mic_input is not None and mic_input != "":
filter.append(mic_input)
cleanup_tmp(filter)
return noisy_wav, noisy_im, enhanced_wav, enh_im
def specshow(
spec,
ax=None,
title=None,
xlabel=None,
ylabel=None,
sr=48000,
n_fft=None,
hop=None,
t=None,
f=None,
vmin=-100,
vmax=0,
xlim=None,
ylim=None,
cmap="inferno",
):
"""Plots a spectrogram of shape [F, T]"""
spec_np = spec.cpu().numpy() if isinstance(spec, torch.Tensor) else spec
if ax is not None:
set_title = ax.set_title
set_xlabel = ax.set_xlabel
set_ylabel = ax.set_ylabel
set_xlim = ax.set_xlim
set_ylim = ax.set_ylim
else:
ax = plt
set_title = plt.title
set_xlabel = plt.xlabel
set_ylabel = plt.ylabel
set_xlim = plt.xlim
set_ylim = plt.ylim
if n_fft is None:
if spec.shape[0] % 2 == 0:
n_fft = spec.shape[0] * 2
else:
n_fft = (spec.shape[0] - 1) * 2
hop = hop or n_fft // 4
if t is None:
t = np.arange(0, spec_np.shape[-1]) * hop / sr
if f is None:
f = np.arange(0, spec_np.shape[0]) * sr // 2 / (n_fft // 2) / 1000
im = ax.pcolormesh(
t, f, spec_np, rasterized=True, shading="auto", vmin=vmin, vmax=vmax, cmap=cmap
)
if title is not None:
set_title(title)
if xlabel is not None:
set_xlabel(xlabel)
if ylabel is not None:
set_ylabel(ylabel)
if xlim is not None:
set_xlim(xlim)
if ylim is not None:
set_ylim(ylim)
return im
def spec_im(
audio: torch.Tensor,
figsize=(15, 5),
colorbar=False,
colorbar_format=None,
figure=None,
labels=True,
**kwargs,
) -> Image:
"""Convert audio to spectrogram image.
Args:
audio: Audio tensor
figsize: Figure size
colorbar: Whether to show colorbar
colorbar_format: Format for colorbar
figure: Matplotlib figure to use
labels: Whether to show axis labels
**kwargs: Additional arguments for specshow
Returns:
PIL Image of the spectrogram
"""
audio = torch.as_tensor(audio)
if labels:
kwargs.setdefault("xlabel", "Time [s]")
kwargs.setdefault("ylabel", "Frequency [Hz]")
n_fft = kwargs.setdefault("n_fft", 1024)
hop = kwargs.setdefault("hop", 512)
w = torch.hann_window(n_fft, device=audio.device)
spec = torch.stft(audio, n_fft, hop, window=w, return_complex=False)
spec = spec.div_(w.pow(2).sum())
spec = torch.view_as_complex(spec).abs().clamp_min(1e-12).log10().mul(10)
kwargs.setdefault("vmax", max(0.0, spec.max().item()))
if figure is None:
figure = plt.figure(figsize=figsize)
figure.set_tight_layout(True)
if spec.dim() > 2:
spec = spec.squeeze(0)
im = specshow(spec, **kwargs)
if colorbar:
ckwargs = {}
if "ax" in kwargs:
if colorbar_format is None:
if kwargs.get("vmin", None) is not None or kwargs.get("vmax", None) is not None:
colorbar_format = "%+2.0f dB"
ckwargs = {"ax": kwargs["ax"]}
plt.colorbar(im, format=colorbar_format, **ckwargs)
figure.canvas.draw()
return Image.frombytes("RGB", figure.canvas.get_width_height(), figure.canvas.tostring_rgb())
def cleanup_tmp(filter: List[str] = [], hours_keep=2):
"""Clean up old temporary files.
Args:
filter: List of file paths to keep (not delete)
hours_keep: Number of hours to keep files
"""
filter.append("p232")
logger.info(f"Filter: {filter}")
# Cleanup some old wav files
if os.path.exists("/tmp"):
for f in glob.glob("/tmp/*"):
print(f"Got file {f}")
is_old = (time.time() - os.path.getmtime(f)) / 3600 > hours_keep
filtered = any(filt in f for filt in filter if filt is not None)
if is_old and not filtered:
try:
os.remove(f)
logger.info(f"Removed file {f}")
except Exception as e:
logger.warning(f"failed to remove file {f}: {e}")
def toggle(choice):
"""Toggle between microphone and file input.
Args:
choice: "mic" or "file"
Returns:
Tuple of updated components visibility
"""
if choice == "mic":
return gr.update(visible=True, value=None), gr.update(visible=False, value=None)
else:
return gr.update(visible=False, value=None), gr.update(visible=True, value=None)
# Create Gradio interface
with gr.Blocks() as demo:
with gr.Row():
gr.Markdown(
"""
## DeepFilterNet2 Demo
This demo denoises audio files using DeepFilterNet. Try it with your own voice!
"""
)
with gr.Row():
with gr.Column():
radio = gr.Radio(
["mic", "file"], value="file", label="How would you like to upload your audio?"
)
mic_input = gr.Mic(label="Input", type="filepath", visible=False)
audio_file = gr.Audio(type="filepath", label="Input", visible=True)
inputs = [
audio_file,
gr.Dropdown(
label="Add background noise",
choices=list(NOISES.keys()),
value="None",
),
gr.Dropdown(
label="Noise Level (SNR)",
choices=["-5", "0", "10", "20"],
value="10",
),
mic_input,
]
btn = gr.Button("Generate")
with gr.Column():
outputs = [
gr.Audio(type="filepath", label="Noisy audio"),
gr.Image(label="Noisy spectrogram"),
gr.Audio(type="filepath", label="Enhanced audio"),
gr.Image(label="Enhanced spectrogram"),
]
btn.click(fn=demo_fn, inputs=inputs, outputs=outputs, api_name='denoise')
radio.change(toggle, radio, [mic_input, audio_file])
gr.Examples(
[
["./samples/p232_013_clean.wav", "Kitchen", "10"],
["./samples/p232_013_clean.wav", "Cafe", "10"],
["./samples/p232_019_clean.wav", "Cafe", "10"],
["./samples/p232_019_clean.wav", "River", "10"],
],
fn=demo_fn,
inputs=inputs,
outputs=outputs,
cache_examples=True,
)
gr.Markdown(open("usage.md").read())
cleanup_tmp()
demo.launch(enable_queue=True)