LIBRE / src /infrastructure /processing /sa_helpers.py
RyZ
fix: fix some little mistake on SA
6e84e40
Raw
History Blame Contribute Delete
9.55 kB
"""
infrastructure/processing/sa_helpers.py (FIXED)
───────────────────────────────────────────────
Numba-accelerated signal entropy, plateau detection, and Simulated Annealing logic.
FIX: N < 4 no longer returns hardcoded loss=0.0.
Feature precomputation and evaluate_sa() are hoisted before the N<4 guard
so that real signal-quality loss is computed even with 1–3 segments.
"""
from __future__ import annotations
import numpy as np
# Prefer Numba's real JIT decorator; without Numba, degrade to a no-op
# decorator so the decorated functions still run as plain Python.
try:
    from numba import njit
except ImportError:
    def njit(*args, **kwargs):
        """Stand-in for ``numba.njit`` that leaves the function untouched.

        Supports both the bare form (``@njit``) and the parameterised form
        (``@njit(cache=True)``); any options are accepted and ignored.
        """
        used_bare = len(args) == 1 and callable(args[0])
        if used_bare:
            # ``@njit`` directly on a function: hand the function back.
            return args[0]
        # ``@njit(...)``: return a decorator that hands the function back.
        return lambda func: func
@njit(cache=True)
def _sample_entropy_numba(x: np.ndarray, m: int, r: float) -> float:
    """Sample-entropy kernel (JIT-compiled when Numba is available).

    Every pair of start positions ``a < b`` among the first ``len(x) - m``
    positions is compared. A pair counts as a length-``m`` match when all
    ``m`` aligned samples differ by at most ``r``; it additionally counts as
    a length-``m + 1`` match when the following sample also differs by at
    most ``r``.

    Returns ``-log(matches_m_plus_1 / matches_m)``, or ``0.0`` when either
    count is zero.
    """
    n_starts = len(x) - m
    matches_m = 0
    matches_m_plus_1 = 0
    for a in range(n_starts):
        for b in range(a + 1, n_starts):
            templates_agree = True
            for offset in range(m):
                if abs(x[a + offset] - x[b + offset]) > r:
                    templates_agree = False
                    break
            if not templates_agree:
                continue
            matches_m += 1
            # Extend the matched template by one sample.
            if abs(x[a + m] - x[b + m]) <= r:
                matches_m_plus_1 += 1
    if matches_m == 0 or matches_m_plus_1 == 0:
        return 0.0
    return -np.log(matches_m_plus_1 / matches_m)
def compute_sample_entropy(signal: np.ndarray, m: int = 2, r_scale: float = 0.2) -> float:
    """Return the sample entropy of a signal.

    The matching tolerance handed to the kernel is ``r_scale`` times the
    signal's standard deviation.

    Args:
        signal: Samples to analyse. Any array-like convertible to float64 is
            accepted (NumPy array, list, tuple).
        m: Template length used for matching.
        r_scale: Multiplier applied to the standard deviation to obtain the
            tolerance.

    Returns:
        The sample entropy as a Python float. ``0.0`` is returned when the
        signal has fewer than ``m + 2`` samples (no two templates exist to
        compare) or when its standard deviation is below ``1e-8``.
    """
    # asarray (rather than .astype) lets plain Python sequences through and
    # avoids a copy when the input is already float64.
    samples = np.atleast_1d(np.asarray(signal, dtype=np.float64))
    # With fewer than m + 2 samples the kernel has no template pair to compare
    # and would return 0.0; short-circuiting also keeps np.std away from empty
    # input, which would emit a RuntimeWarning.
    if len(samples) < m + 2:
        return 0.0
    std = np.std(samples)
    if std < 1e-8:
        return 0.0
    return float(_sample_entropy_numba(samples, m, r_scale * std))
@njit(cache=True)
def longest_plateau(signal: np.ndarray) -> int:
    """Return the longest run of consecutive near-zero sample-to-sample steps.

    A step is "near-zero" when the absolute difference between two adjacent
    samples (after conversion to float64) is below ``1e-6``. The value
    returned is the number of such steps in a row, so a flat stretch of ``k``
    samples yields ``k - 1``. Signals with fewer than two samples yield ``0``.
    """
    if len(signal) < 2:
        return 0
    values = signal.astype(np.float64)
    best_run = 0
    current_run = 0
    for idx in range(1, len(values)):
        if abs(values[idx] - values[idx - 1]) < 1e-6:
            current_run += 1
            if current_run > best_run:
                best_run = current_run
        else:
            # Any larger step ends the current flat stretch.
            current_run = 0
    return best_run
def run_simulated_annealing(
    ppg_segments: np.ndarray,
    ecg_segments: np.ndarray,
    sbp_preds: np.ndarray,
    dbp_preds: np.ndarray,
    n_steps: int = 1000,
    alpha: float = 0.05,
) -> dict:
    """
    Run Simulated Annealing to optimize filtering thresholds (lo, hi, max_plat).

    A segment is "clean" when its PPG sample entropy lies in ``[lo, hi]`` and
    both its PPG and ECG plateau lengths are below ``max_plat``. The loss being
    minimised is ``-yield_rate + alpha * (std(SBP) + std(DBP))`` over the clean
    segments, plus (for N >= 4) a penalty when fewer than 25% of the segments
    survive. A threshold set that keeps no segment scores ``1e9``.

    Random draws come from NumPy's global ``np.random`` state, so results can
    be made repeatable with ``np.random.seed``.

    Args:
        ppg_segments: Segmented PPG windows, shape (N, W)
        ecg_segments: Segmented ECG windows, shape (N, W)
        sbp_preds: VGTL-Net predicted SBP for each window, shape (N,)
        dbp_preds: VGTL-Net predicted DBP for each window, shape (N,)
        n_steps: Number of SA iterations (default: 1000)
        alpha: Weight for variance penalty (default: 0.05)

    Returns:
        dict with keys ``optimal_lo``, ``optimal_hi``, ``optimal_max_plateau``,
        ``best_loss``, ``initial_loss``, ``n_total_segments``,
        ``n_clean_segments``, ``yield_rate``, ``history`` (list of per-step
        log dicts) and ``clean_indices`` (indices of segments passing the
        optimal thresholds).
    """
    # Default thresholds: SA starting point and the answer for tiny inputs.
    init_lo, init_hi, init_plat = 0.0, 2.5, 5

    N = len(ppg_segments)

    # ── Guard: empty input ──────────────────────────────────────────────────
    if N == 0:
        return {
            "optimal_lo": init_lo,
            "optimal_hi": init_hi,
            "optimal_max_plateau": init_plat,
            "best_loss": 1e9,
            "initial_loss": 1e9,
            "n_total_segments": 0,
            "n_clean_segments": 0,
            "yield_rate": 0.0,
            "history": [],
            "clean_indices": [],
        }

    # Fancy indexing below needs arrays; this also lets callers pass lists.
    sbp_preds = np.asarray(sbp_preds)
    dbp_preds = np.asarray(dbp_preds)

    # ── 1. Precompute per-segment features once (reused by every SA step) ───
    ppg_entropies = np.array([
        compute_sample_entropy(ppg_segments[i]) for i in range(N)
    ])
    ppg_plateaus = np.array([
        longest_plateau(ppg_segments[i]) for i in range(N)
    ])
    ecg_plateaus = np.array([
        longest_plateau(ecg_segments[i]) for i in range(N)
    ])

    # ── 2. Shared filter + loss evaluator ───────────────────────────────────
    def select_clean(lo: float, hi: float, max_plat: int) -> np.ndarray:
        """Return the ascending indices of segments passing the thresholds."""
        passes = (
            (ppg_entropies >= lo)
            & (ppg_entropies <= hi)
            & (ppg_plateaus < max_plat)
            & (ecg_plateaus < max_plat)
        )
        return np.flatnonzero(passes)

    def evaluate_sa(lo: float, hi: float, max_plat: int) -> float:
        """Return the loss for one threshold set (lower is better)."""
        matched = select_clean(lo, hi, max_plat)
        n_kept = int(matched.size)
        if n_kept == 0:
            return 1e9

        kept_fraction = n_kept / N
        std_sbp = float(np.std(sbp_preds[matched])) if n_kept > 1 else 0.0
        std_dbp = float(np.std(dbp_preds[matched])) if n_kept > 1 else 0.0

        # Loss to minimize: want high yield AND low prediction variance.
        loss = -kept_fraction + alpha * (std_sbp + std_dbp)

        # Penalise when too few segments survive.
        if N >= 4:
            min_clean = max(1, int(0.25 * N))
            if n_kept < min_clean:
                loss += 2.0 * (min_clean - n_kept) / min_clean
        return loss

    # ── Guard: too few segments for meaningful SA optimisation ─────────────
    # The real loss at the default thresholds is still reported so that the
    # caller's chart is informative.
    if N < 4:
        actual_loss = evaluate_sa(init_lo, init_hi, init_plat)
        clean_indices = select_clean(init_lo, init_hi, init_plat).tolist()
        n_clean = len(clean_indices)
        return {
            "optimal_lo": init_lo,
            "optimal_hi": init_hi,
            "optimal_max_plateau": init_plat,
            "best_loss": float(actual_loss),
            "initial_loss": float(actual_loss),
            "n_total_segments": int(N),
            "n_clean_segments": int(n_clean),
            "yield_rate": float(n_clean / N),
            "history": [{
                "step": 0,
                "temperature": 1.0,
                "curr_loss": float(f"{actual_loss:.4g}"),
                "best_loss": float(f"{actual_loss:.4g}"),
                "best_lo": init_lo,
                "best_hi": init_hi,
                "best_plat": init_plat,
            }],
            "clean_indices": clean_indices,
        }

    # ── 3. SA Loop Initialisation ───────────────────────────────────────────
    curr_lo, curr_hi, curr_plat = init_lo, init_hi, init_plat
    curr_loss = evaluate_sa(curr_lo, curr_hi, curr_plat)

    best_lo, best_hi, best_plat = curr_lo, curr_hi, curr_plat
    best_loss = curr_loss
    initial_loss = curr_loss

    # SA Hyperparameters
    T_init = 1.0
    T_min = 1e-4

    last_step = n_steps - 1
    # Denominator of the cooling exponent; clamped so that n_steps == 1 does
    # not divide by zero (the single step then runs at T_init).
    cooling_span = max(last_step, 1)

    history = []

    # ── 4. Run SA ───────────────────────────────────────────────────────────
    for step in range(n_steps):
        # Geometric temperature schedule from T_init down to T_min.
        T = T_init * ((T_min / T_init) ** (step / cooling_span))

        # Perturb parameters (draw order matters for seeded reproducibility).
        cand_lo = float(np.clip(curr_lo + np.random.normal(0, 0.02), 0.0, 0.5))
        cand_hi = float(np.clip(curr_hi + np.random.normal(0, 0.15), 1.0, 5.0))
        cand_plat = int(np.clip(curr_plat + np.random.choice([-1, 0, 1]), 2, 15))

        cand_loss = evaluate_sa(cand_lo, cand_hi, cand_plat)

        # Metropolis acceptance: always take improvements, otherwise accept
        # with probability exp(-delta / T).
        if cand_loss < curr_loss:
            accept = True
        else:
            accept = float(np.random.random()) < np.exp((curr_loss - cand_loss) / T)

        improved = False
        if accept:
            curr_lo = cand_lo
            curr_hi = cand_hi
            curr_plat = cand_plat
            curr_loss = cand_loss

            if curr_loss < best_loss:
                best_lo = curr_lo
                best_hi = curr_hi
                best_plat = curr_plat
                best_loss = curr_loss
                improved = True

        # Log compact optimisation history: first step, last step, every
        # 100th step, and each step that sets a new best. Equal-loss moves
        # are deliberately not logged; they carry no new information and
        # would otherwise flood the history.
        if step in (0, last_step) or (step + 1) % 100 == 0 or improved:
            history.append({
                "step": step,
                "temperature": float(f"{T:.4g}"),
                "curr_loss": float(f"{curr_loss:.4g}"),
                "best_loss": float(f"{best_loss:.4g}"),
                "best_lo": float(f"{best_lo:.4f}"),
                "best_hi": float(f"{best_hi:.4f}"),
                "best_plat": int(best_plat),
            })

    # ── 5. Final filter with optimised thresholds ───────────────────────────
    clean_indices = select_clean(best_lo, best_hi, best_plat).tolist()
    n_clean = len(clean_indices)

    return {
        "optimal_lo": float(best_lo),
        "optimal_hi": float(best_hi),
        "optimal_max_plateau": int(best_plat),
        "best_loss": float(best_loss),
        "initial_loss": float(initial_loss),
        "n_total_segments": int(N),
        "n_clean_segments": int(n_clean),
        "yield_rate": float(n_clean / N),
        "history": history,
        "clean_indices": clean_indices,
    }