# wesad / app.py
# Uploaded by Mrinal007 — commit 79d7a8b ("Update app.py", verified)
import numpy as np
import gradio as gr
from tensorflow.keras.models import load_model, Model
from hpelm import ELM
import joblib
# Load the fitted feature scaler and the trained ELM classifier head.
# NOTE(review): ELM(128, 1, ...) presumably has to match the width of the
# "penultimate_dense" layer used below — confirm against the training code.
scaler = joblib.load("elm_scaler1.pkl")
elm = ELM(128, 1, classification='c')
elm.load("elm_model1.txt")

# Load the trained MobileNet1D and build a sub-model that stops at the
# "penultimate_dense" layer, so feature_extractor.predict() returns that
# layer's activations.
mobilenet = load_model("physio_model2.h5")
feature_extractor = Model(
    inputs=mobilenet.input,
    outputs=mobilenet.get_layer("penultimate_dense").output,
)
def preprocess_signal(ecg, eda, temp):
    """
    Stack the ECG, EDA and temperature samples side by side as three columns.

    Each input is flattened into a column vector and the three columns are
    joined in the order ECG, EDA, temperature. The result has shape
    (n_samples, 3), the 3-channel layout fed to MobileNet1D.

    All three signals must contain the same number of samples (e.g. 1280).
    Otherwise NumPy raises ValueError while joining the columns.
    """
    columns = [np.array(channel).reshape(-1, 1) for channel in (ecg, eda, temp)]
    return np.hstack(columns)
def predict_stress(ecg, eda, temp):
    """Classify one window of ECG/EDA/temperature samples as stress or not.

    Pipeline: the signals are stacked by ``preprocess_signal`` and reshaped
    into a batch of one, (1, 1280, 3). The batch goes through the MobileNet1D
    ``feature_extractor``, is scaled with ``scaler`` and scored by ``elm``.
    The raw ELM score is clipped to [-20, 20], passed through a logistic
    sigmoid and thresholded at 0.5.

    Args:
        ecg, eda, temp: Sequences of numbers, each exactly 1280 samples long.

    Returns:
        A string such as ``"Stress (Confidence: 0.73)"``. If any step raises,
        the string is ``"❌ Error: <message>"`` instead.
    """
    try:
        # Preprocessing: (1280, 3) -> batch of one, (1, 1280, 3).
        signal = preprocess_signal(ecg, eda, temp).reshape(1, 1280, 3)
        # Feature extraction; verbose=0 stops Keras printing a progress bar
        # for every request.
        features = feature_extractor.predict(signal, verbose=0)
        # Scale the features with the loaded scaler before ELM inference.
        features_scaled = scaler.transform(features)
        # ELM inference. Clipping keeps np.exp below from overflowing.
        raw_pred = np.clip(elm.predict(features_scaled), -20, 20)
        # NOTE(review): the sigmoid treats the ELM output as a logit. If the ELM
        # was trained on 0/1 targets, its raw output is already roughly in
        # [0, 1], and "prob > 0.5" reduces to "raw > 0" — confirm.
        prob = float((1 / (1 + np.exp(-raw_pred)))[0][0])
        label = "Stress" if prob > 0.5 else "No Stress"
        # NOTE(review): the reported confidence is P(stress) even when the label
        # is "No Stress".
        return f"{label} (Confidence: {prob:.2f})"
    except Exception as e:
        # Top-level UI boundary: report the failure instead of crashing Gradio.
        return f"❌ Error: {e}"
# Gradio demo interface: one two-line textbox per physiological signal,
# in the order ECG, EDA, temperature (the order parse_and_predict expects).
inputs = [
    gr.Textbox(label=label, lines=2, placeholder=hint)
    for label, hint in (
        ("ECG Signal (comma-separated)", "e.g. 0.1,0.12,...,0.13"),
        ("EDA Signal (comma-separated)", "e.g. 0.2,0.18,...,0.15"),
        ("Temperature Signal (comma-separated)", "e.g. 32.1,32.2,...,32.5"),
    )
]
# Number of samples each signal must contain; predict_stress reshapes the
# stacked input to (1, 1280, 3).
_SIGNAL_LENGTH = 1280


def _parse_signal(text):
    """Parse a comma-separated string of numbers into a list of floats.

    Whitespace around each value is tolerated, because ``float`` strips it.

    Raises:
        ValueError: if *text* is not a string or any entry is not a number.
    """
    if not isinstance(text, str):
        raise ValueError(f"expected a string, got {type(text).__name__}")
    return [float(token) for token in text.strip().split(',')]


def parse_and_predict(ecg_str, eda_str, temp_str):
    """Gradio callback: parse the three text inputs and run stress prediction.

    Args:
        ecg_str, eda_str, temp_str: Comma-separated numeric samples, one
            string per signal, each exactly 1280 values long.

    Returns:
        The string produced by ``predict_stress``. If any input cannot be
        parsed or has the wrong length, a ``"❌ ..."`` message instead.
    """
    # Only parsing is guarded. A bare except here would also swallow
    # KeyboardInterrupt/SystemExit, and predict_stress handles its own errors.
    try:
        ecg, eda, temp = [_parse_signal(s) for s in (ecg_str, eda_str, temp_str)]
    except ValueError:
        return "❌ Invalid input format. Use comma-separated values only."
    if any(len(signal) != _SIGNAL_LENGTH for signal in (ecg, eda, temp)):
        return f"❌ Each signal must be exactly {_SIGNAL_LENGTH} samples long."
    return predict_stress(ecg, eda, temp)
# Wire the text inputs to parse_and_predict and show the result as plain text.
demo = gr.Interface(
    fn=parse_and_predict,
    inputs=inputs,
    outputs="text",
    title="Stress Detection via Physio Signals",
    description=(
        "Upload ECG, EDA, and Temperature signals (1280 samples each) "
        "to detect stress using MobileNet1D + ELM"
    ),
)
# ๐Ÿ Run app
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=True,
show_api=True
)