Spaces:
Running
Running
| import io | |
| import json | |
| import os | |
| import sys | |
| import math | |
| from typing import Any, Dict, List | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import matplotlib.image as mpimg | |
| import pandas as pd | |
# When this file is launched as a Hugging Face Space entry point, make the
# bundled ``app`` package importable by putting this file's own directory at
# the front of sys.path. The project imports below rely on that entry.
ROOT = os.path.dirname(os.path.abspath(__file__))
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)

from app.ml.gating import gate_signal  # noqa: E402  (must follow the sys.path tweak)
from app.ml.inference import infer_ecg, load_model  # noqa: E402
from app.rules.engine import evaluate_ecg_rules  # noqa: E402

# Preload the model once at import time.
# NOTE(review): expected to use ./checkpoints/ecg_classifier.pt if present;
# confirm in app.ml.inference.load_model.
load_model()
| def parse_signal(text: str | List[float]) -> List[float]: | |
| if isinstance(text, list): | |
| return [float(x) for x in text] | |
| try: | |
| return [float(x) for x in json.loads(text)] | |
| except Exception: | |
| raise gr.Error("Provide ECG samples as a JSON list, e.g., [0.1, 0.2, 0.3]") | |
def run_infer(signal_text: str) -> Dict[str, Any]:
    """Run the demo pipeline (gating -> model -> rules) on one ECG trace.

    Args:
        signal_text: ECG samples as a JSON list string (see ``parse_signal``).

    Returns:
        A dict for the JSON output component with keys ``label``, ``score``
        (rounded to 3 dp), ``hr``, ``alert_level`` (``"none"`` when the rules
        engine gives none), ``gated_ratio`` (rounded to 3 dp), the raw
        ``gating`` metadata, a one-line ``gating_summary``, and
        ``explanations`` (gating explanations first, then the rules').

    Raises:
        gr.Error: Propagated from ``parse_signal`` for malformed input.
    """
    sig = parse_signal(signal_text)
    gated, gating_meta = gate_signal(sig, return_windows=True)
    model_output: Dict[str, Any] = infer_ecg(
        gated,
        original_len=len(sig),
        gating_meta=gating_meta,
    )
    rules_result = evaluate_ecg_rules({"patient_id": "demo"}, model_output)

    # The model's "gating" entry only contributes explanations when it is a dict.
    model_gating = model_output.get("gating")
    gating_explanations = (
        model_gating.get("explanations", []) if isinstance(model_gating, dict) else []
    )
    explanations = [*gating_explanations, *rules_result.get("explanations", [])]

    summary = f"Windows kept: {gating_meta.get('selected_windows',0)}/{gating_meta.get('total_windows',0)} • ratio={gating_meta.get('ratio',1):.2f}"
    return {
        "label": model_output.get("label"),
        "score": round(float(model_output.get("score", 0.0)), 3),
        "hr": model_output.get("hr"),
        "alert_level": rules_result.get("alert_level", "none"),
        # Coerced with float() exactly like "score", so a non-builtin numeric
        # scalar is emitted as a plain Python float.
        "gated_ratio": round(float(model_output.get("gated_ratio", 1.0)), 3),
        "gating": gating_meta,
        "gating_summary": summary,
        "explanations": explanations,
    }
def _render_gating_figure(sig, gated, ratio):
    """Plot raw (top) and gated (bottom) traces; return the rendered PNG as an image array.

    Uses the object-oriented ``matplotlib.figure.Figure`` API rather than
    ``pyplot``. The figure touches no global pyplot state, needs no GUI
    backend, and is never left registered in pyplot if ``savefig`` raises.
    """
    from matplotlib.figure import Figure  # local: only pyplot/image are imported at file top

    fig = Figure(figsize=(6, 4))
    ax_raw, ax_gated = fig.subplots(2, 1)
    ax_raw.plot(sig, color="#0066ff", linewidth=1)
    ax_raw.set_title("Raw signal")
    ax_gated.plot(gated, color="#ff6600", linewidth=1)
    ax_gated.set_title(f"Gated signal (ratio={ratio:.2f})")
    fig.tight_layout()
    with io.BytesIO() as buf:
        fig.savefig(buf, format="png", dpi=120)
        buf.seek(0)
        # imread returns PNG data as a float array with values in [0, 1].
        return mpimg.imread(buf)


def _gating_windows_table(windows):
    """Tabulate gating windows, keeping only the first of any duplicate (start, end, selected, forced)."""
    rows = []
    seen = set()
    for w in windows:
        selected = bool(w.get("selected"))
        forced = bool(w.get("forced", False))
        key = (w.get("start"), w.get("end"), selected, forced)
        if key in seen:
            continue
        seen.add(key)
        rows.append(
            [
                w.get("start"),
                w.get("end"),
                round(float(w.get("significance", 0.0)), 3),
                round(float(w.get("probability", 0.0)), 3),
                selected,
                forced,
            ]
        )
    return pd.DataFrame(rows, columns=["start", "end", "significance", "prob", "selected", "forced"])


def plot_gating(signal_text: str):
    """Visualise gating for one ECG trace.

    Args:
        signal_text: ECG samples as a JSON list string (see ``parse_signal``).

    Returns:
        ``(image, summary, table)``:
        - ``image``: two-panel raw-vs-gated plot decoded from PNG.
        - ``summary``: one-line "Windows kept: k/n • ratio=r" text.
        - ``table``: DataFrame with columns start, end, significance, prob,
          selected, forced, with one row per distinct gating window.

    Raises:
        gr.Error: Propagated from ``parse_signal`` for malformed input.
    """
    sig = parse_signal(signal_text)
    gated, meta = gate_signal(sig, return_windows=True)
    # Read "ratio" with the same default the summary uses, so a missing key no
    # longer raises KeyError while titling the plot.
    np_img = _render_gating_figure(sig, gated, meta.get("ratio", 1))
    # "or []" also tolerates an explicit None for "windows".
    df = _gating_windows_table(meta.get("windows") or [])
    summary = f"Windows kept: {meta.get('selected_windows',0)}/{meta.get('total_windows',0)} • ratio={meta.get('ratio',1):.2f}"
    return np_img, summary, df
# Synthetic demo traces: sums of sine tones plus sparse spikes, giving the
# gate some structure to keep or skip. Every trace has _DEMO_LEN samples.
_DEMO_LEN = 256


def _tone(amplitude, cycles, i):
    """Return sample ``i`` of a sine that completes ``cycles`` periods over ``_DEMO_LEN`` samples."""
    return amplitude * math.sin(2 * math.pi * cycles * (i / _DEMO_LEN))


demo_normal = [_tone(0.05, 2, i) for i in range(_DEMO_LEN)]
demo_afib = [
    _tone(0.25, 6, i) + _tone(0.05, 15, i) + (0.15 if i % 40 == 0 else 0.0)
    for i in range(_DEMO_LEN)
]
demo_noise = [
    _tone(0.02, 1, i) + (0.01 if i % 13 == 0 else 0.0)
    for i in range(_DEMO_LEN)
]
# Initial contents of both JSON-input textboxes: the first 128 samples of the
# "Arrhythmia-ish" demo trace.
_default_samples_json = json.dumps(demo_afib[:128])

with gr.Blocks(title="Sundew ECG Demo") as demo:
    gr.Markdown("### Neurosymbolic ECG • Sundew Gating + Rules")
    with gr.Tabs():
        # Paste samples and run gating -> model -> rules end to end.
        with gr.Tab("Upload/Infer"):
            inp = gr.Textbox(label="ECG samples (JSON list)", value=_default_samples_json)
            out = gr.JSON(label="Inference")
            btn = gr.Button("Run")
            btn.click(fn=run_infer, inputs=inp, outputs=out)

        # Show the raw vs gated traces and the per-window gating table.
        with gr.Tab("Gating Preview"):
            inp2 = gr.Textbox(label="ECG samples (JSON list)", value=_default_samples_json)
            img = gr.Image(type="numpy", label="Raw vs Gated")
            summary_box = gr.Textbox(label="Gating summary")
            # Column order mirrors the DataFrame returned by plot_gating.
            table = gr.Dataframe(
                headers=["start", "end", "significance", "prob", "selected", "forced"],
                datatype=["number", "number", "number", "number", "bool", "bool"],
                wrap=True,
            )
            btn2 = gr.Button("Show gating")
            btn2.click(fn=plot_gating, inputs=inp2, outputs=[img, summary_box, table])

        # One-click runs on the built-in synthetic traces. Each trace sits in an
        # invisible Textbox and is fed to run_infer as that button's input.
        with gr.Tab("Demos"):
            out_demo = gr.JSON()
            btn_n, btn_a, btn_noise = (
                gr.Button(caption) for caption in ("Normal", "Arrhythmia-ish", "Noisy baseline")
            )
            hidden_n, hidden_a, hidden_noise = (
                gr.Textbox(value=json.dumps(trace), visible=False)
                for trace in (demo_normal, demo_afib, demo_noise)
            )
            for trigger, source in ((btn_n, hidden_n), (btn_a, hidden_a), (btn_noise, hidden_noise)):
                trigger.click(fn=run_infer, inputs=source, outputs=out_demo)

if __name__ == "__main__":
    demo.launch()