Spaces:
Running
Running
| import io | |
| import json | |
| import os | |
| import sys | |
| from typing import Any, Dict, List | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
# Put the directory holding this file at the front of ``sys.path`` so the
# bundled ``app`` package can be imported when running as a Hugging Face
# Space (or when launched from some other working directory).
ROOT = os.path.split(os.path.abspath(__file__))[0]
_root_on_path = ROOT in sys.path
if not _root_on_path:
    sys.path.insert(0, ROOT)
| from app.ml.gating import gate_signal | |
| from app.ml.inference import infer_ecg, load_model | |
| from app.rules.engine import evaluate_ecg_rules | |
# Preload the model once at import time. According to the original note, this
# uses ./checkpoints/ecg_classifier.pt if present.
# NOTE(review): the return value is discarded, so load_model presumably caches
# the model internally for later infer_ecg calls — confirm. Also confirm
# whether the checkpoint path is resolved against ROOT or the process CWD.
load_model()
| def parse_signal(text: str | List[float]) -> List[float]: | |
| if isinstance(text, list): | |
| return [float(x) for x in text] | |
| try: | |
| return [float(x) for x in json.loads(text)] | |
| except Exception: | |
| raise gr.Error("Provide ECG samples as a JSON list, e.g., [0.1, 0.2, 0.3]") | |
def _rounded_float(value: Any, default: float) -> float:
    """Return *value* as a built-in ``float`` rounded to 3 decimals.

    ``None`` (missing or explicitly unset) falls back to *default*. Going
    through ``float()`` also turns numeric scalar types such as numpy scalars
    into plain Python floats, which keeps the result JSON-serialisable.
    """
    return round(float(default if value is None else value), 3)


def _collect_explanations(
    model_output: Dict[str, Any], rules_result: Dict[str, Any]
) -> List[Any]:
    """Concatenate gating explanations (if reported) with rule-engine explanations."""
    gating = model_output.get("gating")
    gating_explanations = gating.get("explanations") if isinstance(gating, dict) else None
    # ``or []`` tolerates keys that are present but set to None.
    return [*(gating_explanations or []), *(rules_result.get("explanations") or [])]


def run_infer(signal_text: str) -> Dict[str, Any]:
    """Run the demo pipeline (parse -> gate -> model -> rules) on one input.

    Args:
        signal_text: ECG samples as a JSON list string (see ``parse_signal``).

    Returns:
        A summary dict with ``label``, ``score`` and ``gated_ratio`` (rounded
        to 3 decimals), ``hr``, ``alert_level`` (``"none"`` if the rule
        engine does not report one), the ``gating`` metadata from
        ``gate_signal``, and the combined ``explanations`` list.

    Raises:
        gr.Error: propagated from ``parse_signal`` for malformed input.
    """
    sig = parse_signal(signal_text)
    gated, gating_meta = gate_signal(sig, return_windows=True)
    model_output: Dict[str, Any] = infer_ecg(
        gated,
        original_len=len(sig),
        gating_meta=gating_meta,
    )
    # Fixed demo context; the UI does not collect patient data.
    patient_context = {"patient_id": "demo"}
    rules_result = evaluate_ecg_rules(patient_context, model_output)
    return {
        "label": model_output.get("label"),
        "score": _rounded_float(model_output.get("score"), 0.0),
        "hr": model_output.get("hr"),
        "alert_level": rules_result.get("alert_level", "none"),
        "gated_ratio": _rounded_float(model_output.get("gated_ratio"), 1.0),
        # NOTE(review): gating_meta was requested with return_windows=True; if
        # it holds numpy arrays they may not be JSON-serialisable — confirm.
        "gating": gating_meta,
        "explanations": _collect_explanations(model_output, rules_result),
    }
def plot_gating(signal_text: str) -> str:
    """Plot the raw and gated signals one above the other and return a PNG path.

    The UI wires this function to ``gr.Image(type="filepath")``. Gradio image
    outputs accept a file path, numpy array or PIL image, but not an
    ``io.BytesIO``. The figure is therefore saved to a temporary PNG file and
    that file's path is returned.

    Args:
        signal_text: ECG samples as a JSON list string (see ``parse_signal``).

    Returns:
        Filesystem path of the rendered PNG. The file is created with
        ``delete=False`` so Gradio can read it after this returns; it is not
        cleaned up here.

    Raises:
        gr.Error: propagated from ``parse_signal`` for malformed input.
    """
    import tempfile

    from matplotlib.figure import Figure

    sig = parse_signal(signal_text)
    gated, meta = gate_signal(sig, return_windows=True)

    # Use the object-oriented Figure API rather than pyplot. pyplot keeps
    # global figure state, which is unsafe if handlers run concurrently.
    # It would also leak the figure if savefig raised before plt.close.
    fig = Figure(figsize=(6, 4))
    raw_ax, gated_ax = fig.subplots(2, 1)
    raw_ax.plot(sig, color="#0066ff", linewidth=1)
    raw_ax.set_title("Raw signal")
    gated_ax.plot(gated, color="#ff6600", linewidth=1)
    gated_ax.set_title(f"Gated signal (ratio={meta['ratio']:.2f})")
    fig.tight_layout()

    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        fig.savefig(tmp, format="png", dpi=120)
    return tmp.name
# Synthetic demo inputs: 256 samples each, at a single constant level.
# NOTE(review): both traces are flat lines, not physiological waveforms.
# The "Arrhythmia-ish" label given to demo_afib in the UI is nominal only.
demo_normal = [0.05] * 256
demo_afib = [0.3] * 256
# Default text for both free-form inputs: the first 128 demo samples as JSON.
_DEFAULT_SAMPLES_JSON = json.dumps(demo_afib[:128])

# Gradio UI: three tabs (free-form inference, gating preview, canned demos).
# Components are created in display order, so statement order defines layout.
with gr.Blocks(title="Sundew ECG Demo") as demo:
    gr.Markdown("### Neurosymbolic ECG • Sundew Gating + Rules")
    with gr.Tabs():
        # Paste samples and run the full gating -> model -> rules pipeline.
        with gr.Tab("Upload/Infer"):
            inp = gr.Textbox(label="ECG samples (JSON list)", value=_DEFAULT_SAMPLES_JSON)
            out = gr.JSON(label="Inference")
            btn = gr.Button("Run")
            btn.click(fn=run_infer, inputs=[inp], outputs=[out])

        # Visual comparison of the raw input and the gated signal.
        with gr.Tab("Gating Preview"):
            inp2 = gr.Textbox(label="ECG samples (JSON list)", value=_DEFAULT_SAMPLES_JSON)
            img = gr.Image(type="filepath", label="Raw vs Gated")
            btn2 = gr.Button("Show gating")
            btn2.click(fn=plot_gating, inputs=[inp2], outputs=[img])

        # One-click demos. The full-length samples live in hidden textboxes
        # so they can be passed as handler inputs.
        with gr.Tab("Demos"):
            out_demo = gr.JSON()
            btn_n = gr.Button("Normal")
            btn_a = gr.Button("Arrhythmia-ish")
            hidden_n = gr.Textbox(value=json.dumps(demo_normal), visible=False)
            hidden_a = gr.Textbox(value=json.dumps(demo_afib), visible=False)
            btn_n.click(fn=run_infer, inputs=[hidden_n], outputs=[out_demo])
            btn_a.click(fn=run_infer, inputs=[hidden_a], outputs=[out_demo])

# Launch the server only when executed directly, not when imported.
if __name__ == "__main__":
    demo.launch()