# Spaces: Running | File size: 7,661 Bytes | 423bed8
"""
BarVox Audio Processing API - Audio Preprocessing Functions
"""
import logging
import numpy as np
import torch
import torchaudio.transforms as T
from scipy.signal import butter, filtfilt
from typing import Optional
logger = logging.getLogger(__name__)
def butter_highpass(cutoff, fs, order=5):
    """Design a digital Butterworth high-pass filter.

    Args:
        cutoff: Cutoff frequency, in the same units as ``fs``.
        fs: Sampling rate of the signal the filter is meant for.
        order: Filter order passed to ``scipy.signal.butter``.

    Returns:
        The ``(b, a)`` numerator/denominator coefficient pair produced by
        ``scipy.signal.butter``.
    """
    # butter() takes the cutoff as a fraction of the Nyquist frequency.
    nyquist = fs * 0.5
    return butter(order, cutoff / nyquist, btype='high', analog=False)
def butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a digital Butterworth band-pass filter.

    Args:
        lowcut: Lower band edge, in the same units as ``fs``.
        highcut: Upper band edge, in the same units as ``fs``.
        fs: Sampling rate of the signal the filter is meant for.
        order: Filter order passed to ``scipy.signal.butter``.

    Returns:
        The ``(b, a)`` numerator/denominator coefficient pair produced by
        ``scipy.signal.butter``.
    """
    # Both band edges are expressed as fractions of the Nyquist frequency.
    nyquist = fs * 0.5
    band_edges = [lowcut / nyquist, highcut / nyquist]
    return butter(order, band_edges, btype='band')
def highpass_filter(data, cutoff=30, fs=16000, order=2):
    """Remove low-frequency content from ``data`` with a zero-phase high-pass.

    Args:
        data: Signal to filter (passed straight to ``scipy.signal.filtfilt``).
        cutoff: High-pass cutoff frequency, in the same units as ``fs``.
        fs: Sampling rate of ``data``.
        order: Order of the Butterworth design.

    Returns:
        A freshly allocated array holding the filtered signal.
    """
    numerator, denominator = butter_highpass(cutoff, fs, order=order)
    filtered = filtfilt(numerator, denominator, data)
    # Hand back an independent copy of the filtfilt result
    # (presumably so downstream consumers get a plain contiguous buffer).
    return filtered.copy()
def rms_normalize(waveform_np, target_dbfs=-20.0):
    """Scale a waveform so that its RMS level equals ``target_dbfs``.

    The RMS is computed in float64, so integer PCM input (e.g. int16) does
    not overflow when squared.

    Args:
        waveform_np: Array-like of samples.
        target_dbfs: Desired RMS level in dBFS (``20 * log10(rms)``).

    Returns:
        The scaled samples. Empty or all-zero input is returned unchanged,
        since there is no level to scale.
    """
    samples = np.asarray(waveform_np)
    if samples.size == 0:
        # Nothing to measure; avoid the NaN that mean() of an empty array gives.
        return waveform_np

    # Square in float64: squaring int16 samples in their own dtype wraps around.
    rms = float(np.sqrt(np.mean(np.square(samples.astype(np.float64)))))
    if rms == 0:
        return waveform_np

    # Plain Python float keeps the dtype of floating-point input (float32 stays float32).
    scalar = (10 ** (target_dbfs / 20)) / rms
    return samples * scalar
def resample_to_16khz_mono(waveform: torch.Tensor, orig_sample_rate: int) -> torch.Tensor:
    """Downmix a waveform to mono and resample it to 16 kHz.

    Args:
        waveform: Audio tensor, either ``(channels, samples)`` or a 1-D
            ``(samples,)`` tensor, which is treated as a single channel.
        orig_sample_rate: Sampling rate of ``waveform`` in Hz.

    Returns:
        A ``(1, samples)`` tensor at 16 kHz. When the input is already mono
        at 16 kHz, the samples are returned unmodified.
    """
    # A 1-D tensor has shape[0] == number of samples, so the channel check
    # below would average the whole signal into one value. Give it an
    # explicit channel axis first.
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)

    # Downmix multi-channel audio by averaging across the channel axis.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    if orig_sample_rate != 16000:
        resampler = T.Resample(orig_freq=orig_sample_rate, new_freq=16000)
        waveform = resampler(waveform)

    return waveform
def apply_noise_reduction(waveform_np, sample_rate=16000):
    """Remove the DC offset and low-frequency content from a waveform.

    Args:
        waveform_np: Array of samples.
        sample_rate: Sampling rate of ``waveform_np`` in Hz.

    Returns:
        The mean-centred signal after a 2nd-order 30 Hz high-pass filter.
    """
    # Subtract the mean first so the filter does not have to remove a DC step.
    centered = waveform_np - np.mean(waveform_np)
    return highpass_filter(centered, cutoff=30, fs=sample_rate, order=2)
def apply_normalization(waveform_np, target_dbfs=-20.0):
    """Normalize a waveform's RMS level; thin wrapper over ``rms_normalize``.

    Args:
        waveform_np: Array of samples.
        target_dbfs: Desired RMS level in dBFS.

    Returns:
        Whatever ``rms_normalize`` returns for the same arguments.
    """
    normalized = rms_normalize(waveform_np, target_dbfs=target_dbfs)
    return normalized
def apply_dynamic_compression(waveform_np, threshold_ratio=0.5, ratio=3.0):
    """Apply a simple static dynamic-range compressor.

    Samples whose magnitude exceeds ``threshold_ratio * RMS`` have the
    excess above the threshold divided by ``ratio``. The result is then
    given make-up gain so its peak matches the input peak, capped at 2x.

    Args:
        waveform_np: Array-like of samples. Integer input is promoted to
            float64 so squaring cannot overflow and compressed values are
            not truncated.
        threshold_ratio: Threshold as a fraction of the signal's RMS.
        ratio: Compression ratio applied above the threshold.

    Returns:
        A new array with the compressed signal. Empty input yields an
        empty copy.
    """
    samples = np.asarray(waveform_np)
    if samples.size == 0:
        # mean()/max() are undefined on empty arrays; nothing to compress.
        return np.copy(samples)
    if not np.issubdtype(samples.dtype, np.floating):
        samples = samples.astype(np.float64)

    magnitude = np.abs(samples)
    rms = np.sqrt(np.mean(np.square(samples)))
    threshold = rms * threshold_ratio

    compressed = np.copy(samples)
    mask = magnitude > threshold
    compressed[mask] = np.sign(samples[mask]) * (
        threshold + (magnitude[mask] - threshold) / ratio
    )

    # Make-up gain: restore the original peak, but never amplify by more than 2x.
    peak_out = np.max(np.abs(compressed))
    gain = np.max(magnitude) / peak_out if peak_out > 0 else 1.0
    compressed = compressed * min(gain, 2.0)

    logger.info(
        "Applied dynamic compression (threshold_ratio=%s, ratio=%s)",
        threshold_ratio,
        ratio,
    )
    return compressed
def apply_transient_enhancement(waveform_np, sample_rate=16000, attack_boost=1.5):
    """Boost the parts of a signal where its short-term level rises quickly.

    A moving RMS envelope is computed over a ~10 ms window. Samples where
    the envelope rises faster than half the standard deviation of its
    sample-to-sample change are marked as transients and given a gain of
    ``attack_boost``. The gain curve is Gaussian-smoothed, and the output is
    rescaled so its peak equals the input peak.

    Args:
        waveform_np: 1-D array-like of samples (assumed mono).
        sample_rate: Sampling rate in Hz; sets the envelope window length.
        attack_boost: Linear gain applied at detected transients.

    Returns:
        The enhanced signal. Floating-point input keeps its dtype; other
        input is returned as float64. Empty input yields an empty copy.
    """
    from scipy.ndimage import gaussian_filter1d

    samples = np.asarray(waveform_np)
    n_samples = len(samples)
    if n_samples == 0:
        return np.copy(samples)

    out_dtype = samples.dtype if np.issubdtype(samples.dtype, np.floating) else np.float64
    work = samples.astype(np.float64)

    window_size = int(sample_rate * 0.010)
    # At least one sample on each side so no window is ever empty.
    half = max(1, window_size // 2)

    # Moving RMS over [i - half, i + half), clipped to the signal bounds.
    # Prefix sums give every window's energy in O(n) instead of a Python
    # loop that re-sums each window.
    prefix = np.concatenate(([0.0], np.cumsum(np.square(work))))
    idx = np.arange(n_samples)
    starts = np.maximum(idx - half, 0)
    ends = np.minimum(idx + half, n_samples)
    # Clamp tiny negative values caused by floating-point cancellation.
    window_energy = np.maximum(prefix[ends] - prefix[starts], 0.0)
    envelope = np.sqrt(window_energy / (ends - starts))

    # A transient is a rise in the envelope that stands out from its typical change.
    envelope_diff = np.diff(envelope, prepend=envelope[0])
    transient_mask = envelope_diff > (np.std(envelope_diff) * 0.5)

    boost = np.ones(n_samples, dtype=np.float64)
    boost[transient_mask] = attack_boost
    # Smooth the gain curve so the boost does not switch on and off abruptly.
    boost_smooth = gaussian_filter1d(boost, sigma=max(window_size, 1) / 4)

    enhanced = work * boost_smooth
    max_val = np.max(np.abs(enhanced))
    if max_val > 0:
        # Rescale so the output peak matches the input peak.
        enhanced = enhanced / max_val * np.max(np.abs(work))

    logger.info("Applied transient enhancement (boost=%s)", attack_boost)
    return enhanced.astype(out_dtype, copy=False)
def apply_high_frequency_boost(waveform_np, sample_rate=16000, boost_db=6.0):
    """Raise the level of the 2-6 kHz band by ``boost_db`` decibels.

    The band is isolated with a zero-phase 4th-order Butterworth band-pass
    and mixed back into the original signal. If the result peaks above 1.0
    it is scaled down so its peak is exactly 1.0.

    Args:
        waveform_np: Array of samples.
        sample_rate: Sampling rate of ``waveform_np`` in Hz.
        boost_db: Gain applied to the 2-6 kHz band, in dB.

    Returns:
        The equalized signal.
    """
    numerator, denominator = butter_bandpass(2000, 6000, sample_rate, order=4)
    linear_gain = 10 ** (boost_db / 20.0)

    # Adding (gain - 1) times the band on top of the dry signal lifts that
    # band by `gain` overall while leaving the rest untouched.
    presence = filtfilt(numerator, denominator, waveform_np)
    mixed = waveform_np + (linear_gain - 1.0) * presence

    peak = np.max(np.abs(mixed))
    if peak > 1.0:
        mixed = mixed / peak

    logger.info(f"Applied high-frequency boost (+{boost_db} dB in 2-6 kHz)")
    return mixed
def apply_silero_vad(
    waveform: torch.Tensor,
    sample_rate: int = 16000,
    threshold: float = 0.35,
    min_speech_duration_ms: int = 60,
    min_silence_duration_ms: int = 300,
    padding_before_ms: int = 180,
    padding_after_ms: int = 900,
    max_speech_duration_s: float = 0,
    chunk_selection: str = 'longest'
) -> Optional[torch.Tensor]:
    """Trim a waveform to one speech segment found by Silero VAD.

    The model and its helper functions are fetched from
    ``model_loader.get_models()`` under the keys ``'silero_vad'`` and
    ``'silero_utils'``. One detected segment is chosen, widened by the
    requested padding (clipped to the signal bounds), and returned.

    Args:
        waveform: Audio as a tensor (squeezed and moved to CPU) or an
            array, which is used as-is. Assumed to be single-channel.
        sample_rate: Sampling rate of ``waveform`` in Hz.
        threshold: Speech probability threshold forwarded to the VAD.
        min_speech_duration_ms: Forwarded to the VAD.
        min_silence_duration_ms: Forwarded to the VAD.
        padding_before_ms: Audio kept before the chosen segment, in ms.
        padding_after_ms: Audio kept after the chosen segment, in ms.
        max_speech_duration_s: Maximum segment length forwarded to the VAD;
            values <= 0 mean no limit.
        chunk_selection: ``'first'``, ``'last'`` or ``'longest'``. Any other
            value logs a warning and is treated as ``'longest'``.

    Returns:
        A float32 tensor with the selected segment plus padding, or ``None``
        when no speech is detected or any step raises.
    """
    try:
        from model_loader import get_models

        models = get_models()
        model = models['silero_vad']
        # Only the first helper in the utils tuple is needed here.
        get_speech_timestamps, *_ = models['silero_utils']

        if isinstance(waveform, torch.Tensor):
            # detach() first: .numpy() raises on tensors that require grad.
            waveform_np = waveform.detach().squeeze().cpu().numpy()
        else:
            waveform_np = waveform

        total_duration_ms = (len(waveform_np) / sample_rate) * 1000

        # Padding is applied manually below, so the VAD's own padding is off.
        speech_timestamps = get_speech_timestamps(
            waveform_np,
            model,
            sampling_rate=sample_rate,
            threshold=threshold,
            min_speech_duration_ms=min_speech_duration_ms,
            min_silence_duration_ms=min_silence_duration_ms,
            speech_pad_ms=0,
            max_speech_duration_s=max_speech_duration_s if max_speech_duration_s > 0 else float('inf'),
            return_seconds=False
        )

        if not speech_timestamps:
            logger.warning("No speech detected by Silero VAD")
            return None

        if chunk_selection == 'first':
            selected_chunk = speech_timestamps[0]
        elif chunk_selection == 'last':
            selected_chunk = speech_timestamps[-1]
        else:
            if chunk_selection != 'longest':
                logger.warning(
                    "Unknown chunk_selection %r; falling back to 'longest'",
                    chunk_selection,
                )
            selected_chunk = max(speech_timestamps, key=lambda x: x['end'] - x['start'])

        # Timestamps are sample offsets (return_seconds=False); convert for logging.
        vad_start_ms = (selected_chunk['start'] / sample_rate) * 1000
        vad_end_ms = (selected_chunk['end'] / sample_rate) * 1000

        pad_before_samples = int((padding_before_ms / 1000) * sample_rate)
        pad_after_samples = int((padding_after_ms / 1000) * sample_rate)
        start_idx = max(0, selected_chunk['start'] - pad_before_samples)
        end_idx = min(len(waveform_np), selected_chunk['end'] + pad_after_samples)

        # Slice first, then copy: only the kept segment is duplicated, and
        # the result does not share memory with the caller's audio.
        trimmed_waveform_np = waveform_np[start_idx:end_idx].copy()
        trimmed_tensor = torch.from_numpy(trimmed_waveform_np).float()

        final_start_ms = (start_idx / sample_rate) * 1000
        final_end_ms = (end_idx / sample_rate) * 1000
        final_duration_ms = final_end_ms - final_start_ms
        logger.info(f"VAD: Original={total_duration_ms:.0f}ms | Speech=[{vad_start_ms:.0f}-{vad_end_ms:.0f}]ms | Final=[{final_start_ms:.0f}-{final_end_ms:.0f}]ms ({final_duration_ms:.0f}ms) | Selection={chunk_selection}")
        return trimmed_tensor
    except Exception as e:
        # Best-effort boundary: callers get None on failure, but the
        # traceback is kept in the log for debugging.
        logger.exception("Error in Silero VAD: %s", e)
        return None