"""Gradio demo UI for neurosymbolic ECG inference (Sundew gating + rules)."""

import io
import json
import math
import tempfile
from typing import Any, Dict, List

import gradio as gr
import matplotlib.pyplot as plt

from app.ml.gating import gate_signal
from app.ml.inference import infer_ecg, load_model
from app.rules.engine import evaluate_ecg_rules

# Preload model (uses ./checkpoints/ecg_classifier.pt if present)
load_model()

# User-facing message shown (via gr.Error) whenever the signal input is malformed.
_SIGNAL_FORMAT_HINT = "Provide ECG samples as a JSON list, e.g., [0.1, 0.2, 0.3]"


def parse_signal(text: str | List[float]) -> List[float]:
    """Parse ECG samples from a JSON-array string or an already-decoded list.

    Args:
        text: A JSON array string such as ``"[0.1, 0.2]"``, or a list of numbers
            when called programmatically.

    Returns:
        The samples converted to a list of floats.

    Raises:
        gr.Error: If the input is not a non-empty list of finite numbers. Gradio
            displays this message to the user instead of a stack trace.
    """
    if isinstance(text, list):
        raw: Any = text
    else:
        try:
            raw = json.loads(text)
        except (TypeError, ValueError):  # json.JSONDecodeError subclasses ValueError
            raise gr.Error(_SIGNAL_FORMAT_HINT) from None

    # json.loads also returns strings/objects/numbers; iterating a JSON string
    # such as "123" would otherwise silently yield [1.0, 2.0, 3.0].
    if not isinstance(raw, list) or not raw:
        raise gr.Error(_SIGNAL_FORMAT_HINT)

    try:
        samples = [float(x) for x in raw]
    except (TypeError, ValueError):
        raise gr.Error(_SIGNAL_FORMAT_HINT) from None

    # json.loads accepts the non-standard tokens NaN/Infinity; reject them here
    # rather than letting them propagate into gating and the model.
    if not all(math.isfinite(x) for x in samples):
        raise gr.Error("ECG samples must be finite numbers (no NaN or Infinity).")
    return samples


def run_infer(signal_text: str) -> Dict[str, Any]:
    """Gate a signal, run the ECG model, evaluate rules, and summarize the result.

    Args:
        signal_text: ECG samples in any form accepted by ``parse_signal``.

    Returns:
        A JSON-friendly dict with the model label/score/hr, the rule alert level,
        the gated ratio, the raw gating metadata, and merged explanations
        (gating explanations reported by the model first, then rule explanations).

    Raises:
        gr.Error: Propagated from ``parse_signal`` on malformed input.
    """
    sig = parse_signal(signal_text)
    gated, gating_meta = gate_signal(sig, return_windows=True)
    model_output: Dict[str, Any] = infer_ecg(
        gated,
        original_len=len(sig),
        gating_meta=gating_meta,
    )
    patient_context = {"patient_id": "demo"}
    rules_result = evaluate_ecg_rules(patient_context, model_output)

    # model_output["gating"] is only trusted when it is a dict; `or []` also
    # guards against explicit None values, which would break `*` unpacking.
    model_gating = model_output.get("gating")
    gating_explanations = (
        model_gating.get("explanations") if isinstance(model_gating, dict) else None
    )
    explanations = [
        *(gating_explanations or []),
        *(rules_result.get("explanations") or []),
    ]

    # Treat an explicit None like a missing key so float()/round() cannot fail.
    score = model_output.get("score")
    gated_ratio = model_output.get("gated_ratio")
    return {
        "label": model_output.get("label"),
        "score": round(float(0.0 if score is None else score), 3),
        "hr": model_output.get("hr"),
        "alert_level": rules_result.get("alert_level", "none"),
        "gated_ratio": round(float(1.0 if gated_ratio is None else gated_ratio), 3),
        "gating": gating_meta,
        "explanations": explanations,
    }


def plot_gating(signal_text: str) -> str:
    """Render the raw and gated signals as a two-panel PNG.

    Args:
        signal_text: ECG samples in any form accepted by ``parse_signal``.

    Returns:
        Path to a PNG file. ``gr.Image`` output accepts a file path (or an
        array/PIL image) but not an in-memory ``BytesIO`` buffer.

    Raises:
        gr.Error: Propagated from ``parse_signal`` on malformed input.
    """
    sig = parse_signal(signal_text)
    gated, meta = gate_signal(sig, return_windows=True)

    fig, axes = plt.subplots(2, 1, figsize=(6, 4))
    try:
        axes[0].plot(sig, color="#0066ff", linewidth=1)
        axes[0].set_title("Raw signal")
        axes[1].plot(gated, color="#ff6600", linewidth=1)
        # NOTE(review): assumes gate_signal's metadata always carries "ratio".
        axes[1].set_title(f"Gated signal (ratio={meta['ratio']:.2f})")
        fig.tight_layout()
        # delete=False: the file must outlive this function so Gradio can read it.
        # Writing through the open handle (not by name) also works on Windows.
        with tempfile.NamedTemporaryFile(
            prefix="gating_", suffix=".png", delete=False
        ) as tmp:
            fig.savefig(tmp, format="png", dpi=120)
    finally:
        # Always release the pyplot figure, even if plotting/saving raised.
        plt.close(fig)
    return tmp.name


# NOTE(review): both demo traces are flat constant signals (placeholders),
# not real ECG morphology.
demo_normal = [0.05] * 256
demo_afib = [0.3] * 256

with gr.Blocks(title="Sundew ECG Demo") as demo:
    gr.Markdown("### Neurosymbolic ECG • Sundew Gating + Rules")
    with gr.Tabs():
        with gr.Tab("Upload/Infer"):
            inp = gr.Textbox(
                label="ECG samples (JSON list)",
                value=json.dumps(demo_afib[:128]),
            )
            out = gr.JSON(label="Inference")
            btn = gr.Button("Run")
            btn.click(run_infer, inputs=inp, outputs=out)
        with gr.Tab("Gating Preview"):
            inp2 = gr.Textbox(
                label="ECG samples (JSON list)",
                value=json.dumps(demo_afib[:128]),
            )
            img = gr.Image(type="filepath", label="Raw vs Gated")
            btn2 = gr.Button("Show gating")
            btn2.click(plot_gating, inputs=inp2, outputs=img)
        with gr.Tab("Demos"):
            out_demo = gr.JSON()
            btn_n = gr.Button("Normal")
            btn_a = gr.Button("Arrhythmia-ish")
            # Hidden textboxes feed the canned signals through the same run_infer path.
            hidden_n = gr.Textbox(value=json.dumps(demo_normal), visible=False)
            hidden_a = gr.Textbox(value=json.dumps(demo_afib), visible=False)
            btn_n.click(run_infer, inputs=hidden_n, outputs=out_demo)
            btn_a.click(run_infer, inputs=hidden_a, outputs=out_demo)

if __name__ == "__main__":
    demo.launch()