"""Gradio demo for the CAMEL ECG model: prompt-based ECG analysis and rhythm forecasting.

At import time this module symlinks the CAMEL checkpoints downloaded from the
Hugging Face Hub into ``checkpoints/`` and pre-loads the base and forecast
models on CPU.
"""
import gradio as gr
import os
import sys
import time
from pathlib import Path
from types import SimpleNamespace
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import torch
import wfdb
from huggingface_hub import hf_hub_download

# spaces.GPU is only available on HF Spaces; stub it for local use
try:
    import spaces
except ImportError:
    def _gpu_passthrough(func=None, **_kwargs):
        """No-op stand-in for ``spaces.GPU``.

        Supports both the bare form (``@spaces.GPU``) and the called form
        (``@spaces.GPU(...)``) by returning the function unchanged.
        """
        if func is None:
            return lambda f: f
        return func

    spaces = SimpleNamespace(GPU=_gpu_passthrough)

# --- Setup: Checkpoints & Model ---
CHECKPOINT_DIR = Path("checkpoints")
CHECKPOINT_DIR.mkdir(exist_ok=True)

files = ["camel_base.pt", "camel_ecginstruct.pt", "camel_forecast.pt"]
repo_id = "CAMEL-ECG/CAMEL"

for f in files:
    cached_path = Path(hf_hub_download(repo_id=repo_id, filename=f))
    target_link = CHECKPOINT_DIR / f
    # Drop links that are broken OR point at a different cached file (e.g. an
    # older Hub revision), so checkpoints/ always mirrors the current download.
    if target_link.is_symlink() and target_link.resolve() != cached_path.resolve():
        reason = "stale" if target_link.exists() else "broken"
        target_link.unlink()
        print(f"Removed {reason} symlink for {f}")
    if not target_link.exists():
        target_link.symlink_to(cached_path)
        print(f"Symlinked {f} -> {cached_path}")

# NOTE(review): imported only after the checkpoint links exist; presumably
# CAMEL locates its weights via checkpoints/ — keep this ordering.
from camel.camel_model import CAMEL

# Pre-load BOTH models on CPU at startup
print("Pre-loading CAMEL base model on CPU...")
t0 = time.time()
MODEL_BASE = CAMEL(mode='base', device='cpu')
print(f"  Base model loaded in {time.time() - t0:.1f}s")

print("Pre-loading CAMEL forecast model on CPU...")
t1 = time.time()
MODEL_FORECAST = CAMEL(mode='forecast', device='cpu')
print(f"  Forecast model loaded in {time.time() - t1:.1f}s")

# ──────────────────────────────────────────────
# Constants & metadata
# ──────────────────────────────────────────────
DEMO_DIR = Path("demo")
FORECAST_DIR = DEMO_DIR / "forecast"

CLASSIFICATION_PROMPTS = [
    "Describe the ECG.",
    "What is the QRS duration?",
    "What is the R-R interval and heart rate?",
]

FORECAST_PROMPT = (
    "Analyze the ECG signal and predict the cardiac rhythm for the next 120 seconds.\n"
    "NORM: Normal ECG\n"
    "ABNORMAL: Atrial Fibrillation or Atrial Flutter\n"
    "Output one of: NORM or ABNORMAL."
)

# Comments for classification examples (keyed by record stem)
CLASSIFICATION_COMMENTS = {
    "08704_hr": "Left Ventricular Hypertrophy; Non-Specific ST Changes",
    "12585_hr": "Left Ventricular Hypertrophy, ST-T Changes",
}

# Metadata for forecasting examples. All *_orig values are sample indices in
# the ORIGINAL recording; the wfdb files in demo/forecast/ hold only the
# extracted segment starting at seg_start_orig, so subtract seg_start_orig to
# get offsets into the stored file.
FORECAST_META = {
    "p09164_s03_600": {
        "label": "AFIB",
        "fs": 250,
        "input_duration_s": 60,
        "horizon_s": 600,
        # Original: input=[212243, 227243], afib_onset=232692
        # Segment starts at 212243 - 10*250 = 209743
        "seg_start_orig": 209743,
        "input_start_orig": 212243,
        "input_end_orig": 227243,
        "afib_onset_orig": 232692,
        "seg_end_orig": 232692 + 30 * 250,  # 240192
    },
    "p01894_s47_300": {
        "label": "AFIB",
        "fs": 250,
        "input_duration_s": 120,
        "horizon_s": 300,
        # Original: input=[332516, 362516], afib_onset=385178
        # Segment starts at 332516 - 10*250 = 330016
        "seg_start_orig": 330016,
        "input_start_orig": 332516,
        "input_end_orig": 362516,
        "afib_onset_orig": 385178,
        "seg_end_orig": 385178 + 30 * 250,  # 392678
    },
}


# ──────────────────────────────────────────────
# Helper functions
# ──────────────────────────────────────────────
def _list_record_stems(directory):
    """Return sorted stems of the non-hidden ``*.hea`` headers in *directory*.

    Returns an empty list when the directory does not exist.
    """
    if not directory.exists():
        return []
    return sorted(p.stem for p in directory.glob("*.hea") if not p.stem.startswith("."))


def load_classification_examples():
    """List classification ECG stems from demo/."""
    return _list_record_stems(DEMO_DIR)


def load_forecast_examples():
    """List forecasting ECG stems from demo/forecast/."""
    return _list_record_stems(FORECAST_DIR)


# Extensions stripped before handing a path to wfdb, which expects the bare
# record name. Only these are removed so dotted record names stay intact.
_WFDB_SUFFIXES = (".hea", ".dat")


def read_ecg(path_str):
    """Read a WFDB record, returning None (after logging) if it cannot be read.

    Args:
        path_str: record path, with or without a ``.hea``/``.dat`` extension.

    Returns:
        The ``wfdb`` record object, or None on any read error.
    """
    try:
        path = Path(path_str)
        if path.suffix in _WFDB_SUFFIXES:
            path = path.with_suffix("")
        return wfdb.rdrecord(str(path))
    except Exception as e:
        # Best effort: UI callers treat None as "nothing to show".
        print(f"Error reading {path_str}: {e}")
        return None
# ──────────────────────────────────────────────
# Classification plotting: 12-lead grid
# ──────────────────────────────────────────────
def _downsample_for_plot(sig, t, max_pts=3000):
    """Downsample a signal for fast matplotlib rendering.

    Args:
        sig: 1-D signal samples.
        t: time axis, same length as ``sig``.
        max_pts: upper bound on the number of points kept.

    Returns:
        ``(t, sig)`` — note the order is swapped relative to the arguments.
    """
    if len(sig) <= max_pts:
        return t, sig
    step = -(-len(sig) // max_pts)  # ceil division so the result fits max_pts
    return t[::step], sig[::step]


def _release_figure(fig):
    """Unregister *fig* from pyplot and return it.

    pyplot keeps every figure it creates alive until it is closed; handlers
    build a fresh figure on every event, so without this each redraw leaks
    one figure. The Figure object itself stays usable for rendering.
    """
    plt.close(fig)
    return fig


def plot_classification_ecg(record):
    """Plot every lead of *record* in a 3-column grid (4×3 for 12 leads).

    Returns None when *record* is None.
    """
    if record is None:
        return None
    signals = record.p_signal
    n_leads = signals.shape[1]
    t = np.arange(signals.shape[0]) / record.fs

    cols = 3
    rows = max(1, -(-n_leads // cols))
    # squeeze=False keeps `axes` 2-D for every lead count, including 1 lead.
    fig, axes = plt.subplots(rows, cols, figsize=(12, 1.2 * rows + 0.5),
                             sharex=True, squeeze=False)
    for i, ax in enumerate(axes.flat):
        if i >= n_leads:
            ax.set_visible(False)  # hide unused grid cells
            continue
        ax.plot(t, signals[:, i], linewidth=0.5, color='#1a1a2e')
        ax.set_ylabel(record.sig_name[i], fontsize=8, fontweight='bold')
        ax.tick_params(labelsize=6)
        ax.grid(True, alpha=0.3)
    for ax in axes[-1]:
        ax.set_xlabel("Time (s)", fontsize=8)
    fig.subplots_adjust(hspace=0.3, wspace=0.3, top=0.97, bottom=0.06)
    return _release_figure(fig)


def get_classification_info(record, name=""):
    """Return an info string (Fs, duration, leads, known diagnosis) for a classification ECG."""
    if record is None:
        return ""
    lines = [
        f"Fs: {record.fs} Hz",
        f"Duration: {record.sig_len / record.fs:.1f}s",
        f"Leads: {', '.join(record.sig_name)}",
    ]
    comment = CLASSIFICATION_COMMENTS.get(name, "")
    if comment:
        lines.append(f"Diagnosis: {comment}")
    return "\n".join(lines)


# ──────────────────────────────────────────────
# Forecasting plotting: 3-panel timeline
# ──────────────────────────────────────────────
def _plot_uploaded_forecast(signals, fs, lead_name):
    """Plot a plain single-lead trace for records without FORECAST_META."""
    t = np.arange(len(signals)) / fs
    td, sd = _downsample_for_plot(signals, t)
    fig, ax = plt.subplots(figsize=(12, 2.5))
    ax.plot(td, sd, linewidth=0.5, color='#1a1a2e')
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("ECG (mV)")
    ax.set_title(f"Uploaded ECG (Lead {lead_name})", fontsize=11, fontweight='bold')
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    return fig


def _plot_forecast_timeline(signals, fs, meta, name):
    """Draw the 3-panel view: full timeline, input zoom, AFib-onset zoom."""
    seg_start = meta["seg_start_orig"]
    input_start_rel = meta["input_start_orig"] - seg_start
    input_end_rel = meta["input_end_orig"] - seg_start
    afib_onset_rel = meta["afib_onset_orig"] - seg_start
    # Integer sample counts so slicing works even if the header reports a float fs.
    ten_s = int(round(10 * fs))
    twenty_s = int(round(20 * fs))

    t = np.arange(len(signals)) / fs
    t_in_s = input_start_rel / fs
    t_in_e = input_end_rel / fs
    t_af = afib_onset_rel / fs

    fig = plt.figure(figsize=(12, 7))
    gs = gridspec.GridSpec(3, 1, height_ratios=[1.5, 1, 1], hspace=0.45)

    # Panel 1: full timeline with shaded input / gap / post-onset regions
    ax1 = fig.add_subplot(gs[0])
    ax1.plot(t, signals, linewidth=0.4, color='#1a1a2e', zorder=2)
    ax1.axvspan(t_in_s, t_in_e, alpha=0.15, color='#2196F3', label='Model Input', zorder=1)
    ax1.axvspan(t_in_e, t_af, alpha=0.10, color='#9E9E9E', label='Intervening Gap', zorder=1)
    ax1.axvspan(t_af, t[-1], alpha=0.15, color='#F44336', label='AFib Onset', zorder=1)
    ax1.axvline(t_af, color='#D32F2F', linestyle='--', linewidth=1.5, alpha=0.8, zorder=3)
    ax1.set_title(f"Full Timeline — {name}", fontsize=11, fontweight='bold')
    ax1.set_xlabel("Time (s)", fontsize=8)
    ax1.set_ylabel("mV", fontsize=8)
    ax1.legend(loc='upper right', fontsize=7)
    ax1.grid(True, alpha=0.3)
    ax1.tick_params(labelsize=7)

    # Panel 2: input zoom (first 10s of the model input)
    ax2 = fig.add_subplot(gs[1])
    zoom_end = input_start_rel + min(ten_s, input_end_rel - input_start_rel)
    ax2.plot(t[input_start_rel:zoom_end], signals[input_start_rel:zoom_end],
             linewidth=0.5, color='#1565C0')
    remaining_s = meta['input_duration_s'] - 10
    suffix = f" (first 10s of {meta['input_duration_s']}s)" if remaining_s > 0 else ""
    ax2.set_title(f"Input Segment{suffix}", fontsize=10, fontweight='bold')
    if remaining_s > 0:
        ax2.annotate(f'→ {remaining_s}s more →', xy=(1.0, 0.5), xycoords='axes fraction',
                     fontsize=9, fontweight='bold', color='#1565C0', alpha=0.7,
                     ha='right', va='center')
    ax2.set_xlabel("Time (s)", fontsize=8)
    ax2.set_ylabel("mV", fontsize=8)
    ax2.grid(True, alpha=0.3)
    ax2.tick_params(labelsize=7)

    # Panel 3: 10s before to 20s after AFib onset, clipped to the stored segment
    ax3 = fig.add_subplot(gs[2])
    afib_s = max(afib_onset_rel - ten_s, 0)
    afib_e = min(afib_onset_rel + twenty_s, len(signals))
    ax3.plot(t[afib_s:afib_e], signals[afib_s:afib_e], linewidth=0.5, color='#C62828')
    ax3.axvline(t_af, color='#D32F2F', linestyle='--', linewidth=1.5, alpha=0.8,
                label=f'AFib onset ({t_af:.1f}s)')
    ax3.set_title("AFib Onset Region", fontsize=10, fontweight='bold')
    ax3.set_xlabel("Time (s)", fontsize=8)
    ax3.set_ylabel("mV", fontsize=8)
    ax3.legend(loc='upper right', fontsize=7)
    ax3.grid(True, alpha=0.3)
    ax3.tick_params(labelsize=7)

    fig.subplots_adjust(top=0.95, bottom=0.06)
    return fig


def plot_forecast_ecg(record, name=""):
    """Plot a forecasting ECG using its first channel only.

    Known examples (present in FORECAST_META) get the 3-panel timeline; any
    other record gets a plain single-trace plot. Returns None when *record*
    is None.
    """
    if record is None:
        return None
    signals = record.p_signal[:, 0]  # single-lead
    meta = FORECAST_META.get(name)
    if meta is None:
        fig = _plot_uploaded_forecast(signals, record.fs, record.sig_name[0])
    else:
        fig = _plot_forecast_timeline(signals, record.fs, meta, name)
    return _release_figure(fig)


def get_forecast_info(record, name=""):
    """Return an info string for a forecasting ECG (plus ground truth for known examples)."""
    if record is None:
        return ""
    lines = [
        f"Fs: {record.fs} Hz",
        f"Leads: {', '.join(record.sig_name)}",
        f"Segment duration: {record.sig_len / record.fs:.1f}s",
    ]
    meta = FORECAST_META.get(name)
    if meta:
        onset_after_s = (meta['afib_onset_orig'] - meta['input_end_orig']) / meta['fs']
        lines += [
            f"Input duration: {meta['input_duration_s']}s",
            f"Ground truth: {meta['label']}",
            f"AFib onset: {onset_after_s:.0f}s after input ends",
        ]
    return "\n".join(lines)


# ──────────────────────────────────────────────
# Upload helpers
# ──────────────────────────────────────────────
def _stage_upload(files):
    """Copy uploaded files into one fresh directory and return the ``.hea`` path there.

    Gradio may keep each uploaded file in its own temporary directory, while
    ``wfdb.rdrecord`` needs the header and signal files side by side.
    Returns None when no ``.hea`` file was uploaded.
    NOTE(review): staging dirs are not cleaned up; the path must outlive this
    call because inference reads it later.
    """
    import shutil
    import tempfile

    # Upload items may be plain path strings or objects carrying the path in `.name`.
    paths = [Path(getattr(f, "name", f)) for f in files]
    hea = next((p for p in paths if p.suffix == '.hea'), None)
    if hea is None:
        return None
    staging = Path(tempfile.mkdtemp(prefix="camel_upload_"))
    for p in paths:
        shutil.copy(p, staging / p.name)
    return staging / hea.name


def _load_upload(files):
    """Stage and read an uploaded WFDB record.

    Returns:
        ``(record, record_path, error)``. On failure ``record`` and
        ``record_path`` are None and ``error`` is a user-facing message.
    """
    hea = _stage_upload(files)
    if hea is None:
        return None, None, "Missing .hea file"
    record = read_ecg(str(hea))
    if record is None:
        return None, None, ("Could not read the uploaded ECG; "
                            "please upload the matching .hea and .dat files.")
    return record, str(hea.with_suffix("")), ""


# ──────────────────────────────────────────────
# Event handlers: Classification
# ──────────────────────────────────────────────
def on_cls_example_select(name):
    """Load a bundled classification example; returns (plot, info, record path)."""
    if not name:
        return None, "", None
    path = DEMO_DIR / name
    record = read_ecg(str(path))
    return plot_classification_ecg(record), get_classification_info(record, name), str(path)


def on_cls_upload(files):
    """Load an uploaded classification ECG; returns (plot, info, record path)."""
    if not files:
        return None, "", None
    record, path, error = _load_upload(files)
    if record is None:
        return None, error, None
    return plot_classification_ecg(record), get_classification_info(record), path


def cls_load_first():
    """Pre-load first classification example on page open."""
    examples = load_classification_examples()
    if not examples:
        return gr.update(), None, "", None
    first = examples[0]
    plot, info, path = on_cls_example_select(first)
    return gr.update(value=first), plot, info, path


# ──────────────────────────────────────────────
# Event handlers: Forecasting
# ──────────────────────────────────────────────
def on_fc_example_select(name):
    """Load a bundled forecasting example; returns (plot, info, record path)."""
    if not name:
        return None, "", None
    path = FORECAST_DIR / name
    record = read_ecg(str(path))
    return plot_forecast_ecg(record, name), get_forecast_info(record, name), str(path)


def on_fc_upload(files):
    """Load an uploaded forecasting ECG; returns (plot, info, record path)."""
    if not files:
        return None, "", None
    record, path, error = _load_upload(files)
    if record is None:
        return None, error, None
    return plot_forecast_ecg(record), get_forecast_info(record), path


def fc_load_first():
    """Pre-load first forecast example on page open."""
    examples = load_forecast_examples()
    if not examples:
        return gr.update(), None, "", None
    first = examples[0]
    plot, info, path = on_fc_example_select(first)
    return gr.update(value=first), plot, info, path


# ──────────────────────────────────────────────
# Inference
# ──────────────────────────────────────────────
def _inference_device():
    """Return 'cuda' when a GPU is visible, otherwise 'cpu' (local runs)."""
    return "cuda" if torch.cuda.is_available() else "cpu"


def _generate(model, label, mode, prompt, ecg_path, ecg_configs=None):
    """Run one greedy generation with *model* and move it back to CPU afterwards.

    The move back to CPU happens in ``finally`` so the GPU is released even
    when ``model.run`` raises.
    """
    t0 = time.time()
    device = _inference_device()
    model.move_to_device(device)
    print(f"  {label} model .to({device}): {time.time() - t0:.2f}s")
    try:
        with torch.inference_mode():
            args = SimpleNamespace(
                mode=mode,
                text=prompt,
                ecgs=[ecg_path],
                device=device,
                ecg_configs=ecg_configs,
                json=None,
                temperature=0.0,
                top_k=64,
                top_p=0.95,
                min_p=0.0,
                max_new_tokens=512,
            )
            output, _ = model.run(args)
    finally:
        # Move back to CPU to free GPU for other models
        model.move_to_device("cpu")
        torch.cuda.empty_cache()
    return output


@spaces.GPU
def run_cls_inference(ecg_path, prompt):
    """Answer *prompt* about the ECG at *ecg_path* with the base model."""
    if not ecg_path:
        return "Select an ECG first."
    t0 = time.time()
    output = _generate(MODEL_BASE, "Base", "base", prompt, ecg_path)
    print(f"  Classification inference: {time.time() - t0:.2f}s")
    return output


@spaces.GPU
def run_fc_inference(ecg_path, name):
    """Run the fixed forecasting prompt on the ECG at *ecg_path*.

    For known examples (*name* in FORECAST_META) only the model-input window
    is passed, via ``ecg_configs``; otherwise the whole record is used.
    """
    if not ecg_path:
        return "Select an ECG first."
    ecg_configs = None
    meta = FORECAST_META.get(name)
    if meta:
        # The wfdb file starts at seg_start_orig; convert the model-input
        # window to whole seconds relative to the start of the stored segment.
        seg_start = meta["seg_start_orig"]
        input_start_sec = (meta["input_start_orig"] - seg_start) // meta["fs"]
        input_end_sec = (meta["input_end_orig"] - seg_start) // meta["fs"]
        ecg_configs = [f"start:{input_start_sec};end:{input_end_sec}"]
    t0 = time.time()
    output = _generate(MODEL_FORECAST, "Forecast", "forecast", FORECAST_PROMPT,
                       ecg_path, ecg_configs)
    print(f"  Forecast inference: {time.time() - t0:.2f}s")
    return output


# ──────────────────────────────────────────────
# UI
# ──────────────────────────────────────────────
_CSS = """
.gr-button-primary { min-width: 120px; }
.prompt-display {
    background: var(--background-fill-secondary) !important;
    border: 1px solid var(--border-color-primary);
    border-radius: 8px;
    padding: 12px;
    font-size: 0.9em;
}
.model-output textarea {
    background: rgba(100, 180, 255, 0.08) !important;
    color: var(--body-text-color) !important;
    border: 1px solid rgba(100, 180, 255, 0.25) !important;
}
"""

# Markdown joins single newlines into one paragraph; use hard breaks so each
# prompt line is shown on its own row.
_FORECAST_PROMPT_MD = FORECAST_PROMPT.replace("\n", "  \n")

with gr.Blocks(title="CAMEL ECG", css=_CSS) as demo:
    gr.Markdown("# 🐪 CAMEL ECG Model")
    gr.Markdown("Cardiac AI Model for ECG analysis and rhythm forecasting.")

    with gr.Tabs():
        # ═══════════════════════════════
        # Tab 1: Classification
        # ═══════════════════════════════
        with gr.Tab("🩺 Classification"):
            cls_path = gr.State()
            with gr.Row():
                with gr.Column(scale=1, min_width=280):
                    with gr.Tabs():
                        with gr.Tab("Example"):
                            cls_dd = gr.Dropdown(
                                load_classification_examples(),
                                label="Select Example ECG",
                            )
                        with gr.Tab("Upload"):
                            cls_upload = gr.File(file_count="multiple", label="Upload .hea + .dat")
                    cls_prompt = gr.Dropdown(
                        CLASSIFICATION_PROMPTS,
                        value=CLASSIFICATION_PROMPTS[0],
                        label="Prompt",
                    )
                    cls_run_btn = gr.Button("Run Inference", variant="primary")
                    cls_out = gr.Textbox(label="Model Output", lines=6, interactive=False,
                                         elem_classes=["model-output"])
                with gr.Column(scale=3):
                    cls_plot = gr.Plot(label="ECG")
                    cls_info = gr.Textbox(label="Ground-truth", lines=3, interactive=False)

            cls_dd.change(on_cls_example_select, cls_dd, [cls_plot, cls_info, cls_path])
            cls_upload.upload(on_cls_upload, cls_upload, [cls_plot, cls_info, cls_path])
            cls_run_btn.click(run_cls_inference, [cls_path, cls_prompt], cls_out)

        # ═══════════════════════════════
        # Tab 2: Forecasting
        # ═══════════════════════════════
        with gr.Tab("🔮 Forecasting"):
            fc_path = gr.State()
            fc_name = gr.State()
            with gr.Row():
                with gr.Column(scale=1, min_width=280):
                    with gr.Tabs():
                        with gr.Tab("Example"):
                            fc_dd = gr.Dropdown(
                                load_forecast_examples(),
                                label="Select Example ECG",
                            )
                        with gr.Tab("Upload"):
                            fc_upload = gr.File(file_count="multiple", label="Upload .hea + .dat")
                    gr.Markdown(
                        f"**Prompt (fixed):**\n\n{_FORECAST_PROMPT_MD}",
                        elem_classes=["prompt-display"],
                    )
                    fc_run_btn = gr.Button("Run Inference", variant="primary")
                    fc_out = gr.Textbox(label="Model Output", lines=5, interactive=False,
                                        elem_classes=["model-output"])
                with gr.Column(scale=3):
                    fc_plot = gr.Plot(label="ECG Timeline")
                    fc_info = gr.Textbox(label="Info", lines=4, interactive=False)

            def _fc_select_wrapper(name):
                # Also remember the example name so inference can look up FORECAST_META.
                plot, info, path = on_fc_example_select(name)
                return plot, info, path, name

            def _fc_upload_wrapper(files):
                # Uploads have no metadata: clear the remembered example name.
                plot, info, path = on_fc_upload(files)
                return plot, info, path, ""

            def _fc_load_first_wrapper():
                # Set fc_name on page load too, instead of relying on the
                # dropdown's change event to fill it in.
                examples = load_forecast_examples()
                return (*fc_load_first(), examples[0] if examples else None)

            fc_dd.change(_fc_select_wrapper, fc_dd, [fc_plot, fc_info, fc_path, fc_name])
            fc_upload.upload(_fc_upload_wrapper, fc_upload, [fc_plot, fc_info, fc_path, fc_name])
            fc_run_btn.click(run_fc_inference, [fc_path, fc_name], fc_out)

    # Pre-load first examples on page open
    demo.load(cls_load_first, outputs=[cls_dd, cls_plot, cls_info, cls_path])
    demo.load(_fc_load_first_wrapper, outputs=[fc_dd, fc_plot, fc_info, fc_path, fc_name])

if __name__ == "__main__":
    demo.launch()