Spaces:
Running
Running
File size: 5,765 Bytes
e47d4ee 82e5439 a003746 e47d4ee a003746 d21e3e5 e47d4ee 82e5439 e47d4ee a003746 e47d4ee a003746 e47d4ee 3a42f68 a003746 d21e3e5 6b447b8 d21e3e5 6b447b8 d21e3e5 a003746 d21e3e5 a003746 e47d4ee 457c8fa a003746 d21e3e5 e47d4ee a003746 e47d4ee a003746 e47d4ee a003746 e47d4ee a003746 e47d4ee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import io
import json
import os
import sys
import math
from typing import Any, Dict, List
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.image as mpimg
import pandas as pd
# Put the directory containing this file at the front of the import path (only
# once) so the local `app` package imported below resolves when the Space runs
# this script.
ROOT = os.path.split(os.path.abspath(__file__))[0]
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)
from app.ml.gating import gate_signal
from app.ml.inference import infer_ecg, load_model
from app.rules.engine import evaluate_ecg_rules
# Preload the ECG classifier once, at import time, before the UI below is built;
# the return value is discarded. Per the original note it uses
# ./checkpoints/ecg_classifier.pt if present.
# NOTE(review): whether that path is resolved against the CWD or ROOT, and what
# happens when the checkpoint is missing, is decided inside
# app.ml.inference.load_model — confirm there. Presumably this warms a cached
# model so the first request is not slowed; verify.
load_model()
def parse_signal(text: str | List[float]) -> List[float]:
if isinstance(text, list):
return [float(x) for x in text]
try:
return [float(x) for x in json.loads(text)]
except Exception:
raise gr.Error("Provide ECG samples as a JSON list, e.g., [0.1, 0.2, 0.3]")
def run_infer(signal_text: str) -> Dict[str, Any]:
    """Run the pipeline (parse -> gate -> classify -> rules) on one ECG signal.

    Args:
        signal_text: ECG samples as a JSON list string; see ``parse_signal``.

    Returns:
        A dict shown in the ``gr.JSON`` panels with keys ``label``, ``score``
        (rounded to 3 places), ``hr``, ``alert_level`` (defaults to ``"none"``),
        ``gated_ratio`` (rounded to 3 places), ``gating`` (the gating metadata
        as returned by ``gate_signal``), ``gating_summary`` (one-line text) and
        ``explanations`` (gating explanations followed by rule explanations).

    Raises:
        gr.Error: if the signal cannot be parsed (raised by ``parse_signal``).
    """
    sig = parse_signal(signal_text)
    gated, gating_meta = gate_signal(sig, return_windows=True)
    model_output: Dict[str, Any] = infer_ecg(
        gated,
        original_len=len(sig),
        gating_meta=gating_meta,
    )
    # Fixed placeholder context: this demo evaluates rules for a dummy patient.
    patient_context = {"patient_id": "demo"}
    rules_result = evaluate_ecg_rules(patient_context, model_output)

    # Gating explanations first, then rule explanations. `or []` also covers a
    # key that is present but set to None, which would break the `*` unpacking.
    gating_info = model_output.get("gating")
    gating_explanations = (
        (gating_info.get("explanations") or []) if isinstance(gating_info, dict) else []
    )
    explanations = [
        *gating_explanations,
        *(rules_result.get("explanations") or []),
    ]

    summary = f"Windows kept: {gating_meta.get('selected_windows',0)}/{gating_meta.get('total_windows',0)} • ratio={gating_meta.get('ratio',1):.2f}"
    return {
        "label": model_output.get("label"),
        "score": round(float(model_output.get("score", 0.0)), 3),
        "hr": model_output.get("hr"),
        "alert_level": rules_result.get("alert_level", "none"),
        # float() as for "score": yields a plain Python float even if the model
        # reports a NumPy scalar, which a JSON encoder may not accept.
        "gated_ratio": round(float(model_output.get("gated_ratio", 1.0)), 3),
        "gating": gating_meta,
        "gating_summary": summary,
        "explanations": explanations,
    }
def _render_gating_figure(raw: List[float], gated: Any, ratio: float) -> np.ndarray:
    """Plot ``raw`` above ``gated`` and return the rendered PNG decoded as an image array."""
    # Build the Figure directly rather than through pyplot: pyplot keeps global
    # state (its list of open figures) that is not thread-safe, and any figure it
    # creates leaks unless explicitly closed. A standalone Figure is simply
    # garbage-collected, even if savefig raises.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(6, 4))
    raw_ax, gated_ax = fig.subplots(2, 1)
    raw_ax.plot(raw, color="#0066ff", linewidth=1)
    raw_ax.set_title("Raw signal")
    gated_ax.plot(gated, color="#ff6600", linewidth=1)
    gated_ax.set_title(f"Gated signal (ratio={ratio:.2f})")
    fig.tight_layout()
    with io.BytesIO() as buf:
        fig.savefig(buf, format="png", dpi=120)
        buf.seek(0)
        return mpimg.imread(buf)


def _gating_window_rows(windows: List[Dict[str, Any]]) -> List[List[Any]]:
    """Return one table row per distinct window; exact repeats (same span and flags) are dropped."""
    rows: List[List[Any]] = []
    seen = set()
    for w in windows:
        selected = bool(w.get("selected"))
        forced = bool(w.get("forced", False))
        key = (w.get("start"), w.get("end"), selected, forced)
        if key in seen:
            continue
        seen.add(key)
        rows.append(
            [
                w.get("start"),
                w.get("end"),
                # `or 0.0` also covers a key that is present but set to None.
                round(float(w.get("significance") or 0.0), 3),
                round(float(w.get("probability") or 0.0), 3),
                selected,
                forced,
            ]
        )
    return rows


def plot_gating(signal_text: str):
    """Build the Gating Preview outputs for one ECG signal.

    Args:
        signal_text: ECG samples as a JSON list string; see ``parse_signal``.

    Returns:
        ``(image, summary, table)``: the raw-vs-gated plot as an image array, a
        one-line "Windows kept" summary, and a DataFrame with one row per
        distinct gating window (columns start, end, significance, prob,
        selected, forced).

    Raises:
        gr.Error: if the signal cannot be parsed (raised by ``parse_signal``).
    """
    sig = parse_signal(signal_text)
    gated, meta = gate_signal(sig, return_windows=True)
    # Read the ratio once, with the summary's default, so a missing key cannot
    # raise KeyError in the plot title.
    ratio = meta.get("ratio", 1)
    np_img = _render_gating_figure(sig, gated, ratio)
    rows = _gating_window_rows(meta.get("windows") or [])
    df = pd.DataFrame(rows, columns=["start", "end", "significance", "prob", "selected", "forced"])
    summary = f"Windows kept: {meta.get('selected_windows',0)}/{meta.get('total_windows',0)} • ratio={ratio:.2f}"
    return np_img, summary, df
# Synthetic 256-sample demo traces. They carry more structure than a flat line
# so the gating step has windows worth keeping or skipping.
def _demo_tone(amplitude, cycles, i, n=256):
    """Return sample ``i`` of a sine with ``cycles`` periods over ``n`` samples, scaled by ``amplitude``."""
    return amplitude * math.sin(2 * math.pi * cycles * (i / n))


# Low-amplitude, slow sine.
demo_normal = [_demo_tone(0.05, 2, i) for i in range(256)]

# Faster mixed sines plus a spike every 40 samples.
demo_afib = [
    _demo_tone(0.25, 6, i) + _demo_tone(0.05, 15, i) + (0.15 if i % 40 == 0 else 0.0)
    for i in range(256)
]

# Very small, slow baseline with a tiny blip every 13 samples.
demo_noise = [_demo_tone(0.02, 1, i) + (0.01 if i % 13 == 0 else 0.0) for i in range(256)]
# Gradio UI: free-form inference, a gating preview, and one-click demo signals.
# Both free-form tabs start pre-filled with the first 128 samples of the
# arrhythmia demo trace.
_default_samples = json.dumps(demo_afib[:128])

with gr.Blocks(title="Sundew ECG Demo") as demo:
    gr.Markdown("### Neurosymbolic ECG • Sundew Gating + Rules")
    with gr.Tabs():
        # Paste samples and see the JSON produced by run_infer.
        with gr.Tab("Upload/Infer"):
            inp = gr.Textbox(label="ECG samples (JSON list)", value=_default_samples)
            out = gr.JSON(label="Inference")
            btn = gr.Button("Run")
            btn.click(run_infer, inputs=inp, outputs=out)

        # Paste samples and see raw vs. gated plots plus the per-window table.
        with gr.Tab("Gating Preview"):
            inp2 = gr.Textbox(label="ECG samples (JSON list)", value=_default_samples)
            img = gr.Image(type="numpy", label="Raw vs Gated")
            summary_box = gr.Textbox(label="Gating summary")
            table = gr.Dataframe(
                headers=["start", "end", "significance", "prob", "selected", "forced"],
                datatype=["number", "number", "number", "number", "bool", "bool"],
                wrap=True,
            )
            btn2 = gr.Button("Show gating")
            btn2.click(plot_gating, inputs=inp2, outputs=[img, summary_box, table])

        # Each button sends the full demo trace held in an invisible textbox
        # through run_infer; all three share one JSON output panel.
        with gr.Tab("Demos"):
            out_demo = gr.JSON()
            btn_n = gr.Button("Normal")
            btn_a = gr.Button("Arrhythmia-ish")
            btn_noise = gr.Button("Noisy baseline")
            hidden_n = gr.Textbox(value=json.dumps(demo_normal), visible=False)
            hidden_a = gr.Textbox(value=json.dumps(demo_afib), visible=False)
            hidden_noise = gr.Textbox(value=json.dumps(demo_noise), visible=False)
            for _trigger, _payload in (
                (btn_n, hidden_n),
                (btn_a, hidden_a),
                (btn_noise, hidden_noise),
            ):
                _trigger.click(run_infer, inputs=_payload, outputs=out_demo)

if __name__ == "__main__":
    demo.launch()
|