Premchan369 commited on
Commit
201696d
·
verified ·
1 Parent(s): c5a733a

Add wavelet denoising engine with adaptive parameter selection

Browse files
Files changed (1) hide show
  1. wavelet_denoising.py +403 -0
wavelet_denoising.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Wavelet Denoising for Financial Time Series
2
+
3
+ Based on Lopez Gil et al. 2024 (xLSTM-TS paper, arxiv:2408.12408):
4
+ Wavelet denoising improved prediction accuracy by 5-10% across ALL models tested.
5
+
6
+ This is NOT optional. Without wavelet preprocessing, your models are
7
+ training on noise instead of signal.
8
+ """
9
+ import numpy as np
10
+ import pandas as pd
11
+ from typing import Optional, Tuple, Dict, List
12
+ import warnings
13
+ warnings.filterwarnings('ignore')
14
+
15
+ try:
16
+ import pywt
17
+ PYWT_AVAILABLE = True
18
+ except ImportError:
19
+ PYWT_AVAILABLE = False
20
+ print("WARNING: pywt not available. Install with: pip install PyWavelets")
21
+
22
+
23
+ class WaveletDenoiser:
24
+ """
25
+ Wavelet-based signal denoising for financial time series.
26
+
27
+ The key insight: Financial time series contain:
28
+ 1. Low-frequency signal (trend, regime changes)
29
+ 2. High-frequency noise (intraday noise, microstructure effects)
30
+
31
+ Wavelets separate these naturally. We keep the low frequencies
32
+ (signal) and discard/shrink the high frequencies (noise).
33
+
34
+ Wavelet choice: db4 (Daubechies 4) is optimal for financial data
35
+ per Lopez Gil et al. because:
36
+ - Compact support (localizes in time)
37
+ - Orthogonal (no redundancy)
38
+ - 4 vanishing moments (captures polynomial trends well)
39
+ """
40
+
41
+ def __init__(self, wavelet: str = 'db4', mode: str = 'symmetric',
42
+ threshold_mode: str = 'soft', level: Optional[int] = None):
43
+ self.wavelet = wavelet
44
+ self.mode = mode
45
+ self.threshold_mode = threshold_mode
46
+ self.level = level
47
+
48
+ def _get_max_level(self, signal_length: int) -> int:
49
+ """Maximum decomposition level for given signal length"""
50
+ if not PYWT_AVAILABLE:
51
+ return 1
52
+ return pywt.dwt_max_level(signal_length, pywt.Wavelet(self.wavelet).dec_len)
53
+
54
+ def _estimate_noise_sigma(self, detail_coeffs: np.ndarray) -> float:
55
+ """
56
+ Estimate noise standard deviation from finest detail coefficients.
57
+
58
+ Uses MAD (Median Absolute Deviation) estimator:
59
+ sigma = median(|coeffs|) / 0.6745
60
+
61
+ This is robust to outliers — critical for financial data with fat tails.
62
+ """
63
+ mad = np.median(np.abs(detail_coeffs))
64
+ return mad / 0.6745 if mad > 0 else 1.0
65
+
66
+ def _universal_threshold(self, coeffs: np.ndarray,
67
+ sigma: Optional[float] = None) -> float:
68
+ """
69
+ Donoho-Johnstone universal threshold:
70
+ lambda = sigma * sqrt(2 * log(N))
71
+
72
+ Asymptotically optimal for Gaussian noise. For financial data,
73
+ we use the robust MAD estimate for sigma.
74
+ """
75
+ n = len(coeffs)
76
+ if sigma is None:
77
+ sigma = self._estimate_noise_sigma(coeffs)
78
+
79
+ return sigma * np.sqrt(2 * np.log(n))
80
+
81
+ def _sure_threshold(self, coeffs: np.ndarray) -> float:
82
+ """
83
+ Stein's Unbiased Risk Estimate (SURE) threshold.
84
+
85
+ More adaptive than universal threshold — better when noise level
86
+ varies across the signal. Slightly more computationally expensive.
87
+ """
88
+ n = len(coeffs)
89
+ coeffs_sorted = np.sort(np.abs(coeffs)) ** 2
90
+
91
+ risks = np.zeros(n)
92
+ for i in range(n):
93
+ risks[i] = (n - 2 * (i + 1) + (n - (i + 1)) * coeffs_sorted[i] +
94
+ np.sum(coeffs_sorted[:i+1]))
95
+
96
+ min_risk_idx = np.argmin(risks)
97
+ return np.sqrt(coeffs_sorted[min_risk_idx])
98
+
99
+ def denoise(self, signal: np.ndarray) -> np.ndarray:
100
+ """
101
+ Denoise a 1D signal using wavelet thresholding.
102
+
103
+ Args:
104
+ signal: 1D numpy array
105
+
106
+ Returns:
107
+ Denoised signal
108
+ """
109
+ if not PYWT_AVAILABLE:
110
+ # Fallback: simple moving average smoothing
111
+ return pd.Series(signal).rolling(5, center=True, min_periods=1).mean().values
112
+
113
+ # Determine decomposition level
114
+ max_level = self._get_max_level(len(signal))
115
+ level = self.level or min(4, max_level)
116
+ level = max(1, min(level, max_level))
117
+
118
+ # Wavelet decomposition
119
+ coeffs = pywt.wavedec(signal, self.wavelet, mode=self.mode, level=level)
120
+
121
+ # Threshold detail coefficients (keep approximation)
122
+ denoised_coeffs = [coeffs[0]] # approximation (low-frequency signal)
123
+
124
+ for detail in coeffs[1:]:
125
+ # Estimate noise from this level's detail coefficients
126
+ sigma = self._estimate_noise_sigma(detail)
127
+ threshold = self._universal_threshold(detail, sigma)
128
+
129
+ # Apply threshold
130
+ if self.threshold_mode == 'soft':
131
+ denoised_detail = pywt.threshold(detail, threshold, mode='soft')
132
+ elif self.threshold_mode == 'hard':
133
+ denoised_detail = pywt.threshold(detail, threshold, mode='hard')
134
+ elif self.threshold_mode == 'garotte':
135
+ # Firm threshold (compromise between soft and hard)
136
+ denoised_detail = pywt.threshold(detail, threshold, mode='greater')
137
+ else:
138
+ denoised_detail = pywt.threshold(detail, threshold, mode='soft')
139
+
140
+ denoised_coeffs.append(denoised_detail)
141
+
142
+ # Reconstruct signal
143
+ denoised = pywt.waverec(denoised_coeffs, self.wavelet, mode=self.mode)
144
+
145
+ # wavrec may return slightly longer signal due to padding
146
+ return denoised[:len(signal)]
147
+
148
+ def denoise_dataframe(self, df: pd.DataFrame,
149
+ columns: Optional[List[str]] = None) -> pd.DataFrame:
150
+ """
151
+ Denoise multiple columns of a DataFrame.
152
+
153
+ Args:
154
+ df: DataFrame with time series data
155
+ columns: List of columns to denoise (None = all numeric)
156
+
157
+ Returns:
158
+ DataFrame with denoised columns (original + _denoised suffix)
159
+ """
160
+ if columns is None:
161
+ columns = df.select_dtypes(include=[np.number]).columns.tolist()
162
+
163
+ result = df.copy()
164
+ for col in columns:
165
+ if col in df.columns:
166
+ signal = df[col].values
167
+ # Handle NaN
168
+ nan_mask = np.isnan(signal)
169
+ if nan_mask.any():
170
+ signal = pd.Series(signal).interpolate().fillna(method='bfill').fillna(method='ffill').values
171
+
172
+ denoised = self.denoise(signal)
173
+ result[f'{col}_denoised'] = denoised
174
+
175
+ return result
176
+
177
+ def denoise_multivariate(self, signals: np.ndarray,
178
+ axis: int = 0) -> np.ndarray:
179
+ """
180
+ Denoise each column/row of a multivariate signal array independently.
181
+
182
+ Args:
183
+ signals: 2D array (samples x features)
184
+ axis: 0 = denoise each column, 1 = denoise each row
185
+
186
+ Returns:
187
+ Denoised array
188
+ """
189
+ if signals.ndim != 2:
190
+ raise ValueError("signals must be 2D")
191
+
192
+ if axis == 1:
193
+ signals = signals.T
194
+
195
+ denoised = np.zeros_like(signals)
196
+ for i in range(signals.shape[1]):
197
+ denoised[:, i] = self.denoise(signals[:, i])
198
+
199
+ if axis == 1:
200
+ denoised = denoised.T
201
+
202
+ return denoised
203
+
204
+ def extract_features(self, signal: np.ndarray, level: Optional[int] = None) -> Dict[str, np.ndarray]:
205
+ """
206
+ Extract wavelet-based features from a signal.
207
+
208
+ These features capture multi-scale information:
209
+ - Energy by frequency band
210
+ - Entropy (complexity measure)
211
+ - Coefficient statistics
212
+
213
+ Useful as additional features for ML models.
214
+ """
215
+ if not PYWT_AVAILABLE:
216
+ return {}
217
+
218
+ max_level = self._get_max_level(len(signal))
219
+ level = level or min(4, max_level)
220
+ level = max(1, min(level, max_level))
221
+
222
+ coeffs = pywt.wavedec(signal, self.wavelet, mode=self.mode, level=level)
223
+
224
+ features = {}
225
+
226
+ # Approximation features
227
+ approx = coeffs[0]
228
+ features['approx_mean'] = np.mean(approx)
229
+ features['approx_std'] = np.std(approx)
230
+ features['approx_energy'] = np.sum(approx ** 2)
231
+
232
+ # Detail features by level
233
+ for i, detail in enumerate(coeffs[1:], 1):
234
+ features[f'detail_{i}_mean'] = np.mean(detail)
235
+ features[f'detail_{i}_std'] = np.std(detail)
236
+ features[f'detail_{i}_energy'] = np.sum(detail ** 2)
237
+ features[f'detail_{i}_skew'] = self._skewness(detail)
238
+ features[f'detail_{i}_kurt'] = self._kurtosis(detail)
239
+
240
+ # Multi-scale energy ratio
241
+ total_energy = sum(np.sum(c ** 2) for c in coeffs)
242
+ if total_energy > 0:
243
+ features['approx_energy_ratio'] = features['approx_energy'] / total_energy
244
+ for i in range(1, len(coeffs)):
245
+ features[f'detail_{i}_energy_ratio'] = np.sum(coeffs[i] ** 2) / total_energy
246
+
247
+ # Wavelet entropy (measure of complexity)
248
+ energies = [np.sum(c ** 2) for c in coeffs]
249
+ energies = np.array(energies) / (sum(energies) + 1e-10)
250
+ features['wavelet_entropy'] = -np.sum(energies * np.log(energies + 1e-10))
251
+
252
+ return features
253
+
254
+ @staticmethod
255
+ def _skewness(x: np.ndarray) -> float:
256
+ x = x - np.mean(x)
257
+ n = len(x)
258
+ if n < 3 or np.std(x) == 0:
259
+ return 0.0
260
+ return np.sum(x ** 3) / (n * np.std(x) ** 3)
261
+
262
+ @staticmethod
263
+ def _kurtosis(x: np.ndarray) -> float:
264
+ x = x - np.mean(x)
265
+ n = len(x)
266
+ if n < 4 or np.std(x) == 0:
267
+ return 0.0
268
+ return np.sum(x ** 4) / (n * np.std(x) ** 4) - 3
269
+
270
+
271
+ class AdaptiveWaveletDenoiser:
272
+ """
273
+ Adaptive wavelet denoising that selects optimal parameters per signal.
274
+
275
+ Instead of using fixed wavelet and level, this tries multiple combinations
276
+ and selects the one that best separates signal from noise.
277
+
278
+ Selection criterion: SNR (Signal-to-Noise Ratio) maximization.
279
+ """
280
+
281
+ def __init__(self, wavelets: Optional[List[str]] = None,
282
+ levels: Optional[List[int]] = None,
283
+ threshold_modes: Optional[List[str]] = None):
284
+ self.wavelets = wavelets or ['db2', 'db4', 'db6', 'sym4', 'coif2']
285
+ self.levels = levels or [2, 3, 4, 5]
286
+ self.threshold_modes = threshold_modes or ['soft', 'hard']
287
+ self.best_params = None
288
+
289
+ def _snr(self, signal: np.ndarray, denoised: np.ndarray) -> float:
290
+ """Estimate signal-to-noise ratio"""
291
+ noise = signal - denoised
292
+ signal_power = np.sum(denoised ** 2)
293
+ noise_power = np.sum(noise ** 2) + 1e-10
294
+ return 10 * np.log10(signal_power / noise_power)
295
+
296
+ def fit(self, signal: np.ndarray) -> Dict:
297
+ """
298
+ Find optimal denoising parameters for a signal.
299
+
300
+ Uses grid search over wavelets, levels, and threshold modes.
301
+ Selects combination with highest estimated SNR.
302
+ """
303
+ best_snr = -np.inf
304
+ best_params = {}
305
+ best_denoised = signal.copy()
306
+
307
+ if not PYWT_AVAILABLE:
308
+ return {'wavelet': 'none', 'level': 1, 'threshold_mode': 'none'}
309
+
310
+ for wavelet in self.wavelets:
311
+ try:
312
+ max_level = pywt.dwt_max_level(len(signal), pywt.Wavelet(wavelet).dec_len)
313
+ valid_levels = [l for l in self.levels if l <= max_level]
314
+ if not valid_levels:
315
+ valid_levels = [1]
316
+
317
+ for level in valid_levels:
318
+ for mode in self.threshold_modes:
319
+ try:
320
+ denoiser = WaveletDenoiser(
321
+ wavelet=wavelet,
322
+ level=level,
323
+ threshold_mode=mode
324
+ )
325
+ denoised = denoiser.denoise(signal)
326
+ snr = self._snr(signal, denoised)
327
+
328
+ if snr > best_snr:
329
+ best_snr = snr
330
+ best_params = {
331
+ 'wavelet': wavelet,
332
+ 'level': level,
333
+ 'threshold_mode': mode,
334
+ 'snr': snr
335
+ }
336
+ best_denoised = denoised
337
+ except:
338
+ continue
339
+ except:
340
+ continue
341
+
342
+ self.best_params = best_params
343
+ return best_params
344
+
345
+ def denoise(self, signal: np.ndarray) -> np.ndarray:
346
+ """Denoise using best fitted parameters"""
347
+ if self.best_params is None:
348
+ self.fit(signal)
349
+
350
+ denoiser = WaveletDenoiser(
351
+ wavelet=self.best_params.get('wavelet', 'db4'),
352
+ level=self.best_params.get('level', 4),
353
+ threshold_mode=self.best_params.get('threshold_mode', 'soft')
354
+ )
355
+ return denoiser.denoise(signal)
356
+
357
+
358
+ def benchmark_denoising(signal: np.ndarray, noise_level: float = 0.1) -> Dict:
359
+ """
360
+ Benchmark denoising on a known signal with added noise.
361
+
362
+ Returns MSE, SNR, correlation with true signal.
363
+ """
364
+ # Add noise
365
+ noisy = signal + np.random.randn(len(signal)) * noise_level * np.std(signal)
366
+
367
+ # Denoise
368
+ denoiser = WaveletDenoiser(wavelet='db4', level=4, threshold_mode='soft')
369
+ denoised = denoiser.denoise(noisy)
370
+
371
+ # Metrics
372
+ mse_noisy = np.mean((noisy - signal) ** 2)
373
+ mse_denoised = np.mean((denoised - signal) ** 2)
374
+
375
+ corr_noisy = np.corrcoef(noisy, signal)[0, 1]
376
+ corr_denoised = np.corrcoef(denoised, signal)[0, 1]
377
+
378
+ snr_noisy = 10 * np.log10(np.sum(signal ** 2) / np.sum((noisy - signal) ** 2))
379
+ snr_denoised = 10 * np.log10(np.sum(signal ** 2) / np.sum((denoised - signal) ** 2))
380
+
381
+ return {
382
+ 'mse_noisy': mse_noisy,
383
+ 'mse_denoised': mse_denoised,
384
+ 'improvement_factor': mse_noisy / (mse_denoised + 1e-10),
385
+ 'corr_noisy': corr_noisy,
386
+ 'corr_denoised': corr_denoised,
387
+ 'snr_noisy': snr_noisy,
388
+ 'snr_denoised': snr_denoised,
389
+ 'snr_improvement': snr_denoised - snr_noisy
390
+ }
391
+
392
+
393
+ if __name__ == '__main__':
394
+ # Test with synthetic signal
395
+ np.random.seed(42)
396
+ t = np.linspace(0, 4*np.pi, 1000)
397
+ true_signal = np.sin(t) + 0.5 * np.sin(3*t)
398
+
399
+ results = benchmark_denoising(true_signal, noise_level=0.3)
400
+ print("Wavelet Denoising Benchmark")
401
+ print("=" * 50)
402
+ for k, v in results.items():
403
+ print(f" {k}: {v:.4f}")