File size: 5,973 Bytes
a361db3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
"""
Target Enhancement Module - Audio Enhancement & Denoising for TOI

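Applies signal processing to enhance the talker of interest (TOI):
1. Spectral subtraction for noise reduction (noise estimated from the opening ~500 ms)
2. Wiener-like, energy-gated smoothing: each frame is blended with a smoothed
   copy of itself, more strongly for low-energy (likely non-speech) frames
3. Temporal smoothing and artifact removal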
"""

import numpy as np
import soundfile as sf
from scipy import signal


def apply_spectral_subtraction(audio, sr, noise_estimate_factor=0.5):
    """
    Spectral subtraction: subtract an estimated noise power spectrum.

    The noise spectrum is estimated as the mean power of the STFT frames that
    cover roughly the first 500 ms, so this assumes the recording starts with
    noise only (no target speech). After subtraction each bin's power is
    floored at 10% of its original power to limit over-subtraction, and the
    signal is resynthesised with the ORIGINAL (noisy) phase of every bin.

    Args:
        audio: 1-D audio time series.
        sr: Sample rate in Hz.
        noise_estimate_factor: Over-subtraction factor applied to the noise
            power estimate (0 disables subtraction; typically 0-1).

    Returns:
        Enhanced audio as a float array with the same length as ``audio``.
    """
    audio = np.asarray(audio, dtype=float)
    n_samples = len(audio)
    if n_samples == 0:
        return audio.copy()

    # Window of min(2048, sr/10) samples: at most 100 ms, ~46 ms at 44.1 kHz.
    # Clamp to the signal length: scipy's stft would otherwise shrink nperseg
    # on its own for short inputs and no longer match the istft below.
    nperseg = int(max(1, min(2048, sr // 10, n_samples)))
    noverlap = nperseg // 2
    hop = nperseg - noverlap

    # Complex STFT (not a power spectrogram) so the phase can be reused; the
    # default Hann window at 50% overlap allows exact reconstruction.
    _, _, Zxx = signal.stft(audio, sr, nperseg=nperseg, noverlap=noverlap)
    power = np.abs(Zxx) ** 2

    # Noise estimate from the frames spanning the first ~500 ms.
    n_noise_frames = max(1, int(0.5 * sr / hop))
    noise_power = np.mean(power[:, :n_noise_frames], axis=1, keepdims=True)

    # Subtract, flooring at 10% of the original power to prevent over-subtraction.
    enhanced_power = np.maximum(power - noise_estimate_factor * noise_power, 0.1 * power)

    # Only the magnitude changes; keep each bin's original phase.
    Zxx_enhanced = np.sqrt(enhanced_power) * np.exp(1j * np.angle(Zxx))
    _, enhanced = signal.istft(Zxx_enhanced, sr, nperseg=nperseg, noverlap=noverlap)

    # istft returns the zero-padded length; trim back to the input length.
    return enhanced[:n_samples]


def apply_wiener_filtering(audio, sr, frame_length_ms=20):
    """
    Wiener-like, energy-gated smoothing.

    This is not an MSE-optimal Wiener filter. The audio is cut into frames of
    ``frame_length_ms`` with 50% hop. Each frame is classified as speech when
    its energy exceeds the 25th percentile of all frame energies. Every frame
    is then blended with a Savitzky-Golay-smoothed copy of itself: 10% smoothed
    for speech frames, 50% smoothed for the rest. Where frames overlap, the
    later frame's result is kept.

    Args:
        audio: Audio time series (1-D array).
        sr: Sample rate in Hz.
        frame_length_ms: Frame length in milliseconds.

    Returns:
        Enhanced audio with the same shape and dtype as ``audio``. Input too
        short to hold one full frame is returned unchanged (as a copy).
    """
    enhanced = audio.copy()
    frame_len = int(sr * frame_length_ms / 1000)
    if frame_len < 2 or len(audio) < frame_len:
        # No complete frame to analyse; the percentile below would fail on
        # an empty energy array.
        return enhanced
    hop_len = frame_len // 2

    # "+ 1" so the last full frame (ending exactly at len(audio)) is included.
    starts = range(0, len(audio) - frame_len + 1, hop_len)

    # Short-time energy as a crude voice-activity detector.
    energy = np.array([np.sum(audio[s:s + frame_len] ** 2) for s in starts])
    speech_activity = energy > np.percentile(energy, 25)

    # Odd window of at most 11 samples; polyorder must stay below the window length.
    window_length = min(11, frame_len | 1)
    polyorder = min(3, window_length - 1)

    for start, is_speech in zip(starts, speech_activity):
        frame = audio[start:start + frame_len]
        # Preserve speech frames; smooth non-speech frames more strongly.
        smoothing_factor = 0.1 if is_speech else 0.5
        smoothed = signal.savgol_filter(frame, window_length=window_length,
                                        polyorder=polyorder, mode='nearest')
        enhanced[start:start + frame_len] = (
            (1 - smoothing_factor) * frame + smoothing_factor * smoothed
        )

    return enhanced


def apply_temporal_smoothing(audio, sr, window_ms=5):
    """
    Apply temporal smoothing to reduce artifacts and clicks.

    Runs a Savitzky-Golay filter (polyorder 2, 'nearest' edge mode). The window
    spans ``window_ms`` milliseconds, truncated to whole samples, then forced
    odd (even counts are bumped up by one), and is never shorter than 3 samples.

    Args:
        audio: Audio time series.
        sr: Sample rate in Hz.
        window_ms: Smoothing window length in milliseconds.

    Returns:
        Smoothed audio, same length as ``audio``.

    NOTE(review): a multi-millisecond quadratic Savitzky-Golay window is a
    fairly aggressive low-pass at common sample rates (5 ms at 44.1 kHz is a
    221-sample window); confirm it does not remove speech content.
    """
    odd_span = int(sr * window_ms / 1000) | 1  # setting the low bit makes it odd
    span = odd_span if odd_span > 3 else 3
    return signal.savgol_filter(
        audio,
        window_length=span,
        polyorder=2,
        mode='nearest',
    )


def enhance_target_speaker(input_file, output_file, enhancement_level='medium'):
    """
    Main enhancement pipeline for the talker of interest.

    Steps, by level:
      - downmix to mono and peak-normalise to [-1, 1]
      - spectral subtraction: 'medium' (factor 0.3) and 'heavy' (factor 0.5)
      - energy-gated smoothing (apply_wiener_filtering): 20 ms frames for
        'light'/'medium', 10 ms frames for 'heavy'
      - temporal smoothing: 3 ms window for 'light', 5 ms otherwise
      - tanh soft clipping, then peak normalisation to 0.95 (headroom)
    The result is written as 16-bit PCM.

    Args:
        input_file: Path to source audio.
        output_file: Path to save enhanced audio.
        enhancement_level: 'light', 'medium', or 'heavy'.

    Returns:
        dict with input/output paths, level, sample rate, duration in seconds,
        and the list of processing steps that were actually applied.

    Raises:
        ValueError: if ``enhancement_level`` is not a supported level.
    """
    valid_levels = ('light', 'medium', 'heavy')
    if enhancement_level not in valid_levels:
        # Previously an unknown level silently ran only part of the pipeline.
        raise ValueError(
            f"enhancement_level must be one of {valid_levels}, got {enhancement_level!r}"
        )

    # Load audio
    audio, sr = sf.read(str(input_file))
    original_length = len(audio)

    # Ensure mono (average channels)
    if audio.ndim > 1:
        audio = np.mean(audio, axis=1)

    # Normalize to [-1, 1]; guard empty files (np.max of an empty array raises).
    max_val = np.max(np.abs(audio)) if audio.size else 0.0
    if max_val > 0:
        audio = audio / max_val

    methods_applied = []

    if enhancement_level in ('medium', 'heavy'):
        noise_factor = 0.3 if enhancement_level == 'medium' else 0.5
        audio = apply_spectral_subtraction(audio, sr, noise_estimate_factor=noise_factor)
        methods_applied.append('spectral_subtraction')

    # Shorter frames for 'heavy' give finer-grained gating.
    frame_ms = 10 if enhancement_level == 'heavy' else 20
    audio = apply_wiener_filtering(audio, sr, frame_length_ms=frame_ms)
    methods_applied.append('wiener_filtering')

    smoothing_ms = 3 if enhancement_level == 'light' else 5
    audio = apply_temporal_smoothing(audio, sr, window_ms=smoothing_ms)
    methods_applied.append('temporal_smoothing')

    # Prevent clipping with gentle (tanh) compression
    audio = np.tanh(audio * 0.95)
    methods_applied.append('soft_clipping')

    # Normalize output, leaving headroom
    max_val = np.max(np.abs(audio)) if audio.size else 0.0
    if max_val > 0:
        audio = 0.95 * audio / max_val

    # Ensure correct length
    audio = audio[:original_length]

    # Save enhanced audio
    sf.write(str(output_file), audio, sr, subtype='PCM_16')

    return {
        'input_file': str(input_file),
        'output_file': str(output_file),
        'enhancement_level': enhancement_level,
        'sample_rate': sr,
        'duration_seconds': len(audio) / sr,
        'methods_applied': methods_applied,
    }


if __name__ == '__main__':
    import os
    import sys

    valid_levels = ('light', 'medium', 'heavy')

    if len(sys.argv) < 2:
        print("Usage: python enhance_target.py <input_wav> [output_wav] [level]")
        print("  level: 'light', 'medium' (default), or 'heavy'")
        sys.exit(1)

    input_file = sys.argv[1]
    if len(sys.argv) > 2:
        output_file = sys.argv[2]
    else:
        # Build "<name>_enhanced<ext>" from the real extension. A plain
        # str.replace('.wav', ...) rewrote every ".wav" anywhere in the path
        # and, for inputs without a lowercase ".wav", returned the input path
        # itself, so the source file would be overwritten.
        root, ext = os.path.splitext(input_file)
        output_file = f"{root}_enhanced{ext or '.wav'}"

    level = sys.argv[3] if len(sys.argv) > 3 else 'medium'
    if level not in valid_levels:
        print(f"Invalid level {level!r}; expected one of: {', '.join(valid_levels)}",
              file=sys.stderr)
        sys.exit(1)

    result = enhance_target_speaker(input_file, output_file, enhancement_level=level)
    print("Enhancement complete:")
    for key, value in result.items():
        print(f"  {key}: {value}")