File size: 7,661 Bytes
423bed8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
"""
BarVox Audio Processing API - Audio Preprocessing Functions
"""

import logging
import numpy as np
import torch
import torchaudio.transforms as T
from scipy.signal import butter, filtfilt
from typing import Optional

logger = logging.getLogger(__name__)

def butter_highpass(cutoff, fs, order=5):
    """Design a digital Butterworth high-pass filter.

    Args:
        cutoff: Cutoff frequency, in the same units as ``fs``.
        fs: Sampling rate.
        order: Filter order.

    Returns:
        Tuple ``(b, a)`` of numerator / denominator coefficients as produced
        by ``scipy.signal.butter``. SciPy rejects a cutoff at or above the
        Nyquist frequency (``fs / 2``).
    """
    nyquist = fs / 2.0
    numerator, denominator = butter(
        order, cutoff / nyquist, btype='high', analog=False
    )
    return numerator, denominator

def butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a digital Butterworth band-pass filter.

    Args:
        lowcut: Lower band edge, in the same units as ``fs``.
        highcut: Upper band edge, in the same units as ``fs``.
        fs: Sampling rate.
        order: Filter order (the band-pass result has ``2 * order`` poles).

    Returns:
        Tuple ``(b, a)`` of numerator / denominator coefficients from
        ``scipy.signal.butter``. Both edges must lie strictly between 0 and
        the Nyquist frequency (``fs / 2``) or SciPy raises ``ValueError``.
    """
    nyquist = fs / 2.0
    band_edges = [edge / nyquist for edge in (lowcut, highcut)]
    numerator, denominator = butter(order, band_edges, btype='band')
    return numerator, denominator

def highpass_filter(data, cutoff=30, fs=16000, order=2):
    """Apply a zero-phase Butterworth high-pass filter.

    The filter runs forward and backward via ``filtfilt``, so no phase shift
    is introduced. With the defaults it removes content below 30 Hz at a
    16 kHz sampling rate.

    Args:
        data: Sample array to filter. ``filtfilt`` raises ``ValueError`` if it
            is not longer than its default padding length.
        cutoff: Cutoff frequency, in the same units as ``fs``.
        fs: Sampling rate.
        order: Filter order.

    Returns:
        A filtered copy of ``data``.
    """
    num, den = butter_highpass(cutoff, fs, order=order)
    filtered = filtfilt(num, den, data)
    # Return a standalone array; filtfilt's result may be a view into its
    # internal (padded) work buffer.
    return filtered.copy()

def rms_normalize(waveform_np, target_dbfs=-20.0):
    """Scale a waveform so its RMS level equals ``target_dbfs``.

    The level is relative to an RMS of 1.0 (0 dBFS), so the output RMS is
    ``10 ** (target_dbfs / 20)``.

    Args:
        waveform_np: Sample array. Integer arrays (e.g. int16 PCM) are
            accepted: the RMS is accumulated in float64 so squaring cannot
            overflow.
        target_dbfs: Desired RMS level in dB.

    Returns:
        The scaled waveform. Silent (all-zero) or empty input is returned
        unchanged.
    """
    if np.size(waveform_np) == 0:
        # Nothing to scale; avoids the "Mean of empty slice" warning / NaN.
        return waveform_np
    rms = float(np.sqrt(np.mean(np.square(waveform_np, dtype=np.float64))))
    if rms == 0:
        return waveform_np
    # A Python-float scalar keeps a float32 input as float32 on multiplication.
    scalar = (10 ** (target_dbfs / 20)) / rms
    return waveform_np * scalar

def resample_to_16khz_mono(waveform: torch.Tensor, orig_sample_rate: int) -> torch.Tensor:
    """Down-mix audio to mono and resample it to 16 kHz.

    Args:
        waveform: Audio laid out as ``(channels, samples)``. A 1-D tensor is
            treated as a single channel of samples.
        orig_sample_rate: Sampling rate of ``waveform`` in Hz.

    Returns:
        For 1-D or 2-D input, a ``(1, samples)`` tensor at 16 kHz. Channels
        are averaged; no resampling is done when the rate is already 16 kHz.
    """
    if waveform.dim() == 1:
        # A bare sample vector is one channel. Without this, shape[0] would be
        # the sample count and the channel mean below would collapse the whole
        # signal into a single value.
        waveform = waveform.unsqueeze(0)

    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    if orig_sample_rate != 16000:
        resampler = T.Resample(orig_freq=orig_sample_rate, new_freq=16000)
        waveform = resampler(waveform)

    return waveform

def apply_noise_reduction(waveform_np, sample_rate=16000):
    """Remove DC offset and low-frequency rumble.

    This does no spectral noise suppression. It subtracts the mean, then
    applies an order-2, 30 Hz zero-phase high-pass via ``highpass_filter``.

    Args:
        waveform_np: Sample array.
        sample_rate: Sampling rate in Hz.

    Returns:
        The mean-centred, high-passed waveform.
    """
    dc_offset = np.mean(waveform_np)
    centered = waveform_np - dc_offset
    return highpass_filter(centered, cutoff=30, fs=sample_rate, order=2)

def apply_normalization(waveform_np, target_dbfs=-20.0):
    """Pipeline step that RMS-normalizes a waveform.

    Thin wrapper around ``rms_normalize`` with the same arguments.

    Args:
        waveform_np: Sample array.
        target_dbfs: Desired RMS level in dB relative to an RMS of 1.0.

    Returns:
        Whatever ``rms_normalize`` returns for these arguments.
    """
    normalized = rms_normalize(waveform_np, target_dbfs=target_dbfs)
    return normalized

def apply_dynamic_compression(waveform_np, threshold_ratio=0.5, ratio=3.0):
    """Apply simple sample-wise dynamic range compression.

    Any sample whose magnitude exceeds ``threshold = rms * threshold_ratio``
    has its excess over the threshold divided by ``ratio``, with the sign
    preserved. Make-up gain then restores the original peak magnitude, with
    the gain capped at 2x.

    Args:
        waveform_np: Sample array (intended for floating-point audio).
        threshold_ratio: Threshold as a fraction of the signal's RMS.
        ratio: Compression ratio applied above the threshold.

    Returns:
        A compressed copy of the input. Empty input yields an empty copy.
    """
    if np.size(waveform_np) == 0:
        # np.max on an empty array raises and the RMS would be NaN.
        return np.copy(waveform_np)

    magnitude = np.abs(waveform_np)
    rms = np.sqrt(np.mean(np.square(waveform_np)))
    threshold = rms * threshold_ratio

    compressed = np.copy(waveform_np)
    mask = magnitude > threshold
    compressed[mask] = np.sign(waveform_np[mask]) * (
        threshold + (magnitude[mask] - threshold) / ratio
    )

    # Make-up gain: bring the compressed peak back to the original peak.
    original_peak = np.max(magnitude)
    compressed_peak = np.max(np.abs(compressed))
    gain = original_peak / compressed_peak if compressed_peak > 0 else 1.0
    compressed = compressed * min(gain, 2.0)

    logger.info(f"Applied dynamic compression (threshold_ratio={threshold_ratio}, ratio={ratio})")
    return compressed

def apply_transient_enhancement(waveform_np, sample_rate=16000, attack_boost=1.5):
    """Boost samples where the short-term RMS envelope rises sharply.

    Steps:

    1. Compute a sliding RMS envelope over roughly 10 ms.
    2. Mark a sample as a transient when its envelope increment exceeds half
       the standard deviation of all increments.
    3. Give transient samples a gain of ``attack_boost`` and smooth the gain
       curve with a Gaussian.
    4. Apply the gain, then rescale the result to the input's peak magnitude.

    Args:
        waveform_np: 1-D sample array (assumed floating point; TODO confirm
            callers never pass integer PCM).
        sample_rate: Sampling rate in Hz; sets the ~10 ms window length.
        attack_boost: Gain applied at detected transients before smoothing.

    Returns:
        The enhanced waveform, with the same peak magnitude as the input.
        Empty input returns an empty copy.
    """
    from scipy.ndimage import gaussian_filter1d

    num_samples = len(waveform_np)
    if num_samples == 0:
        # Nothing to enhance (and envelope[0] below would raise IndexError).
        return np.copy(waveform_np)

    window_size = int(sample_rate * 0.010)
    half_window = window_size // 2

    # Sliding RMS over [i - half, i + half), clipped at the array edges.
    # Prefix sums of squared samples give every window's energy in O(n)
    # instead of re-summing each window in a Python loop.
    idx = np.arange(num_samples)
    starts = np.maximum(idx - half_window, 0)
    ends = np.minimum(idx + half_window, num_samples)
    prefix = np.concatenate(([0.0], np.cumsum(np.square(waveform_np, dtype=np.float64))))
    # Clamp tiny negatives from floating-point cancellation before the sqrt.
    window_energy = np.maximum(prefix[ends] - prefix[starts], 0.0)
    envelope = np.zeros_like(waveform_np)
    envelope[:] = np.sqrt(window_energy / (ends - starts))

    envelope_diff = np.diff(envelope, prepend=envelope[0])
    transient_mask = envelope_diff > (np.std(envelope_diff) * 0.5)

    boost = np.ones_like(waveform_np)
    boost[transient_mask] = attack_boost
    boost_smooth = gaussian_filter1d(boost, sigma=window_size / 4)

    enhanced = waveform_np * boost_smooth

    # Rescale so the output peak matches the input peak.
    max_val = np.max(np.abs(enhanced))
    if max_val > 0:
        enhanced = enhanced / max_val * np.max(np.abs(waveform_np))

    logger.info(f"Applied transient enhancement (boost={attack_boost})")
    return enhanced

def apply_high_frequency_boost(waveform_np, sample_rate=16000, boost_db=6.0):
    """Boost the 2-6 kHz presence band by roughly ``boost_db``.

    The band is isolated with a zero-phase, order-4 Butterworth band-pass.
    It is then added back onto the input scaled by ``gain - 1``, so content
    inside the band is raised by approximately ``boost_db``. If the boosted
    peak exceeds 1.0, the whole signal is scaled down to a peak of 1.0.

    Args:
        waveform_np: Sample array.
        sample_rate: Sampling rate in Hz. It must exceed 12 kHz so that 6 kHz
            lies below Nyquist; otherwise the filter design raises.
        boost_db: Presence boost in dB.

    Returns:
        The boosted waveform.
    """
    linear_gain = 10 ** (boost_db / 20.0)
    band_b, band_a = butter_bandpass(2000, 6000, sample_rate, order=4)
    presence = filtfilt(band_b, band_a, waveform_np)
    boosted = waveform_np + presence * (linear_gain - 1.0)

    peak = np.max(np.abs(boosted))
    # Only attenuate when the boost pushed the signal past full scale.
    if peak > 1.0:
        boosted = boosted / peak

    logger.info(f"Applied high-frequency boost (+{boost_db} dB in 2-6 kHz)")
    return boosted

def _select_speech_chunk(speech_timestamps, chunk_selection):
    """Return one VAD segment dict: the first, the last, or (default) the longest."""
    if chunk_selection == 'first':
        return speech_timestamps[0]
    if chunk_selection == 'last':
        return speech_timestamps[-1]
    if chunk_selection != 'longest':
        # Keep the historical fallback for unknown values, but make it visible.
        logger.warning(f"Unknown chunk_selection={chunk_selection!r}; using 'longest'")
    return max(speech_timestamps, key=lambda chunk: chunk['end'] - chunk['start'])


def apply_silero_vad(
    waveform: torch.Tensor,
    sample_rate: int = 16000,
    threshold: float = 0.35,
    min_speech_duration_ms: int = 60,
    min_silence_duration_ms: int = 300,
    padding_before_ms: int = 180,
    padding_after_ms: int = 900,
    max_speech_duration_s: float = 0,
    chunk_selection: str = 'longest'
) -> Optional[torch.Tensor]:
    """Trim audio to a single speech segment found by Silero VAD.

    Silero VAD is run with zero built-in padding. One detected segment is
    chosen by ``chunk_selection``. Asymmetric padding is then added on each
    side, clipped to the signal bounds.

    Args:
        waveform: Audio tensor (squeezed, so a ``(1, N)`` tensor becomes 1-D),
            or an already-prepared array passed through unchanged.
        sample_rate: Sampling rate in Hz, forwarded to the VAD.
        threshold: Speech probability threshold for the VAD.
        min_speech_duration_ms: Minimum speech segment length for the VAD.
        min_silence_duration_ms: Minimum silence that splits segments.
        padding_before_ms: Audio kept before the chosen segment.
        padding_after_ms: Audio kept after the chosen segment.
        max_speech_duration_s: Maximum segment length; ``<= 0`` means no cap.
        chunk_selection: ``'first'``, ``'last'`` or ``'longest'`` (default;
            also used, with a warning, for any other value).

    Returns:
        A float32 tensor of the selected segment plus padding. Returns
        ``None`` if no speech is detected or if any error occurs; errors are
        logged with their traceback.
    """
    try:
        # Project-local loader, imported lazily; failures here also yield None.
        from model_loader import get_models
        models = get_models()
        model = models['silero_vad']
        (get_speech_timestamps, _, _, _, _) = models['silero_utils']

        if isinstance(waveform, torch.Tensor):
            waveform_np = waveform.squeeze().cpu().numpy()
        else:
            waveform_np = waveform

        total_duration_ms = (len(waveform_np) / sample_rate) * 1000

        speech_timestamps = get_speech_timestamps(
            waveform_np,
            model,
            sampling_rate=sample_rate,
            threshold=threshold,
            min_speech_duration_ms=min_speech_duration_ms,
            min_silence_duration_ms=min_silence_duration_ms,
            # Padding is applied below with separate before/after amounts.
            speech_pad_ms=0,
            max_speech_duration_s=max_speech_duration_s if max_speech_duration_s > 0 else float('inf'),
            return_seconds=False,
        )

        if not speech_timestamps:
            logger.warning("No speech detected by Silero VAD")
            return None

        selected_chunk = _select_speech_chunk(speech_timestamps, chunk_selection)

        # Timestamps are sample indices (return_seconds=False).
        vad_start_ms = (selected_chunk['start'] / sample_rate) * 1000
        vad_end_ms = (selected_chunk['end'] / sample_rate) * 1000

        pad_before_samples = int((padding_before_ms / 1000) * sample_rate)
        pad_after_samples = int((padding_after_ms / 1000) * sample_rate)

        start_idx = max(0, selected_chunk['start'] - pad_before_samples)
        end_idx = min(len(waveform_np), selected_chunk['end'] + pad_after_samples)

        # Slice first, then copy only the kept span. The copy stops the returned
        # tensor (torch.from_numpy shares memory) from aliasing the caller's data.
        trimmed_waveform_np = waveform_np[start_idx:end_idx].copy()
        trimmed_tensor = torch.from_numpy(trimmed_waveform_np).float()

        final_start_ms = (start_idx / sample_rate) * 1000
        final_end_ms = (end_idx / sample_rate) * 1000
        final_duration_ms = final_end_ms - final_start_ms

        logger.info(f"VAD: Original={total_duration_ms:.0f}ms | Speech=[{vad_start_ms:.0f}-{vad_end_ms:.0f}]ms | Final=[{final_start_ms:.0f}-{final_end_ms:.0f}]ms ({final_duration_ms:.0f}ms) | Selection={chunk_selection}")

        return trimmed_tensor

    except Exception as e:
        # Best-effort contract: any failure returns None, but keep the traceback.
        logger.exception(f"Error in Silero VAD: {e}")
        return None