# Commit header (web-page residue, commented out so the file parses):
# osamos360 — "Update app.py" — 75f4f04 (verified)
import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer
import pickle
import json
# --- Register custom layers before loading model ---
# The decorator registers the class under the 'PhysioWatch' package so the
# saved model can refer to it by name when it is deserialized.
@tf.keras.utils.register_keras_serializable(package='PhysioWatch')
class ExpandDims(Layer):
    """Append a trailing axis of size 1 to the input tensor."""

    def call(self, x):
        return tf.expand_dims(x, -1)

    def compute_output_shape(self, input_shape):
        # Mirror call() for any input rank: the output shape is the input
        # shape with a trailing 1 appended, e.g. (batch, steps) -> (batch, steps, 1).
        return tuple(input_shape) + (1,)

    def get_config(self):
        # The layer takes no constructor arguments of its own, so the base
        # Layer config is all that needs to be serialized.
        return super().get_config()
@tf.keras.utils.register_keras_serializable(package='PhysioWatch')
class ReduceSum(Layer):
    """Sum the input tensor over axis 1, removing that axis."""

    def call(self, x):
        return tf.reduce_sum(x, axis=1)

    def compute_output_shape(self, input_shape):
        # Mirror call() for any input rank >= 2: axis 1 is summed away and
        # every other dimension is kept, e.g. (batch, steps, dim) -> (batch, dim).
        shape = tuple(input_shape)
        return shape[:1] + shape[2:]

    def get_config(self):
        # No layer-specific constructor arguments; the base config suffices.
        return super().get_config()
# --- Load model and scaler ---
# Both artifacts are loaded once, at import time, from the current working
# directory; a missing or unreadable file raises here before the UI is built.
model = tf.keras.models.load_model('physiowatch_model_clean.keras')

# NOTE(review): pickle.load can execute arbitrary code embedded in the file.
# That is acceptable only because scaler.pkl is a bundled artifact; never
# point this at user-supplied data.
with open('scaler.pkl', 'rb') as scaler_file:
    scaler = pickle.load(scaler_file)

print("✅ Model and scaler loaded successfully")
# --- JSON prediction endpoint ---
def predict_json(json_input):
    """Classify one or more vital-sign windows supplied as JSON text.

    Args:
        json_input: JSON text of the form ``{"sequences": ...}``. The value is
            either a single window of shape (24, 5) or a batch of windows of
            shape (n, 24, 5). The UI documents the column order as heart rate,
            systolic BP, diastolic BP, respiratory rate, SpO2.

    Returns:
        A JSON string. On success: ``{"results": [...], "status": "success"}``
        with one result dict per window, in input order. On any failure
        (malformed JSON, missing key, wrong shape, non-finite values,
        scaler/model errors): ``{"error": "<message>", "status": "failed"}``.
    """
    n_steps, n_features = 24, 5

    def summarize(prob):
        # Probabilities above 0.5 are labelled "Abnormal"; confidence is the
        # probability of whichever label was chosen.
        return {
            "prediction": "Abnormal" if prob > 0.5 else "Normal",
            "confidence": round(max(prob, 1 - prob) * 100, 2),
            "probability_abnormal": round(prob * 100, 2),
            "probability_normal": round((1 - prob) * 100, 2),
        }

    try:
        data = json.loads(json_input)
        if not isinstance(data, dict) or "sequences" not in data:
            raise ValueError('Input must be a JSON object with a "sequences" key')

        # dtype=float makes ragged or non-numeric nesting fail here with a
        # ValueError instead of silently building an object array.
        sequences = np.asarray(data["sequences"], dtype=float)
        given_shape = sequences.shape
        # Promote a single (24, 5) window to a batch of one.
        if sequences.ndim == 2:
            sequences = sequences[np.newaxis, ...]
        if sequences.ndim != 3 or sequences.shape[1:] != (n_steps, n_features):
            raise ValueError(
                f"Expected sequences of shape ({n_steps}, {n_features}) or "
                f"(n, {n_steps}, {n_features}), got {given_shape}"
            )
        # json.loads accepts NaN/Infinity literals; keep them away from the
        # scaler and model.
        if not np.isfinite(sequences).all():
            raise ValueError("All vital sign values must be finite numbers")

        n = sequences.shape[0]
        # The scaler is applied to 5-column rows: flatten the batch to
        # (n * 24, 5), scale, then restore (n, 24, 5) for the model.
        normalized = scaler.transform(sequences.reshape(-1, n_features))
        X = normalized.reshape(n, n_steps, n_features)
        probs = model.predict(X, verbose=0).flatten()

        results = [summarize(float(p)) for p in probs]
        return json.dumps({"results": results, "status": "success"}, indent=2)
    except Exception as e:
        # API boundary: report any failure to the caller as JSON rather than
        # raising into the Gradio handler.
        return json.dumps({"error": str(e), "status": "failed"})
# --- Manual input endpoint ---
def predict_vitals(heart_rate_seq, systolic_bp_seq, diastolic_bp_seq,
                   respiratory_rate_seq, spo2_seq):
    """Classify a single 24-step window entered as comma-separated text.

    Args:
        heart_rate_seq, systolic_bp_seq, diastolic_bp_seq,
        respiratory_rate_seq, spo2_seq: strings of exactly 24
            comma-separated numbers each, one value per timestep.

    Returns:
        A dict. On success it has keys ``prediction``, ``confidence``,
        ``probability_abnormal``, ``probability_normal`` and
        ``status: "success"``. Every failure (unparseable value, wrong
        count, non-finite value, scaler/model error) returns
        ``{"error": <message>, "status": "failed"}``.
    """
    n_steps, n_features = 24, 5
    labeled_inputs = (
        ("Heart Rate", heart_rate_seq),
        ("Systolic BP", systolic_bp_seq),
        ("Diastolic BP", diastolic_bp_seq),
        ("Respiratory Rate", respiratory_rate_seq),
        ("SpO2", spo2_seq),
    )

    def parse_series(label, text):
        # Prefix the field name: a bare float() error does not say which
        # of the five textboxes contained the bad value.
        try:
            return [float(x) for x in text.strip().split(',')]
        except ValueError as err:
            raise ValueError(f"{label}: {err}") from err

    try:
        columns = [parse_series(label, text) for label, text in labeled_inputs]
        if not all(len(col) == n_steps for col in columns):
            # Carry "status" like every other failure path.
            return {"error": "Each vital sign must have exactly 24 values",
                    "status": "failed"}

        # Stack the five series as columns -> shape (24, 5).
        raw_sequence = np.array(columns).T
        # float() accepts "nan"/"inf"; keep them away from the scaler/model.
        if not np.isfinite(raw_sequence).all():
            return {"error": "All vital sign values must be finite numbers",
                    "status": "failed"}

        normalized = scaler.transform(raw_sequence)
        sequence = normalized.reshape(1, n_steps, n_features)
        prob = float(model.predict(sequence, verbose=0)[0][0])
        label = "Abnormal" if prob > 0.5 else "Normal"
        return {
            "prediction": label,
            "confidence": round(max(prob, 1 - prob) * 100, 2),
            "probability_abnormal": round(prob * 100, 2),
            "probability_normal": round((1 - prob) * 100, 2),
            "status": "success"
        }
    except Exception as e:
        # UI boundary: surface any failure as a JSON-renderable dict.
        return {"error": str(e), "status": "failed"}
# --- Gradio Interface ---
# Two tabs: a raw JSON endpoint for programmatic use, and a form with one
# textbox per vital sign. Both call the prediction functions defined above.
with gr.Blocks(title="PhysioWatch API") as demo:
    gr.Markdown("# PhysioWatch — Abnormal Health Pattern Detection API")
    gr.Markdown("CNN-BiLSTM-Attention model trained on MIMIC-IV ICU data")

    with gr.Tab("JSON API"):
        gr.Markdown("**Input format:** `{\"sequences\": [[24 timesteps × 5 vitals]]}`")
        gr.Markdown("**Column order:** Heart Rate, Systolic BP, Diastolic BP, Respiratory Rate, SpO2")
        json_in = gr.Textbox(
            label="JSON Input",
            lines=10,
            placeholder='{"sequences": [[[72, 120, 80, 16, 98], ...]]}',
        )
        json_out = gr.Textbox(label="Result", lines=10)
        json_button = gr.Button("Run Prediction", variant="primary")
        json_button.click(predict_json, inputs=json_in, outputs=json_out)

    with gr.Tab("Manual Input"):
        gr.Markdown("Enter exactly **24 comma-separated values** per vital sign")
        # Created in the same order as predict_vitals' parameters.
        hr_in, sbp_in, dbp_in, rr_in, spo2_in = (
            gr.Textbox(label=caption)
            for caption in (
                "Heart Rate (bpm)",
                "Systolic BP (mmHg)",
                "Diastolic BP (mmHg)",
                "Respiratory Rate (breaths/min)",
                "SpO2 (%)",
            )
        )
        manual_out = gr.JSON(label="Result")
        manual_button = gr.Button("Predict", variant="primary")
        manual_button.click(
            predict_vitals,
            inputs=[hr_in, sbp_in, dbp_in, rr_in, spo2_in],
            outputs=manual_out,
        )

demo.launch()