"""Gradio demo for the neurosymbolic ECG pipeline (Sundew gating + rules)."""

import io
import json
import math
import os
import sys
from typing import Any, Dict, List, Tuple

import gradio as gr
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Ensure local package is importable when running in Hugging Face Spaces
ROOT = os.path.dirname(os.path.abspath(__file__))
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)

# These imports must come after the sys.path tweak above.
from app.ml.gating import gate_signal  # noqa: E402
from app.ml.inference import infer_ecg, load_model  # noqa: E402
from app.rules.engine import evaluate_ecg_rules  # noqa: E402

# Preload model (uses ./checkpoints/ecg_classifier.pt if present)
load_model()

# Column order shared by the gating table built in plot_gating and the UI Dataframe.
GATING_TABLE_COLUMNS = ["start", "end", "significance", "prob", "selected", "forced"]

_SIGNAL_FORMAT_ERROR = "Provide ECG samples as a JSON list, e.g., [0.1, 0.2, 0.3]"


def parse_signal(text: str | List[float]) -> List[float]:
    """Parse ECG samples into a non-empty list of finite floats.

    Args:
        text: A JSON array string (e.g. ``"[0.1, 0.2]"``) or an already-decoded
            list/tuple of numbers.

    Returns:
        The samples converted to ``float``.

    Raises:
        gr.Error: If the input is not a non-empty array of finite numbers.
    """
    if isinstance(text, (list, tuple)):
        raw: Any = text
    else:
        try:
            raw = json.loads(text)
        except (TypeError, ValueError) as err:  # JSONDecodeError is a ValueError
            raise gr.Error(_SIGNAL_FORMAT_ERROR) from err

    # json.loads also returns strings/dicts/scalars; iterating a string would
    # silently turn "123" into [1.0, 2.0, 3.0], so insist on an array.
    if not isinstance(raw, (list, tuple)):
        raise gr.Error(_SIGNAL_FORMAT_ERROR)

    try:
        samples = [float(x) for x in raw]
    except (TypeError, ValueError) as err:
        raise gr.Error(_SIGNAL_FORMAT_ERROR) from err

    if not samples:
        raise gr.Error("Provide at least one ECG sample.")
    # json.loads accepts NaN/Infinity literals; reject them before they reach the model.
    if not all(math.isfinite(x) for x in samples):
        raise gr.Error("ECG samples must be finite numbers (no NaN or Infinity).")
    return samples


def _gating_ratio(meta: Dict[str, Any]) -> float:
    """Return ``meta['ratio']`` as a float, defaulting to 1.0 when absent or None."""
    ratio = meta.get("ratio")
    return 1.0 if ratio is None else float(ratio)


def _format_gating_summary(meta: Dict[str, Any]) -> str:
    """Build the one-line 'windows kept / ratio' summary shown in both tabs."""
    kept = meta.get("selected_windows", 0)
    total = meta.get("total_windows", 0)
    return f"Windows kept: {kept}/{total} • ratio={_gating_ratio(meta):.2f}"


def run_infer(signal_text: str) -> Dict[str, Any]:
    """Gate, classify and rule-check a signal, returning a JSON-friendly result.

    Args:
        signal_text: ECG samples in any form accepted by ``parse_signal``.

    Returns:
        A dict with label, rounded score, heart rate, alert level, gating
        metadata/summary and the combined gating + rule explanations.

    Raises:
        gr.Error: If the signal cannot be parsed.
    """
    sig = parse_signal(signal_text)
    gated, gating_meta = gate_signal(sig, return_windows=True)
    model_output: Dict[str, Any] = infer_ecg(
        gated,
        original_len=len(sig),
        gating_meta=gating_meta,
    )

    patient_context = {"patient_id": "demo"}
    rules_result = evaluate_ecg_rules(patient_context, model_output)

    # The model may or may not attach a "gating" dict with its own explanations.
    model_gating = model_output.get("gating")
    gating_explanations = (
        (model_gating.get("explanations") or []) if isinstance(model_gating, dict) else []
    )
    explanations = [*gating_explanations, *(rules_result.get("explanations") or [])]

    # Coerce to plain floats (tolerating None) so rounding and JSON output never fail.
    score = model_output.get("score")
    gated_ratio = model_output.get("gated_ratio")
    return {
        "label": model_output.get("label"),
        "score": round(float(0.0 if score is None else score), 3),
        "hr": model_output.get("hr"),
        "alert_level": rules_result.get("alert_level", "none"),
        "gated_ratio": round(float(1.0 if gated_ratio is None else gated_ratio), 3),
        "gating": gating_meta,
        "gating_summary": _format_gating_summary(gating_meta),
        "explanations": explanations,
    }


def _render_gating_plot(raw: List[float], gated: Any, ratio: float) -> np.ndarray:
    """Plot raw vs gated signal and return the rendered PNG as an image array."""
    fig, axes = plt.subplots(2, 1, figsize=(6, 4))
    buf = io.BytesIO()
    try:
        axes[0].plot(raw, color="#0066ff", linewidth=1)
        axes[0].set_title("Raw signal")
        axes[1].plot(gated, color="#ff6600", linewidth=1)
        axes[1].set_title(f"Gated signal (ratio={ratio:.2f})")
        fig.tight_layout()
        fig.savefig(buf, format="png", dpi=120)
    finally:
        # Always release the pyplot figure, even if plotting fails, to avoid leaks.
        plt.close(fig)
    buf.seek(0)
    return mpimg.imread(buf)


def _windows_table(meta: Dict[str, Any]) -> pd.DataFrame:
    """Tabulate gating windows, dropping duplicates by (start, end, selected, forced)."""
    rows = []
    seen = set()
    for w in meta.get("windows") or []:
        selected = bool(w.get("selected"))
        forced = bool(w.get("forced", False))
        key = (w.get("start"), w.get("end"), selected, forced)
        if key in seen:
            continue
        seen.add(key)
        rows.append(
            [
                w.get("start"),
                w.get("end"),
                round(float(w.get("significance", 0.0)), 3),
                round(float(w.get("probability", 0.0)), 3),
                selected,
                forced,
            ]
        )
    return pd.DataFrame(rows, columns=GATING_TABLE_COLUMNS)


def plot_gating(signal_text: str) -> Tuple[np.ndarray, str, pd.DataFrame]:
    """Visualise how gating treats a signal.

    Args:
        signal_text: ECG samples in any form accepted by ``parse_signal``.

    Returns:
        (image array of raw vs gated plots, summary line, per-window table).

    Raises:
        gr.Error: If the signal cannot be parsed.
    """
    sig = parse_signal(signal_text)
    gated, meta = gate_signal(sig, return_windows=True)
    np_img = _render_gating_plot(sig, gated, _gating_ratio(meta))
    return np_img, _format_gating_summary(meta), _windows_table(meta)


# Demo signals with more structure so gating can skip/keep meaningfully
_DEMO_LEN = 256

demo_normal = [0.05 * math.sin(2 * math.pi * 2 * (i / _DEMO_LEN)) for i in range(_DEMO_LEN)]
demo_afib = [
    0.25 * math.sin(2 * math.pi * 6 * (i / _DEMO_LEN))
    + 0.05 * math.sin(2 * math.pi * 15 * (i / _DEMO_LEN))
    + (0.15 if i % 40 == 0 else 0.0)
    for i in range(_DEMO_LEN)
]
demo_noise = [
    0.02 * math.sin(2 * math.pi * 1 * (i / _DEMO_LEN)) + (0.01 if i % 13 == 0 else 0.0)
    for i in range(_DEMO_LEN)
]

# Default textbox content for the first two tabs (first half of the arrhythmia demo).
_DEFAULT_SIGNAL_TEXT = json.dumps(demo_afib[:128])

with gr.Blocks(title="Sundew ECG Demo") as demo:
    gr.Markdown("### Neurosymbolic ECG • Sundew Gating + Rules")
    with gr.Tabs():
        with gr.Tab("Upload/Infer"):
            inp = gr.Textbox(
                label="ECG samples (JSON list)",
                value=_DEFAULT_SIGNAL_TEXT,
            )
            out = gr.JSON(label="Inference")
            btn = gr.Button("Run")
            btn.click(run_infer, inputs=inp, outputs=out)

        with gr.Tab("Gating Preview"):
            inp2 = gr.Textbox(
                label="ECG samples (JSON list)",
                value=_DEFAULT_SIGNAL_TEXT,
            )
            img = gr.Image(type="numpy", label="Raw vs Gated")
            summary_box = gr.Textbox(label="Gating summary")
            table = gr.Dataframe(
                headers=list(GATING_TABLE_COLUMNS),
                datatype=["number", "number", "number", "number", "bool", "bool"],
                wrap=True,
            )
            btn2 = gr.Button("Show gating")
            btn2.click(plot_gating, inputs=inp2, outputs=[img, summary_box, table])

        with gr.Tab("Demos"):
            out_demo = gr.JSON()
            btn_n = gr.Button("Normal")
            btn_a = gr.Button("Arrhythmia-ish")
            btn_noise = gr.Button("Noisy baseline")
            # Hidden textboxes carry the demo signals into run_infer as its input.
            hidden_n = gr.Textbox(value=json.dumps(demo_normal), visible=False)
            hidden_a = gr.Textbox(value=json.dumps(demo_afib), visible=False)
            hidden_noise = gr.Textbox(value=json.dumps(demo_noise), visible=False)
            btn_n.click(run_infer, inputs=hidden_n, outputs=out_demo)
            btn_a.click(run_infer, inputs=hidden_a, outputs=out_demo)
            btn_noise.click(run_infer, inputs=hidden_noise, outputs=out_demo)

if __name__ == "__main__":
    demo.launch()