# Page header captured together with this file (not part of the program):
# Mathivani's picture
# Upload app.py with huggingface_hub
# 82f2094 verified
import gradio as gr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import load_model
from scipy.signal import butter, filtfilt
import pywt
import tempfile
import os
from huggingface_hub import hf_hub_download
# Download the serialized model file from the Hugging Face Hub. The access
# token is read from the HUGGINGFACEHUB_TOKEN environment variable so that it
# never appears in the source.
model_path = hf_hub_download(
    repo_id="Mathivani/arrhythmia-private-model",
    filename="final_model.keras",
    token=os.environ.get("HUGGINGFACEHUB_TOKEN"),
)

# Load the model once at import time; every prediction request reuses it.
model = load_model(model_path)
# === Label & Color Maps ===
# Class index (as produced by argmax over the model output) -> label name.
label_map = dict(
    enumerate(("Bradycardia", "Tachycardia", "VFib", "VTach", "Normal"))
)

# Display color associated with each label name.
color_map = dict(
    Bradycardia="blue",
    Tachycardia="orange",
    VFib="red",
    VTach="purple",
    Normal="green",
)
# === Preprocess PPG ===
def preprocess_signal(sig, fs):
    """Filter, denoise and normalize one PPG window.

    Processing steps, in order:

    1. 4th-order Butterworth band-pass (0.5-8 Hz), applied forward and
       backward with ``filtfilt``.
    2. 5-sample moving average.
    3. 4-level ``db4`` wavelet decomposition in which the last coefficient
       array is zeroed before reconstruction.
    4. Subtract the median and divide by the peak-to-peak range.

    Args:
        sig: 1-D sequence of raw PPG samples.
        fs: Sampling rate of ``sig`` in Hz.

    Returns:
        The processed signal as a NumPy array.
        NOTE(review): the wavelet reconstruction may not have exactly the
        same length as ``sig`` -- confirm against the PyWavelets docs.
    """
    numerator, denominator = butter(4, [0.5, 8], btype='bandpass', fs=fs)
    band_limited = filtfilt(numerator, denominator, sig)

    kernel = np.ones(5) / 5
    averaged = np.convolve(band_limited, kernel, mode='same')

    bands = pywt.wavedec(averaged, 'db4', level=4)
    bands[-1] = np.zeros_like(bands[-1])
    denoised = pywt.waverec(bands, 'db4')

    # The small epsilon keeps the division finite for a perfectly flat window.
    span = np.max(denoised) - np.min(denoised) + 1e-8
    return (denoised - np.median(denoised)) / span
# === Segment PPG ===
def segment_signal(ppg, fs, win_sec=20, overlap_sec=10, target_len=2500):
    """Cut a PPG recording into preprocessed windows of ``target_len`` samples.

    Windows of ``win_sec`` seconds are taken every ``win_sec - overlap_sec``
    seconds, so consecutive windows share ``overlap_sec`` seconds of signal.
    Each window is passed through ``preprocess_signal``; windows whose
    processed length is below ``target_len`` are dropped and longer ones are
    cut to their first ``target_len`` samples.

    NOTE(review): longer windows are truncated, not resampled, so the time
    span covered by a segment depends on ``fs`` -- confirm this matches how
    the model was trained.

    Args:
        ppg: 1-D sequence of raw PPG samples.
        fs: Sampling rate in Hz.
        win_sec: Window length in seconds.
        overlap_sec: Overlap between consecutive windows in seconds; must be
            smaller than ``win_sec``.
        target_len: Number of samples kept from each processed window.

    Returns:
        A list of arrays, each exactly ``target_len`` samples long. The list
        is empty when the recording is shorter than one window.

    Raises:
        ValueError: If the window length or the hop between windows is not
            positive.
    """
    win_len = int(fs * win_sec)
    # The hop is the part of a window that is NOT shared with the next one.
    stride = int(fs * (win_sec - overlap_sec))

    if win_len <= 0:
        raise ValueError(
            f"window length must be positive (fs={fs}, win_sec={win_sec})"
        )
    if stride <= 0:
        raise ValueError(
            f"overlap_sec ({overlap_sec}) must be smaller than win_sec ({win_sec})"
        )

    segments = []
    for start in range(0, len(ppg) - win_len + 1, stride):
        window = preprocess_signal(ppg[start:start + win_len], fs)
        if len(window) >= target_len:
            segments.append(window[:target_len])
    return segments
# === Analyze Single File ===
def analyze_file(file, show_plot):
    """Classify one uploaded PPG CSV and summarize the per-segment results.

    The CSV must contain a ``PPG`` or ``PLETH`` column (matched after
    stripping and upper-casing the header). When a ``TIME`` column is present
    the sampling rate is estimated from the median difference between
    consecutive time values; otherwise 250 Hz is assumed.

    Args:
        file: Uploaded file object; only its ``name`` attribute (the path on
            disk) is read.
        show_plot: When truthy, also build a figure of the first 10 seconds
            of the raw signal.

    Returns:
        A 4-tuple ``(summary_row, segment_df, fig, error)``.

        * On success ``error`` is ``None``; ``summary_row`` is a dict with
          the majority-vote class and the percentage of segments that voted
          for it, ``segment_df`` lists every segment's predicted class and
          highest model output, and ``fig`` is a Matplotlib figure or
          ``None``.
        * On failure the first three items are ``None`` and ``error`` is a
          message prefixed with the file's base name.
    """
    try:
        frame = pd.read_csv(file.name)
        base_name = os.path.basename(file.name)
        frame.columns = frame.columns.str.strip().str.upper()

        # Accept either of the two supported column names, preferring "PPG".
        ppg_col = None
        for candidate in ("PPG", "PLETH"):
            if candidate in frame.columns:
                ppg_col = candidate
                break
        if not ppg_col:
            return None, None, None, f"{base_name}: PPG column missing"

        # Fall back to 250 Hz unless the file carries usable time values.
        fs = 250
        if "TIME" in frame.columns:
            time_steps = frame["TIME"].diff().dropna().values
            if len(time_steps) > 0:
                fs = int(round(1 / np.median(time_steps)))

        ppg = frame[ppg_col].values
        segments = segment_signal(ppg, fs)
        if not segments:
            return None, None, None, f"{base_name}: Insufficient signal duration"

        # Add a trailing channel axis: (segments, samples) -> (segments, samples, 1).
        model_input = np.array(segments)[:, :, np.newaxis]
        preds = model.predict(model_input, verbose=0)
        pred_classes = np.argmax(preds, axis=1)
        n_segments = len(pred_classes)

        # Majority vote across segments; confidence is the winning share.
        votes = pd.Series(pred_classes).value_counts()
        winner = votes.idxmax()
        confidence = round(100 * votes.max() / n_segments, 2)

        summary_row = {
            "File Name": base_name,
            "Predicted Class": label_map[winner],
            "Confidence (%)": confidence,
            "Segments Used": n_segments,
        }

        segment_df = pd.DataFrame({
            "Segment #": list(range(1, n_segments + 1)),
            "Predicted Class": [label_map[c] for c in pred_classes],
            "Confidence": preds.max(axis=1),
        })
        segment_df.insert(0, "File", base_name)

        fig = None
        if show_plot:
            fig, ax = plt.subplots(figsize=(12, 3))
            ax.plot(ppg[:fs*10], color='blue')
            ax.set_title(f"PPG Signal (First 10s) - {base_name}")
            plt.tight_layout()

        return summary_row, segment_df, fig, None
    except Exception as e:
        # Report any failure as a per-file message instead of aborting the batch.
        return None, None, None, f"{os.path.basename(file.name)}: Error - {e}"
# === Batch Predict ===
def batch_predict(files, show_plot):
    """Run ``analyze_file`` on every uploaded file and collect the outputs.

    Args:
        files: List of uploaded file objects, or ``None`` when nothing was
            uploaded.
        show_plot: When truthy, a PNG preview is saved for each file that
            produced a figure.

    Returns:
        A 4-tuple ``(summary_df, segments_df, csv_path, image_paths)``:

        * ``summary_df``: one row per successfully analyzed file (empty when
          none succeeded).
        * ``segments_df``: all per-segment predictions concatenated, or
          ``None`` when there are none.
        * ``csv_path``: path of a temporary CSV holding ``segments_df``, or
          ``None``.
        * ``image_paths``: list of temporary PNG paths when ``show_plot`` is
          truthy, otherwise ``None``.
    """
    summary_rows = []
    all_segments = []
    all_image_paths = []
    errors = []

    # ``files`` is None when the form is submitted without an upload.
    for file in files or []:
        summary_row, df_segment, fig, err = analyze_file(file, show_plot)
        if summary_row:
            summary_rows.append(summary_row)
        if df_segment is not None:
            all_segments.append(df_segment)
        if fig is not None:
            # Reserve a path, then release the handle before Matplotlib
            # writes to it so no descriptor stays open.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
                image_path = tmp.name
            try:
                fig.savefig(image_path)
            finally:
                # Always free the figure, even if saving fails.
                plt.close(fig)
            all_image_paths.append(image_path)
        if err:
            errors.append(err)

    # Show per-file failures to the user instead of silently dropping them.
    for message in errors:
        gr.Warning(message)

    summary_df = pd.DataFrame(summary_rows)
    segments_df = pd.concat(all_segments, ignore_index=True) if all_segments else None

    csv_path = None
    if segments_df is not None:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as tmp:
            csv_path = tmp.name
        segments_df.to_csv(csv_path, index=False)

    return summary_df, segments_df, csv_path, all_image_paths if show_plot else None
# === Gradio UI ===
# Inputs: one or more CSV uploads plus a toggle for the signal preview plots.
input_components = [
    gr.File(file_types=[".csv"], file_count="multiple", label="Upload PPG CSV(s)"),
    gr.Checkbox(label="Show signal plot preview (first 10 seconds)", value=True),
]

# Outputs, in the same order as the tuple returned by batch_predict.
output_components = [
    gr.Dataframe(label="🧠 Final Prediction Summary"),
    gr.Dataframe(label="📄 Segment-wise Predictions"),
    gr.File(label="⬇️ Download Segment Report"),
    gr.Gallery(label="📈 Signal Plots (Optional)", show_label=True, visible=True),
]

demo = gr.Interface(
    fn=batch_predict,
    inputs=input_components,
    outputs=output_components,
    title="🩺 Arrhythmia Detection using PPG (CNN+LSTM)",
    description="Upload CSV files with Time and PPG/Pleth columns. CNN + LSTM model (PPG-only). Displays predictions and plots.",
    allow_flagging="never",
)

# Start the web app as soon as the script is executed.
demo.launch()