import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer
import pickle
import json

# Expected input geometry: SEQ_LEN timesteps per window and N_FEATURES vitals
# per timestep, in the column order Heart Rate, Systolic BP, Diastolic BP,
# Respiratory Rate, SpO2 (the order the UI documents and predict_vitals stacks).
SEQ_LEN = 24
N_FEATURES = 5
# Model outputs strictly above this value are labelled "Abnormal".
DECISION_THRESHOLD = 0.5


# --- Register custom layers before loading model ---
# The saved model refers to these layers by package + class name, so neither
# the class names nor the 'PhysioWatch' package string may change.
@tf.keras.utils.register_keras_serializable(package='PhysioWatch')
class ExpandDims(Layer):
    """Append a trailing length-1 axis: (d0, d1) -> (d0, d1, 1).

    The output shape declared below assumes a 2-D input.
    """

    def call(self, x):
        return tf.expand_dims(x, -1)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[1], 1)

    def get_config(self):
        # No extra constructor arguments to serialize.
        return super().get_config()


@tf.keras.utils.register_keras_serializable(package='PhysioWatch')
class ReduceSum(Layer):
    """Sum over axis 1: (d0, d1, d2) -> (d0, d2)."""

    def call(self, x):
        return tf.reduce_sum(x, axis=1)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[2])

    def get_config(self):
        # No extra constructor arguments to serialize.
        return super().get_config()


# --- Load model and scaler ---
model = tf.keras.models.load_model('physiowatch_model_clean.keras')
# SECURITY: pickle.load can execute arbitrary code. scaler.pkl must be a
# trusted artifact shipped with this app; never load a user-supplied pickle.
with open('scaler.pkl', 'rb') as f:
    scaler = pickle.load(f)
print("✅ Model and scaler loaded successfully")


# --- Shared inference helpers ---
def _predict_probabilities(sequences):
    """Scale a batch of raw vital-sign windows and run the model on it.

    Args:
        sequences: array-like of shape (n, SEQ_LEN, N_FEATURES), n >= 1, in
            raw (unscaled) units.

    Returns:
        1-D array with one model output per window (the flattened output of
        ``model.predict``; assumes a single output unit per window).

    Raises:
        ValueError: if the shape is wrong, the batch is empty, or any value
            is NaN/infinite.
    """
    batch = np.asarray(sequences, dtype=np.float64)
    # Validate explicitly: the reshapes below would otherwise silently
    # reinterpret any input whose element count happens to be n * 24 * 5.
    if batch.ndim != 3 or batch.shape[0] == 0 or batch.shape[1:] != (SEQ_LEN, N_FEATURES):
        raise ValueError(
            f"Expected sequences of shape ({SEQ_LEN}, {N_FEATURES}) or "
            f"(n, {SEQ_LEN}, {N_FEATURES}) with n >= 1; got {batch.shape}"
        )
    # NaN inputs would yield a NaN probability, which fails the threshold
    # comparison and would be reported as "Normal".
    if not np.all(np.isfinite(batch)):
        raise ValueError("Input contains NaN or infinite values")

    n = batch.shape[0]
    # The scaler works per row of N_FEATURES, so flatten the time axis
    # before scaling and restore it afterwards.
    normalized = scaler.transform(batch.reshape(-1, N_FEATURES))
    model_input = normalized.reshape(n, SEQ_LEN, N_FEATURES)
    return model.predict(model_input, verbose=0).flatten()


def _format_result(prob):
    """Build the response dict for one model output ``prob`` in percent units."""
    prob = float(prob)
    return {
        "prediction": "Abnormal" if prob > DECISION_THRESHOLD else "Normal",
        "confidence": round(max(prob, 1 - prob) * 100, 2),
        "probability_abnormal": round(prob * 100, 2),
        "probability_normal": round((1 - prob) * 100, 2),
    }


# --- JSON prediction endpoint ---
def predict_json(json_input):
    """Run batch inference on a JSON payload and return a JSON string.

    Args:
        json_input: JSON text of the form ``{"sequences": ...}``. The value is
            either a single window (SEQ_LEN x N_FEATURES) or a batch
            (n x SEQ_LEN x N_FEATURES) of raw vitals.

    Returns:
        On success, ``{"results": [...], "status": "success"}`` (indented
        JSON). On any failure, ``{"error": <message>, "status": "failed"}``.
    """
    try:
        data = json.loads(json_input)
        if not isinstance(data, dict) or "sequences" not in data:
            raise ValueError('Input must be a JSON object with a "sequences" key')
        sequences = np.asarray(data["sequences"], dtype=np.float64)
        # Handle both single (24, 5) and batch (n, 24, 5) inputs.
        if sequences.ndim == 2:
            sequences = sequences[np.newaxis, ...]
        probs = _predict_probabilities(sequences)
        results = [_format_result(prob) for prob in probs]
        return json.dumps({"results": results, "status": "success"}, indent=2)
    except Exception as e:
        # API boundary: report every failure to the caller as a JSON error.
        return json.dumps({"error": str(e), "status": "failed"})


# --- Manual input endpoint ---
def predict_vitals(heart_rate_seq, systolic_bp_seq, diastolic_bp_seq,
                   respiratory_rate_seq, spo2_seq):
    """Classify one window entered as five comma-separated strings.

    Each argument must contain exactly SEQ_LEN comma-separated numbers.

    Returns:
        A dict with prediction fields and ``"status": "success"``, or
        ``{"error": <message>, "status": "failed"}`` on any failure.
    """
    try:
        series = [
            [float(x) for x in text.strip().split(',')]
            for text in (heart_rate_seq, systolic_bp_seq, diastolic_bp_seq,
                         respiratory_rate_seq, spo2_seq)
        ]
        if not all(len(s) == SEQ_LEN for s in series):
            return {
                "error": f"Each vital sign must have exactly {SEQ_LEN} values",
                "status": "failed",
            }
        # Stack the vitals as columns -> (SEQ_LEN, N_FEATURES), then add a
        # batch axis of size 1.
        window = np.array(series, dtype=np.float64).T[np.newaxis, ...]
        prob = _predict_probabilities(window)[0]
        return {**_format_result(prob), "status": "success"}
    except Exception as e:
        # API boundary: report every failure to the caller as an error dict.
        return {"error": str(e), "status": "failed"}


# --- Gradio Interface ---
with gr.Blocks(title="PhysioWatch API") as demo:
    gr.Markdown("# PhysioWatch — Abnormal Health Pattern Detection API")
    gr.Markdown("CNN-BiLSTM-Attention model trained on MIMIC-IV ICU data")

    with gr.Tab("JSON API"):
        gr.Markdown("**Input format:** `{\"sequences\": [[24 timesteps × 5 vitals]]}`")
        gr.Markdown("**Column order:** Heart Rate, Systolic BP, Diastolic BP, Respiratory Rate, SpO2")
        json_in = gr.Textbox(label="JSON Input", lines=10,
                             placeholder='{"sequences": [[[72, 120, 80, 16, 98], ...]]}')
        json_out = gr.Textbox(label="Result", lines=10)
        gr.Button("Run Prediction", variant="primary").click(
            predict_json, inputs=json_in, outputs=json_out
        )

    with gr.Tab("Manual Input"):
        gr.Markdown("Enter exactly **24 comma-separated values** per vital sign")
        hr_in = gr.Textbox(label="Heart Rate (bpm)")
        sbp_in = gr.Textbox(label="Systolic BP (mmHg)")
        dbp_in = gr.Textbox(label="Diastolic BP (mmHg)")
        rr_in = gr.Textbox(label="Respiratory Rate (breaths/min)")
        spo2_in = gr.Textbox(label="SpO2 (%)")
        manual_out = gr.JSON(label="Result")
        gr.Button("Predict", variant="primary").click(
            predict_vitals,
            inputs=[hr_in, sbp_in, dbp_in, rr_in, spo2_in],
            outputs=manual_out
        )

demo.launch()