Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import tensorflow as tf | |
| from tensorflow.keras.models import load_model | |
| from scipy.signal import butter, filtfilt | |
| import pywt | |
| import tempfile | |
| import os | |
| from huggingface_hub import hf_hub_download | |
# === Load Model ===
# Fetch the trained Keras model file from the Hugging Face Hub and load it once
# at import time. The access token is read from the HUGGINGFACEHUB_TOKEN
# environment variable (os.getenv yields None when it is unset).
_HF_TOKEN = os.getenv("HUGGINGFACEHUB_TOKEN")
model_path = hf_hub_download(
    repo_id="Mathivani/arrhythmia-private-model",
    filename="final_model.keras",
    token=_HF_TOKEN,
)
model = load_model(model_path)
# === Label & Color Maps ===
# Class index -> rhythm name. The position in the list is the class index.
label_map = dict(
    enumerate(["Bradycardia", "Tachycardia", "VFib", "VTach", "Normal"])
)

# Rhythm name -> colour name.
color_map = dict(
    Bradycardia="blue",
    Tachycardia="orange",
    VFib="red",
    VTach="purple",
    Normal="green",
)
# === Preprocess PPG ===
def preprocess_signal(sig, fs):
    """Filter, denoise and normalise one window of PPG samples.

    Processing steps, in order:
      1. Butterworth band-pass (``butter(4, [0.5, 8])``) applied forwards and
         backwards with ``filtfilt``.
      2. 5-sample moving average.
      3. 4-level ``db4`` wavelet decomposition; the finest detail band is
         zeroed before the signal is reconstructed.
      4. Subtract the median and divide by the peak-to-peak range.

    Args:
        sig: 1-D sequence of PPG samples.
        fs: Sampling rate in Hz. SciPy requires the 8 Hz upper cut-off to be
            below ``fs / 2``, so ``fs`` must be greater than 16.

    Returns:
        1-D float array with exactly ``len(sig)`` samples.

    Raises:
        ValueError: Propagated from SciPy when ``fs`` is too low for the
            pass band or ``sig`` is too short for ``filtfilt`` padding.
    """
    sig = np.asarray(sig, dtype=float)
    b, a = butter(4, [0.5, 8], btype='bandpass', fs=fs)
    filtered = filtfilt(b, a, sig)
    smoothed = np.convolve(filtered, np.ones(5) / 5, mode='same')

    coeffs = pywt.wavedec(smoothed, 'db4', level=4)
    coeffs[-1] = np.zeros_like(coeffs[-1])
    # waverec yields one extra trailing sample for odd-length input; trim so
    # the output always lines up sample-for-sample with the input window.
    cleaned = pywt.waverec(coeffs, 'db4')[:len(sig)]

    # The small epsilon keeps a flat (zero-range) window from dividing by zero.
    span = np.max(cleaned) - np.min(cleaned)
    return (cleaned - np.median(cleaned)) / (span + 1e-8)
# === Segment PPG ===
def segment_signal(ppg, fs, win_sec=20, overlap_sec=10, *, target_len=2500,
                   preprocess=None):
    """Cut a PPG recording into overlapping, preprocessed, fixed-length windows.

    Consecutive windows start ``win_sec - overlap_sec`` seconds apart. With the
    default arguments this is the same 10-second step as before; previously
    ``overlap_sec`` itself was used as the step, which was only correct when
    the overlap happened to be half the window.

    Args:
        ppg: 1-D sequence of raw PPG samples.
        fs: Sampling rate in Hz.
        win_sec: Window length in seconds.
        overlap_sec: Overlap between consecutive windows in seconds; must be
            smaller than ``win_sec``.
        target_len: Number of samples kept from the start of each preprocessed
            window. Windows that come out shorter than this are skipped.
        preprocess: Optional callable ``(window, fs) -> 1-D array`` applied to
            each window. Defaults to ``preprocess_signal``.

    Returns:
        List of arrays, each with exactly ``target_len`` samples. Empty when
        the recording is shorter than one window.

    Raises:
        ValueError: If the window length or the step between windows is not
            positive.
    """
    if preprocess is None:
        preprocess = preprocess_signal

    win_len = int(round(fs * win_sec))
    stride = int(round(fs * (win_sec - overlap_sec)))
    if win_len <= 0:
        raise ValueError("fs * win_sec must be positive")
    if stride <= 0:
        raise ValueError("overlap_sec must be smaller than win_sec")

    segments = []
    for start in range(0, len(ppg) - win_len + 1, stride):
        cleaned = preprocess(ppg[start:start + win_len], fs)
        if len(cleaned) >= target_len:
            segments.append(cleaned[:target_len])
    return segments
# === Analyze Single File ===
def analyze_file(file, show_plot):
    """Classify one PPG CSV file and optionally plot its first 10 seconds.

    The CSV must contain a ``PPG`` or ``PLETH`` column (header matching is
    case-insensitive and ignores surrounding whitespace). If a ``TIME`` column
    is present, the sampling rate is taken as ``1 / median(time step)``;
    otherwise 250 Hz is assumed. The signal is segmented, every segment is
    classified by the module-level ``model``, and the file-level label is the
    majority vote over segments.

    Args:
        file: Path to the CSV (``str`` or ``os.PathLike``) or an object whose
            ``name`` attribute holds that path.
        show_plot: When true, also build a figure of the first 10 seconds.

    Returns:
        ``(summary_row, segment_df, fig, error)``. On success ``error`` is
        ``None``, ``summary_row`` is a dict whose "Confidence (%)" is the share
        of segments that voted for the majority class, ``segment_df`` holds one
        row per segment with the model's top score as "Confidence", and ``fig``
        is a Matplotlib figure or ``None``. On failure the first three items
        are ``None`` and ``error`` is a ``"<file name>: <reason>"`` string.
    """
    # Accept both plain paths and upload objects exposing ``.name``.
    path = file if isinstance(file, (str, os.PathLike)) else file.name
    name = os.path.basename(path)
    try:
        df = pd.read_csv(path)
        df.columns = df.columns.str.strip().str.upper()
        ppg_col = next((c for c in ("PPG", "PLETH") if c in df.columns), None)
        if ppg_col is None:
            return None, None, None, f"{name}: PPG column missing"

        fs = 250
        if "TIME" in df.columns:
            diffs = df["TIME"].diff().dropna().values
            if len(diffs) > 0:
                dt = float(np.median(diffs))
                # A zero, negative or non-finite step (or one of a second or
                # more, which rounds to 0 Hz) cannot give a usable rate.
                if not np.isfinite(dt) or dt <= 0 or round(1 / dt) < 1:
                    return (None, None, None,
                            f"{name}: Invalid sampling rate derived from TIME column")
                fs = int(round(1 / dt))

        # Non-numeric or missing samples would turn the whole filtered signal
        # (and hence the predictions) into NaN, so bridge gaps by interpolation.
        ppg = (
            pd.to_numeric(df[ppg_col], errors="coerce")
            .interpolate(limit_direction="both")
            .to_numpy(dtype=float)
        )
        if np.isnan(ppg).any():
            return None, None, None, f"{name}: PPG column has no numeric data"

        segments = segment_signal(ppg, fs)
        if not segments:
            return None, None, None, f"{name}: Insufficient signal duration"

        # Append a trailing channel axis: (segments, samples, 1).
        ppg_input = np.array(segments)[:, :, np.newaxis]
        preds = model.predict(ppg_input, verbose=0)
        pred_classes = np.argmax(preds, axis=1)

        counts = pd.Series(pred_classes).value_counts()
        majority_class = counts.idxmax()
        confidence = round(100 * counts.max() / len(pred_classes), 2)
        final_label = label_map[majority_class]

        summary_row = {
            "File Name": name,
            "Predicted Class": final_label,
            "Confidence (%)": confidence,
            "Segments Used": len(pred_classes)
        }
        segment_df = pd.DataFrame({
            "Segment #": list(range(1, len(pred_classes) + 1)),
            "Predicted Class": [label_map[c] for c in pred_classes],
            "Confidence": preds.max(axis=1)
        })
        segment_df.insert(0, "File", name)

        fig = None
        if show_plot:
            fig, ax = plt.subplots(figsize=(12, 3))
            ax.plot(ppg[:fs * 10], color='blue')
            ax.set_title(f"PPG Signal (First 10s) - {name}")
            # Lay out this figure explicitly rather than pyplot's current one.
            fig.tight_layout()
        return summary_row, segment_df, fig, None
    except Exception as e:
        # Per-file boundary: one unreadable file must not abort a whole batch.
        return None, None, None, f"{name}: Error - {e}"
# === Batch Predict ===
def batch_predict(files, show_plot):
    """Run ``analyze_file`` over every uploaded file and collect the results.

    Args:
        files: Iterable of uploaded files (paths or objects with a ``name``
            attribute), or ``None`` when nothing was uploaded.
        show_plot: Whether to render a preview plot per file.

    Returns:
        A 4-tuple matching the interface outputs:
          * summary ``DataFrame`` with one row per file. Files that failed get
            a row with "Predicted Class" set to "Error" and the reason in an
            extra "Error" column; that column is absent when every file
            succeeded.
          * ``DataFrame`` of all per-segment predictions, or ``None``.
          * path of a temporary CSV holding the per-segment table, or ``None``.
          * list of temporary PNG paths when ``show_plot`` is true, else
            ``None``.
    """
    summary_rows = []
    all_segments = []
    all_image_paths = []

    for file in files or []:
        summary_row, df_segment, fig, err = analyze_file(file, show_plot)
        if summary_row:
            summary_rows.append(summary_row)
        if df_segment is not None:
            all_segments.append(df_segment)
        if fig is not None:
            # Reserve a path and close the handle before Matplotlib writes to
            # it, so no descriptor is left open.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
                tmp_path = tmp.name
            try:
                fig.savefig(tmp_path)
            finally:
                plt.close(fig)
            all_image_paths.append(tmp_path)
        if err:
            # Surface the failure in the summary instead of dropping it.
            summary_rows.append({
                "File Name": os.path.basename(str(getattr(file, "name", file))),
                "Predicted Class": "Error",
                "Confidence (%)": None,
                "Segments Used": 0,
                "Error": err,
            })

    summary_df = pd.DataFrame(summary_rows)
    segments_df = pd.concat(all_segments, ignore_index=True) if all_segments else None

    csv_path = None
    if segments_df is not None:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as tmp:
            csv_path = tmp.name
        segments_df.to_csv(csv_path, index=False)

    return summary_df, segments_df, csv_path, all_image_paths if show_plot else None
# === Gradio UI ===
# Inputs: one or more CSV uploads plus a toggle for the preview plots.
csv_upload = gr.File(file_types=[".csv"], file_count="multiple", label="Upload PPG CSV(s)")
plot_toggle = gr.Checkbox(label="Show signal plot preview (first 10 seconds)", value=True)

# Outputs, in the same order as the tuple returned by batch_predict.
summary_table = gr.Dataframe(label="🧠 Final Prediction Summary")
segment_table = gr.Dataframe(label="📄 Segment-wise Predictions")
report_download = gr.File(label="⬇️ Download Segment Report")
plot_gallery = gr.Gallery(label="📈 Signal Plots (Optional)", show_label=True, visible=True)

demo = gr.Interface(
    fn=batch_predict,
    inputs=[csv_upload, plot_toggle],
    outputs=[summary_table, segment_table, report_download, plot_gallery],
    title="🩺 Arrhythmia Detection using PPG (CNN+LSTM)",
    description="Upload CSV files with Time and PPG/Pleth columns. CNN + LSTM model (PPG-only). Displays predictions and plots.",
    allow_flagging="never",
)
demo.launch()