Spaces:
Running
Running
| """ | |
| Sundew Live Monitor - Enhanced "Wow" Demo | |
| Production-quality interface showcasing neurosymbolic ECG monitoring | |
| """ | |
import io
import json
import math
import os
import sys
from typing import Any, Dict, List

import gradio as gr
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.figure import Figure
# Put the directory containing this file first on sys.path so the local
# `app` package imported below resolves regardless of the working directory.
ROOT = os.path.dirname(os.path.abspath(__file__))
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)

# Order matters: these imports depend on the sys.path change above.
from app.ml.gating import gate_signal
from app.ml.inference import infer_ecg, load_model
from app.rules.engine import evaluate_ecg_rules

# Import-time side effect: loads the model once when this module is imported.
# NOTE(review): presumably this warms up infer_ecg so the first request is not
# slowed by model loading — confirm against app.ml.inference.
load_model()
# Demo patient scenarios, keyed by the value the scenario picker submits.
# Each entry supplies:
#   - patient context for the rules engine: age, has_prior_stroke
#   - the synthetic waveform to generate: signal_type (see generate_signal)
#   - display text: name, description, icon
# Icons are plain Unicode symbols; they previously held mis-decoded UTF-8
# bytes that rendered as "β" / "π€" in the UI.
SCENARIOS = {
    "healthy": {
        "name": "Healthy Adult (60yo)",
        "age": 60,
        "has_prior_stroke": False,
        "signal_type": "normal",
        "description": "Normal sinus rhythm, no risk factors, routine monitoring",
        "icon": "✓",
    },
    "afib_high_risk": {
        "name": "AFib Suspect (85yo, Prior Stroke)",
        "age": 85,
        "has_prior_stroke": True,
        "signal_type": "afib",
        "description": "Irregular rhythm detected, high-risk patient requiring immediate review",
        "icon": "⚠",
    },
    "tachycardia": {
        "name": "Tachycardia Episode (45yo)",
        "age": 45,
        "has_prior_stroke": False,
        "signal_type": "tachy",
        "description": "Elevated heart rate (120+ bpm), otherwise healthy patient",
        "icon": "↑",
    },
    "elderly_normal": {
        "name": "Elderly Patient (78yo, Normal ECG)",
        "age": 78,
        "has_prior_stroke": True,
        "signal_type": "normal",
        "description": "High-risk profile but currently stable rhythm",
        "icon": "👤",
    },
    "noisy": {
        "name": "Poor Signal Quality",
        "age": 60,
        "has_prior_stroke": False,
        "signal_type": "noise",
        "description": "Motion artifacts, low-quality signal requiring gating",
        "icon": "~",
    },
}
def generate_signal(signal_type: str, length: int = 512) -> List[float]:
    """Synthesize a deterministic toy waveform for a demo scenario.

    Args:
        signal_type: ``"normal"``, ``"afib"``, ``"tachy"`` or ``"noise"``.
            Any other value yields a flat line of zeros.
        length: Number of samples to produce.

    Returns:
        A list of ``length`` floats. Nothing here is random: the "jitter"
        terms are the sawtooth ``i % 100 - 50`` (``hash(i) == i`` for the
        non-negative ints used, so this equals the former ``hash(i) % 100``).
    """

    def phase(i: int, cycles: float) -> float:
        # Angle after `cycles` full periods spread over the whole window.
        return 2 * math.pi * cycles * (i / length)

    def normal(i: int) -> float:
        return 0.05 * math.sin(phase(i, 2)) + 0.02 * math.sin(phase(i, 0.5))

    def afib(i: int) -> float:
        spike = 0.15 if i % 40 == 0 else 0.0
        return (
            0.25 * math.sin(phase(i, 6))
            + 0.05 * math.sin(phase(i, 15))
            + spike
            + 0.03 * (i % 100 - 50) / 500
        )

    def tachy(i: int) -> float:
        return 0.08 * math.sin(phase(i, 4.5)) + 0.03 * math.sin(phase(i, 1))

    def noise(i: int) -> float:
        blip = 0.01 if i % 13 == 0 else 0.0
        return 0.02 * math.sin(phase(i, 1)) + blip + 0.005 * (i % 100 - 50) / 50

    samplers = (
        ("normal", normal),
        ("afib", afib),
        ("tachy", tachy),
        ("noise", noise),
    )
    for kind, sampler in samplers:
        if signal_type == kind:
            return [sampler(i) for i in range(length)]
    return [0.0] * length
def _figure_to_image(fig) -> np.ndarray:
    """Rasterize a Matplotlib figure to an RGBA array that ``gr.Image`` can show.

    The figure is saved as an in-memory PNG (150 dpi, tight bounding box) and
    decoded back with ``mpimg.imread``.
    """
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=150, bbox_inches="tight")
    buf.seek(0)
    return mpimg.imread(buf)


def _signal_image(signal: List[float], gated, ratio: float) -> np.ndarray:
    """Plot the original signal (top) above the gated signal (bottom)."""
    # A standalone Figure instead of pyplot. pyplot keeps a process-global
    # figure registry and is not thread-safe, and UI handlers may run in worker
    # threads. A Figure pyplot never tracks also needs no plt.close(), even if
    # rendering raises.
    fig = Figure(figsize=(12, 6))
    top, bottom = fig.subplots(2, 1)

    top.plot(signal, color="#3498db", linewidth=1.5, alpha=0.8)
    top.set_title("Original ECG Signal", fontsize=13, fontweight="bold")
    top.set_ylabel("Amplitude")
    top.grid(alpha=0.3)
    top.set_xlim(0, len(signal))

    bottom.plot(gated, color="#e74c3c", linewidth=1.5, alpha=0.8)
    bottom.set_title(
        f"Gated Signal (Compression: {ratio:.1%})", fontsize=13, fontweight="bold"
    )
    bottom.set_xlabel("Sample Index")
    bottom.set_ylabel("Amplitude")
    bottom.grid(alpha=0.3)

    fig.tight_layout()
    return _figure_to_image(fig)


def _energy_image(ratio: float, energy_saved: float) -> np.ndarray:
    """Bar chart of compute used without gating (100%) versus with gating."""
    fig = Figure(figsize=(10, 4))  # standalone Figure, see _signal_image
    ax = fig.subplots()

    categories = ["Baseline\n(No Gating)", "Sundew\n(With Gating)"]
    compute = [100, ratio * 100]
    bars = ax.barh(
        categories, compute, color=["#e74c3c", "#2ecc71"], edgecolor="black", linewidth=1.5
    )
    ax.set_xlabel("Compute Used (%)", fontsize=12, fontweight="bold")
    ax.set_xlim(0, 110)
    for bar, val in zip(bars, compute):
        ax.text(
            val + 2,
            bar.get_y() + bar.get_height() / 2,
            f"{val:.1f}%",
            va="center",
            fontsize=12,
            fontweight="bold",
        )
    # Banner placed just above the top bar (bars sit at y=0 and y=1).
    ax.text(
        55,
        1.6,
        f"Energy Savings: {energy_saved:.1f}%",
        ha="center",
        fontsize=14,
        fontweight="bold",
        bbox=dict(boxstyle="round,pad=0.8", facecolor="#f39c12", alpha=0.8),
    )
    ax.set_title("Computational Efficiency", fontsize=14, fontweight="bold")
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    fig.tight_layout()
    return _figure_to_image(fig)


def run_pipeline(scenario_key: str):
    """Run one demo scenario end to end and build every UI output.

    Synthesizes the scenario's signal, gates it, runs inference on the gated
    signal, and evaluates the rules engine on the patient context plus the
    model output.

    Args:
        scenario_key: A key of ``SCENARIOS``.

    Returns:
        ``(summary_html, signal_img, energy_img, rule_md)``: an HTML summary
        card, two RGBA image arrays, and a Markdown rule trace, in the order
        of the Run button's ``outputs`` list.

    Raises:
        KeyError: If ``scenario_key`` is not in ``SCENARIOS``.
    """
    scenario = SCENARIOS[scenario_key]
    signal = generate_signal(scenario["signal_type"], length=512)
    gated, gating_meta = gate_signal(signal, return_windows=True)
    model_output = infer_ecg(gated, original_len=len(signal), gating_meta=gating_meta)
    patient_context = {
        "patient_id": scenario_key,
        "age": scenario["age"],
        "has_prior_stroke": scenario["has_prior_stroke"],
    }
    rules_result = evaluate_ecg_rules(patient_context, model_output)

    # Normalize once. A key that is present but None would otherwise crash
    # .upper() or a numeric format spec below.
    label = model_output.get("label") or "Unknown"
    score = model_output.get("score") or 0.0
    hr = model_output.get("hr")
    hr_text = "N/A" if hr is None else hr
    alert = (rules_result.get("alert_level") or "NONE").upper()
    explanations = rules_result.get("explanations") or []
    ratio = gating_meta.get("ratio")
    if ratio is None:
        ratio = 1.0
    # `ratio` is the fraction of compute still used after gating (it is drawn
    # as "Compute Used"), so the remainder is the saving.
    energy_saved = (1 - ratio) * 100

    summary_html = f"""
<div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 25px; border-radius: 15px; margin: 10px 0;">
<h2 style="margin: 0 0 15px 0;">Patient: {scenario['name']}</h2>
<div style="display: grid; grid-template-columns: 1fr 1fr; gap: 15px;">
<div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px;">
<h3 style="margin: 0; font-size: 14px; opacity: 0.9;">Diagnosis</h3>
<p style="margin: 5px 0 0 0; font-size: 24px; font-weight: bold;">{label.upper()}</p>
<p style="margin: 5px 0 0 0; opacity: 0.8;">Confidence: {score:.1%}</p>
</div>
<div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px;">
<h3 style="margin: 0; font-size: 14px; opacity: 0.9;">Alert Level</h3>
<p style="margin: 5px 0 0 0; font-size: 24px; font-weight: bold;">{alert}</p>
<p style="margin: 5px 0 0 0; opacity: 0.8;">HR: {hr_text} bpm</p>
</div>
</div>
<div style="margin-top: 15px; background: rgba(46,213,115,0.2); padding: 12px; border-radius: 8px; border-left: 4px solid #2ed573;">
<strong>Energy Savings: {energy_saved:.1f}%</strong> | Windows: {gating_meta.get('selected_windows', 0)}/{gating_meta.get('total_windows', 0)}
</div>
</div>
"""

    signal_img = _signal_image(signal, gated, ratio)
    energy_img = _energy_image(ratio, energy_saved)

    # Blank lines are required. Without them Markdown folds each bold heading
    # into the preceding bullet as a lazy continuation line.
    rule_lines = [
        "### Rule Chain Trace",
        "",
        "**Neural Network Output:**",
        "",
        f"- Label: `{label}` (Confidence: {score:.3f})",
        f"- Estimated HR: `{hr_text} bpm`",
        "",
        "**Patient Context:**",
        "",
        f"- Age: {scenario['age']} years",
        f"- Prior Stroke: {'Yes' if scenario['has_prior_stroke'] else 'No'}",
        "",
        "**Rules Evaluated:**",
        "",
    ]
    rule_lines.extend(f"- {exp}" for exp in explanations)
    rule_lines += ["", f"**Final Alert:** `{alert}`"]
    rule_md = "\n".join(rule_lines)

    return summary_html, signal_img, energy_img, rule_md
# Build Gradio Interface
#
# Markup lives in column-0 constants so that source indentation can never
# leak into the strings. In Markdown, 4+ leading spaces would turn the text
# into a code block.

_DEFAULT_SCENARIO = "afib_high_risk"

_APP_CSS = """
.gradio-container {font-family: 'Inter', sans-serif;}
.gr-button-primary {background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border: none;}
"""

_HEADER_HTML = """
<div style="text-align: center; padding: 30px 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 15px; margin-bottom: 20px;">
<h1 style="margin: 0; font-size: 42px; font-weight: 800;">Sundew ECG Monitor</h1>
<p style="margin: 10px 0 0 0; font-size: 18px; opacity: 0.95;">Neurosymbolic AI for Energy-Efficient Medical Monitoring</p>
<div style="margin-top: 15px; display: inline-flex; gap: 20px; flex-wrap: wrap; justify-content: center;">
<span style="background: rgba(255,255,255,0.2); padding: 8px 16px; border-radius: 20px;">⚡ 85% Energy Savings</span>
<span style="background: rgba(255,255,255,0.2); padding: 8px 16px; border-radius: 20px;">🧠 Explainable AI</span>
<span style="background: rgba(255,255,255,0.2); padding: 8px 16px; border-radius: 20px;">🏥 Clinical-Grade Rules</span>
</div>
</div>
"""

_ARCHITECTURE_MD = """
**Architecture:**
```
ECG Signal → Sundew Gating → ML Inference → Rule Engine
             (50-90% reduction)  (PyTorch)  (Symbolic)
```
"""

_FOOTER_HTML = """
<div style="text-align: center; padding: 20px; margin-top: 30px; border-top: 1px solid #eee;">
<p style="color: #666; font-size: 14px;">
Built with Sundew Algorithm · FastAPI · PyTorch · Gradio
</p>
</div>
"""


def _scenario_description(scenario_key: str) -> str:
    """Return the Markdown blurb (icon, name, description) for a scenario key.

    Returns an empty string for an unknown or empty selection.
    """
    scenario = SCENARIOS.get(scenario_key)
    if scenario is None:
        return ""
    return f"**{scenario['icon']} {scenario['name']}**\n{scenario['description']}"


with gr.Blocks(title="Sundew ECG Monitor", css=_APP_CSS) as demo:
    # Header
    gr.HTML(_HEADER_HTML)
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Select Patient Scenario")
            scenario_dropdown = gr.Radio(
                choices=list(SCENARIOS.keys()),
                value=_DEFAULT_SCENARIO,
                label="",
                info="Choose a patient to analyze",
            )
            # A single description panel, rewritten by the .change() listener
            # below so the text always matches the selected scenario.
            scenario_info = gr.Markdown(_scenario_description(_DEFAULT_SCENARIO))
            run_btn = gr.Button("Run Analysis", variant="primary", size="lg")
            gr.Markdown("---")
            gr.Markdown(_ARCHITECTURE_MD)
        with gr.Column(scale=2):
            summary_card = gr.HTML()
            with gr.Tabs():
                with gr.Tab("📈 Signal Analysis"):
                    signal_plot = gr.Image(label="ECG: Original vs Gated")
                with gr.Tab("⚡ Energy Efficiency"):
                    energy_plot = gr.Image(label="Compute Savings")
                with gr.Tab("📋 Rule Chain"):
                    rule_trace = gr.Markdown()
    scenario_dropdown.change(
        _scenario_description,
        inputs=scenario_dropdown,
        outputs=scenario_info,
    )
    # Output order must match the tuple returned by run_pipeline.
    run_btn.click(
        run_pipeline,
        inputs=scenario_dropdown,
        outputs=[summary_card, signal_plot, energy_plot, rule_trace],
    )
    # Footer
    gr.HTML(_FOOTER_HTML)
def main() -> None:
    """Start the Gradio server for the demo UI with default launch settings."""
    demo.launch()


if __name__ == "__main__":
    main()