# CAMEL / app.py
# Last commit by Mayank Keoliya: "done" (00ba54f)
import gradio as gr
import os
import sys
import time
from pathlib import Path
from types import SimpleNamespace
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import torch
import wfdb
from huggingface_hub import hf_hub_download
# `spaces` provides the @spaces.GPU decorator used on HF Spaces. NOTE: no local
# stub is defined here, so running locally still requires `spaces` to be importable.
import spaces
# --- Setup: Checkpoints & Model ---
# Download each CAMEL checkpoint from the Hugging Face Hub (hf_hub_download returns
# the path of the locally cached file) and expose it as ./checkpoints/<name> via a
# symlink.
CHECKPOINT_DIR = Path("checkpoints")
CHECKPOINT_DIR.mkdir(exist_ok=True)
files = ["camel_base.pt", "camel_ecginstruct.pt", "camel_forecast.pt"]
repo_id = "CAMEL-ECG/CAMEL"
for f in files:
    cached_path = hf_hub_download(repo_id=repo_id, filename=f)
    target_link = CHECKPOINT_DIR / f
    if target_link.is_symlink():
        if not target_link.exists():
            # Dangling link: whatever it pointed to no longer exists.
            target_link.unlink()
            print(f"Removed broken symlink for {f}")
        elif Path(os.readlink(target_link)) != Path(cached_path):
            # Link still resolves but points at a different (e.g. older) cached file
            # than the one just returned; refresh it so the current revision is used.
            target_link.unlink()
            print(f"Removed stale symlink for {f}")
    # Regular (non-symlink) files already at target_link are left untouched.
    if not target_link.exists():
        target_link.symlink_to(cached_path)
        print(f"Symlinked {f} -> {cached_path}")
from camel.camel_model import CAMEL
# Pre-load BOTH models on CPU at startup; the inference handlers move a model to
# CUDA only for the duration of a request.
def _preload_camel(mode, announce, report_template):
    """Build ``CAMEL(mode=mode, device='cpu')``, printing a start and a timing line.

    ``report_template`` is formatted with the elapsed load time in seconds.
    """
    print(announce)
    started = time.time()
    model = CAMEL(mode=mode, device='cpu')
    elapsed = time.time() - started
    print(report_template.format(elapsed))
    return model


MODEL_BASE = _preload_camel(
    'base',
    "Pre-loading CAMEL base model on CPU...",
    " Base model loaded in {:.1f}s",
)
MODEL_FORECAST = _preload_camel(
    'forecast',
    "Pre-loading CAMEL forecast model on CPU...",
    " Forecast model loaded in {:.1f}s",
)
# ──────────────────────────────────────────────
# Constants & metadata
# ──────────────────────────────────────────────
# Root folder of the bundled example recordings; forecasting examples live in a
# subfolder of it.
DEMO_DIR = Path("demo")
FORECAST_DIR = DEMO_DIR.joinpath("forecast")

# Prompts offered in the classification tab's dropdown (the first is preselected).
CLASSIFICATION_PROMPTS = [
    "Describe the ECG.",
    "What is the QRS duration?",
    "What is the R-R interval and heart rate?",
]

# Fixed prompt sent with every forecasting request (one instruction per line).
FORECAST_PROMPT = "\n".join([
    "Analyze the ECG signal and predict the cardiac rhythm for the next 120 seconds.",
    "NORM: Normal ECG",
    "ABNORMAL: Atrial Fibrillation or Atrial Flutter",
    "Output one of: NORM or ABNORMAL.",
])

# Diagnosis text shown for the bundled classification examples, keyed by record stem.
CLASSIFICATION_COMMENTS = {
    "08704_hr": "Left Ventricular Hypertrophy; Non-Specific ST Changes",
    "12585_hr": "Left Ventricular Hypertrophy, ST-T Changes",
}
# Metadata for the bundled forecasting examples.
# Every "*_orig" value is a sample index into the ORIGINAL full-length recording.
# The wfdb files in demo/forecast/ hold only an extracted segment that starts at
# "seg_start_orig", so subtract "seg_start_orig" from any "*_orig" index to get the
# matching offset inside the stored file.
# The segments were cut from 10 s before the model input ("seg_start_orig") to
# 30 s after AFib onset ("seg_end_orig").
_FORECAST_PAD_BEFORE_S = 10
_FORECAST_PAD_AFTER_S = 30
FORECAST_META = {
    "p09164_s03_600": {
        "label": "AFIB",
        "fs": 250,
        "input_duration_s": 60,  # (227243 - 212243) / 250
        "horizon_s": 600,
        "seg_start_orig": 212243 - _FORECAST_PAD_BEFORE_S * 250,  # 209743
        "input_start_orig": 212243,
        "input_end_orig": 227243,
        "afib_onset_orig": 232692,
        "seg_end_orig": 232692 + _FORECAST_PAD_AFTER_S * 250,  # 240192
    },
    "p01894_s47_300": {
        "label": "AFIB",
        "fs": 250,
        "input_duration_s": 120,  # (362516 - 332516) / 250
        "horizon_s": 300,
        "seg_start_orig": 332516 - _FORECAST_PAD_BEFORE_S * 250,  # 330016
        "input_start_orig": 332516,
        "input_end_orig": 362516,
        "afib_onset_orig": 385178,
        "seg_end_orig": 385178 + _FORECAST_PAD_AFTER_S * 250,  # 392678
    },
}
# ──────────────────────────────────────────────
# Helper functions
# ──────────────────────────────────────────────
def load_classification_examples():
    """Return the sorted record stems of the bundled classification examples.

    Scans ``DEMO_DIR`` (not recursively) for ``*.hea`` headers and skips stems
    starting with a dot. Returns an empty list if ``DEMO_DIR`` does not exist.
    """
    if DEMO_DIR.exists():
        stems = (header.stem for header in DEMO_DIR.glob("*.hea"))
        return sorted(stem for stem in stems if not stem.startswith("."))
    return []
def load_forecast_examples(directory=None):
    """Return the sorted record stems of the bundled forecasting examples.

    Args:
        directory: Folder to scan (not recursively); defaults to ``FORECAST_DIR``.

    Returns:
        Sorted stems of the ``*.hea`` headers in the folder, skipping hidden files
        (stems starting with "."), or ``[]`` if the folder does not exist.
    """
    folder = FORECAST_DIR if directory is None else Path(directory)
    if not folder.exists():
        return []
    # Path.glob("*.hea") also matches dot-files (e.g. macOS "._x.hea" resource
    # forks); skip them exactly like load_classification_examples does.
    return sorted(p.stem for p in folder.glob("*.hea") if not p.stem.startswith("."))
def read_ecg(path_str):
    """Read a WFDB record, returning ``None`` instead of raising on failure.

    Args:
        path_str: The bare record path (e.g. ``demo/08704_hr``) or the path of one of
            its files (``.hea`` / ``.dat``).

    Returns:
        The object returned by ``wfdb.rdrecord``, or ``None`` if reading failed. The
        error is printed rather than raised so UI callbacks can degrade gracefully.
    """
    try:
        record_path = Path(path_str)
        # rdrecord takes the record name without an extension. Strip only WFDB file
        # suffixes, so a bare record name containing a dot is not truncated.
        if record_path.suffix in (".hea", ".dat"):
            record_path = record_path.with_suffix("")
        return wfdb.rdrecord(str(record_path))
    except Exception as e:
        print(f"Error reading {path_str}: {e}")
        return None
# ──────────────────────────────────────────────
# Classification plotting: 12-lead grid
# ──────────────────────────────────────────────
def _downsample_for_plot(sig, t, max_pts=3000):
"""Downsample signal for fast matplotlib rendering."""
if len(sig) <= max_pts:
return t, sig
step = max(1, len(sig) // max_pts)
return t[::step], sig[::step]
def plot_classification_ecg(record):
    """Plot every lead of a record in a 3-column grid (4×3 for a 12-lead ECG).

    Args:
        record: A ``wfdb.Record``-like object with ``p_signal`` (samples x leads),
            ``fs`` and ``sig_name``; may be ``None``.

    Returns:
        A matplotlib ``Figure``, or ``None`` when ``record`` is ``None``.
    """
    if record is None:
        return None
    signals = record.p_signal
    n_leads = signals.shape[1]
    t = np.arange(signals.shape[0]) / record.fs
    cols = 3
    rows = max(1, (n_leads + cols - 1) // cols)
    # squeeze=False always yields a 2-D (rows, cols) Axes array, so single-row
    # layouts (1-3 leads) are indexed the same way as the 12-lead case.
    fig, axes = plt.subplots(
        rows, cols, figsize=(12, 1.2 * rows + 0.5), sharex=True, squeeze=False
    )
    for i in range(n_leads):
        r, c = divmod(i, cols)
        ax = axes[r, c]
        ax.plot(t, signals[:, i], linewidth=0.5, color='#1a1a2e')
        ax.set_ylabel(record.sig_name[i], fontsize=8, fontweight='bold')
        ax.tick_params(labelsize=6)
        ax.grid(True, alpha=0.3)
    # Hide the unused cells of a partially filled last row.
    for ax in axes.flat[n_leads:]:
        ax.set_visible(False)
    # Label the lowest *visible* axis in each column. With sharex=True only the bottom
    # row keeps x tick labels, so re-enable them where that bottom cell is hidden.
    for c in range(min(cols, n_leads)):
        bottom = axes[(n_leads - 1 - c) // cols, c]
        bottom.set_xlabel("Time (s)", fontsize=8)
        bottom.tick_params(labelbottom=True)
    fig.subplots_adjust(hspace=0.3, wspace=0.3, top=0.97, bottom=0.06)
    # Drop pyplot's reference so repeated callbacks don't accumulate open figures;
    # the returned Figure object can still be rendered/saved.
    plt.close(fig)
    return fig
def get_classification_info(record, name=""):
    """Summarise a classification record as newline-separated text.

    Lists the sampling rate, duration and lead names, plus a "Diagnosis:" line when
    ``name`` has a non-empty entry in ``CLASSIFICATION_COMMENTS``. A falsy
    ``record`` yields an empty string.
    """
    if not record:
        return ""
    diagnosis = CLASSIFICATION_COMMENTS.get(name, "")
    info = [f"Fs: {record.fs} Hz"]
    info.append(f"Duration: {record.sig_len / record.fs:.1f}s")
    info.append(f"Leads: {', '.join(record.sig_name)}")
    if diagnosis:
        info.append(f"Diagnosis: {diagnosis}")
    return "\n".join(info)
# ──────────────────────────────────────────────
# Forecasting plotting: 3-panel timeline
# ──────────────────────────────────────────────
def _plot_forecast_unannotated(t, signals, lead_name):
    """Single-panel plot for a forecasting record without a FORECAST_META entry."""
    # _downsample_for_plot takes (sig, t) but returns (t, sig).
    td, sd = _downsample_for_plot(signals, t)
    fig, ax = plt.subplots(figsize=(12, 2.5))
    ax.plot(td, sd, linewidth=0.5, color='#1a1a2e')
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("ECG (mV)")
    ax.set_title(f"Uploaded ECG ({lead_name})", fontsize=11, fontweight='bold')
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    # Drop pyplot's reference; the returned Figure remains renderable.
    plt.close(fig)
    return fig


def plot_forecast_ecg(record, name=""):
    """Plot a forecasting ECG (first channel only).

    For a bundled example with a ``FORECAST_META`` entry the figure has 3 panels:
    1. the whole stored segment, with the model input, the gap and the region from
       AFib onset onwards shaded;
    2. a zoom on (at most) the first 10 s of the model input;
    3. a zoom from 10 s before to 20 s after AFib onset.
    Any other record (e.g. an upload) gets a single full-length plot.

    Args:
        record: ``wfdb.Record``-like object (``p_signal``, ``fs``, ``sig_name``) or None.
        name: Record stem used to look up ``FORECAST_META``.

    Returns:
        A matplotlib ``Figure``, or ``None`` when ``record`` is ``None``.
    """
    if record is None:
        return None
    signals = record.p_signal[:, 0]  # first channel only
    fs = record.fs
    t = np.arange(len(signals)) / fs
    meta = FORECAST_META.get(name)
    if meta is None:
        # Title with the channel actually drawn instead of assuming Lead I.
        sig_names = getattr(record, "sig_name", None) or []
        return _plot_forecast_unannotated(t, signals, sig_names[0] if sig_names else "Lead I")

    # FORECAST_META stores indices into the original recording; convert them to
    # sample offsets within the stored segment.
    seg_start = meta["seg_start_orig"]
    input_start_rel = meta["input_start_orig"] - seg_start
    input_end_rel = meta["input_end_orig"] - seg_start
    afib_onset_rel = meta["afib_onset_orig"] - seg_start
    t_in_s = input_start_rel / fs
    t_in_e = input_end_rel / fs
    t_af = afib_onset_rel / fs

    fig = plt.figure(figsize=(12, 7))
    gs = gridspec.GridSpec(3, 1, height_ratios=[1.5, 1, 1], hspace=0.45)

    # Panel 1: full timeline
    ax1 = fig.add_subplot(gs[0])
    ax1.plot(t, signals, linewidth=0.4, color='#1a1a2e', zorder=2)
    ax1.axvspan(t_in_s, t_in_e, alpha=0.15, color='#2196F3', label='Model Input', zorder=1)
    ax1.axvspan(t_in_e, t_af, alpha=0.10, color='#9E9E9E', label='Intervening Gap', zorder=1)
    ax1.axvspan(t_af, t[-1], alpha=0.15, color='#F44336', label='AFib Onset', zorder=1)
    ax1.axvline(t_af, color='#D32F2F', linestyle='--', linewidth=1.5, alpha=0.8, zorder=3)
    ax1.set_title(f"Full Timeline — {name}", fontsize=11, fontweight='bold')
    ax1.set_xlabel("Time (s)", fontsize=8)
    ax1.set_ylabel("mV", fontsize=8)
    ax1.legend(loc='upper right', fontsize=7)
    ax1.grid(True, alpha=0.3)
    ax1.tick_params(labelsize=7)

    # Panel 2: input zoom (first 10 s). int() in case record.fs is a float, since
    # slice indices must be integers.
    ax2 = fig.add_subplot(gs[1])
    zoom_end = int(input_start_rel + min(10 * fs, input_end_rel - input_start_rel))
    ax2.plot(t[input_start_rel:zoom_end], signals[input_start_rel:zoom_end],
             linewidth=0.5, color='#1565C0')
    remaining_s = meta['input_duration_s'] - 10
    suffix = f" (first 10s of {meta['input_duration_s']}s)" if remaining_s > 0 else ""
    ax2.set_title(f"Input Segment{suffix}", fontsize=10, fontweight='bold')
    if remaining_s > 0:
        ax2.annotate(f'→ {remaining_s}s more →', xy=(1.0, 0.5), xycoords='axes fraction',
                     fontsize=9, fontweight='bold', color='#1565C0', alpha=0.7,
                     ha='right', va='center')
    ax2.set_xlabel("Time (s)", fontsize=8)
    ax2.set_ylabel("mV", fontsize=8)
    ax2.grid(True, alpha=0.3)
    ax2.tick_params(labelsize=7)

    # Panel 3: AFib zoom (10 s before to 20 s after onset, clipped to the segment)
    ax3 = fig.add_subplot(gs[2])
    afib_s = int(max(afib_onset_rel - 10 * fs, 0))
    afib_e = int(min(afib_onset_rel + 20 * fs, len(signals)))
    ax3.plot(t[afib_s:afib_e], signals[afib_s:afib_e], linewidth=0.5, color='#C62828')
    ax3.axvline(t_af, color='#D32F2F', linestyle='--', linewidth=1.5, alpha=0.8,
                label=f'AFib onset ({t_af:.1f}s)')
    ax3.set_title("AFib Onset Region", fontsize=10, fontweight='bold')
    ax3.set_xlabel("Time (s)", fontsize=8)
    ax3.set_ylabel("mV", fontsize=8)
    ax3.legend(loc='upper right', fontsize=7)
    ax3.grid(True, alpha=0.3)
    ax3.tick_params(labelsize=7)

    fig.subplots_adjust(top=0.95, bottom=0.06)
    # Drop pyplot's reference; the returned Figure remains renderable.
    plt.close(fig)
    return fig
def get_forecast_info(record, name=""):
    """Summarise a forecasting record as newline-separated text.

    Always lists the sampling rate, leads and stored-segment duration. When ``name``
    has a ``FORECAST_META`` entry, it also lists the input duration, the ground-truth
    label and how many seconds after the end of the model input AFib begins. A falsy
    ``record`` yields an empty string.
    """
    if not record:
        return ""
    meta = FORECAST_META.get(name)
    lines = [
        f"Fs: {record.fs} Hz",
        f"Leads: {', '.join(record.sig_name)}",
        f"Segment duration: {record.sig_len / record.fs:.1f}s",
    ]
    if meta:
        onset_after_input_s = (meta['afib_onset_orig'] - meta['input_end_orig']) / meta['fs']
        lines.extend((
            f"Input duration: {meta['input_duration_s']}s",
            f"Ground truth: {meta['label']}",
            f"AFib onset: {onset_after_input_s:.0f}s after input ends",
        ))
    return "\n".join(lines)
# ──────────────────────────────────────────────
# Event handlers: Classification
# ──────────────────────────────────────────────
def on_cls_example_select(name):
    """Dropdown callback: load a bundled classification example by record stem.

    Returns ``(figure, info_text, record_path)``; ``(None, "", None)`` when nothing
    is selected. ``record_path`` has no file extension.
    """
    if name:
        record_path = str(DEMO_DIR / name)
        record = read_ecg(record_path)
        return (
            plot_classification_ecg(record),
            get_classification_info(record, name),
            record_path,
        )
    return None, "", None
def on_cls_upload(files):
    """Upload callback for the classification tab (a ``.hea`` header plus its data).

    Args:
        files: Files from ``gr.File(file_count="multiple")``. Depending on the Gradio
            version these are path strings or tempfile-like objects with ``.name``;
            both are accepted.

    Returns:
        ``(figure, info_text, record_path)``. ``record_path`` is the header path
        without its suffix, or ``None`` if there is no header or it cannot be read.
    """
    if not files:
        return None, "", None
    paths = [Path(getattr(f, "name", f)) for f in files]
    hea = next((p for p in paths if p.suffix == '.hea'), None)
    if hea is None:
        return None, "Missing .hea file", None
    # wfdb looks for the signal file next to the header, but the uploaded files may
    # have been stored in different folders; copy them into one staging directory.
    # (Staging directories are not cleaned up here.)
    if len({p.parent for p in paths}) > 1:
        import shutil
        import tempfile
        staging = Path(tempfile.mkdtemp(prefix="camel_upload_"))
        for p in paths:
            shutil.copy(p, staging / p.name)
        hea = staging / hea.name
    record = read_ecg(str(hea))
    if record is None:
        # Don't hand an unreadable record path on to inference.
        return None, "Could not read uploaded ECG record", None
    return plot_classification_ecg(record), get_classification_info(record), str(hea.with_suffix(""))
def cls_load_first():
    """Page-load callback: select and render the first classification example.

    Returns ``(dropdown_update, figure, info_text, record_path)``. When no examples
    are bundled, the dropdown is left unchanged and the other outputs are empty.
    """
    examples = load_classification_examples()
    if examples:
        first_stem = examples[0]
        figure, summary, record_path = on_cls_example_select(first_stem)
        return gr.update(value=first_stem), figure, summary, record_path
    return gr.update(), None, "", None
# ──────────────────────────────────────────────
# Event handlers: Forecasting
# ──────────────────────────────────────────────
def on_fc_example_select(name):
    """Dropdown callback: load a bundled forecasting example by record stem.

    ``name`` is forwarded to the plot/info helpers so the FORECAST_META annotations
    apply. Returns ``(figure, info_text, record_path)``, or ``(None, "", None)``
    when nothing is selected.
    """
    if name:
        record_path = str(FORECAST_DIR / name)
        record = read_ecg(record_path)
        return (
            plot_forecast_ecg(record, name),
            get_forecast_info(record, name),
            record_path,
        )
    return None, "", None
def on_fc_upload(files):
    """Upload callback for the forecasting tab (a ``.hea`` header plus its data).

    Uploaded records have no FORECAST_META entry, so they are plotted and summarised
    without example annotations.

    Args:
        files: Files from ``gr.File(file_count="multiple")``. Depending on the Gradio
            version these are path strings or tempfile-like objects with ``.name``;
            both are accepted.

    Returns:
        ``(figure, info_text, record_path)``. ``record_path`` is the header path
        without its suffix, or ``None`` if there is no header or it cannot be read.
    """
    if not files:
        return None, "", None
    paths = [Path(getattr(f, "name", f)) for f in files]
    hea = next((p for p in paths if p.suffix == '.hea'), None)
    if hea is None:
        return None, "Missing .hea file", None
    # wfdb looks for the signal file next to the header, but the uploaded files may
    # have been stored in different folders; copy them into one staging directory.
    # (Staging directories are not cleaned up here.)
    if len({p.parent for p in paths}) > 1:
        import shutil
        import tempfile
        staging = Path(tempfile.mkdtemp(prefix="camel_upload_"))
        for p in paths:
            shutil.copy(p, staging / p.name)
        hea = staging / hea.name
    record = read_ecg(str(hea))
    if record is None:
        # Don't hand an unreadable record path on to inference.
        return None, "Could not read uploaded ECG record", None
    return plot_forecast_ecg(record), get_forecast_info(record), str(hea.with_suffix(""))
def fc_load_first():
    """Page-load callback: select and render the first forecasting example.

    Returns ``(dropdown_update, figure, info_text, record_path)``. When no examples
    are bundled, the dropdown is left unchanged and the other outputs are empty.
    """
    examples = load_forecast_examples()
    if examples:
        first_stem = examples[0]
        figure, summary, record_path = on_fc_example_select(first_stem)
        return gr.update(value=first_stem), figure, summary, record_path
    return gr.update(), None, "", None
# ──────────────────────────────────────────────
# Inference
# ──────────────────────────────────────────────
@spaces.GPU
def run_cls_inference(ecg_path, prompt):
    """Run the CAMEL base model on one ECG record with a free-text prompt.

    The model is kept on the CPU between requests. It is moved to CUDA for this call
    and moved back afterwards, also when inference raises, so the GPU is released.

    Args:
        ecg_path: WFDB record path (without extension) from the UI state, or None/"".
        prompt: Instruction/question text passed to the model.

    Returns:
        The model's text output, or a hint string when no ECG is selected.
    """
    if not ecg_path:
        return "Select an ECG first."
    t0 = time.time()
    try:
        MODEL_BASE.move_to_device('cuda')
        print(f" Base model .to(cuda): {time.time() - t0:.2f}s")
        with torch.inference_mode():
            args = SimpleNamespace(
                mode='base', text=prompt, ecgs=[ecg_path], device='cuda',
                ecg_configs=None, json=None, temperature=0.0, top_k=64,
                top_p=0.95, min_p=0.0, max_new_tokens=512,
            )
            output, _ = MODEL_BASE.run(args)
    finally:
        # Move back to CPU to free GPU for other models, even if run() failed.
        MODEL_BASE.move_to_device('cpu')
        torch.cuda.empty_cache()
    print(f" Classification inference: {time.time() - t0:.2f}s")
    return output
@spaces.GPU
def run_fc_inference(ecg_path, name):
    """Run the CAMEL forecast model on one ECG record with ``FORECAST_PROMPT``.

    The model is kept on the CPU between requests. It is moved to CUDA for this call
    and moved back afterwards, also when inference raises, so the GPU is released.

    Args:
        ecg_path: WFDB record path (without extension) from the UI state, or None/"".
        name: Stem of the selected bundled example ("" or None for uploads). When it
            has a ``FORECAST_META`` entry, only that example's model-input window is
            requested via ``ecg_configs``; otherwise ``ecg_configs`` is None.

    Returns:
        The model's text output, or a hint string when no ECG is selected.
    """
    if not ecg_path:
        return "Select an ECG first."
    meta = FORECAST_META.get(name)
    ecg_configs = None
    if meta:
        # The stored file starts at seg_start_orig and the model input is
        # [input_start_orig, input_end_orig]; express both ends in whole seconds from
        # the start of the stored segment. Floor division is exact for the bundled
        # examples, whose offsets are multiples of fs.
        seg_start = meta["seg_start_orig"]
        input_start_sec = (meta["input_start_orig"] - seg_start) // meta["fs"]
        input_end_sec = (meta["input_end_orig"] - seg_start) // meta["fs"]
        ecg_configs = [f"start:{input_start_sec};end:{input_end_sec}"]
    t0 = time.time()
    try:
        MODEL_FORECAST.move_to_device('cuda')
        print(f" Forecast model .to(cuda): {time.time() - t0:.2f}s")
        with torch.inference_mode():
            args = SimpleNamespace(
                mode='forecast', text=FORECAST_PROMPT, ecgs=[ecg_path], device='cuda',
                ecg_configs=ecg_configs, json=None, temperature=0.0, top_k=64,
                top_p=0.95, min_p=0.0, max_new_tokens=512,
            )
            output, _ = MODEL_FORECAST.run(args)
    finally:
        # Move back to CPU to free GPU for other models, even if run() failed.
        MODEL_FORECAST.move_to_device('cpu')
        torch.cuda.empty_cache()
    print(f" Forecast inference: {time.time() - t0:.2f}s")
    return output
# ──────────────────────────────────────────────
# UI
# ──────────────────────────────────────────────
with gr.Blocks(title="CAMEL ECG", css="""
.gr-button-primary { min-width: 120px; }
.prompt-display { background: var(--background-fill-secondary) !important; border: 1px solid var(--border-color-primary); border-radius: 8px; padding: 12px; font-size: 0.9em; }
.model-output textarea { background: rgba(100, 180, 255, 0.08) !important; color: var(--body-text-color) !important; border: 1px solid rgba(100, 180, 255, 0.25) !important; }
""") as demo:
    gr.Markdown("# 🐪 CAMEL ECG Model")
    gr.Markdown("Cardiac AI Model for ECG analysis and rhythm forecasting.")
    with gr.Tabs():
        # ═══════════════════════════════
        # Tab 1: Classification
        # ═══════════════════════════════
        with gr.Tab("🩺 Classification"):
            # Record path (no extension) of the ECG currently shown; fed to inference.
            cls_path = gr.State()
            with gr.Row():
                with gr.Column(scale=1, min_width=280):
                    with gr.Tabs():
                        with gr.Tab("Example"):
                            cls_dd = gr.Dropdown(
                                load_classification_examples(),
                                label="Select Example ECG",
                            )
                        with gr.Tab("Upload"):
                            cls_upload = gr.File(file_count="multiple", label="Upload .hea + .dat")
                    cls_prompt = gr.Dropdown(
                        CLASSIFICATION_PROMPTS,
                        value=CLASSIFICATION_PROMPTS[0],
                        label="Prompt",
                    )
                    cls_run_btn = gr.Button("Run Inference", variant="primary")
                    cls_out = gr.Textbox(label="Model Output", lines=6, interactive=False,
                                         elem_classes=["model-output"])
                with gr.Column(scale=3):
                    cls_plot = gr.Plot(label="ECG")
                    cls_info = gr.Textbox(label="Ground-truth", lines=3, interactive=False)
            cls_dd.change(on_cls_example_select, cls_dd, [cls_plot, cls_info, cls_path])
            cls_upload.upload(on_cls_upload, cls_upload, [cls_plot, cls_info, cls_path])
            cls_run_btn.click(run_cls_inference, [cls_path, cls_prompt], cls_out)

        # ═══════════════════════════════
        # Tab 2: Forecasting
        # ═══════════════════════════════
        with gr.Tab("🔮 Forecasting"):
            # Record path (no extension) of the ECG currently shown; fed to inference.
            fc_path = gr.State()
            # Stem of the selected bundled example ("" for uploads). run_fc_inference
            # uses it to look up FORECAST_META and restrict the model input window.
            fc_name = gr.State()
            with gr.Row():
                with gr.Column(scale=1, min_width=280):
                    with gr.Tabs():
                        with gr.Tab("Example"):
                            fc_dd = gr.Dropdown(
                                load_forecast_examples(),
                                label="Select Example ECG",
                            )
                        with gr.Tab("Upload"):
                            fc_upload = gr.File(file_count="multiple", label="Upload .hea + .dat")
                    gr.Markdown(
                        f"**Prompt (fixed):**\n\n{FORECAST_PROMPT}",
                        elem_classes=["prompt-display"],
                    )
                    fc_run_btn = gr.Button("Run Inference", variant="primary")
                    fc_out = gr.Textbox(label="Model Output", lines=5, interactive=False,
                                        elem_classes=["model-output"])
                with gr.Column(scale=3):
                    fc_plot = gr.Plot(label="ECG Timeline")
                    fc_info = gr.Textbox(label="Info", lines=4, interactive=False)

            def _fc_select_wrapper(name):
                """Example selection: also remember the example's stem in fc_name."""
                plot, info, path = on_fc_example_select(name)
                return plot, info, path, name

            def _fc_upload_wrapper(files):
                """Upload: clear fc_name so no example metadata is applied."""
                plot, info, path = on_fc_upload(files)
                return plot, info, path, ""

            fc_dd.change(_fc_select_wrapper, fc_dd, [fc_plot, fc_info, fc_path, fc_name])
            fc_upload.upload(_fc_upload_wrapper, fc_upload, [fc_plot, fc_info, fc_path, fc_name])
            fc_run_btn.click(run_fc_inference, [fc_path, fc_name], fc_out)

    def _fc_load_first_wrapper():
        """Like fc_load_first, but also stores the preselected example's stem.

        Setting fc_name here avoids relying on fc_dd.change firing for the
        programmatic dropdown update; without the name, inference would not know
        the example's input window.
        """
        dd_update, plot, info, path = fc_load_first()
        examples = load_forecast_examples()
        return dd_update, plot, info, path, (examples[0] if examples else None)

    # Pre-load first examples on page open
    demo.load(cls_load_first, outputs=[cls_dd, cls_plot, cls_info, cls_path])
    demo.load(_fc_load_first_wrapper, outputs=[fc_dd, fc_plot, fc_info, fc_path, fc_name])

if __name__ == "__main__":
    demo.launch()