"""Gradio app that detects Atrial Fibrillation (AF) in ECG recordings.

Pipeline: the ECG signal of an uploaded CSV is cut into non-overlapping
windows of ``WINDOW_SIZE`` samples, each window is summarised by five
statistics, the statistics are scaled with a pre-fitted scaler and grouped
into overlapping (stride 1) sequences of ``SEQUENCE_LENGTH`` windows. The
sequences are fed either to a single model or to a hybrid pipeline
(feature-extraction model followed by a classifier).
"""

import os
from functools import lru_cache
from io import BytesIO

import gradio as gr
import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL.Image
from scipy.stats import kurtosis, skew

# Directory containing this file. Model artefacts are resolved relative to it
# so loading does not depend on the current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

SAMPLING_RATE = 125  # Hz, as reported in the dataset summary
WINDOW_SIZE = 125  # samples per feature window
SEQUENCE_LENGTH = 10  # consecutive windows per model input sequence
N_STAT_FEATURES = 5  # number of values returned by ekstraksi_fitur_statistik
SMOOTHING_WINDOW = 25  # samples in the moving average drawn on the plot
PREVIEW_SECONDS = 10  # seconds of signal drawn in the dataset preview plot

# Maps the class emitted by the models (as a string) to a user-facing name.
LABEL_NAMES = {"0": "Non-AF", "1": "AF"}


def _resolve(path):
    """Return *path* resolved relative to this file's directory."""
    return os.path.join(BASE_DIR, path)


scaler = joblib.load(_resolve("models/scaler/scaler.joblib"))

# (backbone, classifier, UI label, feature-extractor path, classifier path)
HYBRID_MODEL = [
    ("LSTM", "RandomForest", "LSTM + Random Forest",
     "models/LSTM/rf_lstm_optuna_10122025_122323.joblib",
     "models/RandomForest/rf_lstm_optuna_10122025_123350.joblib"),
    ("LSTM", "SVM", "LSTM + SVM",
     "models/LSTM/SVM_LSTM_optuna_10122025_124318.joblib",
     "models/SVM/svm_LSTM_optuna_10122025_125033.joblib"),
    ("LSTM", "XGBOOST", "LSTM + XGBoost",
     "models/LSTM/xgboost_LSTM_optuna_10122025_113333.joblib",
     "models/XGBOOST/xgboost_LSTM_optuna_10122025_113701.joblib"),
    ("TRANSFORMER", "RandomForest", "Transformer + Random Forest",
     "models/TRANSFORMER/rf_transformer_optuna_10122025_045920.joblib",
     "models/RandomForest/rf_transformer_optuna_10122025_045552.joblib"),
    ("TRANSFORMER", "SVM", "Transformer + SVM",
     "models/TRANSFORMER/svm_transformer_optuna_10122025_050213.joblib",
     "models/SVM/svm_transformer_optuna_10122025_045226.joblib"),
    ("TRANSFORMER", "XGBOOST", "Transformer + XGBoost",
     "models/TRANSFORMER/xgboost_transformer_optuna_10122025_050346.joblib",
     "models/XGBOOST/xgboost_transformer_optuna_10122025_044957.joblib"),
]

# (UI label, model path)
SINGLE_MODELS = [
    ("Single LSTM", "models/LSTM/single_LSTM_default_10122025_135523.joblib"),
    ("Single Transformer", "models/TRANSFORMER/single_transformer_grid_10122025_115452.joblib"),
]

# Hybrid pipelines whose artefact paths mention a supported deep-learning
# backbone (every entry of HYBRID_MODEL currently qualifies).
_DL_BACKBONES = ("lstm", "transformer")
MODEL_PAIRS = [
    m for m in HYBRID_MODEL
    if any(key in path.lower() for path in (m[3], m[4]) for key in _DL_BACKBONES)
]


def ekstraksi_fitur_statistik(w):
    """Return ``[skewness, kurtosis, min, max, std]`` of one signal window."""
    return np.array([skew(w), kurtosis(w), np.min(w), np.max(w), np.std(w)])


def ekstraksi_fitur_sinyal(ecg):
    """Extract statistical features from each non-overlapping window of *ecg*.

    Trailing samples that do not fill a whole window are ignored.

    Returns:
        Array of shape ``(n_windows, N_STAT_FEATURES)``; ``(0, N_STAT_FEATURES)``
        when the signal is shorter than one window.
    """
    n_windows = len(ecg) // WINDOW_SIZE
    fitur = [
        ekstraksi_fitur_statistik(ecg[i * WINDOW_SIZE:(i + 1) * WINDOW_SIZE])
        for i in range(n_windows)
    ]
    return np.array(fitur).reshape(n_windows, N_STAT_FEATURES)


def buat_sequence(fitur):
    """Group consecutive feature rows into overlapping sequences (stride 1).

    Returns:
        Array of ``len(fitur) - SEQUENCE_LENGTH + 1`` sequences of
        ``SEQUENCE_LENGTH`` rows each (empty when there are fewer rows).
    """
    return np.array([
        fitur[i:i + SEQUENCE_LENGTH]
        for i in range(len(fitur) - SEQUENCE_LENGTH + 1)
    ])


def preprocessing_sinyal(ecg):
    """Turn a raw 1-D signal into scaled model input sequences.

    Raises:
        ValueError: if the signal is too short to form one full sequence.
    """
    fitur = ekstraksi_fitur_sinyal(ecg)
    if len(fitur) < SEQUENCE_LENGTH:
        min_samples = SEQUENCE_LENGTH * WINDOW_SIZE
        raise ValueError(
            f"Signal too short: at least {min_samples} samples "
            f"({min_samples / SAMPLING_RATE:g} s) are required, got {len(ecg)}."
        )
    return buat_sequence(scaler.transform(fitur))


@lru_cache(maxsize=None)  # key space is the fixed set of model paths above
def _load_model(path):
    """Load a joblib artefact once and reuse it for later requests."""
    # joblib.load unpickles the file: only ever point this at trusted artefacts.
    return joblib.load(path)


def load_hybrid_models(p1, p2):
    """Return ``(feature_extractor, classifier)`` loaded from *p1* and *p2*."""
    return _load_model(p1), _load_model(p2)


def _cari_kolom_ecg(df):
    """Return the first column whose name contains 'ecg' (any case), or None."""
    return next((col for col in df.columns if "ecg" in str(col).lower()), None)


def _baca_sinyal(path):
    """Read the ECG signal of a CSV file as a 1-D array.

    Uses the 'ecg' column when one exists, so other columns (time stamps,
    other channels) are not mixed into the signal. Files without such a
    column fall back to flattening every value, as before.
    """
    df = pd.read_csv(path)
    ecg_col = _cari_kolom_ecg(df)
    if ecg_col is not None:
        return df[ecg_col].to_numpy()
    return df.values.flatten()


def _prediksi_single(model, seq):
    """Classify *seq* with one model; return ``(label, None)`` or ``(None, error)``."""
    # NOTE(review): only the first sequence drives the prediction, as in the
    # original code — confirm this is intended rather than aggregating.
    try:
        if hasattr(model, "predict_proba"):
            return int(np.argmax(model.predict_proba(seq)[0])), None
        return int(model.predict(seq)[0]), None
    except Exception as e:
        return None, f"Failed to predict with single learning model: {e}"


def _prediksi_hybrid(model_dl, model_clf, seq):
    """Classify *seq* with extractor + classifier; return ``(label, None)`` or ``(None, error)``."""
    try:
        raw = model_dl.predict(seq)
    except Exception as e:
        return None, f"Failed to predict with feature extraction model: {e}"
    debug_info = (
        f"Type model_dl: {type(model_dl)}, Output predict: {type(raw)}, "
        f"Shape: {getattr(raw, 'shape', None)}"
    )
    fitur = np.asarray(raw)

    n_expected = getattr(model_clf, "n_features_in_", None)
    if n_expected is not None:
        # When the extractor yields one value per sequence, the classifier
        # input is the last n_expected of those values as a single sample.
        if fitur.ndim >= 1 and fitur.shape[0] >= n_expected:
            tail = fitur[-n_expected:]
            if tail.size == n_expected:
                fitur = tail.reshape(1, n_expected)
        if fitur.ndim != 2 or fitur.shape[1] != n_expected:
            n_found = fitur.shape[-1] if fitur.ndim >= 1 else "?"
            return None, (
                debug_info
                + f"\nNumber of extracted features ({n_found}) does not match "
                f"the model's expected number ({n_expected}). Ensure the feature "
                "extraction model and classifier are compatible."
            )

    try:
        if hasattr(model_clf, "predict_proba"):
            return int(np.argmax(model_clf.predict_proba(fitur)[0])), None
        return int(model_clf.predict(fitur)[0]), None
    except Exception as e:
        return None, f"Failed to predict with classifier model: {e}"


def _plot_sinyal(sinyal):
    """Render the raw signal plus a moving average to a PIL image."""
    fig, ax = plt.subplots(figsize=(8, 3))
    buf = BytesIO()
    try:
        ax.plot(sinyal, label="Raw ECG", color="#2196f3", linewidth=1)
        if len(sinyal) > SMOOTHING_WINDOW:
            smoothed = pd.Series(sinyal).rolling(
                window=SMOOTHING_WINDOW, min_periods=1, center=True
            ).mean()
            ax.plot(smoothed, label="Moving Average", color="#ff9800",
                    linewidth=2, alpha=0.7)
        ax.set_title("ECG Signal (Raw & Smoothed)")
        ax.set_xlabel("Sample")
        ax.set_ylabel("Amplitude")
        ax.legend()
        ax.grid(True, linestyle="--", alpha=0.5)
        fig.tight_layout()
        fig.savefig(buf, format="png")
    finally:
        # Release the figure from pyplot's registry; otherwise every request
        # leaks one figure in this long-running server.
        plt.close(fig)
    buf.seek(0)
    return PIL.Image.open(buf)


def analisis_sinyal(file, model_pair_label):
    """Run the selected model on an ECG CSV.

    Args:
        file: Path to the CSV, or an object whose ``.name`` holds that path.
        model_pair_label: A label from ``SINGLE_MODELS`` (first field) or
            ``MODEL_PAIRS`` (third field).

    Returns:
        ``(label, image)`` on success, where *label* is the predicted class as
        a string and *image* a PIL plot of the signal; ``(message, None)``
        when the model is unknown or a prediction step fails.

    Raises:
        Errors from loading models, reading the CSV or preprocessing
        (e.g. ``ValueError`` for too-short signals) propagate to the caller.
    """
    path = getattr(file, "name", file)
    single = next((s for s in SINGLE_MODELS if s[0] == model_pair_label), None)
    if single is not None:
        model = _load_model(_resolve(single[1]))
        sinyal = _baca_sinyal(path)
        label, error = _prediksi_single(model, preprocessing_sinyal(sinyal))
    else:
        selected = next((m for m in MODEL_PAIRS if m[2] == model_pair_label), None)
        if selected is None:
            return "Model not found", None
        model_dl, model_clf = load_hybrid_models(_resolve(selected[3]),
                                                 _resolve(selected[4]))
        sinyal = _baca_sinyal(path)
        label, error = _prediksi_hybrid(model_dl, model_clf,
                                        preprocessing_sinyal(sinyal))
    if error is not None:
        return error, None
    return str(label), _plot_sinyal(sinyal)


css = """
body {background-color: #181818; color: #f5f5f5;}
.gradio-container, .gradio-app {background-color: #181818 !important;}
#title {text-align:center; font-size:32px; font-weight:700; margin-bottom:20px; color:#f5f5f5;}
#subtitle {text-align:center; font-size:18px; margin-bottom:40px; color:#bbbbbb;}
input, select, textarea, .gr-button, .gr-input, .gr-textbox, .gr-dropdown, .gr-file, .gr-image {
    background-color: #232323 !important;
    color: #f5f5f5 !important;
    border-color: #444 !important;
}
.gr-button {border-radius: 6px;}
.license-box {
    border: 2px solid #fff;
    border-radius: 10px;
    padding: 18px;
    margin-top: 18px;
    background: #111;
    color: #fff;
}
"""

DATASET_FILES_MD = """
## Dataset Files
The dataset files below are the same as those used for model testing in this application.
Please download them using the buttons below:
"""

LICENSE_MD = """
### Dataset License

This dataset is licensed under the Open Data Commons Open Database License v1.0 (ODbL 1.0).
Further details: [ODbL 1.0](https://opendatacommons.org/licenses/odbl/1-0/)

This dataset is derived from the MIMIC-III Waveform Database:
- Moody, B., Moody, G., Villarroel, M., Clifford, G. D., & Silva, I. (2020). MIMIC-III Waveform Database (version 1.0). PhysioNet. https://doi.org/10.13026/c2607m

The MIMIC-III Waveform Database is licensed under the ODbL 1.0 license.

The MIMIC-III database is described in:
- Johnson, A. E. W., Pollard, T. J., Shen, L., Lehman, L. H., Feng, M., Ghassemi, M., Moody, B., Szolovits, P., Celi, L. A., & Mark, R. G. (2016). MIMIC-III, a freely accessible critical care database. Scientific Data, 3, 160035. https://doi.org/10.1038/sdata.2016.35

It is available on PhysioNet: https://physionet.org/
- Goldberger, A., Amaral, L., Glass, L., Hausdorff, J., Ivanov, P. C., Mark, R., ... & Stanley, H. E. (2000). PhysioBank, PhysioToolkit, and PhysioNet: Components of a new research resource for complex physiologic signals. Circulation [Online]. 101 (23), pp. e215–e220.

The following annotations of AF and non-AF were used to create the dataset:
- Bashar, Syed Khairul (2020): Atrial Fibrillation annotations of electrocardiogram from MIMIC III matched subset. figshare. Dataset. https://doi.org/10.6084/m9.figshare.12149091.v1
- Bashar, S.K., Ding, E., Walkey, A.J., McManus, D.D. and Chon, K.H., 2019. Noise Detection in Electrocardiogram Signals for Intensive Care Unit Patients. IEEE Access, 7, pp.88357-88368. https://doi.org/10.1109/ACCESS.2019.2926199

This annotation information is reproduced under the terms of the CC BY 4.0 license.
"""


def store_file(file):
    """Return the uploaded file's path, or None when the upload was cleared."""
    if file is None:
        return None
    return getattr(file, "name", file)


def clear_file():
    """Reset the upload widget, the stored path and the prediction result."""
    return None, None, None


def analyze_dataset(file_path):
    """Summarise an uploaded CSV and plot the first seconds of its ECG column.

    Returns:
        ``(info_text, preview_dataframe, figure)``; the last two are None
        when the file is missing or unreadable, and the figure is None when
        no numeric 'ecg' column exists or plotting fails.
    """
    if file_path is None:
        return "No file uploaded. Please upload a CSV file first.", None, None
    try:
        df = pd.read_csv(file_path)
    except Exception as e:
        return f"Failed to read file. Make sure the file is a valid CSV. Error: {e}", None, None

    duration = len(df) / SAMPLING_RATE
    info_lines = [
        f"Shape: {df.shape}",
        f"Columns: {list(df.columns)}",
        f"Missing: {df.isnull().sum().to_dict()}",
        f"Sampling rate: {SAMPLING_RATE} Hz",
        f"Data duration: {duration:.2f} seconds ({duration/60:.2f} minutes)",
    ]
    preview = df.head(10)

    ecg_col = _cari_kolom_ecg(df)
    if ecg_col is None:
        return ("The uploaded CSV does not contain an 'ecg' column. Please upload a CSV "
                "file with an 'ecg' feature/column."), preview, None
    if not pd.api.types.is_numeric_dtype(df[ecg_col]):
        return (f"The selected signal column ('{ecg_col}') is not numeric. "
                "Please upload a valid ECG CSV."), preview, None

    plot_samples = min(SAMPLING_RATE * PREVIEW_SECONDS, len(df))
    fig = None
    try:
        fig, ax = plt.subplots()
        ax.plot(df[ecg_col].values[:plot_samples])
        ax.set_title(f"First {PREVIEW_SECONDS} Seconds Signal Plot: {ecg_col}")
        ax.set_xlabel("Sample")
        ax.set_ylabel("Amplitude")
    except Exception as e:
        return f"Failed to plot ECG signal: {e}", preview, None
    finally:
        # Detach the figure from pyplot so figures do not accumulate across
        # requests; the Figure object itself can still be rendered by gr.Plot.
        if fig is not None:
            plt.close(fig)
    return "\n".join(info_lines), preview, fig


def handle_predict(file_path, model_label):
    """Validate the upload, run the selected model and return a readable result."""
    if file_path is None or str(file_path).strip() == "":
        return "No file uploaded. Please upload a CSV file first."
    if not model_label:
        return "No model selected. Please select a model."
    try:
        df = pd.read_csv(file_path)
        ecg_col = _cari_kolom_ecg(df)
        if ecg_col is None:
            return ("The uploaded CSV does not contain an 'ecg' column. Please upload a CSV "
                    "file with an 'ecg' feature/column for prediction.")
        if not pd.api.types.is_numeric_dtype(df[ecg_col]):
            return (f"The selected signal column ('{ecg_col}') is not numeric. "
                    "Please upload a valid ECG CSV.")
        result, _ = analisis_sinyal(file_path, model_label)
        return LABEL_NAMES.get(str(result).strip(), str(result))
    except Exception as e:
        return f"Prediction failed: {e}"


with gr.Blocks() as demo:
    # Markup reconstructed from the #title / #subtitle CSS selectors above.
    gr.HTML('<div id="title">Atrial Fibrillation Detection</div>')
    gr.HTML('<div id="subtitle">Analyze ECG data for Atrial Fibrillation presence</div>')

    file_upload = gr.File(label="Upload Dataset/Signal (CSV)", file_types=[".csv"])
    shared_file_path = gr.State()
    clear_btn = gr.Button("Clear File")
    file_upload.change(store_file, inputs=file_upload, outputs=shared_file_path)

    with gr.Tab("Dataset Info"):
        gr.Markdown(DATASET_FILES_MD)
        with gr.Row():
            gr.File(value="Data/mimic_perform_af_001_data.csv",
                    label="Download Atrial Fibrillation Data", interactive=False)
            gr.File(value="Data/mimic_perform_non_af_001_data.csv",
                    label="Download Non-Atrial Fibrillation Data", interactive=False)
        gr.Markdown(LICENSE_MD, elem_classes=["license-box"])

    with gr.Tab("Analyze Dataset"):
        info_ds = gr.Textbox(label="Dataset Info", interactive=False, lines=5)
        preview_ds = gr.Dataframe(label="Data Preview", interactive=False)
        plot_ds = gr.Plot(label="ECG Signal Plot")
        btn_ds = gr.Button("Analyze Dataset")
        btn_ds.click(analyze_dataset, inputs=[shared_file_path],
                     outputs=[info_ds, preview_ds, plot_ds])

    with gr.Tab("Analyze Model"):
        all_model_labels = [m[2] for m in MODEL_PAIRS] + [s[0] for s in SINGLE_MODELS]
        pilih_model = gr.Dropdown(all_model_labels, label="Select Model",
                                  value=all_model_labels[0])
        hasil = gr.Textbox(label="Prediction Result", interactive=False)
        tombol = gr.Button("Predict")
        tombol.click(handle_predict, inputs=[shared_file_path, pilih_model],
                     outputs=[hasil])

    # One handler clears everything; previously two handlers fired per click.
    clear_btn.click(clear_file, inputs=None,
                    outputs=[file_upload, shared_file_path, hasil])

demo.launch(css=css)