import os
import shutil
import sys
import tempfile
import time
from pathlib import Path
from types import SimpleNamespace

import gradio as gr
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import torch
import wfdb
from huggingface_hub import hf_hub_download

import spaces
|
|
# --- Model checkpoints: fetch from the Hugging Face Hub, expose under ./checkpoints ---
CHECKPOINT_DIR = Path("checkpoints")
CHECKPOINT_DIR.mkdir(exist_ok=True)

files = ["camel_base.pt", "camel_ecginstruct.pt", "camel_forecast.pt"]
repo_id = "CAMEL-ECG/CAMEL"

# Download every checkpoint into the Hugging Face cache and expose it as
# ./checkpoints/<name> via a symlink. A dangling link (one whose target no
# longer exists) is deleted first so it can be recreated below; an existing
# valid link or regular file is left alone.
for ckpt_name in files:
    hub_path = hf_hub_download(repo_id=repo_id, filename=ckpt_name)
    link_path = CHECKPOINT_DIR / ckpt_name

    is_dangling = link_path.is_symlink() and not link_path.exists()
    if is_dangling:
        link_path.unlink()
        print(f"Removed broken symlink for {ckpt_name}")

    if link_path.exists():
        continue
    link_path.symlink_to(hub_path)
    print(f"Symlinked {ckpt_name} -> {hub_path}")
|
|
# NOTE(review): kept below the checkpoint setup; presumably CAMEL expects the
# ./checkpoints symlinks to exist — confirm before moving this import up.
from camel.camel_model import CAMEL

# Both models are built on CPU at import time; the inference handlers move
# them to CUDA only for the duration of a request.
print("Pre-loading CAMEL base model on CPU...")
_base_load_start = time.time()
MODEL_BASE = CAMEL(mode='base', device='cpu')
_base_load_secs = time.time() - _base_load_start
print(f" Base model loaded in {_base_load_secs:.1f}s")

print("Pre-loading CAMEL forecast model on CPU...")
_forecast_load_start = time.time()
MODEL_FORECAST = CAMEL(mode='forecast', device='cpu')
_forecast_load_secs = time.time() - _forecast_load_start
print(f" Forecast model loaded in {_forecast_load_secs:.1f}s")
|
|
|
|
# --- Demo data locations, prompts and reference labels ---
# Bundled example records: classification ECGs live directly in demo/,
# forecasting ECGs in demo/forecast/ (WFDB .hea/.dat pairs).
DEMO_DIR = Path("demo")
FORECAST_DIR = DEMO_DIR.joinpath("forecast")

# Canned prompts offered in the Classification tab's prompt dropdown.
CLASSIFICATION_PROMPTS = [
    "Describe the ECG.",
    "What is the QRS duration?",
    "What is the R-R interval and heart rate?",
]

# Fixed prompt sent with every forecasting request (also shown read-only in
# the UI). NOTE(review): it says "next 120 seconds" while FORECAST_META lists
# horizons of 300/600 s — confirm this is the prompt the model was trained on.
FORECAST_PROMPT = "\n".join([
    "Analyze the ECG signal and predict the cardiac rhythm for the next 120 seconds.",
    "NORM: Normal ECG",
    "ABNORMAL: Atrial Fibrillation or Atrial Flutter",
    "Output one of: NORM or ABNORMAL.",
])

# Reference diagnoses for the bundled classification examples, keyed by
# record stem; shown in the info box next to the plot.
CLASSIFICATION_COMMENTS = {
    "08704_hr": "Left Ventricular Hypertrophy; Non-Specific ST Changes",
    "12585_hr": "Left Ventricular Hypertrophy, ST-T Changes",
}
|
|
# --- Forecasting example metadata ---
# Metadata for the bundled forecasting records, keyed by record stem.
# All *_orig values are sample indices (at ``fs`` Hz) in the recording the
# demo file was presumably cut from; the demo file begins at seg_start_orig,
# so consumers subtract it to obtain in-file offsets.
FORECAST_META = {
    "p09164_s03_600": dict(
        label="AFIB",
        fs=250,
        input_duration_s=60,
        horizon_s=600,
        seg_start_orig=209743,
        input_start_orig=212243,          # seg_start + 2500 samples (10 s)
        input_end_orig=227243,            # input_start + 15000 samples (60 s)
        afib_onset_orig=232692,
        seg_end_orig=232692 + 30 * 250,   # 30 s past AFib onset
    ),
    "p01894_s47_300": dict(
        label="AFIB",
        fs=250,
        input_duration_s=120,
        horizon_s=300,
        seg_start_orig=330016,
        input_start_orig=332516,          # seg_start + 2500 samples (10 s)
        input_end_orig=362516,            # input_start + 30000 samples (120 s)
        afib_onset_orig=385178,
        seg_end_orig=385178 + 30 * 250,   # 30 s past AFib onset
    ),
}
|
|
|
|
# --- Example discovery and ECG loading ---
|
|
def load_classification_examples():
    """Return the sorted record stems of the classification demo ECGs.

    Scans ``DEMO_DIR`` (non-recursively) for WFDB ``.hea`` header files and
    skips hidden ones whose stem starts with ``"."``. Returns an empty list
    when the directory does not exist.
    """
    if DEMO_DIR.exists():
        stems = (header.stem for header in DEMO_DIR.glob("*.hea"))
        return sorted(stem for stem in stems if not stem.startswith("."))
    return []
|
|
|
|
def load_forecast_examples(forecast_dir=None):
    """Return the sorted record stems of the forecasting demo ECGs.

    Scans a directory (``FORECAST_DIR`` by default) non-recursively for WFDB
    ``.hea`` header files. Hidden headers (stem starting with ``"."``, e.g.
    macOS ``._*`` AppleDouble files) are skipped, matching
    ``load_classification_examples``.

    Args:
        forecast_dir: Optional directory to scan instead of ``FORECAST_DIR``.

    Returns:
        Sorted list of stems; empty when the directory does not exist.
    """
    directory = FORECAST_DIR if forecast_dir is None else Path(forecast_dir)
    if not directory.exists():
        return []
    return sorted(
        p.stem for p in directory.glob("*.hea") if not p.stem.startswith(".")
    )
|
|
|
|
def _wfdb_record_name(path_str):
    """Return ``path_str`` as a WFDB record name: only a .hea/.dat suffix is removed."""
    path = Path(path_str)
    # Path.with_suffix("") would also drop a dot-suffix that is part of the
    # record name itself (e.g. "rec.v2" -> "rec"), so strip known suffixes only.
    if path.suffix in (".hea", ".dat"):
        path = path.with_suffix("")
    return str(path)


def read_ecg(path_str):
    """Read a WFDB record, best-effort.

    Args:
        path_str: Record path, with or without a ``.hea``/``.dat`` extension.

    Returns:
        The ``wfdb.Record`` returned by ``wfdb.rdrecord``, or ``None`` if the
        path is invalid or the record cannot be read (the error is printed).
    """
    try:
        return wfdb.rdrecord(_wfdb_record_name(path_str))
    except Exception as e:
        # Deliberately broad: the UI shows an empty plot instead of crashing.
        print(f"Error reading {path_str}: {e}")
        return None
|
|
|
|
# --- Plotting helpers (classification) ---
|
|
def _downsample_for_plot(sig, t, max_pts=3000):
    """Decimate a signal and its time axis so at most ``max_pts`` points are drawn.

    Keeps every ``step``-th sample. ``step`` is the ceiling of
    ``len(sig) / max_pts``, so the result never exceeds ``max_pts`` points.
    Floor division could return nearly ``2 * max_pts``.

    Args:
        sig: Sample values (any sliceable sequence, e.g. a NumPy array).
        t: Time stamps, same length as ``sig``.
        max_pts: Upper bound on the number of returned points (>= 1).

    Returns:
        ``(t, sig)``, possibly decimated. Note the order is swapped relative
        to the arguments.
    """
    n = len(sig)
    if n <= max_pts:
        return t, sig
    step = -(-n // max_pts)  # ceil(n / max_pts) with integer arithmetic
    return t[::step], sig[::step]
|
|
|
|
def plot_classification_ecg(record):
    """Plot every lead of an ECG record in a compact 3-column grid.

    A 12-lead record yields a 4×3 grid; cells beyond the last lead are hidden.

    Args:
        record: A ``wfdb.Record`` (reads ``p_signal``, ``fs`` and
            ``sig_name``), or ``None``.

    Returns:
        A matplotlib ``Figure``, or ``None`` when ``record`` is falsy.
    """
    if not record:
        return None
    signals = record.p_signal
    n_leads = signals.shape[1]
    t = np.arange(signals.shape[0]) / record.fs

    cols = 3
    rows = max(1, (n_leads + cols - 1) // cols)
    # squeeze=False guarantees a 2-D (rows, cols) array of Axes, so axes[r][c]
    # is valid for any lead count, including single-row (1-3 lead) records.
    fig, axes = plt.subplots(
        rows, cols, figsize=(12, 1.2 * rows + 0.5), sharex=True, squeeze=False
    )

    for i in range(n_leads):
        r, c = divmod(i, cols)
        ax = axes[r][c]
        ax.plot(t, signals[:, i], linewidth=0.5, color='#1a1a2e')
        ax.set_ylabel(record.sig_name[i], fontsize=8, fontweight='bold')
        ax.tick_params(labelsize=6)
        ax.grid(True, alpha=0.3)

    # Hide grid cells left over when n_leads is not a multiple of cols.
    for i in range(n_leads, rows * cols):
        r, c = divmod(i, cols)
        axes[r][c].set_visible(False)

    for c in range(cols):
        axes[-1][c].set_xlabel("Time (s)", fontsize=8)
    fig.subplots_adjust(hspace=0.3, wspace=0.3, top=0.97, bottom=0.06)
    return fig
|
|
|
|
def get_classification_info(record, name=""):
    """Summarise a classification ECG for the info textbox.

    Lists sampling rate, duration and lead names. It also adds the reference
    diagnosis from ``CLASSIFICATION_COMMENTS`` when ``name`` has one. Returns
    ``""`` for a missing record.
    """
    if not record:
        return ""
    diagnosis = CLASSIFICATION_COMMENTS.get(name, "")
    duration_s = record.sig_len / record.fs
    summary = [
        f"Fs: {record.fs} Hz",
        f"Duration: {duration_s:.1f}s",
        "Leads: " + ", ".join(record.sig_name),
    ]
    if diagnosis:
        summary.append(f"Diagnosis: {diagnosis}")
    return "\n".join(summary)
|
|
|
|
# --- Plotting helpers (forecasting) ---
|
|
def plot_forecast_ecg(record, name=""):
    """Plot a forecasting ECG (first channel only).

    A bundled example with an entry in ``FORECAST_META`` gets three stacked
    panels:
    1. The full timeline, with the model-input window, the gap and the AFib
       region shaded.
    2. The first 10 s of the input.
    3. The window from 10 s before to 20 s after AFib onset.

    Any other record (e.g. an upload) gets one downsampled trace.

    Args:
        record: A ``wfdb.Record`` (reads ``p_signal``, ``fs``, ``sig_name``),
            or ``None``.
        name: Record stem used to look up ``FORECAST_META``.

    Returns:
        A matplotlib ``Figure``, or ``None`` when ``record`` is falsy.
    """
    if not record:
        return None

    signals = record.p_signal[:, 0]
    fs = record.fs
    meta = FORECAST_META.get(name)
    t = np.arange(len(signals)) / fs

    if meta is None:
        # _downsample_for_plot takes (sig, t) but returns (t, sig).
        td, sd = _downsample_for_plot(signals, t)
        fig, ax = plt.subplots(figsize=(12, 2.5))
        ax.plot(td, sd, linewidth=0.5, color='#1a1a2e')
        ax.set_xlabel("Time (s)")
        ax.set_ylabel("ECG (mV)")
        # Label with the actual channel-0 name instead of assuming Lead I.
        ax.set_title(f"Uploaded ECG ({record.sig_name[0]})", fontsize=11, fontweight='bold')
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        return fig

    # FORECAST_META indices refer to the original recording; this file
    # starts at seg_start_orig, so subtract it to get in-file sample offsets.
    seg_start = meta["seg_start_orig"]
    input_start_rel = meta["input_start_orig"] - seg_start
    input_end_rel = meta["input_end_orig"] - seg_start
    afib_onset_rel = meta["afib_onset_orig"] - seg_start
    # Sample counts used as slice bounds must be ints even if fs is a float.
    ten_s = int(round(10 * fs))
    twenty_s = int(round(20 * fs))

    t_in_s = input_start_rel / fs
    t_in_e = input_end_rel / fs
    t_af = afib_onset_rel / fs

    fig = plt.figure(figsize=(12, 7))
    gs = gridspec.GridSpec(3, 1, height_ratios=[1.5, 1, 1], hspace=0.45)

    # Panel 1: whole segment with input window, gap and AFib region shaded.
    ax1 = fig.add_subplot(gs[0])
    ax1.plot(t, signals, linewidth=0.4, color='#1a1a2e', zorder=2)
    ax1.axvspan(t_in_s, t_in_e, alpha=0.15, color='#2196F3', label='Model Input', zorder=1)
    ax1.axvspan(t_in_e, t_af, alpha=0.10, color='#9E9E9E', label='Intervening Gap', zorder=1)
    ax1.axvspan(t_af, t[-1], alpha=0.15, color='#F44336', label='AFib Onset', zorder=1)
    ax1.axvline(t_af, color='#D32F2F', linestyle='--', linewidth=1.5, alpha=0.8, zorder=3)
    ax1.set_title(f"Full Timeline — {name}", fontsize=11, fontweight='bold')
    ax1.set_xlabel("Time (s)", fontsize=8)
    ax1.set_ylabel("mV", fontsize=8)
    ax1.legend(loc='upper right', fontsize=7)
    ax1.grid(True, alpha=0.3)
    ax1.tick_params(labelsize=7)

    # Panel 2: first 10 s of the model input (or all of it, if shorter).
    ax2 = fig.add_subplot(gs[1])
    zoom_end = input_start_rel + min(ten_s, input_end_rel - input_start_rel)
    ax2.plot(t[input_start_rel:zoom_end], signals[input_start_rel:zoom_end],
             linewidth=0.5, color='#1565C0')
    remaining_s = meta['input_duration_s'] - 10
    suffix = f" (first 10s of {meta['input_duration_s']}s)" if remaining_s > 0 else ""
    ax2.set_title(f"Input Segment{suffix}", fontsize=10, fontweight='bold')
    if remaining_s > 0:
        ax2.annotate(f'→ {remaining_s}s more →', xy=(1.0, 0.5), xycoords='axes fraction',
                     fontsize=9, fontweight='bold', color='#1565C0', alpha=0.7,
                     ha='right', va='center')
    ax2.set_xlabel("Time (s)", fontsize=8)
    ax2.set_ylabel("mV", fontsize=8)
    ax2.grid(True, alpha=0.3)
    ax2.tick_params(labelsize=7)

    # Panel 3: 10 s before to 20 s after AFib onset, clipped to the file.
    ax3 = fig.add_subplot(gs[2])
    afib_s = max(afib_onset_rel - ten_s, 0)
    afib_e = min(afib_onset_rel + twenty_s, len(signals))
    ax3.plot(t[afib_s:afib_e], signals[afib_s:afib_e], linewidth=0.5, color='#C62828')
    ax3.axvline(t_af, color='#D32F2F', linestyle='--', linewidth=1.5, alpha=0.8,
                label=f'AFib onset ({t_af:.1f}s)')
    ax3.set_title("AFib Onset Region", fontsize=10, fontweight='bold')
    ax3.set_xlabel("Time (s)", fontsize=8)
    ax3.set_ylabel("mV", fontsize=8)
    ax3.legend(loc='upper right', fontsize=7)
    ax3.grid(True, alpha=0.3)
    ax3.tick_params(labelsize=7)

    fig.subplots_adjust(top=0.95, bottom=0.06)
    return fig
|
|
|
|
def get_forecast_info(record, name=""):
    """Summarise a forecasting ECG for the info textbox.

    Always lists sampling rate, lead names and segment duration. For records
    with a ``FORECAST_META`` entry it also lists:
    - the input window length,
    - the ground-truth label,
    - how many seconds after the input window AFib begins.

    Returns ``""`` for a missing record.
    """
    if not record:
        return ""
    meta = FORECAST_META.get(name)
    info = [
        f"Fs: {record.fs} Hz",
        "Leads: " + ", ".join(record.sig_name),
        f"Segment duration: {record.sig_len / record.fs:.1f}s",
    ]
    if meta:
        onset_delay_s = (meta["afib_onset_orig"] - meta["input_end_orig"]) / meta["fs"]
        info.extend([
            f"Input duration: {meta['input_duration_s']}s",
            f"Ground truth: {meta['label']}",
            f"AFib onset: {onset_delay_s:.0f}s after input ends",
        ])
    return "\n".join(info)
|
|
|
|
# --- Classification tab callbacks ---
|
|
def on_cls_example_select(name):
    """Load a bundled classification example chosen in the dropdown.

    Returns ``(figure, info_text, record_path)``, or ``(None, "", None)``
    when the selection is empty.
    """
    if name:
        record_path = str(DEMO_DIR / name)
        record = read_ecg(record_path)
        return (
            plot_classification_ecg(record),
            get_classification_info(record, name),
            record_path,
        )
    return None, "", None
|
|
|
|
def on_cls_upload(files):
    """Load an uploaded WFDB record (``.hea`` header plus ``.dat`` signal file).

    ``wfdb`` looks for the signal file next to the header, but Gradio may
    cache each uploaded file in its own temporary directory. All uploaded
    files are therefore copied into one fresh temporary directory before the
    header is read. NOTE: staged copies are not deleted here.

    Args:
        files: Uploaded files from ``gr.File(file_count="multiple")``. Each is
            either a path string or an object exposing its path as ``.name``.

    Returns:
        ``(figure, info_text, record_path)``. ``record_path`` is the staged
        record path without extension, or ``None`` when nothing usable was
        uploaded.
    """
    if not files:
        return None, "", None
    paths = [Path(getattr(f, "name", f)) for f in files]
    if not any(p.suffix == '.hea' for p in paths):
        return None, "Missing .hea file", None

    staging_dir = Path(tempfile.mkdtemp(prefix="camel_upload_"))
    try:
        staged = [Path(shutil.copy(p, staging_dir / p.name)) for p in paths]
    except OSError as e:
        print(f"Error staging upload: {e}")
        shutil.rmtree(staging_dir, ignore_errors=True)
        return None, "Could not read uploaded files", None

    hea = next(p for p in staged if p.suffix == '.hea')
    record = read_ecg(str(hea))
    return plot_classification_ecg(record), get_classification_info(record), str(hea.with_suffix(""))
|
|
|
|
def cls_load_first():
    """Populate the Classification tab with the first bundled example on page load.

    Returns ``(dropdown_update, figure, info_text, record_path)``. When no
    examples exist, the dropdown is left unchanged and the rest are empty.
    """
    examples = load_classification_examples()
    if examples:
        first = examples[0]
        return (gr.update(value=first), *on_cls_example_select(first))
    return gr.update(), None, "", None
|
|
|
|
# --- Forecasting tab callbacks ---
|
|
def on_fc_example_select(name):
    """Load a bundled forecasting example chosen in the dropdown.

    ``name`` is also used to look up ``FORECAST_META`` for the plot and info
    text. Returns ``(figure, info_text, record_path)``, or
    ``(None, "", None)`` when the selection is empty.
    """
    if not name:
        return None, "", None
    record_path = str(FORECAST_DIR / name)
    record = read_ecg(record_path)
    figure = plot_forecast_ecg(record, name)
    summary = get_forecast_info(record, name)
    return figure, summary, record_path
|
|
|
|
def on_fc_upload(files):
    """Load an uploaded WFDB record for the Forecasting tab.

    ``wfdb`` looks for the ``.dat`` signal file next to the ``.hea`` header,
    but Gradio may cache each uploaded file in its own temporary directory.
    All uploaded files are therefore copied into one fresh temporary
    directory before the header is read. NOTE: staged copies are not deleted
    here.

    Args:
        files: Uploaded files from ``gr.File(file_count="multiple")``. Each is
            either a path string or an object exposing its path as ``.name``.

    Returns:
        ``(figure, info_text, record_path)``. ``record_path`` is the staged
        record path without extension, or ``None`` when nothing usable was
        uploaded.
    """
    if not files:
        return None, "", None
    paths = [Path(getattr(f, "name", f)) for f in files]
    if not any(p.suffix == '.hea' for p in paths):
        return None, "Missing .hea file", None

    staging_dir = Path(tempfile.mkdtemp(prefix="camel_upload_"))
    try:
        staged = [Path(shutil.copy(p, staging_dir / p.name)) for p in paths]
    except OSError as e:
        print(f"Error staging upload: {e}")
        shutil.rmtree(staging_dir, ignore_errors=True)
        return None, "Could not read uploaded files", None

    hea = next(p for p in staged if p.suffix == '.hea')
    record = read_ecg(str(hea))
    # Uploads have no FORECAST_META entry, so plot/info use their generic paths.
    return plot_forecast_ecg(record), get_forecast_info(record), str(hea.with_suffix(""))
|
|
|
|
def fc_load_first():
    """Populate the Forecasting tab with the first bundled example on page load.

    Returns ``(dropdown_update, figure, info_text, record_path)``. When no
    examples exist, the dropdown is left untouched and the rest are empty.
    """
    examples = load_forecast_examples()
    if not examples:
        return gr.update(), None, "", None
    first_name = examples[0]
    figure, info_text, record_path = on_fc_example_select(first_name)
    return gr.update(value=first_name), figure, info_text, record_path
|
|
|
|
# --- GPU inference (models are moved to CUDA only while a request runs) ---
|
|
@spaces.GPU
def run_cls_inference(ecg_path, prompt):
    """Run the CAMEL base model on one ECG with a free-text prompt.

    ``MODEL_BASE`` is kept on CPU and moved to CUDA only for this call. The
    move back to CPU and the cache release happen in ``finally``, so a
    failing generation does not leave the model on the GPU.

    Args:
        ecg_path: WFDB record path (no extension), or falsy if nothing is loaded.
        prompt: Instruction/question text passed to the model.

    Returns:
        The model's text output, or a hint string when no ECG is selected.
    """
    if not ecg_path:
        return "Select an ECG first."
    t0 = time.time()
    try:
        MODEL_BASE.move_to_device('cuda')
        print(f" Base model .to(cuda): {time.time() - t0:.2f}s")

        with torch.inference_mode():
            args = SimpleNamespace(
                mode='base', text=prompt, ecgs=[ecg_path], device='cuda',
                ecg_configs=None, json=None, temperature=0.0, top_k=64,
                top_p=0.95, min_p=0.0, max_new_tokens=512
            )
            output, _ = MODEL_BASE.run(args)
    finally:
        # Always return the weights to CPU, even if loading or generation fails.
        MODEL_BASE.move_to_device('cpu')
        torch.cuda.empty_cache()

    print(f" Classification inference: {time.time() - t0:.2f}s")
    return output
|
|
|
|
@spaces.GPU
def run_fc_inference(ecg_path, name):
    """Run the CAMEL forecast model on one ECG with the fixed ``FORECAST_PROMPT``.

    For bundled examples (``name`` in ``FORECAST_META``) only the model-input
    window is passed. It is given as an ``ecg_configs`` entry
    ``"start:<s>;end:<s>"``, with whole-second offsets into the file.
    Records without metadata (uploads) are sent with ``ecg_configs=None``.
    ``MODEL_FORECAST`` is moved to CUDA only for this call and is returned to
    CPU in ``finally``.

    Args:
        ecg_path: WFDB record path (no extension), or falsy if nothing is loaded.
        name: Example stem used to look up ``FORECAST_META``; ``""``/``None``
            for uploads.

    Returns:
        The model's text output, or a hint string when no ECG is selected.
    """
    if not ecg_path:
        return "Select an ECG first."

    meta = FORECAST_META.get(name)
    ecg_configs = None
    if meta:
        # Offsets relative to the start of the demo file, floored to whole seconds.
        input_start_sec = (meta["input_start_orig"] - meta["seg_start_orig"]) // meta["fs"]
        input_end_sec = (meta["input_end_orig"] - meta["seg_start_orig"]) // meta["fs"]
        ecg_configs = [f"start:{input_start_sec};end:{input_end_sec}"]

    t0 = time.time()
    try:
        MODEL_FORECAST.move_to_device('cuda')
        print(f" Forecast model .to(cuda): {time.time() - t0:.2f}s")

        with torch.inference_mode():
            args = SimpleNamespace(
                mode='forecast', text=FORECAST_PROMPT, ecgs=[ecg_path], device='cuda',
                ecg_configs=ecg_configs, json=None, temperature=0.0, top_k=64,
                top_p=0.95, min_p=0.0, max_new_tokens=512
            )
            output, _ = MODEL_FORECAST.run(args)
    finally:
        # Always return the weights to CPU, even if loading or generation fails.
        MODEL_FORECAST.move_to_device('cpu')
        torch.cuda.empty_cache()

    print(f" Forecast inference: {time.time() - t0:.2f}s")
    return output
|
|
|
|
# --- Gradio UI ---
|
|
with gr.Blocks(title="CAMEL ECG", css="""
.gr-button-primary { min-width: 120px; }
.prompt-display { background: var(--background-fill-secondary) !important; border: 1px solid var(--border-color-primary); border-radius: 8px; padding: 12px; font-size: 0.9em; }
.model-output textarea { background: rgba(100, 180, 255, 0.08) !important; color: var(--body-text-color) !important; border: 1px solid rgba(100, 180, 255, 0.25) !important; }
""") as demo:
    gr.Markdown("# 🐪 CAMEL ECG Model")
    gr.Markdown("Cardiac AI Model for ECG analysis and rhythm forecasting.")

    with gr.Tabs():
        # ---- Classification: free-text prompts against MODEL_BASE ----
        with gr.Tab("🩺 Classification"):
            # Record path of the ECG currently shown; input to run_cls_inference.
            cls_path = gr.State()

            with gr.Row():
                with gr.Column(scale=1, min_width=280):
                    with gr.Tabs():
                        with gr.Tab("Example"):
                            cls_dd = gr.Dropdown(
                                load_classification_examples(),
                                label="Select Example ECG",
                            )
                        with gr.Tab("Upload"):
                            cls_upload = gr.File(file_count="multiple", label="Upload .hea + .dat")

                    cls_prompt = gr.Dropdown(
                        CLASSIFICATION_PROMPTS,
                        value=CLASSIFICATION_PROMPTS[0],
                        label="Prompt",
                    )
                    cls_run_btn = gr.Button("Run Inference", variant="primary")
                    cls_out = gr.Textbox(label="Model Output", lines=6, interactive=False,
                                         elem_classes=["model-output"])

                with gr.Column(scale=3):
                    cls_plot = gr.Plot(label="ECG")
                    cls_info = gr.Textbox(label="Ground-truth", lines=3, interactive=False)

            cls_dd.change(on_cls_example_select, cls_dd, [cls_plot, cls_info, cls_path])
            cls_upload.upload(on_cls_upload, cls_upload, [cls_plot, cls_info, cls_path])
            cls_run_btn.click(run_cls_inference, [cls_path, cls_prompt], cls_out)

        # ---- Forecasting: fixed NORM/ABNORMAL prompt against MODEL_FORECAST ----
        with gr.Tab("🔮 Forecasting"):
            # Record path of the ECG currently shown.
            fc_path = gr.State()
            # Example stem used to look up FORECAST_META; "" after an upload.
            fc_name = gr.State()

            with gr.Row():
                with gr.Column(scale=1, min_width=280):
                    with gr.Tabs():
                        with gr.Tab("Example"):
                            fc_dd = gr.Dropdown(
                                load_forecast_examples(),
                                label="Select Example ECG",
                            )
                        with gr.Tab("Upload"):
                            fc_upload = gr.File(file_count="multiple", label="Upload .hea + .dat")

                    gr.Markdown(
                        f"**Prompt (fixed):**\n\n{FORECAST_PROMPT}",
                        elem_classes=["prompt-display"],
                    )
                    fc_run_btn = gr.Button("Run Inference", variant="primary")
                    fc_out = gr.Textbox(label="Model Output", lines=5, interactive=False,
                                        elem_classes=["model-output"])

                with gr.Column(scale=3):
                    fc_plot = gr.Plot(label="ECG Timeline")
                    fc_info = gr.Textbox(label="Info", lines=4, interactive=False)

            def _fc_select_wrapper(name):
                """Run on_fc_example_select and also emit ``name`` for fc_name."""
                plot, info, path = on_fc_example_select(name)
                return plot, info, path, name

            def _fc_upload_wrapper(files):
                """Run on_fc_upload and clear fc_name (uploads have no metadata)."""
                plot, info, path = on_fc_upload(files)
                return plot, info, path, ""

            fc_dd.change(_fc_select_wrapper, fc_dd, [fc_plot, fc_info, fc_path, fc_name])
            fc_upload.upload(_fc_upload_wrapper, fc_upload, [fc_plot, fc_info, fc_path, fc_name])
            fc_run_btn.click(run_fc_inference, [fc_path, fc_name], fc_out)

    # Show the first bundled example of each tab on page load.
    # NOTE(review): fc_load_first does not output fc_name; it is presumably set
    # when the programmatic dropdown update fires fc_dd.change — confirm.
    demo.load(cls_load_first, outputs=[cls_dd, cls_plot, cls_info, cls_path])
    demo.load(fc_load_first, outputs=[fc_dd, fc_plot, fc_info, fc_path])


if __name__ == "__main__":
    demo.launch()
|
|