barvox-backend / audio_processing.py
RonenShilchikov
Restructure: move Python backend into backend/ directory
423bed8
"""
BarVox Audio Processing API - Audio Preprocessing Functions
"""
import logging
import numpy as np
import torch
import torchaudio.transforms as T
from scipy.signal import butter, filtfilt
from typing import Optional
# Module-level logger; the processing functions below report what they applied through it.
logger = logging.getLogger(__name__)
def butter_highpass(cutoff, fs, order=5):
    """Design a digital Butterworth high-pass filter.

    Args:
        cutoff: Corner frequency in Hz.
        fs: Sample rate in Hz.
        order: Filter order.

    Returns:
        ``(b, a)`` transfer-function coefficients from ``scipy.signal.butter``.
    """
    # scipy expects the corner frequency as a fraction of Nyquist (fs / 2).
    normalized_corner = cutoff / (fs / 2)
    numerator, denominator = butter(order, normalized_corner, btype='high', analog=False)
    return numerator, denominator
def butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a digital Butterworth band-pass filter.

    Args:
        lowcut: Lower band edge in Hz.
        highcut: Upper band edge in Hz.
        fs: Sample rate in Hz.
        order: Filter order.

    Returns:
        ``(b, a)`` transfer-function coefficients from ``scipy.signal.butter``.
    """
    half_rate = fs / 2
    # Both edges are expressed as fractions of Nyquist, as scipy requires.
    band_edges = [lowcut / half_rate, highcut / half_rate]
    numerator, denominator = butter(order, band_edges, btype='band')
    return numerator, denominator
def highpass_filter(data, cutoff=30, fs=16000, order=2):
    """Attenuate content below ``cutoff`` Hz with a zero-phase Butterworth filter.

    Args:
        data: Samples to filter.
        cutoff: Corner frequency in Hz.
        fs: Sample rate in Hz.
        order: Butterworth order of the underlying filter design.

    Returns:
        The filtered samples as an independent array.
    """
    numerator, denominator = butter_highpass(cutoff, fs, order=order)
    # filtfilt runs the filter forward then backward, so there is no phase shift.
    filtered = filtfilt(numerator, denominator, data)
    return filtered.copy()
def rms_normalize(waveform_np, target_dbfs=-20.0):
    """Scale a waveform so its RMS level equals ``target_dbfs`` (dB relative to 1.0).

    Args:
        waveform_np: Samples as a numpy array.
        target_dbfs: Desired RMS level in dBFS.

    Returns:
        The scaled waveform. A signal whose RMS is exactly zero is returned
        unchanged (the very same object), since it cannot be scaled.
    """
    current_rms = np.sqrt(np.mean(np.square(waveform_np)))
    if current_rms == 0:
        return waveform_np
    # Convert the dBFS target to a linear amplitude, then scale to match it.
    target_rms = 10 ** (target_dbfs / 20)
    gain = target_rms / current_rms
    return waveform_np * gain
def resample_to_16khz_mono(waveform: torch.Tensor, orig_sample_rate: int) -> torch.Tensor:
    """Downmix audio to mono and resample it to 16 kHz.

    Args:
        waveform: Audio tensor. A tensor with two or more dimensions is treated
            as channels-first (presumably ``(channels, time)`` as produced by
            ``torchaudio.load`` - confirm with callers); if it has more than one
            channel, the channels are averaged into a single row. A 1-D tensor is
            treated as already-mono samples and is not downmixed.
        orig_sample_rate: Sample rate of ``waveform`` in Hz.

    Returns:
        The mono waveform at 16 kHz (unchanged if it already was).
    """
    # Only downmix when a channel axis exists: for a 1-D tensor shape[0] is the
    # sample count, and averaging over it would collapse the signal to one value.
    if waveform.dim() > 1 and waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    if orig_sample_rate != 16000:
        resampler = T.Resample(orig_freq=orig_sample_rate, new_freq=16000)
        waveform = resampler(waveform)
    return waveform
def apply_noise_reduction(waveform_np, sample_rate=16000):
    """Light clean-up: remove the DC offset, then high-pass filter at 30 Hz.

    Args:
        waveform_np: Samples as a numpy array.
        sample_rate: Sample rate in Hz, forwarded to the high-pass filter.

    Returns:
        The cleaned samples.
    """
    dc_offset = np.mean(waveform_np)
    centered = waveform_np - dc_offset
    return highpass_filter(centered, cutoff=30, fs=sample_rate, order=2)
def apply_normalization(waveform_np, target_dbfs=-20.0):
    """Pipeline step: RMS-normalize ``waveform_np`` to ``target_dbfs`` via ``rms_normalize``."""
    normalized = rms_normalize(waveform_np, target_dbfs=target_dbfs)
    return normalized
def apply_dynamic_compression(waveform_np, threshold_ratio=0.5, ratio=3.0):
    """Apply simple static dynamic-range compression with peak make-up gain.

    Samples whose magnitude exceeds ``threshold_ratio * RMS`` keep the threshold
    level, but the excess above it is divided by ``ratio``. The result is then
    scaled so its peak matches the input's peak, with that make-up gain capped
    at 2.0.

    Args:
        waveform_np: Array-like of samples. Floating input keeps its dtype;
            integer input is promoted to float64 so compressed values are not
            truncated when written back.
        threshold_ratio: Compression threshold as a fraction of the signal RMS.
        ratio: Compression ratio applied above the threshold.

    Returns:
        A new array of compressed samples; empty input yields an empty array.
    """
    samples = np.asarray(waveform_np)
    if not np.issubdtype(samples.dtype, np.floating):
        # An integer working copy would silently truncate the compressed values.
        samples = samples.astype(np.float64)
    if samples.size == 0:
        # np.max below raises on empty input; there is nothing to compress.
        return samples.copy()

    rms = np.sqrt(np.mean(np.square(samples)))
    threshold = rms * threshold_ratio
    magnitude = np.abs(samples)

    compressed = samples.copy()
    mask = magnitude > threshold
    compressed[mask] = np.sign(samples[mask]) * (
        threshold + (magnitude[mask] - threshold) / ratio
    )

    # Make-up gain restores the original peak level, limited to 2x.
    input_peak = np.max(magnitude)
    output_peak = np.max(np.abs(compressed))
    gain = input_peak / output_peak if output_peak > 0 else 1.0
    compressed = compressed * min(gain, 2.0)

    logger.info(
        "Applied dynamic compression (threshold_ratio=%s, ratio=%s)",
        threshold_ratio,
        ratio,
    )
    return compressed
def apply_transient_enhancement(waveform_np, sample_rate=16000, attack_boost=1.5):
    """Emphasise transients by boosting samples where short-term energy rises.

    A moving-RMS envelope over a ~10 ms window is computed. Samples where the
    envelope's first difference exceeds half its standard deviation are marked
    as transients and given a gain of ``attack_boost``. The gain curve is
    Gaussian-smoothed and applied, and the result is rescaled so its peak
    magnitude matches the input's.

    Args:
        waveform_np: 1-D array of samples. Floating input keeps its dtype;
            integer input is promoted to floating point so fractional gains
            are not truncated.
        sample_rate: Sample rate in Hz (sets the 10 ms window length).
        attack_boost: Gain applied at detected transient samples.

    Returns:
        The enhanced signal as a new array; empty input yields an empty array.
    """
    from scipy.ndimage import gaussian_filter1d

    samples = np.asarray(waveform_np)
    samples = samples.astype(np.result_type(samples.dtype, np.float32), copy=False)
    num_samples = len(samples)
    if num_samples == 0:
        # np.diff(..., prepend=envelope[0]) below would raise IndexError.
        return samples.copy()

    window_size = int(sample_rate * 0.010)
    half = window_size // 2

    # Moving RMS over the window [i - half, i + half), clipped to the signal,
    # computed from a prefix sum of squares: O(n) instead of one slice per sample.
    positions = np.arange(num_samples)
    starts = np.maximum(positions - half, 0)
    ends = np.minimum(positions + half, num_samples)
    prefix = np.concatenate(([0.0], np.cumsum(np.square(samples.astype(np.float64)))))
    # Differences of large prefix sums can round to tiny negative values.
    window_energy = np.maximum(prefix[ends] - prefix[starts], 0.0)
    # half == 0 (sample_rate < 200 Hz) gives empty windows; treat them as zero
    # energy rather than dividing by zero.
    envelope = np.sqrt(window_energy / np.maximum(ends - starts, 1))

    envelope_diff = np.diff(envelope, prepend=envelope[0])
    transient_mask = envelope_diff > (np.std(envelope_diff) * 0.5)

    boost = np.ones(num_samples, dtype=samples.dtype)
    boost[transient_mask] = attack_boost
    # Smooth the gain curve so it changes gradually instead of stepping.
    boost_smooth = gaussian_filter1d(boost, sigma=window_size / 4)
    enhanced = samples * boost_smooth

    output_peak = np.max(np.abs(enhanced))
    if output_peak > 0:
        # Restore the input's peak level.
        enhanced = enhanced / output_peak * np.max(np.abs(samples))
    logger.info("Applied transient enhancement (boost=%s)", attack_boost)
    return enhanced
def apply_high_frequency_boost(waveform_np, sample_rate=16000, boost_db=6.0,
                               low_hz=2000, high_hz=6000):
    """Add ``boost_db`` of emphasis to a presence band (default 2-6 kHz).

    The band is isolated with a 4th-order Butterworth band-pass run forward and
    backward (``filtfilt``), and ``(10 ** (boost_db / 20) - 1)`` times that band
    is added back to the dry signal. If the result peaks above 1.0 it is scaled
    down to a peak of exactly 1.0; otherwise it is left unscaled.

    Args:
        waveform_np: Samples as a numpy array.
        sample_rate: Sample rate in Hz.
        boost_db: Gain in dB added to the band.
        low_hz: Lower band edge in Hz (default keeps the original 2 kHz).
        high_hz: Upper band edge in Hz (default keeps the original 6 kHz).

    Returns:
        The boosted signal as a new array.

    Raises:
        ValueError: If the band does not satisfy ``0 < low_hz < high_hz`` below
            Nyquist (``sample_rate / 2``); e.g. the default band requires
            ``sample_rate > 12000``.
    """
    nyquist = sample_rate / 2
    if not 0 < low_hz < high_hz < nyquist:
        raise ValueError(
            f"Band {low_hz}-{high_hz} Hz must satisfy 0 < low < high < Nyquist "
            f"({nyquist} Hz for sample_rate={sample_rate})"
        )
    b, a = butter_bandpass(low_hz, high_hz, sample_rate, order=4)
    presence_band = filtfilt(b, a, waveform_np)
    boost_factor = 10 ** (boost_db / 20.0)
    boosted = waveform_np + presence_band * (boost_factor - 1.0)
    peak = np.max(np.abs(boosted))
    if peak > 1.0:
        # Prevent clipping by peak-normalizing only when the boost overshoots.
        boosted = boosted / peak
    logger.info(
        "Applied high-frequency boost (+%s dB in %g-%g kHz)",
        boost_db,
        low_hz / 1000,
        high_hz / 1000,
    )
    return boosted
def apply_silero_vad(
    waveform: torch.Tensor,
    sample_rate: int = 16000,
    threshold: float = 0.35,
    min_speech_duration_ms: int = 60,
    min_silence_duration_ms: int = 300,
    padding_before_ms: int = 180,
    padding_after_ms: int = 900,
    max_speech_duration_s: float = 0,
    chunk_selection: str = 'longest'
) -> Optional[torch.Tensor]:
    """Trim audio to one speech segment detected by Silero VAD.

    Loads the Silero model and utilities through ``model_loader.get_models()``,
    runs ``get_speech_timestamps``, selects one segment according to
    ``chunk_selection`` and returns it with asymmetric padding clipped to the
    signal bounds.

    Args:
        waveform: Audio as a torch tensor (squeezed before use) or an array-like
            of samples. Assumes mono input; a multi-channel tensor would stay 2-D
            after ``squeeze`` - TODO confirm callers downmix first.
        sample_rate: Sample rate of ``waveform`` in Hz.
        threshold: Speech probability threshold, forwarded to Silero.
        min_speech_duration_ms: Forwarded to ``get_speech_timestamps``.
        min_silence_duration_ms: Forwarded to ``get_speech_timestamps``.
        padding_before_ms: Context kept before the selected segment.
        padding_after_ms: Context kept after the selected segment.
        max_speech_duration_s: Forwarded to Silero; values ``<= 0`` mean unlimited.
        chunk_selection: ``'first'``, ``'last'`` or ``'longest'`` (default). Any
            other value falls back to ``'longest'`` with a warning.

    Returns:
        A float32 tensor of the trimmed samples, or None when no speech was
        detected or any error occurred (errors are logged with a traceback).
    """
    try:
        from model_loader import get_models
        models = get_models()
        model = models['silero_vad']
        (get_speech_timestamps, _, _, _, _) = models['silero_utils']

        if isinstance(waveform, torch.Tensor):
            waveform_np = waveform.squeeze().cpu().numpy()
        else:
            waveform_np = waveform
        num_samples = len(waveform_np)
        total_duration_ms = (num_samples / sample_rate) * 1000

        speech_timestamps = get_speech_timestamps(
            waveform_np,
            model,
            sampling_rate=sample_rate,
            threshold=threshold,
            min_speech_duration_ms=min_speech_duration_ms,
            min_silence_duration_ms=min_silence_duration_ms,
            # Padding is applied below (asymmetrically), so Silero adds none.
            speech_pad_ms=0,
            max_speech_duration_s=max_speech_duration_s if max_speech_duration_s > 0 else float('inf'),
            return_seconds=False
        )
        if not speech_timestamps:
            logger.warning("No speech detected by Silero VAD")
            return None

        if chunk_selection == 'first':
            selected_chunk = speech_timestamps[0]
        elif chunk_selection == 'last':
            selected_chunk = speech_timestamps[-1]
        else:
            if chunk_selection != 'longest':
                logger.warning(
                    "Unknown chunk_selection %r; falling back to 'longest'",
                    chunk_selection,
                )
            selected_chunk = max(speech_timestamps, key=lambda seg: seg['end'] - seg['start'])

        # With return_seconds=False the segment bounds are used as sample indices.
        vad_start_ms = (selected_chunk['start'] / sample_rate) * 1000
        vad_end_ms = (selected_chunk['end'] / sample_rate) * 1000
        pad_before_samples = int((padding_before_ms / 1000) * sample_rate)
        pad_after_samples = int((padding_after_ms / 1000) * sample_rate)
        start_idx = max(0, selected_chunk['start'] - pad_before_samples)
        end_idx = min(num_samples, selected_chunk['end'] + pad_after_samples)

        # Copy only the kept slice (not the whole signal) so the result does not
        # share memory with the caller's buffer.
        trimmed_tensor = torch.from_numpy(waveform_np[start_idx:end_idx].copy()).float()

        final_start_ms = (start_idx / sample_rate) * 1000
        final_end_ms = (end_idx / sample_rate) * 1000
        final_duration_ms = final_end_ms - final_start_ms
        logger.info(
            "VAD: Original=%.0fms | Speech=[%.0f-%.0f]ms | Final=[%.0f-%.0f]ms (%.0fms) | Selection=%s",
            total_duration_ms,
            vad_start_ms,
            vad_end_ms,
            final_start_ms,
            final_end_ms,
            final_duration_ms,
            chunk_selection,
        )
        return trimmed_tensor
    except Exception as e:
        # Best-effort by design: any failure (including model loading) is
        # reported as None, now with the traceback in the log.
        logger.exception("Error in Silero VAD: %s", e)
        return None