import os
from io import BytesIO
from types import SimpleNamespace

import gradio as gr
import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.figure import Figure
from scipy.stats import skew, kurtosis
# --- Signal windowing parameters -------------------------------------------
SAMPLING_RATE = 125    # samples per second (Hz)
WINDOW_SIZE = 125      # samples per feature-extraction window
SEQUENCE_LENGTH = 10   # consecutive feature windows per model input sequence

# Fitted feature scaler applied to the per-window features before sequencing.
# Resolved relative to this file (the same way analisis_sinyal resolves the
# model paths) so start-up does not depend on the current working directory.
scaler = joblib.load(
    os.path.join(os.path.dirname(__file__), "models/scaler/scaler.joblib")
)

# Hybrid models. Each entry is
#   (deep-learning family, classifier family, display label,
#    path of the feature-extraction model, path of the classifier model)
# The label (index 2) is what the UI shows; indexes 3 and 4 are loaded as the
# feature extractor and the classifier respectively.
HYBRID_MODEL = [
    (
        "LSTM",
        "RandomForest",
        "LSTM + Random Forest",
        "models/LSTM/rf_lstm_optuna_10122025_122323.joblib",
        "models/RandomForest/rf_lstm_optuna_10122025_123350.joblib",
    ),
    (
        "LSTM",
        "SVM",
        "LSTM + SVM",
        "models/LSTM/SVM_LSTM_optuna_10122025_124318.joblib",
        "models/SVM/svm_LSTM_optuna_10122025_125033.joblib",
    ),
    (
        "LSTM",
        "XGBOOST",
        "LSTM + XGBoost",
        "models/LSTM/xgboost_LSTM_optuna_10122025_113333.joblib",
        "models/XGBOOST/xgboost_LSTM_optuna_10122025_113701.joblib",
    ),
    (
        "TRANSFORMER",
        "RandomForest",
        "Transformer + Random Forest",
        "models/TRANSFORMER/rf_transformer_optuna_10122025_045920.joblib",
        "models/RandomForest/rf_transformer_optuna_10122025_045552.joblib",
    ),
    (
        "TRANSFORMER",
        "SVM",
        "Transformer + SVM",
        "models/TRANSFORMER/svm_transformer_optuna_10122025_050213.joblib",
        "models/SVM/svm_transformer_optuna_10122025_045226.joblib",
    ),
    (
        "TRANSFORMER",
        "XGBOOST",
        "Transformer + XGBoost",
        "models/TRANSFORMER/xgboost_transformer_optuna_10122025_050346.joblib",
        "models/XGBOOST/xgboost_transformer_optuna_10122025_044957.joblib",
    ),
]

# Stand-alone models: (display label, model path).
SINGLE_MODELS = [
    ("Single LSTM", "models/LSTM/single_LSTM_default_10122025_135523.joblib"),
    ("Single Transformer", "models/TRANSFORMER/single_transformer_grid_10122025_115452.joblib"),
]

# Hybrid entries whose model paths mention a supported deep-learning backbone.
# With the table above this keeps every entry.
MODEL_PAIRS = [
    m
    for m in HYBRID_MODEL
    if any(
        tag in path.lower()
        for path in (m[3], m[4])
        for tag in ("lstm", "transformer")
    )
]
def ekstraksi_fitur_statistik(w):
    """Compute the statistical feature vector of one signal window.

    Args:
        w: Sequence/array of samples belonging to a single window.

    Returns:
        NumPy array holding, in this order: skewness, kurtosis (both with
        ``scipy.stats`` default settings), minimum, maximum and standard
        deviation of ``w``.
    """
    # The tuple order defines the feature order of the result.
    stat_fns = (skew, kurtosis, np.min, np.max, np.std)
    return np.array([fn(w) for fn in stat_fns])
def ekstraksi_fitur_sinyal(ecg):
    """Extract statistical features from every full window of a signal.

    The signal is cut into consecutive, non-overlapping windows of
    ``WINDOW_SIZE`` samples. Trailing samples that do not fill a whole
    window are ignored.

    Args:
        ecg: Sliceable sequence/array of signal samples.

    Returns:
        NumPy array with one feature row (see ``ekstraksi_fitur_statistik``)
        per complete window; an empty array when there is no complete window.
    """
    window_count = len(ecg) // WINDOW_SIZE
    # Start offset of each complete window.
    starts = range(0, window_count * WINDOW_SIZE, WINDOW_SIZE)
    return np.array(
        [ekstraksi_fitur_statistik(ecg[start:start + WINDOW_SIZE]) for start in starts]
    )
def buat_sequence(fitur):
    """Build overlapping sequences of feature rows with a stride of one.

    Args:
        fitur: Sliceable collection of per-window feature rows.

    Returns:
        NumPy array stacking every run of ``SEQUENCE_LENGTH`` consecutive
        rows; an empty array when ``fitur`` has fewer rows than that.
    """
    # Index of the last row at which a full-length sequence can still start.
    last_start = len(fitur) - SEQUENCE_LENGTH
    return np.array(
        [fitur[start:start + SEQUENCE_LENGTH] for start in range(last_start + 1)]
    )
def preprocessing_sinyal(ecg):
    """Turn a raw signal into scaled feature sequences for the models.

    Pipeline: per-window feature extraction, scaling with the module-level
    ``scaler``, then grouping into sequences via ``buat_sequence``.

    Args:
        ecg: Raw signal samples.

    Returns:
        The array of sequences produced by ``buat_sequence``.
    """
    return buat_sequence(scaler.transform(ekstraksi_fitur_sinyal(ecg)))
def load_hybrid_models(p1, p2):
    """Load the two serialized models that make up a hybrid pair.

    Args:
        p1: Path of the first model file.
        p2: Path of the second model file.

    Returns:
        Tuple ``(model_from_p1, model_from_p2)``.
    """
    # The generator is consumed in order, so p1 is still loaded before p2.
    return tuple(joblib.load(model_path) for model_path in (p1, p2))
def _render_signal_image(sinyal):
    """Plot ``sinyal`` with a 25-sample moving average and return a PIL image."""
    import PIL.Image

    fig, ax = plt.subplots(figsize=(8, 3))
    try:
        ax.plot(sinyal, label="Raw ECG", color="#2196f3", linewidth=1)
        # The smoothed overlay is only drawn when the signal is longer than
        # the averaging window.
        if len(sinyal) > 25:
            ma = pd.Series(sinyal).rolling(window=25, min_periods=1, center=True).mean()
            ax.plot(ma, label="Moving Average", color="#ff9800", linewidth=2, alpha=0.7)
        ax.set_title("ECG Signal (Raw & Smoothed)")
        ax.set_xlabel("Sample")
        ax.set_ylabel("Amplitude")
        ax.legend()
        ax.grid(True, linestyle='--', alpha=0.5)
        fig.tight_layout()
        buf = BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # pyplot keeps a reference to every figure until it is closed; without
        # this each request would leak one figure in the long-running server.
        plt.close(fig)
    buf.seek(0)
    return PIL.Image.open(buf)


def analisis_sinyal(file, model_pair_label):
    """Classify the signal stored in a CSV file with the selected model.

    Args:
        file: Object whose ``name`` attribute is the path of the CSV file.
        model_pair_label: Display label of the model to use; matched against
            ``SINGLE_MODELS`` first and ``MODEL_PAIRS`` second.

    Returns:
        Tuple ``(text, image)``. On success ``text`` is the predicted class
        index as a string and ``image`` is a PIL image of the plotted signal.
        On a handled failure ``text`` is an error message and ``image`` is
        ``None``.

    Raises:
        Errors from model loading, CSV reading, preprocessing, the hybrid
        classifier step and plotting are not caught here and propagate to the
        caller.
    """
    base_dir = os.path.dirname(__file__)
    single = next((s for s in SINGLE_MODELS if s[0] == model_pair_label), None)

    if single:
        model = joblib.load(os.path.join(base_dir, single[1]))
        # NOTE(review): this flattens every column of the CSV, not only the
        # ECG column - confirm it matches how the models were trained.
        sinyal = pd.read_csv(file.name).values.flatten()
        seq = preprocessing_sinyal(sinyal)
        try:
            # Only the first sequence decides the label. Run a single
            # inference pass: probabilities when available, else predict().
            if hasattr(model, 'predict_proba'):
                label = int(np.argmax(model.predict_proba(seq)[0]))
            else:
                label = int(model.predict(seq)[0])
        except Exception as e:
            return f"Failed to predict with single learning model: {e}", None
    else:
        selected = next((m for m in MODEL_PAIRS if m[2] == model_pair_label), None)
        if not selected:
            return "Model not found", None
        model_dl, model_clf = load_hybrid_models(
            os.path.join(base_dir, selected[3]),
            os.path.join(base_dir, selected[4]),
        )
        # NOTE(review): flattens every CSV column, see the note above.
        sinyal = pd.read_csv(file.name).values.flatten()
        seq = preprocessing_sinyal(sinyal)
        try:
            fitur = model_dl.predict(seq)
        except Exception as e:
            return f"Failed to predict with feature extraction model: {e}", None
        debug_info = f"Type model_dl: {type(model_dl)}, Output predict: {type(fitur)}, Shape: {getattr(fitur, 'shape', None)}"

        n_features_model = getattr(model_clf, 'n_features_in_', None)
        if n_features_model is not None:
            rank = len(getattr(fitur, 'shape', ()))
            # Keep the last n entries along the first axis and present them
            # to the classifier as one sample.
            if rank >= 1 and fitur.shape[0] >= n_features_model:
                fitur = fitur[-n_features_model:].reshape(1, n_features_model)
                rank = 2
            # Output with fewer than two dimensions has no feature axis to
            # compare; report it as a mismatch instead of raising IndexError.
            n_extracted = fitur.shape[1] if rank >= 2 else None
            if n_extracted != n_features_model:
                shown = '?' if n_extracted is None else n_extracted
                return debug_info + f"\nNumber of extracted features ({shown}) does not match the model's expected number ({n_features_model}). Ensure the feature extraction model and classifier are compatible.", None

        if hasattr(model_clf, "predict_proba"):
            label = int(np.argmax(model_clf.predict_proba(fitur)[0]))
        else:
            label = int(model_clf.predict(fitur)[0])

    return str(label), _render_signal_image(sinyal)
# Custom dark-theme stylesheet for the Gradio app; it is handed to
# demo.launch(css=...) at the end of the file. The #title, #subtitle and
# .license-box selectors target ids/classes expected in the page markup.
css = """
body {background-color: #181818; color: #f5f5f5;}
.gradio-container, .gradio-app {background-color: #181818 !important;}
#title {text-align:center; font-size:32px; font-weight:700; margin-bottom:20px; color:#f5f5f5;}
#subtitle {text-align:center; font-size:18px; margin-bottom:40px; color:#bbbbbb;}
input, select, textarea, .gr-button, .gr-input, .gr-textbox, .gr-dropdown, .gr-file, .gr-image {
background-color: #232323 !important;
color: #f5f5f5 !important;
border-color: #444 !important;
}
.gr-button {border-radius: 6px;}
.license-box {
border: 2px solid #fff;
border-radius: 10px;
padding: 18px;
margin-top: 18px;
background: #111;
color: #fff;
}
"""
def _find_ecg_column(df):
    """Return the first column whose name contains 'ecg' (case-insensitive), or None."""
    return next((col for col in df.columns if 'ecg' in col.lower()), None)


# Gradio UI: one shared CSV upload plus three tabs (dataset download/licence
# info, dataset inspection, model prediction).
with gr.Blocks() as demo:
    # Page header. The ids match the #title / #subtitle rules in `css`.
    gr.HTML('<div id="title">Atrial Fibrillation Detection</div>')
    gr.HTML('<div id="subtitle">Analyze ECG data for Atrial Fibrillation presence</div>')

    file_upload = gr.File(label="Upload Dataset/Signal (CSV)", file_types=[".csv"])
    # Holds the uploaded file's path so every tab reads the same upload.
    shared_file_path = gr.State()
    clear_btn = gr.Button("Clear File")

    def store_file(file):
        """Return the uploaded file's path, or None when nothing is uploaded."""
        return file.name if file is not None else None

    file_upload.change(store_file, inputs=file_upload, outputs=shared_file_path)

    with gr.Tab("Dataset Info"):
        gr.Markdown("""
## Dataset Files
The dataset files below are the same as those used for model testing in this application. Please download using the buttons below:
""")
        with gr.Row():
            gr.File(value="Data/mimic_perform_af_001_data.csv", label="Download Atrial Fibrillation Data", interactive=False)
            gr.File(value="Data/mimic_perform_non_af_001_data.csv", label="Download Non-Atrial Fibrillation Data", interactive=False)
        # NOTE(review): `css` defines a .license-box rule that nothing in this
        # block references - confirm whether this text should be wrapped in it.
        gr.Markdown(
            """
Dataset License
This dataset is licensed under the Open Data Commons Open Database License v1.0 (ODbL 1.0 license).
Further details:
ODbL 1.0
This dataset is derived from the MIMIC III Waveform Database:
Moody, B., Moody, G., Villarroel, M., Clifford, G. D., & Silva, I. (2020). MIMIC-III Waveform Database (version 1.0). PhysioNet.
https://doi.org/10.13026/c2607m
The MIMIC III Waveform Database is licensed under the ODbL 1.0 license.
The MIMIC-III database is described in:
Johnson, A. E. W., Pollard, T. J., Shen, L., Lehman, L. H., Feng, M., Ghassemi, M., Moody, B., Szolovits, P., Celi, L. A., & Mark, R. G. (2016). MIMIC-III, a freely accessible critical care database. Scientific Data, 3, 160035.
https://doi.org/10.1038/sdata.2016.35
It is available on PhysioNet:
https://physionet.org/
Goldberger, A., Amaral, L., Glass, L., Hausdorff, J., Ivanov, P. C., Mark, R., ... & Stanley, H. E. (2000). PhysioBank, PhysioToolkit, and PhysioNet: Components of a new research resource for complex physiologic signals. Circulation [Online]. 101 (23), pp. e215–e220.
The following annotations of AF and non-AF were used to create the dataset:
Bashar, Syed Khairul (2020): Atrial Fibrillation annotations of electrocardiogram from MIMIC III matched subset. figshare. Dataset.
https://doi.org/10.6084/m9.figshare.12149091.v1
Bashar, S.K., Ding, E., Walkey, A.J., McManus, D.D. and Chon, K.H., 2019. Noise Detection in Electrocardiogram Signals for Intensive Care Unit Patients. IEEE Access, 7, pp.88357-88368.
https://doi.org/10.1109/ACCESS.2019.2926199
This annotation information is reproduced under the terms of the
CC BY 4.0 licence
""",
            elem_id=None,
        )

    with gr.Tab("Analyze Dataset"):
        info_ds = gr.Textbox(label="Dataset Info", interactive=False, lines=5)
        preview_ds = gr.Dataframe(label="Data Preview", interactive=False)
        plot_ds = gr.Plot(label="ECG Signal Plot")
        btn_ds = gr.Button("Analyze Dataset")

        def analyze_dataset(file_path):
            """Summarize the uploaded CSV and plot its first ten seconds.

            Args:
                file_path: Path of the uploaded CSV, or None when no file
                    has been uploaded.

            Returns:
                Tuple ``(info_text, preview, figure)``: a text summary or an
                error message, the first ten rows (None when the file could
                not be read) and a matplotlib figure (None on any failure).
            """
            if file_path is None:
                return "No file uploaded. Please upload a CSV file first.", None, None
            try:
                df = pd.read_csv(file_path)
            except Exception as e:
                return f"Failed to read file. Make sure the file is a valid CSV. Error: {e}", None, None

            duration = len(df) / SAMPLING_RATE
            info_lines = [
                f"Shape: {df.shape}",
                f"Columns: {list(df.columns)}",
                f"Missing: {df.isnull().sum().to_dict()}",
                f"Sampling rate: {SAMPLING_RATE} Hz",
                f"Data duration: {duration:.2f} seconds ({duration/60:.2f} minutes)",
            ]
            preview = df.head(10)

            ecg_col = _find_ecg_column(df)
            if ecg_col is None:
                return "The uploaded CSV does not contain an 'ecg' column. Please upload a CSV file with an 'ecg' feature/column.", preview, None
            if not pd.api.types.is_numeric_dtype(df[ecg_col]):
                return f"The selected signal column ('{ecg_col}') is not numeric. Please upload a valid ECG CSV.", preview, None

            plot_samples = min(SAMPLING_RATE * 10, len(df))
            try:
                # Build the figure without pyplot so it is not kept in
                # pyplot's global registry; plt.subplots() would leak one
                # open figure per click.
                fig = Figure()
                ax = fig.subplots()
                ax.plot(df[ecg_col].values[:plot_samples])
                ax.set_title(f"First 10 Seconds Signal Plot: {ecg_col}")
                ax.set_xlabel("Sample")
                ax.set_ylabel("Amplitude")
            except Exception as e:
                return f"Failed to plot ECG signal: {e}", preview, None
            return "\n".join(info_lines), preview, fig

        btn_ds.click(analyze_dataset, inputs=[shared_file_path], outputs=[info_ds, preview_ds, plot_ds])

    with gr.Tab("Analyze Model"):
        all_model_labels = [m[2] for m in MODEL_PAIRS] + [s[0] for s in SINGLE_MODELS]
        pilih_model = gr.Dropdown(all_model_labels, label="Select Model", value=all_model_labels[0])
        hasil = gr.Textbox(label="Prediction Result", interactive=False)
        tombol = gr.Button("Predict")

        def handle_predict(file_path, model_label):
            """Validate the upload, run the selected model and return the result text.

            Args:
                file_path: Path of the uploaded CSV, or None/empty.
                model_label: Label chosen in the model dropdown.

            Returns:
                'AF' or 'Non-AF' for predictions '1' and '0'; otherwise the
                raw result or an explanatory error message.
            """
            if file_path is None or str(file_path).strip() == "":
                return "No file uploaded. Please upload a CSV file first."
            if not model_label:
                return "No model selected. Please select a model."
            try:
                df = pd.read_csv(file_path)
                ecg_col = _find_ecg_column(df)
                if ecg_col is None:
                    return "The uploaded CSV does not contain an 'ecg' column. Please upload a CSV file with an 'ecg' feature/column for prediction."
                if not pd.api.types.is_numeric_dtype(df[ecg_col]):
                    return f"The selected signal column ('{ecg_col}') is not numeric. Please upload a valid ECG CSV."
                # analisis_sinyal reads the path from a `.name` attribute.
                result, _ = analisis_sinyal(SimpleNamespace(name=file_path), model_label)
            except Exception as e:
                return f"Prediction failed: {e}"
            result_text = str(result)
            # Anything other than '0'/'1' is an error message; pass it through.
            return {'0': 'Non-AF', '1': 'AF'}.get(result_text.strip(), result_text)

        tombol.click(handle_predict, inputs=[shared_file_path, pilih_model], outputs=[hasil])

        def clear_file():
            """Reset the upload widget, the stored path and the prediction text."""
            return None, None, None

        # Registered once, here, because it also resets `hasil`, which only
        # exists after this tab has been built.
        clear_btn.click(clear_file, inputs=None, outputs=[file_upload, shared_file_path, hasil])

# NOTE(review): passing `css` to launch() requires a Gradio release whose
# launch() accepts it - confirm against the pinned Gradio version.
demo.launch(css=css)