"""
infrastructure/processing/sa_helpers.py (FIXED)
───────────────────────────────────────────────
Numba-accelerated signal entropy, plateau detection, and Simulated Annealing logic.

FIX: N < 4 no longer returns hardcoded loss=0.0. Feature precomputation and
evaluate_sa() are hoisted before the N<4 guard so that real signal-quality
loss is computed even with 1–3 segments.
"""
from __future__ import annotations

import numpy as np

# Fallback for njit if numba is not installed
try:
    from numba import njit
except ImportError:
    def njit(*args, **kwargs):
        """No-op replacement for ``numba.njit`` used when numba is missing.

        Supports both the bare form (``@njit``) and the parameterised form
        (``@njit(cache=True)``); in either case the decorated function is
        returned unchanged and runs as plain Python.
        """
        def decorator(func):
            return func

        # Bare usage: the decorated function itself is the only argument.
        if len(args) == 1 and callable(args[0]):
            return args[0]
        return decorator


# Starting thresholds for the search; also the thresholds reported when there
# are too few segments to optimise.
_DEFAULT_LO = 0.0
_DEFAULT_HI = 2.5
_DEFAULT_MAX_PLATEAU = 5

# Loss reported when no segment survives the filter (or there is no input).
_NO_SURVIVOR_LOSS = 1e9


@njit(cache=True)
def _sample_entropy_numba(x: np.ndarray, m: int, r: float) -> float:
    """Sample Entropy via Numba JIT -- O(N^2).

    Counts pairs of length-``m`` templates whose element-wise absolute
    difference never exceeds ``r`` (``B``), and how many of those pairs still
    agree within ``r`` on the next sample (``A``).

    Returns:
        ``-log(A / B)``, or ``0.0`` when either count is zero.
    """
    N = len(x)
    B = 0
    A = 0
    for i in range(N - m):
        for j in range(i + 1, N - m):
            match_m = True
            for k in range(m):
                if abs(x[i + k] - x[j + k]) > r:
                    match_m = False
                    break
            if match_m:
                B += 1
                if abs(x[i + m] - x[j + m]) <= r:
                    A += 1
    if B == 0 or A == 0:
        return 0.0
    return -np.log(A / B)


def compute_sample_entropy(signal: np.ndarray, m: int = 2, r_scale: float = 0.2) -> float:
    """Compute sample entropy using Numba compiled function.

    Args:
        signal: 1-D sequence of samples (any array-like convertible to float64).
        m: Template length.
        r_scale: Tolerance expressed as a fraction of the signal's standard
            deviation.

    Returns:
        The sample entropy, or ``0.0`` for an empty or (near-)constant signal
        (standard deviation below 1e-8).
    """
    signal = np.ascontiguousarray(signal, dtype=np.float64)
    if signal.size == 0:
        # Avoid np.std on an empty array (NaN plus a RuntimeWarning).
        return 0.0
    std = np.std(signal)
    if std < 1e-8:
        return 0.0
    return float(_sample_entropy_numba(signal, m, r_scale * std))


@njit(cache=True)
def longest_plateau(signal: np.ndarray) -> int:
    """Return the longest run of consecutive near-zero sample-to-sample changes.

    A change counts as "flat" when its absolute value is below 1e-6. The
    result is the number of consecutive flat *differences*, i.e. one less
    than the number of samples in the flat stretch. Signals with fewer than
    two samples yield 0.
    """
    if len(signal) < 2:
        return 0
    diff = np.abs(np.diff(signal.astype(np.float64)))
    max_count = 0
    count = 0
    for d in diff:
        count = count + 1 if d < 1e-6 else 0
        if count > max_count:
            max_count = count
    return max_count


def run_simulated_annealing(
    ppg_segments: np.ndarray,
    ecg_segments: np.ndarray,
    sbp_preds: np.ndarray,
    dbp_preds: np.ndarray,
    n_steps: int = 1000,
    alpha: float = 0.05,
) -> dict:
    """
    Run Simulated Annealing to optimize filtering thresholds (lo, hi, max_plat).

    A segment is kept when its PPG sample entropy lies in ``[lo, hi]`` and the
    longest plateau of both its PPG and ECG window is strictly below
    ``max_plat``. The loss being minimised is
    ``-yield_rate + alpha * (std(SBP) + std(DBP))`` over the kept segments,
    plus a penalty when fewer than ``max(1, int(0.25 * N))`` segments survive
    (only applied when ``N >= 4``). If nothing survives the loss is 1e9.

    Args:
        ppg_segments: Segmented PPG windows, shape (N, W)
        ecg_segments: Segmented ECG windows, shape (N, W)
        sbp_preds: VGTL-Net predicted SBP for each window, shape (N,)
        dbp_preds: VGTL-Net predicted DBP for each window, shape (N,)
        n_steps: Number of SA iterations (default: 1000)
        alpha: Weight for variance penalty (default: 0.05)

    Returns:
        dict containing optimal thresholds, filtered predictions, yield rate,
        and SA logs. With fewer than 4 segments no search is run: the default
        thresholds are reported together with their real loss and a
        single-entry history.
    """
    n_segments = len(ppg_segments)

    # -- Guard: empty input --------------------------------------------------
    if n_segments == 0:
        return {
            "optimal_lo": _DEFAULT_LO,
            "optimal_hi": _DEFAULT_HI,
            "optimal_max_plateau": _DEFAULT_MAX_PLATEAU,
            "best_loss": _NO_SURVIVOR_LOSS,
            "initial_loss": _NO_SURVIVOR_LOSS,
            "n_total_segments": 0,
            "n_clean_segments": 0,
            "yield_rate": 0.0,
            "history": [],
            "clean_indices": [],
        }

    # Accept plain sequences as well as arrays for the predictions.
    sbp_preds = np.asarray(sbp_preds)
    dbp_preds = np.asarray(dbp_preds)

    # -- 1. Precompute features for each segment -----------------------------
    # Indexed by position so that exactly the first N ECG windows are used,
    # matching the PPG windows one-to-one.
    ppg_entropies = np.array(
        [compute_sample_entropy(ppg_segments[i]) for i in range(n_segments)]
    )
    ppg_plateaus = np.array(
        [longest_plateau(ppg_segments[i]) for i in range(n_segments)]
    )
    ecg_plateaus = np.array(
        [longest_plateau(ecg_segments[i]) for i in range(n_segments)]
    )

    def passing_indices(lo: float, hi: float, max_plat: int) -> np.ndarray:
        """Return the indices of segments that satisfy the given thresholds."""
        keep = (
            (ppg_entropies >= lo)
            & (ppg_entropies <= hi)
            & (ppg_plateaus < max_plat)
            & (ecg_plateaus < max_plat)
        )
        return np.flatnonzero(keep)

    # -- 2. Internal loss evaluator ------------------------------------------
    def evaluate_sa(lo: float, hi: float, max_plat: int) -> float:
        """Return the loss for one candidate threshold triple."""
        kept = passing_indices(lo, hi, max_plat)
        n_clean = int(kept.size)
        if n_clean == 0:
            return _NO_SURVIVOR_LOSS

        yield_rate = n_clean / n_segments
        if n_clean > 1:
            std_sbp = float(np.std(sbp_preds[kept]))
            std_dbp = float(np.std(dbp_preds[kept]))
        else:
            # A single survivor has no spread.
            std_sbp = 0.0
            std_dbp = 0.0

        # Loss to minimize: want high yield_rate AND low variance
        loss = -yield_rate + alpha * (std_sbp + std_dbp)

        # Penalise when too few segments survive
        if n_segments >= 4:
            min_clean = max(1, int(0.25 * n_segments))
            if n_clean < min_clean:
                loss += 2.0 * (min_clean - n_clean) / min_clean
        return loss

    # -- Guard: too few segments for meaningful SA optimisation ---------------
    # We still compute REAL loss (not hardcoded 0) so the chart is informative.
    if n_segments < 4:
        actual_loss = evaluate_sa(_DEFAULT_LO, _DEFAULT_HI, _DEFAULT_MAX_PLATEAU)
        clean_indices = passing_indices(
            _DEFAULT_LO, _DEFAULT_HI, _DEFAULT_MAX_PLATEAU
        ).tolist()
        n_clean = len(clean_indices)
        return {
            "optimal_lo": _DEFAULT_LO,
            "optimal_hi": _DEFAULT_HI,
            "optimal_max_plateau": _DEFAULT_MAX_PLATEAU,
            "best_loss": float(actual_loss),
            "initial_loss": float(actual_loss),
            "n_total_segments": int(n_segments),
            "n_clean_segments": int(n_clean),
            "yield_rate": float(n_clean / n_segments),
            "history": [{
                "step": 0,
                "temperature": 1.0,
                "curr_loss": float(f"{actual_loss:.4g}"),
                "best_loss": float(f"{actual_loss:.4g}"),
                "best_lo": _DEFAULT_LO,
                "best_hi": _DEFAULT_HI,
                "best_plat": _DEFAULT_MAX_PLATEAU,
            }],
            "clean_indices": clean_indices,
        }

    # -- 3. SA Loop Initialisation -------------------------------------------
    curr_lo = _DEFAULT_LO
    curr_hi = _DEFAULT_HI
    curr_plat = _DEFAULT_MAX_PLATEAU
    curr_loss = evaluate_sa(curr_lo, curr_hi, curr_plat)

    best_lo, best_hi, best_plat = curr_lo, curr_hi, curr_plat
    best_loss = curr_loss
    initial_loss = curr_loss

    # SA Hyperparameters
    T_init = 1.0
    T_min = 1e-4
    history = []
    last_step = n_steps - 1

    # -- 4. Run SA -------------------------------------------------------------
    for step in range(n_steps):
        # Geometric temperature schedule from T_init down to T_min. With a
        # single step there is nothing to interpolate, so stay at T_init
        # (this also avoids dividing by zero).
        progress = step / last_step if last_step > 0 else 0.0
        T = T_init * ((T_min / T_init) ** progress)

        # Perturb parameters (RNG call order is kept stable so that seeded
        # runs are reproducible).
        cand_lo = float(np.clip(curr_lo + np.random.normal(0, 0.02), 0.0, 0.5))
        cand_hi = float(np.clip(curr_hi + np.random.normal(0, 0.15), 1.0, 5.0))
        cand_plat = int(np.clip(curr_plat + np.random.choice([-1, 0, 1]), 2, 15))
        cand_loss = evaluate_sa(cand_lo, cand_hi, cand_plat)

        # Acceptance logic: always take improvements, otherwise accept with
        # the Metropolis probability exp(-delta / T).
        if cand_loss < curr_loss:
            accept = True
        else:
            accept = bool(
                float(np.random.random()) < np.exp((curr_loss - cand_loss) / T)
            )

        if accept:
            curr_lo, curr_hi, curr_plat = cand_lo, cand_hi, cand_plat
            curr_loss = cand_loss
            if curr_loss < best_loss:
                best_lo, best_hi, best_plat = curr_lo, curr_hi, curr_plat
                best_loss = curr_loss

        # Log compact optimisation history: first and last step, every 100th
        # step, and any accepted step that sits at the best loss so far.
        if (
            step in (0, last_step)
            or (step + 1) % 100 == 0
            or (accept and curr_loss == best_loss)
        ):
            history.append({
                "step": step,
                "temperature": float(f"{T:.4g}"),
                "curr_loss": float(f"{curr_loss:.4g}"),
                "best_loss": float(f"{best_loss:.4g}"),
                "best_lo": float(f"{best_lo:.4f}"),
                "best_hi": float(f"{best_hi:.4f}"),
                "best_plat": int(best_plat),
            })

    # -- 5. Final filter with optimised thresholds ----------------------------
    clean_indices = passing_indices(best_lo, best_hi, best_plat).tolist()
    n_clean = len(clean_indices)

    return {
        "optimal_lo": float(best_lo),
        "optimal_hi": float(best_hi),
        "optimal_max_plateau": int(best_plat),
        "best_loss": float(best_loss),
        "initial_loss": float(initial_loss),
        "n_total_segments": int(n_segments),
        "n_clean_segments": int(n_clean),
        "yield_rate": float(n_clean / n_segments),
        "history": history,
        "clean_indices": clean_indices,
    }