# SundewAIHealth / app_space.py
# Retrieved from the Hugging Face file view: "Upload app_space.py" by mgbam,
# commit 6b447b8 (verified), 5.77 kB.
import io
import json
import os
import sys
import math
from typing import Any, Dict, List
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.image as mpimg
import pandas as pd
# Put this file's directory at the front of sys.path so the local `app`
# package resolves when the script runs inside a Hugging Face Space.
ROOT = os.path.split(os.path.abspath(__file__))[0]
if not any(entry == ROOT for entry in sys.path):
    sys.path.insert(0, ROOT)
from app.ml.gating import gate_signal
from app.ml.inference import infer_ecg, load_model
from app.rules.engine import evaluate_ecg_rules
# Preload model (uses ./checkpoints/ecg_classifier.pt if present).
# This runs at import time, i.e. before the UI below is built or any
# callback can be invoked.
load_model()
def parse_signal(text: str | List[float]) -> List[float]:
if isinstance(text, list):
return [float(x) for x in text]
try:
return [float(x) for x in json.loads(text)]
except Exception:
raise gr.Error("Provide ECG samples as a JSON list, e.g., [0.1, 0.2, 0.3]")
def run_infer(signal_text: str) -> Dict[str, Any]:
    """Gate an ECG signal, classify it, and apply the rule engine.

    Args:
        signal_text: ECG samples in any form ``parse_signal`` accepts.

    Returns:
        A dict with the model ``label``, the rounded ``score``, ``hr``, the
        rules' ``alert_level``, the rounded ``gated_ratio``, the gating
        metadata, a one-line gating summary, and the gating explanations
        followed by the rule explanations.
    """
    samples = parse_signal(signal_text)
    gated_samples, gating_meta = gate_signal(samples, return_windows=True)
    model_output: Dict[str, Any] = infer_ecg(
        gated_samples,
        original_len=len(samples),
        gating_meta=gating_meta,
    )

    # The demo has no patient record, so a fixed placeholder id is passed.
    rules_result = evaluate_ecg_rules({"patient_id": "demo"}, model_output)

    # Gating explanations come first and are only read when the model output
    # actually carries a dict under "gating".
    explanations = []
    gating_section = model_output.get("gating")
    if isinstance(gating_section, dict):
        explanations.extend(gating_section.get("explanations", []))
    explanations.extend(rules_result.get("explanations", []))

    kept = gating_meta.get("selected_windows", 0)
    total = gating_meta.get("total_windows", 0)
    ratio = gating_meta.get("ratio", 1)
    summary = f"Windows kept: {kept}/{total} • ratio={ratio:.2f}"

    return {
        "label": model_output.get("label"),
        "score": round(float(model_output.get("score", 0.0)), 3),
        "hr": model_output.get("hr"),
        "alert_level": rules_result.get("alert_level", "none"),
        "gated_ratio": round(model_output.get("gated_ratio", 1.0), 3),
        "gating": gating_meta,
        "gating_summary": summary,
        "explanations": explanations,
    }
def plot_gating(signal_text: str):
    """Plot the raw and gated signals and tabulate the gating windows.

    Args:
        signal_text: ECG samples in any form ``parse_signal`` accepts.

    Returns:
        A 3-tuple ``(image, summary, table)``: the two-panel plot read back
        as a NumPy array from an in-memory PNG, a one-line summary string,
        and a ``pandas.DataFrame`` with one row per distinct window.
    """
    sig = parse_signal(signal_text)
    gated, meta = gate_signal(sig, return_windows=True)
    # Read the ratio once with the same default the summary uses, so a
    # missing key cannot raise KeyError in the title only.
    ratio = meta.get("ratio", 1)

    fig, axes = plt.subplots(2, 1, figsize=(6, 4))
    try:
        axes[0].plot(sig, color="#0066ff", linewidth=1)
        axes[0].set_title("Raw signal")
        axes[1].plot(gated, color="#ff6600", linewidth=1)
        axes[1].set_title(f"Gated signal (ratio={ratio:.2f})")
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=120)
    finally:
        # pyplot keeps every figure alive until it is closed; release it
        # even when plotting or encoding fails.
        plt.close(fig)
    buf.seek(0)
    np_img = mpimg.imread(buf)

    # Windows that agree on (start, end, selected, forced) are listed once.
    table_rows = []
    seen = set()
    for w in meta.get("windows", []):
        selected = bool(w.get("selected"))
        forced = bool(w.get("forced", False))
        key = (w.get("start"), w.get("end"), selected, forced)
        if key in seen:
            continue
        seen.add(key)
        table_rows.append(
            [
                w.get("start"),
                w.get("end"),
                round(float(w.get("significance", 0.0)), 3),
                round(float(w.get("probability", 0.0)), 3),
                selected,
                forced,
            ]
        )
    df = pd.DataFrame(
        table_rows,
        columns=["start", "end", "significance", "prob", "selected", "forced"],
    )
    summary = (
        f"Windows kept: {meta.get('selected_windows',0)}/{meta.get('total_windows',0)}"
        f" • ratio={ratio:.2f}"
    )
    return np_img, summary, df
# Demo signals with more structure so gating can skip/keep meaningfully.
# Each trace has 256 samples built from sine components plus periodic spikes.
_DEMO_LENGTH = 256


def _sine(amplitude, cycles, index):
    """Return sample ``index`` of a sine with ``cycles`` periods per trace."""
    return amplitude * math.sin(2 * math.pi * cycles * (index / _DEMO_LENGTH))


def _spike(index, period, height):
    """Return ``height`` on every ``period``-th sample and 0.0 elsewhere."""
    return height if index % period == 0 else 0.0


demo_normal = [_sine(0.05, 2, i) for i in range(_DEMO_LENGTH)]
demo_afib = [
    _sine(0.25, 6, i) + _sine(0.05, 15, i) + _spike(i, 40, 0.15)
    for i in range(_DEMO_LENGTH)
]
demo_noise = [
    _sine(0.02, 1, i) + _spike(i, 13, 0.01) for i in range(_DEMO_LENGTH)
]
# Gradio UI with three tabs: free-form inference, a gating preview, and
# one-click demo signals. Components are created in the order they appear.
_default_samples = json.dumps(demo_afib[:128])
_table_headers = ("start", "end", "significance", "prob", "selected", "forced")

with gr.Blocks(title="Sundew ECG Demo") as demo:
    gr.Markdown("### Neurosymbolic ECG • Sundew Gating + Rules")
    with gr.Tabs():
        with gr.Tab("Upload/Infer"):
            inp = gr.Textbox(label="ECG samples (JSON list)", value=_default_samples)
            out = gr.JSON(label="Inference")
            btn = gr.Button("Run")
            btn.click(run_infer, inputs=inp, outputs=out)

        with gr.Tab("Gating Preview"):
            inp2 = gr.Textbox(label="ECG samples (JSON list)", value=_default_samples)
            img = gr.Image(type="numpy", label="Raw vs Gated")
            summary_box = gr.Textbox(label="Gating summary")
            table = gr.Dataframe(
                headers=list(_table_headers),
                datatype=["number"] * 4 + ["bool"] * 2,
                wrap=True,
            )
            btn2 = gr.Button("Show gating")
            # plot_gating returns (image, summary, dataframe) in this order.
            btn2.click(plot_gating, inputs=inp2, outputs=[img, summary_box, table])

        with gr.Tab("Demos"):
            out_demo = gr.JSON()
            btn_n, btn_a, btn_noise = [
                gr.Button(caption)
                for caption in ("Normal", "Arrhythmia-ish", "Noisy baseline")
            ]
            # Invisible textboxes carry each demo signal to run_infer as JSON.
            hidden_n, hidden_a, hidden_noise = [
                gr.Textbox(value=json.dumps(samples), visible=False)
                for samples in (demo_normal, demo_afib, demo_noise)
            ]
            for _button, _payload in (
                (btn_n, hidden_n),
                (btn_a, hidden_a),
                (btn_noise, hidden_noise),
            ):
                _button.click(run_infer, inputs=_payload, outputs=out_demo)
# Launch the Gradio app only when this file is executed as a script;
# importing the module builds `demo` without calling launch().
if __name__ == "__main__":
    demo.launch()