HeartWatch AI
AI-Powered 12-Lead ECG Analysis
""" HeartWatch AI - ECG Analysis Demo ================================== A Gradio-based web application for AI-powered ECG analysis using DeepECG models. Features: - 77-class ECG diagnosis - LVEF < 40% prediction - LVEF < 50% prediction - 5-year AFib risk assessment - Interactive 12-lead ECG visualization """ import os import logging import numpy as np import gradio as gr from pathlib import Path # Local imports from inference import DeepECGInference from visualization import ( plot_ecg_waveform, plot_diagnosis_bars, plot_risk_gauges, ) # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Global inference engine inference_engine = None # Sample ECG descriptions - mapped by file stem (with underscores replaced by spaces and title-cased) # The files are: Atrial_Flutter.npy, Normal_Sinus_Rhythm.npy, Ventricular_Tachycardia.npy # They get sorted alphabetically: Atrial Flutter, Normal Sinus Rhythm, Ventricular Tachycardia # We want to display them as Sample 1, Sample 2, Sample 3 SAMPLE_FILE_TO_DISPLAY = { "Atrial Flutter": "Sample 1", "Normal Sinus Rhythm": "Sample 2", "Ventricular Tachycardia": "Sample 3", } SAMPLE_DESCRIPTIONS = { "Sample 1": "Atrial Flutter - A rapid but regular atrial rhythm, typically around 250-350 bpm in the atria.", "Sample 2": "Normal Sinus Rhythm - A healthy heart rhythm with regular beats originating from the sinus node.", "Sample 3": "Ventricular Tachycardia - A fast heart rhythm originating from the ventricles, potentially life-threatening.", } # Reverse mapping: display name to real condition info for analysis results DISPLAY_TO_CONDITION = { "Sample 1": { "name": "Atrial Flutter", "description": "A rapid but regular atrial rhythm, typically around 250-350 bpm in the atria." }, "Sample 2": { "name": "Normal Sinus Rhythm", "description": "A healthy heart rhythm with regular beats originating from the sinus node." }, "Sample 3": { "name": "Ventricular Tachycardia", "description": "A fast heart rhythm originating from the ventricles, potentially life-threatening." }, } def load_inference_engine(): """Load the inference engine on startup.""" global inference_engine if inference_engine is None: logger.info("Loading DeepECG inference engine...") inference_engine = DeepECGInference() inference_engine.load_models() logger.info("Inference engine loaded successfully") return inference_engine def get_sample_ecgs(): """Get list of sample ECG files from demo_data directory.""" sample_dir = Path(__file__).parent / "demo_data" / "samples" if not sample_dir.exists(): logger.warning(f"Sample directory not found: {sample_dir}") return [] samples = [] for npy_file in sorted(sample_dir.glob("*.npy")): original_name = npy_file.stem.replace("_", " ").title() # Map to new display name (Sample 1, Sample 2, Sample 3) display_name = SAMPLE_FILE_TO_DISPLAY.get(original_name, original_name) samples.append({ "path": str(npy_file), "name": display_name, "original_name": original_name, "description": SAMPLE_DESCRIPTIONS.get(display_name, "Sample ECG recording") }) logger.info(f"Found {len(samples)} sample ECGs") return samples def analyze_ecg(ecg_signal: np.ndarray, filename: str = "ECG Analysis", condition_info: dict = None): """ Analyze an ECG signal and return all visualizations. Args: ecg_signal: ECG signal array filename: Name to display condition_info: Optional dict with 'name' and 'description' for the condition Returns: Tuple of (ecg_plot, diagnosis_plot, risk_plot, summary_text) """ engine = load_inference_engine() # Run inference results = engine.predict(ecg_signal) # Generate ECG waveform plot ecg_fig = plot_ecg_waveform(ecg_signal, sample_rate=250, title=filename) # Generate diagnosis bar chart if "diagnosis_77" in results: probs = results["diagnosis_77"]["probabilities"] class_names = results["diagnosis_77"]["class_names"] diagnosis_dict = dict(zip(class_names, probs)) diagnosis_fig = plot_diagnosis_bars(diagnosis_dict, top_n=10) else: diagnosis_fig = None # Generate risk gauges lvef_40 = results.get("lvef_40", 0.0) lvef_50 = results.get("lvef_50", 0.0) afib_5y = results.get("afib_5y", 0.0) risk_fig = plot_risk_gauges(lvef_40, lvef_50, afib_5y) # Generate modern HTML summary with styled diagnosis cards inference_time = results.get("inference_time_ms", 0) # Build the diagnosis cards HTML with modern dark theme design diagnosis_html = '
{condition_desc}
' if condition_desc else "" else: display_title = filename condition_html = "" summary = f"""Please upload a .npy file containing ECG data.
" try: # In Gradio 4.x with type="filepath", file is a string path file_path = file if isinstance(file, str) else file.name ecg_signal = np.load(file_path) filename = Path(file_path).stem.replace("_", " ").title() return analyze_ecg(ecg_signal, filename) except Exception as e: logger.error(f"Error loading file: {e}") return None, None, None, f"Error loading file: {str(e)}
" def analyze_sample_by_name(sample_name: str): """Analyze a sample ECG by its name.""" if not sample_name: return None, None, None, "Please select a sample ECG.
" samples = get_sample_ecgs() for sample in samples: if sample["name"] == sample_name: try: ecg_signal = np.load(sample["path"]) # Get the real condition info for display condition_info = DISPLAY_TO_CONDITION.get(sample_name) return analyze_ecg(ecg_signal, sample["name"], condition_info) except Exception as e: logger.error(f"Error loading sample: {e}") return None, None, None, f"Error loading sample: {str(e)}
" return None, None, None, "Sample not found.
" def create_demo_interface(): """Create the Gradio interface.""" # Get samples at startup samples = get_sample_ecgs() sample_names = [s["name"] for s in samples] # Custom CSS for styling with modern animated header custom_css = """ .gradio-container { font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif; } /* Animated Header Styles */ .main-header { text-align: center; padding: 40px 24px; background: linear-gradient(-45deg, #ee7752, #e73c7e, #c0392b, #e74c3c); background-size: 400% 400%; animation: gradientShift 8s ease infinite; color: white; border-radius: 16px; margin-bottom: 24px; box-shadow: 0 10px 40px rgba(231, 76, 60, 0.3); position: relative; overflow: hidden; } .main-header::before { content: ''; position: absolute; top: 0; left: 0; right: 0; bottom: 0; background: radial-gradient(circle at 30% 50%, rgba(255,255,255,0.1) 0%, transparent 50%); pointer-events: none; } @keyframes gradientShift { 0% { background-position: 0% 50%; } 50% { background-position: 100% 50%; } 100% { background-position: 0% 50%; } } .header-content { position: relative; z-index: 2; display: flex; flex-direction: column; align-items: center; gap: 12px; } /* Pulsing Heart Container */ .heart-container { position: relative; width: 100px; height: 100px; display: flex; align-items: center; justify-content: center; } /* Heart SVG Animation */ .heart-svg { width: 80px; height: 80px; animation: heartbeat 1.2s ease-in-out infinite; filter: drop-shadow(0 0 20px rgba(255,255,255,0.5)); } @keyframes heartbeat { 0% { transform: scale(1); } 14% { transform: scale(1.15); } 28% { transform: scale(1); } 42% { transform: scale(1.1); } 70% { transform: scale(1); } } /* ECG Line Animation */ .ecg-line { position: absolute; width: 200px; height: 40px; left: 50%; transform: translateX(-50%); bottom: -10px; } .ecg-path { stroke: rgba(255,255,255,0.8); stroke-width: 2; fill: none; stroke-linecap: round; stroke-dasharray: 200; stroke-dashoffset: 200; animation: ecgDraw 2s ease-in-out infinite; } @keyframes ecgDraw { 0% { stroke-dashoffset: 200; opacity: 0; } 10% { opacity: 1; } 50% { stroke-dashoffset: 0; opacity: 1; } 90% { opacity: 1; } 100% { stroke-dashoffset: -200; opacity: 0; } } .main-header h1 { margin: 0; font-size: 2.8em; font-weight: 700; letter-spacing: -0.02em; text-shadow: 0 2px 10px rgba(0,0,0,0.2); } .main-header p { margin: 0; opacity: 0.95; font-size: 1.2em; font-weight: 400; letter-spacing: 0.02em; } .sample-card { padding: 16px; border-radius: 8px; background: #f8f9fa; margin: 8px 0; border-left: 4px solid #e74c3c; } .quick-start { background: linear-gradient(135deg, #e8f5e9 0%, #c8e6c9 100%); padding: 18px 20px; border-radius: 12px; margin: 20px 0; border-left: 5px solid #4caf50; box-shadow: 0 2px 8px rgba(76, 175, 80, 0.15); } /* Dark Theme Diagnosis Dashboard */ .diagnosis-dashboard { background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%); border-radius: 16px; padding: 24px; margin-top: 8px; box-shadow: 0 8px 32px rgba(0, 0, 0, 0.3), inset 0 1px 0 rgba(255, 255, 255, 0.05); font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif; } .diagnosis-dashboard-title { color: #ffffff; font-size: 0.85em; font-weight: 600; text-transform: uppercase; letter-spacing: 0.1em; margin-bottom: 20px; padding-bottom: 12px; border-bottom: 1px solid rgba(255, 255, 255, 0.1); text-shadow: 0 0 20px rgba(255, 255, 255, 0.3); } .diagnosis-row { display: flex; align-items: center; padding: 12px 16px; margin: 8px 0; background: rgba(255, 255, 255, 0.03); border-radius: 10px; transition: all 0.2s ease; border: 1px solid rgba(255, 255, 255, 0.05); } .diagnosis-row:hover { background: rgba(255, 255, 255, 0.08); transform: translateX(4px); } .diagnosis-rank { font-size: 0.9em; font-weight: 700; color: rgba(255, 255, 255, 0.5); width: 36px; flex-shrink: 0; } .diagnosis-name { font-size: 0.95em; font-weight: 500; color: #ffffff; min-width: 120px; max-width: 180px; flex-shrink: 0; white-space: nowrap; overflow: hidden; text-overflow: ellipsis; text-shadow: 0 1px 2px rgba(0, 0, 0, 0.3); } .diagnosis-bar-container { flex: 1; display: flex; align-items: center; margin: 0 16px; min-width: 80px; } .diagnosis-bar-track { width: 100%; height: 6px; background: rgba(255, 255, 255, 0.1); border-radius: 3px; position: relative; overflow: hidden; } .diagnosis-bar-fill { height: 100%; border-radius: 3px; transition: width 0.6s cubic-bezier(0.4, 0, 0.2, 1); position: relative; } /* Animated shine effect on bars */ .diagnosis-bar-fill::after { content: ''; position: absolute; top: 0; left: 0; right: 0; bottom: 0; background: linear-gradient( 90deg, transparent 0%, rgba(255, 255, 255, 0.3) 50%, transparent 100% ); animation: shine 2s ease-in-out infinite; } @keyframes shine { 0% { transform: translateX(-100%); } 100% { transform: translateX(100%); } } .diagnosis-percent { font-size: 0.9em; font-weight: 700; width: 55px; text-align: right; flex-shrink: 0; text-shadow: 0 0 10px currentColor; } /* Color classes for severity with glow effects */ .severity-low .diagnosis-bar-fill { background: linear-gradient(90deg, #00c853 0%, #69f0ae 100%); box-shadow: 0 0 12px rgba(0, 200, 83, 0.5), 0 0 4px rgba(0, 200, 83, 0.3); } .severity-low .diagnosis-percent { color: #69f0ae; } .severity-medium .diagnosis-bar-fill { background: linear-gradient(90deg, #ff9800 0%, #ffc107 100%); box-shadow: 0 0 12px rgba(255, 152, 0, 0.5), 0 0 4px rgba(255, 152, 0, 0.3); } .severity-medium .diagnosis-percent { color: #ffc107; } .severity-high .diagnosis-bar-fill { background: linear-gradient(90deg, #f44336 0%, #ff5252 100%); box-shadow: 0 0 12px rgba(244, 67, 54, 0.5), 0 0 4px rgba(244, 67, 54, 0.3); } .severity-high .diagnosis-percent { color: #ff5252; } /* Responsive Design for Diagnosis Dashboard */ @media (max-width: 768px) { .diagnosis-dashboard { padding: 16px; border-radius: 12px; } .diagnosis-row { padding: 10px 12px; flex-wrap: wrap; } .diagnosis-rank { width: 28px; font-size: 0.85em; } .diagnosis-name { flex: 1; min-width: 100px; max-width: none; font-size: 0.9em; } .diagnosis-bar-container { order: 3; width: 100%; margin: 8px 0 0 0; flex-basis: 100%; } .diagnosis-percent { width: auto; margin-left: auto; font-size: 0.85em; } } @media (max-width: 480px) { .diagnosis-dashboard { padding: 12px; margin-top: 4px; } .diagnosis-dashboard-title { font-size: 0.75em; margin-bottom: 12px; padding-bottom: 8px; } .diagnosis-row { padding: 8px 10px; margin: 6px 0; } .diagnosis-rank { width: 24px; font-size: 0.8em; } .diagnosis-name { font-size: 0.85em; } .diagnosis-percent { font-size: 0.8em; } .diagnosis-bar-track { height: 5px; } } /* Footer Styles */ .footer-container { margin-top: 40px; padding: 30px; background: linear-gradient(135deg, #2c3e50 0%, #1a252f 100%); border-radius: 16px; color: white; text-align: center; } .footer-content { max-width: 800px; margin: 0 auto; } .footer-acknowledgement { font-size: 1em; margin-bottom: 16px; padding-bottom: 16px; border-bottom: 1px solid rgba(255,255,255,0.2); } .footer-acknowledgement a { color: #3498db; text-decoration: none; font-weight: 600; } .footer-acknowledgement a:hover { text-decoration: underline; } .footer-disclaimer { font-size: 0.9em; color: rgba(255,255,255,0.7); padding: 12px 20px; background: rgba(231, 76, 60, 0.2); border-radius: 8px; border: 1px solid rgba(231, 76, 60, 0.3); } .footer-disclaimer strong { color: #e74c3c; } """ with gr.Blocks(css=custom_css, title="HeartWatch AI", theme=gr.themes.Soft()) as demo: # Animated Header with Pulsing Heart gr.HTML("""AI-Powered 12-Lead ECG Analysis
đ Select a sample and click Analyze to see results.
", label="Analysis Summary" ) with gr.Row(): sample_ecg_plot = gr.Plot(label="12-Lead ECG Waveform") with gr.Row(): with gr.Column(): sample_diagnosis_plot = gr.Plot(label="Diagnosis Probabilities") with gr.Column(): sample_risk_plot = gr.Plot(label="Risk Assessment Gauges") if sample_names: analyze_sample_btn.click( fn=analyze_sample_by_name, inputs=[sample_radio], outputs=[sample_ecg_plot, sample_diagnosis_plot, sample_risk_plot, sample_summary] ) # Tab 2: Upload Your Own ECG with gr.TabItem("đ¤ Upload Your ECG", id=1): gr.Markdown(""" ### Upload Your Own ECG Recording Have your own ECG data? Upload it here for analysis. """) with gr.Row(): with gr.Column(scale=1): file_input = gr.File( label="Upload ECG File (.npy)", file_types=[".npy"], type="filepath" ) analyze_btn = gr.Button( "đ Analyze Uploaded ECG", variant="primary", size="lg" ) gr.Markdown(""" **Expected Format:** - **File type:** NumPy array (.npy) - **Shape:** (2500, 12) or (12, 2500) - **Leads:** I, II, III, aVR, aVL, aVF, V1-V6 - **Duration:** 10 seconds at 250 Hz **Tip:** Use `numpy.save('ecg.npy', signal)` to create compatible files. """) with gr.Column(scale=2): upload_summary = gr.HTML( value="đ Upload a .npy file and click Analyze to see results.
", label="Summary" ) with gr.Row(): upload_ecg_plot = gr.Plot(label="12-Lead ECG Waveform") with gr.Row(): with gr.Column(): upload_diagnosis_plot = gr.Plot(label="Diagnosis Probabilities") with gr.Column(): upload_risk_plot = gr.Plot(label="Risk Assessment Gauges") analyze_btn.click( fn=analyze_uploaded_file, inputs=[file_input], outputs=[upload_ecg_plot, upload_diagnosis_plot, upload_risk_plot, upload_summary] ) # Tab 3: About with gr.TabItem("âšī¸ About", id=2): gr.Markdown(""" ## About HeartWatch AI HeartWatch AI is a deep learning-based ECG analysis system powered by state-of-the-art models. ### đ§ AI Models | Model | Description | |-------|-------------| | **77-Class Diagnosis** | Detects 77 different ECG patterns and cardiac conditions | | **LVEF < 40%** | Predicts reduced left ventricular ejection fraction | | **LVEF < 50%** | Predicts moderately reduced ejection fraction | | **5-Year AFib Risk** | Estimates risk of developing Atrial Fibrillation | ### đ Technical Details - **Architecture:** EfficientNetV2 (TorchScript optimized) - **Input:** 12-lead ECG, 10 seconds, 250 Hz - **Inference:** CPU-optimized for accessibility - **Training Data:** Large clinical ECG datasets ### â ī¸ Important Disclaimer **This is a research demonstration tool.** The predictions provided should **NOT** be used for clinical decision-making. Always consult qualified healthcare professionals for medical advice and diagnosis. ### đ References - Models based on the DeepECG project - Sample ECGs from MIT-BIH Arrhythmia Database (PhysioNet) --- *Built with Gradio and PyTorch* """) # Modern Footer with Acknowledgement and Disclaimer gr.HTML(""" """) return demo # Create and launch the demo if __name__ == "__main__": # Create and launch demo demo = create_demo_interface() demo.launch( server_name="0.0.0.0", server_port=7860, share=False )