| import gradio as gr |
| import numpy as np |
| import tensorflow as tf |
| from tensorflow.keras.layers import Layer |
| import pickle |
| import json |
|
|
| |
@tf.keras.utils.register_keras_serializable(package='PhysioWatch')
class ExpandDims(Layer):
    """Append a trailing size-1 axis to the input tensor.

    Registered under the 'PhysioWatch' package so that ``load_model`` can
    resolve this custom layer when deserializing the saved model.
    """

    def call(self, x):
        return tf.expand_dims(x, -1)

    def compute_output_shape(self, input_shape):
        # expand_dims(x, -1) appends exactly one axis regardless of input
        # rank; hard-coding (batch, dim, 1) was only correct for rank-2 input.
        return tuple(input_shape) + (1,)

    def get_config(self):
        # The layer takes no constructor arguments, so the base config suffices.
        return super().get_config()
|
|
@tf.keras.utils.register_keras_serializable(package='PhysioWatch')
class ReduceSum(Layer):
    """Sum the input tensor over axis 1, removing that axis.

    Registered under the 'PhysioWatch' package so that ``load_model`` can
    resolve this custom layer when deserializing the saved model.
    """

    def call(self, x):
        return tf.reduce_sum(x, axis=1)

    def compute_output_shape(self, input_shape):
        # reduce_sum(axis=1) drops axis 1 for any rank; the previous
        # (batch, input_shape[2]) form was only correct for rank-3 input.
        shape = tuple(input_shape)
        return shape[:1] + shape[2:]

    def get_config(self):
        # The layer takes no constructor arguments, so the base config suffices.
        return super().get_config()
|
|
| |
# Load the trained Keras model and the fitted input scaler once, at import
# time. The module-level names `model` and `scaler` are read by
# predict_json and predict_vitals.
model = tf.keras.models.load_model('physiowatch_model_clean.keras')

# NOTE(review): pickle.load can execute arbitrary code embedded in the file.
# This is presumably safe only because scaler.pkl is a trusted artifact
# deployed with the app — confirm, and never load user-supplied pickles here.
with open('scaler.pkl', 'rb') as scaler_file:
    scaler = pickle.load(scaler_file)

print("✅ Model and scaler loaded successfully")
|
|
| |
def predict_json(json_input):
    """Classify one or more vital-sign sequences supplied as JSON text.

    Args:
        json_input: JSON text of the form ``{"sequences": ...}`` where
            ``sequences`` is either a single sequence (24 rows x 5 vitals) or
            a list of such sequences. Column order is Heart Rate, Systolic BP,
            Diastolic BP, Respiratory Rate, SpO2.

    Returns:
        A JSON string. On success ``{"results": [...], "status": "success"}``
        with one result object per sequence; on any failure
        ``{"error": <message>, "status": "failed"}``.
    """
    # Window length and feature count the scaler/model were built for.
    timesteps, n_vitals = 24, 5

    try:
        data = json.loads(json_input)
        if not isinstance(data, dict) or "sequences" not in data:
            raise ValueError(
                "Input must be a JSON object with a 'sequences' key")

        # Force a numeric dtype so strings/ragged lists fail here, clearly.
        sequences = np.asarray(data["sequences"], dtype=float)

        # A single (24, 5) sequence is accepted by adding a batch axis.
        if sequences.ndim == 2:
            sequences = sequences[np.newaxis, ...]

        # Validate the full shape up front: reshape(-1, 5) on its own would
        # silently accept e.g. (1, 12, 10) and scramble the timesteps.
        if sequences.ndim != 3 or sequences.shape[1:] != (timesteps, n_vitals):
            raise ValueError(
                f"Each sequence must have shape ({timesteps}, {n_vitals}); "
                f"got array of shape {sequences.shape}")

        # json.loads accepts NaN/Infinity; a NaN probability would fail the
        # `> 0.5` test below and be mislabelled "Normal".
        if not np.isfinite(sequences).all():
            raise ValueError("All values must be finite numbers")

        n = sequences.shape[0]
        # The scaler works on 2-D (rows, features); flatten then restore.
        normalized = scaler.transform(sequences.reshape(-1, n_vitals))
        X = normalized.reshape(n, timesteps, n_vitals)

        probs = model.predict(X, verbose=0).flatten()

        results = []
        for raw_prob in probs:
            p = float(raw_prob)
            results.append({
                "prediction": "Abnormal" if p > 0.5 else "Normal",
                "confidence": round(max(p, 1 - p) * 100, 2),
                "probability_abnormal": round(p * 100, 2),
                "probability_normal": round((1 - p) * 100, 2)
            })

        return json.dumps({"results": results, "status": "success"}, indent=2)

    except Exception as e:
        # API boundary: report every failure as a JSON error payload.
        return json.dumps({"error": str(e), "status": "failed"})
|
|
| |
def predict_vitals(heart_rate_seq, systolic_bp_seq, diastolic_bp_seq,
                   respiratory_rate_seq, spo2_seq):
    """Classify a single 24-step window entered as five comma-separated strings.

    Args:
        heart_rate_seq, systolic_bp_seq, diastolic_bp_seq,
        respiratory_rate_seq, spo2_seq: each a string of exactly 24
            comma-separated numbers. Empty tokens (e.g. from a trailing comma)
            are ignored.

    Returns:
        A dict. On success it has "prediction", "confidence",
        "probability_abnormal", "probability_normal" and
        ``"status": "success"``; on any failure it is
        ``{"error": <message>, "status": "failed"}``.
    """
    timesteps = 24
    # Order here defines the feature column order fed to the scaler/model.
    fields = (
        ("Heart Rate", heart_rate_seq),
        ("Systolic BP", systolic_bp_seq),
        ("Diastolic BP", diastolic_bp_seq),
        ("Respiratory Rate", respiratory_rate_seq),
        ("SpO2", spo2_seq),
    )

    try:
        columns = []
        for name, text in fields:
            # Skip empty tokens so a trailing comma or blank box doesn't
            # surface as a cryptic float('') error.
            tokens = [tok.strip() for tok in (text or "").split(",") if tok.strip()]
            values = []
            for tok in tokens:
                try:
                    values.append(float(tok))
                except ValueError:
                    return {"error": f"{name}: '{tok}' is not a number",
                            "status": "failed"}
            columns.append(values)

        wrong = [f"{name} has {len(vals)}"
                 for (name, _), vals in zip(fields, columns)
                 if len(vals) != timesteps]
        if wrong:
            return {"error": f"Each vital sign must have exactly {timesteps} "
                             f"values ({', '.join(wrong)})",
                    "status": "failed"}

        raw_sequence = np.array(columns).T  # (timesteps, n_vitals)
        # float() accepts 'nan'/'inf'; a NaN probability would fail the
        # `> 0.5` test below and be mislabelled "Normal".
        if not np.isfinite(raw_sequence).all():
            return {"error": "All vital-sign values must be finite numbers",
                    "status": "failed"}

        normalized = scaler.transform(raw_sequence)
        sequence = normalized.reshape(1, timesteps, len(fields))

        prob = float(model.predict(sequence, verbose=0)[0][0])
        label = "Abnormal" if prob > 0.5 else "Normal"

        return {
            "prediction": label,
            "confidence": round(max(prob, 1 - prob) * 100, 2),
            "probability_abnormal": round(prob * 100, 2),
            "probability_normal": round((1 - prob) * 100, 2),
            "status": "success"
        }

    except Exception as e:
        # UI boundary: report every failure in the same error shape.
        return {"error": str(e), "status": "failed"}
|
|
| |
# Two-tab Gradio UI: raw JSON batch input, and manual per-vital text input.
with gr.Blocks(title="PhysioWatch API") as demo:
    gr.Markdown("# PhysioWatch — Abnormal Health Pattern Detection API")
    gr.Markdown("CNN-BiLSTM-Attention model trained on MIMIC-IV ICU data")

    # JSON tab: forwards the raw textbox contents to predict_json and shows
    # its JSON string result.
    with gr.Tab("JSON API"):
        gr.Markdown("**Input format:** `{\"sequences\": [[24 timesteps × 5 vitals]]}`")
        gr.Markdown("**Column order:** Heart Rate, Systolic BP, Diastolic BP, Respiratory Rate, SpO2")
        json_in = gr.Textbox(label="JSON Input", lines=10,
                             placeholder='{"sequences": [[[72, 120, 80, 16, 98], ...]]}')
        json_out = gr.Textbox(label="Result", lines=10)
        gr.Button("Run Prediction", variant="primary").click(
            predict_json, inputs=json_in, outputs=json_out
        )

    # Manual tab: one textbox per vital, passed positionally to
    # predict_vitals in the same column order the model expects.
    with gr.Tab("Manual Input"):
        gr.Markdown("Enter exactly **24 comma-separated values** per vital sign")
        hr_in = gr.Textbox(label="Heart Rate (bpm)")
        sbp_in = gr.Textbox(label="Systolic BP (mmHg)")
        dbp_in = gr.Textbox(label="Diastolic BP (mmHg)")
        rr_in = gr.Textbox(label="Respiratory Rate (breaths/min)")
        spo2_in = gr.Textbox(label="SpO2 (%)")
        manual_out = gr.JSON(label="Result")
        gr.Button("Predict", variant="primary").click(
            predict_vitals,
            inputs=[hr_in, sbp_in, dbp_in, rr_in, spo2_in],
            outputs=manual_out
        )


if __name__ == "__main__":
    # Start the server only when run as a script, so importing this module
    # (e.g. from tests) doesn't launch a blocking web server.
    demo.launch()