| import gradio as gr |
| import numpy as np |
| import pandas as pd |
| import wfdb |
| import tensorflow as tf |
| from scipy import signal |
| import os |
| import subprocess |
| import shutil |
| import requests |
| import zipfile |
|
|
| |
# Hide every CUDA device so TensorFlow falls back to the CPU.
# NOTE(review): tensorflow is imported above this line; hiding GPUs here only
# works if TF has not initialised CUDA yet — consider setting it before the import.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# Refuse to start when the Hugging Face token is missing or empty.
# NOTE(review): HF_TOKEN is not referenced anywhere else in the visible code —
# confirm it is still required.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None or HF_TOKEN == "":
    raise ValueError("HF_TOKEN not found. Please set it in the Space's environment variables.")

# Source code repository (cloned into REPO_DIR, relative to the working directory).
REPO_URL = "https://github.com/AutoECG/Automated-ECG-Interpretation.git"
REPO_DIR = "Automated-ECG-Interpretation"

# PTB-XL archive and the folder name it is expected under, relative to PERSISTENT_DIR.
DATASET_URL = (
    "https://physionet.org/static/published-projects/ptb-xl/"
    "ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3.zip"
)
DATASET_DIR = "ptb-xl"

# Root directory holding the downloaded dataset and the cached model file.
PERSISTENT_DIR = "/data"
|
|
| |
def ensure_persistent_dir():
    """Create ``PERSISTENT_DIR`` (including missing parents) if nothing exists there yet.

    If any filesystem entry already exists at that path, this is a no-op.
    """
    if os.path.exists(PERSISTENT_DIR):
        return
    os.makedirs(PERSISTENT_DIR, exist_ok=True)
|
|
| |
def clone_repository():
    """Clone ``REPO_URL`` into ``REPO_DIR`` unless that path already exists.

    Best-effort: if ``git clone`` exits with a non-zero status the error is
    printed and swallowed. If the ``git`` executable itself cannot be found,
    ``subprocess.run`` raises and that exception propagates.
    """
    if os.path.exists(REPO_DIR):
        print("Repository already cloned.")
        return

    print("Cloning repository...")
    try:
        subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)
    except subprocess.CalledProcessError as err:
        print(f"Error cloning repository: {err}")
    else:
        print("Repository cloned successfully.")
|
|
| |
def _download_file(url, dest_path, chunk_size=8192, timeout=60):
    """Stream ``url`` into ``dest_path``, raising ``requests.HTTPError`` on a 4xx/5xx reply.

    ``timeout`` (seconds) bounds the connection attempt and each wait for more
    data, not the total transfer time.
    """
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Without this check an HTML error page would be written to disk and
        # only fail later as a confusing "bad zip file".
        response.raise_for_status()
        with open(dest_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=chunk_size):
                f.write(chunk)


def _extract_archive_flat(zip_path, dest_dir):
    """Extract ``zip_path`` so that its contents end up directly inside ``dest_dir``.

    Extraction goes to a sibling staging directory that is renamed into place
    only once it is complete. An interrupted run therefore never leaves a
    half-populated ``dest_dir`` that a later run would mistake for a finished
    extraction. If the archive wraps everything in a single top-level folder,
    that folder is unwrapped and becomes ``dest_dir``.
    """
    staging_dir = dest_dir + ".partial"
    if os.path.exists(staging_dir):
        shutil.rmtree(staging_dir)  # leftover from an earlier interrupted run
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(staging_dir)

    entries = os.listdir(staging_dir)
    if len(entries) == 1 and os.path.isdir(os.path.join(staging_dir, entries[0])):
        # PhysioNet archives are typically packaged as one versioned folder
        # (e.g. "ptb-xl-...-1.0.3/"); unwrap it so dest_dir/records500 exists.
        os.rename(os.path.join(staging_dir, entries[0]), dest_dir)
        os.rmdir(staging_dir)
    else:
        os.rename(staging_dir, dest_dir)


def download_and_extract_dataset():
    """Make the PTB-XL dataset available at ``PERSISTENT_DIR/DATASET_DIR``.

    Does nothing if that path already exists. Otherwise downloads
    ``DATASET_URL`` to a temporary zip under ``PERSISTENT_DIR``, extracts it so
    the dataset files sit directly in ``PERSISTENT_DIR/DATASET_DIR``, and
    always deletes the zip afterwards, even if the download or extraction failed.

    Raises:
        requests.RequestException: download failed (including HTTP error status).
        zipfile.BadZipFile: the downloaded file is not a valid zip archive.
        OSError: filesystem errors while writing or renaming.
    """
    ensure_persistent_dir()
    zip_path = os.path.join(PERSISTENT_DIR, "ptb-xl.zip")
    extract_path = os.path.join(PERSISTENT_DIR, DATASET_DIR)
    if os.path.exists(extract_path):
        print("Dataset already extracted.")
        return

    try:
        print("Downloading PTB-XL dataset...")
        _download_file(DATASET_URL, zip_path)
        print("Extracting dataset...")
        _extract_archive_flat(zip_path, extract_path)
    finally:
        # Don't leave a partial or already-consumed multi-GB archive on the volume.
        if os.path.exists(zip_path):
            os.remove(zip_path)
    print("Dataset extracted successfully.")
|
|
| |
# Fetch the code repository and the dataset at import time.
clone_repository()
download_and_extract_dataset()

MODEL_FILENAME = "model.h5"
# Copy expected inside the cloned repository.
# NOTE(review): whether upstream actually ships model.h5 is not visible here — confirm.
MODEL_PATH = os.path.join(REPO_DIR, MODEL_FILENAME)
# Cached copy on the persistent volume; this is the file that gets loaded.
PERSISTENT_MODEL_PATH = os.path.join(PERSISTENT_DIR, MODEL_FILENAME)

# Seed the persistent cache from the repository on first run. Once cached,
# the repository copy is never consulted again.
if not os.path.exists(PERSISTENT_MODEL_PATH):
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(
            f"Model file not found at {MODEL_PATH}. Please ensure it's in the repository or upload it manually."
        )
    shutil.copy(MODEL_PATH, PERSISTENT_MODEL_PATH)

model = tf.keras.models.load_model(PERSISTENT_MODEL_PATH)
|
|
| |
def preprocess_ecg(file_path, *, target_fs=360, target_len=3600, lead=0):
    """Load one lead of a WFDB record and shape it into a model input.

    The lead is resampled to ``target_fs`` Hz, z-score normalised, and then
    zero-padded or truncated to exactly ``target_len`` samples. The defaults
    give a 10 s window at 360 Hz.

    Args:
        file_path: Path to the record's ``.dat`` file, or the record name
            without extension. ``wfdb.rdrecord`` also reads the record's
            ``.hea`` header, which must sit alongside it.
        target_fs: Sampling rate, in Hz, to resample to.
        target_len: Number of samples in the returned window.
        lead: Column of ``record.p_signal`` (the channel) to use.

    Returns:
        ``np.ndarray`` of shape ``(1, target_len, 1)``.

    Raises:
        Whatever ``wfdb.rdrecord`` raises for a missing or unreadable record.
    """
    # rdrecord wants the record name without extension. Strip only a trailing
    # ".dat": str.replace would also mangle paths like "/tmp/my.data/x.dat".
    suffix = ".dat"
    record_name = file_path[: -len(suffix)] if file_path.endswith(suffix) else file_path
    record = wfdb.rdrecord(record_name)
    ecg_signal = record.p_signal[:, lead]

    num_samples = int(len(ecg_signal) * target_fs / record.fs)
    ecg_resampled = signal.resample(ecg_signal, num_samples)

    # A flat-line lead has zero spread; dividing by it would fill the input
    # with NaN/inf, so only centre it in that case.
    spread = np.std(ecg_resampled)
    ecg_normalized = (ecg_resampled - np.mean(ecg_resampled)) / (spread if spread > 0 else 1.0)

    if len(ecg_normalized) < target_len:
        ecg_normalized = np.pad(ecg_normalized, (0, target_len - len(ecg_normalized)), "constant")
    else:
        ecg_normalized = ecg_normalized[:target_len]
    return ecg_normalized.reshape(1, target_len, 1)
|
|
| |
def predict_ecg(file=None, dataset_file=None):
    """Classify one ECG as Normal/Abnormal and format the result for the UI.

    Args:
        file: Value of the ``gr.File`` input. Depending on the Gradio version
            and ``type`` setting this is a path string or a tempfile-like
            object exposing ``.name``; both are accepted. If given, it takes
            precedence over ``dataset_file``.
        dataset_file: Path of a ``.dat`` file relative to the PTB-XL
            ``records500`` directory, as offered by the dropdown.

    Returns:
        A display string: the predicted label and its confidence, a prompt
        when no input was given, or an error message when the record cannot
        be read.
    """
    if file:
        # Newer Gradio passes a plain path string; older versions pass a
        # tempfile wrapper whose .name is the path.
        file_path = getattr(file, "name", file)
    elif dataset_file:
        file_path = os.path.join(PERSISTENT_DIR, DATASET_DIR, "records500", dataset_file)
    else:
        return "Please upload a file or select a dataset sample."

    try:
        ecg_data = preprocess_ecg(file_path)
    except (OSError, ValueError) as err:
        # E.g. an uploaded .dat without the .hea header wfdb also needs.
        return f"Could not read ECG record: {err}"

    # Assumes a single sigmoid output where > 0.5 means "Abnormal" — TODO confirm
    # against the trained model.
    abnormal_prob = float(model.predict(ecg_data)[0][0])
    label = "Abnormal" if abnormal_prob > 0.5 else "Normal"
    confidence = abnormal_prob if label == "Abnormal" else 1.0 - abnormal_prob
    return f"Prediction: {label}\nConfidence: {confidence:.2%}"
|
|
| |
_RECORDS500_DIR = os.path.join(PERSISTENT_DIR, DATASET_DIR, "records500")

# Every .dat file under records500, as a path relative to that directory
# (predict_ecg joins dataset_file back onto the same directory). os.walk yields
# nothing for a missing directory, so this is simply [] when the dataset is
# absent. Sorted because os.walk order is filesystem-dependent, which would
# otherwise make the dropdown order arbitrary.
dataset_files = sorted(
    os.path.relpath(os.path.join(root, name), _RECORDS500_DIR)
    for root, _dirs, files in os.walk(_RECORDS500_DIR)
    for name in files
    if name.endswith(".dat")
)
|
|
| |
# The two inputs map positionally onto predict_ecg(file, dataset_file).
# NOTE(review): a lone .dat upload lacks the .hea header that wfdb needs —
# consider accepting multiple files.
upload_input = gr.File(label="Upload ECG File (.dat format)")
sample_input = gr.Dropdown(choices=dataset_files, label="Or Select a PTB-XL Sample")
result_output = gr.Textbox(label="ECG Interpretation")

interface = gr.Interface(
    fn=predict_ecg,
    inputs=[upload_input, sample_input],
    outputs=result_output,
    title="Automated ECG Interpretation",
    description=(
        "Upload an ECG file (.dat) or select a sample from the PTB-XL dataset "
        "for automated interpretation."
    ),
)

interface.launch()