File size: 3,507 Bytes
b63d191
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import io
import json
from typing import Any, Dict, List

import gradio as gr
import matplotlib.pyplot as plt

from app.ml.gating import gate_signal
from app.ml.inference import infer_ecg, load_model
from app.rules.engine import evaluate_ecg_rules


# Preload the model once, at import time, so it is loaded before the UI below
# is built and launched (uses ./checkpoints/ecg_classifier.pt if present, per
# the original note; what happens when the checkpoint is absent is decided
# inside app.ml.inference.load_model — not visible here, confirm there).
load_model()


def parse_signal(text: str | List[float]) -> List[float]:
    if isinstance(text, list):
        return [float(x) for x in text]
    try:
        return [float(x) for x in json.loads(text)]
    except Exception:
        raise gr.Error("Provide ECG samples as a JSON list, e.g., [0.1, 0.2, 0.3]")


def run_infer(signal_text: str) -> Dict[str, Any]:
    """Run the demo pipeline (gating -> model -> rules) on one ECG signal.

    Args:
        signal_text: ECG samples as a JSON-array string (a list is also
            accepted by ``parse_signal``).

    Returns:
        A summary dict with the model's ``label``, ``score`` (rounded to 3
        places), ``hr``, the rules engine's ``alert_level`` (``"none"`` if
        absent), ``gated_ratio`` (rounded, default 1.0), the raw ``gating``
        metadata from ``gate_signal`` and the combined ``explanations``.

    Raises:
        gr.Error: Propagated from ``parse_signal`` when the input is malformed.
    """
    sig = parse_signal(signal_text)
    gated, gating_meta = gate_signal(sig, return_windows=True)
    model_output: Dict[str, Any] = infer_ecg(
        gated,
        original_len=len(sig),
        gating_meta=gating_meta,
    )
    # The rules engine takes a patient context; this demo has no real patient.
    patient_context = {"patient_id": "demo"}
    rules_result = evaluate_ecg_rules(patient_context, model_output)

    # Gating explanations (only when the model output carries a "gating"
    # dict) come first, followed by the rules-engine explanations.
    # "or []" also tolerates keys that are present but set to None.
    model_gating = model_output.get("gating")
    gating_explanations = (
        (model_gating.get("explanations") or [])
        if isinstance(model_gating, dict)
        else []
    )
    rules_explanations = rules_result.get("explanations") or []
    explanations = [*gating_explanations, *rules_explanations]

    return {
        "label": model_output.get("label"),
        # float() normalises any numeric scalar type infer_ecg may return
        # (e.g. numpy scalars — unconfirmed) to a built-in float; applied to
        # both numeric fields for consistency.
        "score": round(float(model_output.get("score", 0.0)), 3),
        "hr": model_output.get("hr"),
        "alert_level": rules_result.get("alert_level", "none"),
        "gated_ratio": round(float(model_output.get("gated_ratio", 1.0)), 3),
        "gating": gating_meta,
        "explanations": explanations,
    }


def plot_gating(signal_text: str):
    """Render the raw and gated signal to a PNG file for the preview tab.

    Args:
        signal_text: ECG samples as a JSON-array string (a list is also
            accepted by ``parse_signal``).

    Returns:
        Path (``str``) of a PNG with two stacked panels: the raw signal on top
        and the gated signal (titled with the gating ratio) below. A file path
        is returned because the output component is
        ``gr.Image(type="filepath")``: Gradio image outputs accept a path,
        PIL image or numpy array, not an in-memory ``io.BytesIO`` buffer.

    Raises:
        gr.Error: Propagated from ``parse_signal`` when the input is malformed.
    """
    import tempfile

    # A standalone Figure avoids pyplot's global figure registry, which is not
    # thread-safe and leaks figures unless closed. Gradio may call this
    # handler from worker threads.
    from matplotlib.figure import Figure

    sig = parse_signal(signal_text)
    gated, meta = gate_signal(sig, return_windows=True)

    fig = Figure(figsize=(6, 4))
    raw_ax, gated_ax = fig.subplots(2, 1)
    raw_ax.plot(sig, color="#0066ff", linewidth=1)
    raw_ax.set_title("Raw signal")
    gated_ax.plot(gated, color="#ff6600", linewidth=1)
    gated_ax.set_title(f"Gated signal (ratio={meta['ratio']:.2f})")
    fig.tight_layout()

    # delete=False keeps the file after the handle closes so Gradio can read
    # it once this function returns; cleanup is left to the OS temp dir.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
        fig.savefig(handle, format="png", dpi=120)
        path = handle.name
    return path


# Built-in demo inputs: 256 samples each, every sample the same value. They
# are flat placeholders rather than realistic ECG traces and differ only in
# amplitude (0.05 vs 0.3).
demo_normal = [0.05] * 256
demo_afib = [0.3] * 256

# Gradio UI with three tabs. Gradio lays out components in the order they are
# created inside each `with` context, so these statements should not be
# reordered.
with gr.Blocks(title="Sundew ECG Demo") as demo:
    gr.Markdown("### Neurosymbolic ECG • Sundew Gating + Rules")
    with gr.Tabs():
        # Tab 1: edit a JSON list of samples (prefilled with the first 128
        # samples of demo_afib) and run the full pipeline via run_infer.
        with gr.Tab("Upload/Infer"):
            inp = gr.Textbox(
                label="ECG samples (JSON list)",
                value=json.dumps(demo_afib[:128]),
            )
            out = gr.JSON(label="Inference")
            btn = gr.Button("Run")
            btn.click(run_infer, inputs=inp, outputs=out)
        # Tab 2: show the raw vs. gated signal produced by plot_gating in an
        # Image component (same 128-sample prefill as tab 1).
        with gr.Tab("Gating Preview"):
            inp2 = gr.Textbox(
                label="ECG samples (JSON list)",
                value=json.dumps(demo_afib[:128]),
            )
            img = gr.Image(type="filepath", label="Raw vs Gated")
            btn2 = gr.Button("Show gating")
            btn2.click(plot_gating, inputs=inp2, outputs=img)
        # Tab 3: one-click demos. Hidden textboxes hold the full 256-sample
        # demo signals as JSON so each button can pass them to run_infer;
        # both buttons write to the same JSON output.
        with gr.Tab("Demos"):
            out_demo = gr.JSON()
            btn_n = gr.Button("Normal")
            btn_a = gr.Button("Arrhythmia-ish")
            hidden_n = gr.Textbox(value=json.dumps(demo_normal), visible=False)
            hidden_a = gr.Textbox(value=json.dumps(demo_afib), visible=False)
            btn_n.click(run_infer, inputs=hidden_n, outputs=out_demo)
            btn_a.click(run_infer, inputs=hidden_a, outputs=out_demo)


# Start the Gradio server only when this file is executed as a script.
# Importing the module still runs the model preload and builds `demo`, but
# does not launch it.
if __name__ == "__main__":
    demo.launch()