hedrekao
HF deploy: clean snapshot without local artifacts
a361db3
"""
Target Enhancement Module - audio enhancement and denoising for the talker of interest (TOI).
Applies a signal-processing pipeline to enhance the TOI recording:
1. Spectral subtraction for noise reduction
2. Wiener-like frame filtering that preserves speech characteristics
3. Temporal smoothing for artifact and click removal
"""
import numpy as np
import soundfile as sf
from scipy import signal
def apply_spectral_subtraction(audio, sr, noise_estimate_factor=0.5):
    """
    Spectral subtraction: subtract an estimated noise power spectrum from speech.

    The noise power per frequency bin is estimated from the first ~500 ms of the
    signal, which is assumed to contain no speech. The subtracted power spectrum
    is converted into a per-bin magnitude gain and applied to the complex STFT,
    so the original phase is kept for reconstruction.

    Args:
        audio: 1-D audio time series
        sr: Sample rate in Hz
        noise_estimate_factor: Factor applied to the estimated noise power
            before subtraction (0-1)
    Returns:
        Enhanced audio with the same length as ``audio``
    """
    audio = np.asarray(audio, dtype=float)
    if audio.size == 0:
        return audio.copy()

    # Window size ~100 ms, capped at 2048 samples and never longer than the
    # signal (stft would otherwise shrink nperseg behind our back and the
    # istft parameters would no longer match).
    nperseg = max(1, min(2048, sr // 10, audio.size))
    noverlap = nperseg // 2
    hop = nperseg - noverlap

    # Complex STFT: unlike signal.spectrogram it keeps the phase, and it uses
    # the same default window/scaling that istft expects for inversion.
    _, _, zxx = signal.stft(audio, sr, nperseg=nperseg, noverlap=noverlap)
    power = np.abs(zxx) ** 2

    # Estimate noise power (assume silent frames at start, ~500 ms)
    n_silence_frames = max(1, int(0.5 * sr / hop))
    noise_power = np.mean(power[:, :n_silence_frames], axis=1, keepdims=True)

    # Spectral subtraction with a floor at 10 % of the original power
    # to prevent over-subtraction ("musical noise").
    enhanced_power = power - noise_estimate_factor * noise_power
    enhanced_power = np.maximum(enhanced_power, 0.1 * power)

    # Magnitude gain per bin; bins with zero power keep gain 1.
    gain = np.sqrt(
        np.divide(enhanced_power, power, out=np.ones_like(power), where=power > 0)
    )

    # Reconstruct with the original phase and matching STFT parameters
    _, enhanced = signal.istft(zxx * gain, sr, nperseg=nperseg, noverlap=noverlap)
    if enhanced.size < audio.size:
        enhanced = np.pad(enhanced, (0, audio.size - enhanced.size))
    return enhanced[:audio.size]
def apply_wiener_filtering(audio, sr, frame_length_ms=20):
    """
    Wiener-like frame filtering: blend each frame with a smoothed copy of itself.

    Frames are classified by short-time energy: frames above the 25th energy
    percentile count as speech. Speech frames receive 10 % of a Savitzky-Golay
    smoothed version of the frame; non-speech frames receive 50 %. Frames
    overlap by 50 %, and where they overlap the later frame's output
    overwrites the earlier one.

    Args:
        audio: 1-D audio time series
        sr: Sample rate in Hz
        frame_length_ms: Frame length in milliseconds
    Returns:
        Enhanced audio with the same length as ``audio``. Samples not covered
        by any full frame (and inputs shorter than one frame) pass through
        unchanged.
    """
    audio = np.asarray(audio)
    if not np.issubdtype(audio.dtype, np.floating):
        # Integer buffers would truncate the blended values below
        audio = audio.astype(float)
    enhanced = audio.copy()

    frame_len = int(sr * frame_length_ms / 1000)
    if frame_len < 1 or len(audio) < frame_len:
        # Not even one full frame: nothing to classify or filter
        return enhanced
    hop_len = max(1, frame_len // 2)

    # "+ 1" so the last frame that fits entirely inside the signal is included
    frame_starts = range(0, len(audio) - frame_len + 1, hop_len)

    # Short-time energy for a simple voice activity decision
    energy = np.array([np.sum(audio[s:s + frame_len] ** 2) for s in frame_starts])
    energy_threshold = np.percentile(energy, 25)
    speech_activity = energy > energy_threshold

    # Savitzky-Golay window: odd, at most 11 samples; polyorder must stay
    # strictly below the window length or savgol_filter raises.
    window_length = min(11, frame_len | 1)
    polyorder = min(3, window_length - 1)

    for start, is_speech in zip(frame_starts, speech_activity):
        end = start + frame_len
        frame = audio[start:end]
        # Speech frames: minimal smoothing; non-speech frames: stronger smoothing
        smoothing_factor = 0.1 if is_speech else 0.5
        smoothed_frame = signal.savgol_filter(
            frame, window_length=window_length, polyorder=polyorder, mode='nearest'
        )
        enhanced[start:end] = (1 - smoothing_factor) * frame + smoothing_factor * smoothed_frame
    return enhanced
def apply_temporal_smoothing(audio, sr, window_ms=5):
    """
    Smooth the signal with a Savitzky-Golay filter to suppress clicks and artifacts.

    Args:
        audio: Audio time series
        sr: Sample rate in Hz
        window_ms: Nominal filter window in milliseconds; converted to samples,
            forced odd, and never smaller than 3
    Returns:
        Smoothed audio (quadratic Savitzky-Golay fit, 'nearest' edge handling)
    """
    samples = int(sr * window_ms / 1000)
    odd_samples = samples | 1  # setting the lowest bit makes the count odd
    window_len = odd_samples if odd_samples > 3 else 3
    return signal.savgol_filter(audio, window_len, 2, mode='nearest')
def enhance_target_speaker(input_file, output_file, enhancement_level='medium'):
    """
    Main enhancement pipeline for the talker of interest.

    Steps: downmix to mono, peak-normalize, spectral subtraction ('medium' and
    'heavy' only), Wiener-like frame filtering, temporal smoothing, tanh soft
    clipping, then peak-normalize to 0.95 and write 16-bit PCM.

    Args:
        input_file: Path to source audio
        output_file: Path to save enhanced audio
        enhancement_level: 'light', 'medium', or 'heavy'
    Returns:
        dict describing the run; 'methods_applied' lists only the steps that
        were actually executed, in order.
    Raises:
        ValueError: if enhancement_level is not one of the supported levels,
            or if the input file contains no samples.
    """
    if enhancement_level not in ('light', 'medium', 'heavy'):
        raise ValueError(
            "enhancement_level must be 'light', 'medium', or 'heavy', "
            f"got {enhancement_level!r}"
        )

    # Load audio
    audio, sr = sf.read(str(input_file))
    original_length = len(audio)
    if original_length == 0:
        raise ValueError(f"{input_file} contains no audio samples")

    # Ensure mono
    if audio.ndim > 1:
        audio = np.mean(audio, axis=1)

    # Normalize to [-1, 1]
    max_val = np.max(np.abs(audio))
    if max_val > 0:
        audio = audio / max_val

    methods_applied = []

    # Spectral subtraction (stronger noise factor for 'heavy')
    if enhancement_level in ('medium', 'heavy'):
        noise_factor = 0.3 if enhancement_level == 'medium' else 0.5
        audio = apply_spectral_subtraction(audio, sr, noise_estimate_factor=noise_factor)
        methods_applied.append('spectral_subtraction')

    # Wiener-like filtering: 'heavy' uses 10 ms frames, other levels 20 ms
    frame_ms = 10 if enhancement_level == 'heavy' else 20
    audio = apply_wiener_filtering(audio, sr, frame_length_ms=frame_ms)
    methods_applied.append('wiener_filtering')

    # Temporal smoothing
    smoothing_ms = 3 if enhancement_level == 'light' else 5
    audio = apply_temporal_smoothing(audio, sr, window_ms=smoothing_ms)
    methods_applied.append('temporal_smoothing')

    # Prevent clipping with gentle compression
    audio = np.tanh(audio * 0.95)
    methods_applied.append('soft_clipping')

    # Normalize output, leaving headroom
    max_val = np.max(np.abs(audio))
    if max_val > 0:
        audio = 0.95 * audio / max_val

    # Ensure correct length
    audio = audio[:original_length]

    # Save enhanced audio
    sf.write(str(output_file), audio, sr, subtype='PCM_16')
    return {
        'input_file': str(input_file),
        'output_file': str(output_file),
        'enhancement_level': enhancement_level,
        'sample_rate': sr,
        'duration_seconds': len(audio) / sr,
        'methods_applied': methods_applied,
    }
if __name__ == '__main__':
    import sys
    from pathlib import Path

    valid_levels = ('light', 'medium', 'heavy')

    if len(sys.argv) < 2:
        print("Usage: python enhance_target.py <input_wav> [output_wav] [level]")
        print(" level: 'light', 'medium' (default), or 'heavy'")
        sys.exit(1)

    input_file = sys.argv[1]
    if len(sys.argv) > 2:
        output_file = sys.argv[2]
    else:
        # Build "<stem>_enhanced<suffix>" next to the input. Unlike
        # str.replace('.wav', ...), this never yields the input path itself
        # (e.g. for .flac or .WAV input) and only touches the file name.
        input_path = Path(input_file)
        output_file = str(input_path.with_name(f"{input_path.stem}_enhanced{input_path.suffix}"))

    level = sys.argv[3] if len(sys.argv) > 3 else 'medium'
    if level not in valid_levels:
        print(f"Unknown level {level!r}; expected one of: {', '.join(valid_levels)}")
        sys.exit(1)

    result = enhance_target_speaker(input_file, output_file, enhancement_level=level)
    print("Enhancement complete:")
    for key, value in result.items():
        print(f" {key}: {value}")