File size: 9,553 Bytes
e391a84
6e84e40
 
e391a84
6e84e40
 
 
 
e391a84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e84e40
e391a84
 
 
 
 
6e84e40
 
e391a84
 
4a1c131
e391a84
 
 
 
 
 
 
 
 
 
 
6e84e40
e391a84
 
 
 
 
 
 
 
 
 
6e84e40
e391a84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e84e40
e391a84
 
6e84e40
4a1c131
 
 
 
e391a84
 
 
6e84e40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a1c131
e391a84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e84e40
e391a84
 
 
 
 
4a1c131
e391a84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e84e40
e391a84
 
 
 
 
 
 
 
 
 
 
6e84e40
e391a84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
"""
infrastructure/processing/sa_helpers.py (FIXED)
───────────────────────────────────────────────
Numba-accelerated signal entropy, plateau detection, and Simulated Annealing logic.

FIX: N < 4 no longer returns hardcoded loss=0.0.
     Feature precomputation and evaluate_sa() are hoisted before the N<4 guard
     so that real signal-quality loss is computed even with 1–3 segments.
"""
from __future__ import annotations

import numpy as np

# Use numba's JIT decorator when available; otherwise fall back to a no-op
# stand-in so the module still imports and runs (uncompiled) without numba.
try:
    from numba import njit
except ImportError:
    def njit(*args, **kwargs):
        """No-op replacement for ``numba.njit``.

        Supports both decorator forms and returns the function unchanged:
        ``@njit`` (the function arrives as the single positional argument)
        and ``@njit(...)`` (options only; a pass-through decorator is
        returned). Keyword options such as ``cache=True`` are ignored.
        """
        bare_use = len(args) == 1 and callable(args[0])
        if bare_use:
            return args[0]

        def _passthrough(fn):
            return fn

        return _passthrough


@njit(cache=True)
def _sample_entropy_numba(x: np.ndarray, m: int, r: float) -> float:
    """Sample-entropy kernel, compiled with numba when available.

    Scans every template pair ``(i, j)`` with ``i < j`` over the first
    ``len(x) - m`` start positions (O(N^2 * m) work). A pair counts toward
    ``matches_m`` when all ``m`` aligned samples differ by at most ``r``
    (Chebyshev distance <= r). It additionally counts toward ``matches_m1``
    when the following sample (offset ``m``) also differs by at most ``r``.

    Returns:
        ``-log(matches_m1 / matches_m)``, or 0.0 when either count is zero.

    NOTE(review): 0.0 is also the value a perfectly regular signal produces,
    so "no length-(m+1) matches" (typically a very irregular segment) is
    indistinguishable from "very regular" here; confirm this is intended.
    """
    n_templates = len(x) - m
    matches_m = 0
    matches_m1 = 0
    for i in range(n_templates):
        for j in range(i + 1, n_templates):
            # Walk the template while samples agree. Written as "not > r" so a
            # NaN difference is treated as agreeing, matching the prior logic.
            k = 0
            while k < m and not abs(x[i + k] - x[j + k]) > r:
                k += 1
            if k < m:
                continue  # some sample differed by more than r
            matches_m += 1
            if abs(x[i + m] - x[j + m]) <= r:
                matches_m1 += 1
    if matches_m == 0 or matches_m1 == 0:
        return 0.0
    return -np.log(matches_m1 / matches_m)


def compute_sample_entropy(signal: np.ndarray, m: int = 2, r_scale: float = 0.2) -> float:
    """Return the sample entropy of a signal via the JIT-compiled kernel.

    The match tolerance passed to the kernel is ``r_scale * std(signal)``.

    Args:
        signal: Samples to analyse. Any array-like convertible to float64 is
            accepted (ndarray, list, tuple).
        m: Template length used by the kernel.
        r_scale: Tolerance expressed as a fraction of the signal's standard
            deviation.

    Returns:
        The sample entropy as a Python float. Returns 0.0 when the signal is
        (near-)constant (std < 1e-8), and also whenever the kernel finds no
        matching template pairs.
    """
    # ascontiguousarray accepts lists/tuples (ndarray.astype does not), skips
    # the copy when the input is already contiguous float64, and gives the
    # JIT kernel a single C-contiguous layout to specialise on.
    x = np.ascontiguousarray(signal, dtype=np.float64)
    std = np.std(x)
    if std < 1e-8:
        # A flat signal makes the tolerance ~0; treat it as zero entropy.
        return 0.0
    return float(_sample_entropy_numba(x, m, r_scale * std))


@njit(cache=True)
def longest_plateau(signal: np.ndarray) -> int:
    """Return the longest run of consecutive near-zero sample-to-sample steps.

    A step is "flat" when ``|signal[t+1] - signal[t]| < 1e-6``. The result
    counts flat *steps*, so a stretch of k near-identical samples yields k - 1.
    Signals shorter than 2 samples return 0. A NaN difference is never flat
    and therefore breaks a run.
    """
    n = len(signal)
    if n < 2:
        return 0
    steps = np.abs(np.diff(signal.astype(np.float64)))
    best = 0
    run = 0
    for idx in range(steps.shape[0]):
        if steps[idx] < 1e-6:
            run += 1
            best = max(best, run)
        else:
            run = 0
    return best


def _select_clean(
    entropies: np.ndarray,
    ppg_plateaus: np.ndarray,
    ecg_plateaus: np.ndarray,
    lo: float,
    hi: float,
    max_plat: int,
) -> np.ndarray:
    """Return the indices of segments that pass the (lo, hi, max_plat) filter.

    A segment passes when its PPG sample entropy lies in ``[lo, hi]`` and both
    its PPG and ECG plateau lengths are strictly below ``max_plat``. NaN
    entropies never pass, because every comparison with NaN is False.
    """
    mask = (
        (entropies >= lo)
        & (entropies <= hi)
        & (ppg_plateaus < max_plat)
        & (ecg_plateaus < max_plat)
    )
    return np.flatnonzero(mask)


def run_simulated_annealing(
    ppg_segments: np.ndarray,
    ecg_segments: np.ndarray,
    sbp_preds: np.ndarray,
    dbp_preds: np.ndarray,
    n_steps: int = 1000,
    alpha: float = 0.05,
    seed: int | None = None,
) -> dict:
    """
    Run Simulated Annealing to optimize filtering thresholds (lo, hi, max_plat).

    A segment is kept when its PPG sample entropy is within [lo, hi] and both
    its PPG and ECG longest plateaus are below max_plat. The loss minimised is
    ``-yield_rate + alpha * (std(SBP_kept) + std(DBP_kept))``. For N >= 4 a
    penalty is added when fewer than ``max(1, int(0.25 * N))`` segments are
    kept. A loss of 1e9 marks "nothing kept".

    With fewer than 4 segments no annealing is run. The default thresholds
    (0.0, 2.5, 5) are reported together with their real loss.

    Args:
        ppg_segments: Segmented PPG windows, shape (N, W)
        ecg_segments: Segmented ECG windows, indexed by position 0..N-1
        sbp_preds: VGTL-Net predicted SBP for each window, shape (N,)
        dbp_preds: VGTL-Net predicted DBP for each window, shape (N,)
        n_steps: Number of SA iterations (default: 1000). Values <= 0 skip
            the loop; 1 runs a single step at the initial temperature.
        alpha: Weight for variance penalty (default: 0.05)
        seed: Optional seed for a private ``np.random.Generator``. When None
            (default), NumPy's global random state is used, exactly as before.

    Returns:
        dict with keys ``optimal_lo``, ``optimal_hi``, ``optimal_max_plateau``,
        ``best_loss``, ``initial_loss``, ``n_total_segments``,
        ``n_clean_segments``, ``yield_rate``, ``history`` (a list of compact
        log dicts) and ``clean_indices`` (a list of ints).
    """
    # Starting / fallback thresholds, shared by the guards and the SA start.
    init_lo, init_hi, init_plat = 0.0, 2.5, 5

    N = len(ppg_segments)

    # ── Guard: empty input ──────────────────────────────────────────────────
    if N == 0:
        return {
            "optimal_lo": init_lo,
            "optimal_hi": init_hi,
            "optimal_max_plateau": init_plat,
            "best_loss": 1e9,
            "initial_loss": 1e9,
            "n_total_segments": 0,
            "n_clean_segments": 0,
            "yield_rate": 0.0,
            "history": [],
            "clean_indices": [],
        }

    # seed=None keeps the legacy global-RNG stream. A seed gives a private,
    # reproducible Generator. Both expose normal(), choice() and random().
    rng = np.random if seed is None else np.random.default_rng(seed)

    # asarray lets list inputs be fancy-indexed by the selected indices.
    sbp_preds = np.asarray(sbp_preds)
    dbp_preds = np.asarray(dbp_preds)

    # ── 1. Precompute features for each segment (before the N<4 guard) ─────
    ppg_entropies = np.array([
        compute_sample_entropy(ppg_segments[i]) for i in range(N)
    ])
    ppg_plateaus = np.array([
        longest_plateau(ppg_segments[i]) for i in range(N)
    ])
    ecg_plateaus = np.array([
        longest_plateau(ecg_segments[i]) for i in range(N)
    ])

    def select(lo: float, hi: float, max_plat: int) -> np.ndarray:
        """Indices of segments passing the given thresholds."""
        return _select_clean(ppg_entropies, ppg_plateaus, ecg_plateaus, lo, hi, max_plat)

    # ── 2. Internal loss evaluator (before the N<4 guard) ──────────────────
    def evaluate_sa(lo: float, hi: float, max_plat: int) -> float:
        matched = select(lo, hi, max_plat)
        n_clean = len(matched)
        if n_clean == 0:
            return 1e9  # sentinel: nothing survives the filter

        yield_rate = n_clean / N
        std_sbp = float(np.std(sbp_preds[matched])) if n_clean > 1 else 0.0
        std_dbp = float(np.std(dbp_preds[matched])) if n_clean > 1 else 0.0

        # Loss to minimize: want high yield_rate AND low variance
        loss = -yield_rate + alpha * (std_sbp + std_dbp)

        # Penalise when too few segments survive
        if N >= 4:
            min_clean = max(1, int(0.25 * N))
            if n_clean < min_clean:
                loss += 2.0 * (min_clean - n_clean) / min_clean

        return loss

    # ── Guard: too few segments for meaningful SA optimisation ─────────────
    # We still compute REAL loss (not hardcoded 0) so the chart is informative.
    if N < 4:
        actual_loss = evaluate_sa(init_lo, init_hi, init_plat)
        clean = select(init_lo, init_hi, init_plat)
        n_clean = len(clean)

        return {
            "optimal_lo": init_lo,
            "optimal_hi": init_hi,
            "optimal_max_plateau": init_plat,
            "best_loss": float(actual_loss),
            "initial_loss": float(actual_loss),
            "n_total_segments": int(N),
            "n_clean_segments": int(n_clean),
            "yield_rate": float(n_clean / N),
            "history": [{
                "step": 0,
                "temperature": 1.0,
                "curr_loss": float(f"{actual_loss:.4g}"),
                "best_loss": float(f"{actual_loss:.4g}"),
                "best_lo": init_lo,
                "best_hi": init_hi,
                "best_plat": init_plat,
            }],
            # tolist() yields plain Python ints (JSON-serialisable).
            "clean_indices": clean.tolist(),
        }

    # ── 3. SA Loop Initialisation ───────────────────────────────────────────
    curr_lo, curr_hi, curr_plat = init_lo, init_hi, init_plat
    curr_loss = evaluate_sa(curr_lo, curr_hi, curr_plat)

    best_lo, best_hi, best_plat = curr_lo, curr_hi, curr_plat
    best_loss = curr_loss
    initial_loss = curr_loss

    # SA Hyperparameters: geometric cooling from T_init to T_min.
    T_init = 1.0
    T_min = 1e-4
    # Guard the schedule's denominator so n_steps == 1 does not divide by 0.
    schedule_span = max(n_steps - 1, 1)
    last_step = n_steps - 1

    history = []

    # ── 4. Run SA ───────────────────────────────────────────────────────────
    for step in range(n_steps):
        # Temperature schedule
        T = T_init * ((T_min / T_init) ** (step / schedule_span))

        # Perturb parameters (draw order kept identical to the original).
        cand_lo = float(np.clip(curr_lo + rng.normal(0, 0.02), 0.0, 0.5))
        cand_hi = float(np.clip(curr_hi + rng.normal(0, 0.15), 1.0, 5.0))
        cand_plat = int(np.clip(curr_plat + rng.choice([-1, 0, 1]), 2, 15))

        cand_loss = evaluate_sa(cand_lo, cand_hi, cand_plat)

        # Metropolis acceptance: always take improvements, otherwise accept
        # with probability exp(-(cand - curr) / T). The exponent is <= 0 here,
        # so exp cannot overflow.
        if cand_loss < curr_loss:
            accept = True
        else:
            accept = float(rng.random()) < np.exp((curr_loss - cand_loss) / T)

        if accept:
            curr_lo, curr_hi, curr_plat = cand_lo, cand_hi, cand_plat
            curr_loss = cand_loss

            if curr_loss < best_loss:
                best_lo, best_hi, best_plat = curr_lo, curr_hi, curr_plat
                best_loss = curr_loss

        # Log compact optimisation history: first/last step, every 100th
        # step, and any accepted move that sits at the current best loss.
        if step in (0, last_step) or (step + 1) % 100 == 0 or (accept and curr_loss == best_loss):
            history.append({
                "step": step,
                "temperature": float(f"{T:.4g}"),
                "curr_loss": float(f"{curr_loss:.4g}"),
                "best_loss": float(f"{best_loss:.4g}"),
                "best_lo": float(f"{best_lo:.4f}"),
                "best_hi": float(f"{best_hi:.4f}"),
                "best_plat": int(best_plat),
            })

    # ── 5. Final filter with optimised thresholds ───────────────────────────
    clean = select(best_lo, best_hi, best_plat)
    n_clean = len(clean)

    return {
        "optimal_lo": float(best_lo),
        "optimal_hi": float(best_hi),
        "optimal_max_plateau": int(best_plat),
        "best_loss": float(best_loss),
        "initial_loss": float(initial_loss),
        "n_total_segments": int(N),
        "n_clean_segments": int(n_clean),
        "yield_rate": float(n_clean / N),
        "history": history,
        "clean_indices": clean.tolist(),
    }