SundewAIHealth / app_space.py
mgbam's picture
Rename app/app_space.py to app_space.py
3aab745 verified
raw
history blame
3.71 kB
import io
import json
import math
import os
import sys
from typing import Any, Dict, List

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
# Put this file's directory at the front of sys.path (only once) so the local
# ``app`` package imported just below resolves regardless of the working
# directory the Space process was started from.
ROOT = os.path.dirname(os.path.abspath(__file__))
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)
from app.ml.gating import gate_signal
from app.ml.inference import infer_ecg, load_model
from app.rules.engine import evaluate_ecg_rules
# Load the classifier once at import time, before the UI below is built,
# instead of inside a request handler (run_infer never calls load_model).
# The return value is discarded, so presumably load_model() caches the model
# inside app.ml.inference for infer_ecg to use — verify. It is expected to
# pick up ./checkpoints/ecg_classifier.pt when present; NOTE(review): confirm
# whether that path is resolved against the CWD or against ROOT.
load_model()
def parse_signal(text: str | List[float]) -> List[float]:
if isinstance(text, list):
return [float(x) for x in text]
try:
return [float(x) for x in json.loads(text)]
except Exception:
raise gr.Error("Provide ECG samples as a JSON list, e.g., [0.1, 0.2, 0.3]")
def _rounded(value: Any, default: float, ndigits: int = 3) -> float:
    """Return ``value`` as a float rounded to ``ndigits``; ``None`` -> ``default``."""
    return round(float(default if value is None else value), ndigits)


def run_infer(signal_text: str) -> Dict[str, Any]:
    """Gate one ECG signal, classify it, and evaluate the rule engine on it.

    Args:
        signal_text: ECG samples in any form ``parse_signal`` accepts.

    Returns:
        A dict with keys ``label``, ``score`` (float rounded to 3 places,
        0.0 when missing), ``hr``, ``alert_level`` (``"none"`` when the rules
        report none), ``gated_ratio`` (float rounded to 3 places, 1.0 when
        missing), ``gating`` (the metadata returned by ``gate_signal``) and
        ``explanations`` (gating explanations followed by rule explanations).

    Raises:
        gr.Error: Propagated from ``parse_signal`` for malformed input.
    """
    sig = parse_signal(signal_text)
    gated, gating_meta = gate_signal(sig, return_windows=True)
    model_output: Dict[str, Any] = infer_ecg(
        gated,
        original_len=len(sig),
        gating_meta=gating_meta,
    )
    # The demo has no real patient; the rule engine only gets a placeholder id.
    patient_context = {"patient_id": "demo"}
    rules_result = evaluate_ecg_rules(patient_context, model_output)

    # Gating explanations come first, then rule explanations. Either source may
    # be missing or None, in which case it simply contributes nothing.
    gating_info = model_output.get("gating")
    gating_explanations = (
        (gating_info.get("explanations") or []) if isinstance(gating_info, dict) else []
    )
    rule_explanations = rules_result.get("explanations") or []
    explanations = [*gating_explanations, *rule_explanations]

    return {
        "label": model_output.get("label"),
        "score": _rounded(model_output.get("score"), 0.0),
        "hr": model_output.get("hr"),
        "alert_level": rules_result.get("alert_level", "none"),
        "gated_ratio": _rounded(model_output.get("gated_ratio"), 1.0),
        "gating": gating_meta,
        "explanations": explanations,
    }
def plot_gating(signal_text: str) -> np.ndarray:
    """Render the raw and gated signals as a stacked two-panel image.

    Args:
        signal_text: ECG samples in any form ``parse_signal`` accepts.

    Returns:
        An RGB ``uint8`` array of shape (height, width, 3). ``gr.Image``
        outputs accept NumPy arrays, PIL images or file paths. They do not
        accept the in-memory PNG buffer this function used to return.

    Raises:
        gr.Error: Propagated from ``parse_signal`` for malformed input.
    """
    sig = parse_signal(signal_text)
    gated, meta = gate_signal(sig, return_windows=True)
    # dpi=120 keeps the resolution the former PNG export (savefig dpi=120) used.
    fig, (ax_raw, ax_gated) = plt.subplots(2, 1, figsize=(6, 4), dpi=120)
    try:
        ax_raw.plot(sig, color="#0066ff", linewidth=1)
        ax_raw.set_title("Raw signal")
        ax_gated.plot(gated, color="#ff6600", linewidth=1)
        ax_gated.set_title(f"Gated signal (ratio={meta['ratio']:.2f})")
        fig.tight_layout()
        fig.canvas.draw()
        # Copy the pixels out of the canvas buffer before the figure is closed.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return rgba[..., :3].copy()
    finally:
        # pyplot tracks every figure globally; close it even when plotting
        # fails so repeated requests do not accumulate open figures.
        plt.close(fig)
# Flat 256-sample placeholder traces for the "Demos" tab. The two differ only in
# amplitude (0.05 vs 0.3). NOTE(review): a constant trace is not a physiological
# arrhythmia, so "afib" here is only a label for the higher-level signal.
demo_normal = [0.05] * 256
demo_afib = [0.3] * 256
# Build the Gradio interface. Components attach to the enclosing Blocks/Tab
# context in the order they are created, so statement order here is the layout.
with gr.Blocks(title="Sundew ECG Demo") as demo:
    gr.Markdown(value="### Neurosymbolic ECG • Sundew Gating + Rules")
    with gr.Tabs():
        # Paste samples and run the full gating -> model -> rules pipeline.
        with gr.Tab(label="Upload/Infer"):
            inp = gr.Textbox(
                value=json.dumps(demo_afib[:128]),
                label="ECG samples (JSON list)",
            )
            out = gr.JSON(label="Inference")
            btn = gr.Button(value="Run")
            btn.click(fn=run_infer, inputs=inp, outputs=out)

        # Show the raw trace alongside what the gate lets through.
        with gr.Tab(label="Gating Preview"):
            inp2 = gr.Textbox(
                value=json.dumps(demo_afib[:128]),
                label="ECG samples (JSON list)",
            )
            img = gr.Image(type="filepath", label="Raw vs Gated")
            btn2 = gr.Button(value="Show gating")
            btn2.click(fn=plot_gating, inputs=inp2, outputs=img)

        # One-click demos: invisible textboxes hold the full demo traces and
        # serve as the inputs handed to run_infer.
        with gr.Tab(label="Demos"):
            out_demo = gr.JSON()
            btn_n = gr.Button(value="Normal")
            btn_a = gr.Button(value="Arrhythmia-ish")
            hidden_n = gr.Textbox(value=json.dumps(demo_normal), visible=False)
            hidden_a = gr.Textbox(value=json.dumps(demo_afib), visible=False)
            btn_n.click(fn=run_infer, inputs=hidden_n, outputs=out_demo)
            btn_a.click(fn=run_infer, inputs=hidden_a, outputs=out_demo)

# Start the server only when executed as a script; importers just get ``demo``.
if __name__ == "__main__":
    demo.launch()