Spaces:
Sleeping
Sleeping
File size: 5,427 Bytes
import gradio as gr
import shap
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
# ============================================================
# 1. FAKE TRAINING DATA (Replace with your real clinical notes)
# ============================================================
# Six toy discharge notes; label 1 = readmitted within 30 days.
_training_notes = [
    "Patient experienced surgical site drainage, elevated pain, difficulty ambulating, high BMI.",
    "Uncomplicated post-op course, normal vitals, ambulating independently.",
    "Severe swelling, infection suspected, fever noted post-op day 3.",
    "Routine knee replacement recovery, discharged home POD2.",
    "History of diabetes, hypertension, wound healing slow, required re-evaluation.",
    "Strong recovery, no complications, outpatient follow-up scheduled.",
]
_training_labels = [1, 0, 1, 0, 1, 0]

data = pd.DataFrame({"note": _training_notes, "readmit_30d": _training_labels})

X = data["note"]          # raw clinical-note text
y = data["readmit_30d"]   # binary readmission target
# ================================
# 2. BUILD PIPELINE & EXTRACT COMPONENTS
# ================================
# Fit the vectorizer and classifier separately so both stay accessible
# for SHAP, which needs the raw linear model and its feature matrix.
vectorizer = TfidfVectorizer(stop_words="english", max_features=3000)
classifier = LogisticRegression(max_iter=500)

X_vectorized = vectorizer.fit_transform(X)  # sparse doc-term TF-IDF matrix
classifier.fit(X_vectorized, y)

# Wrap the already-fitted components so prediction goes straight from
# raw text to a probability.
model = Pipeline([("tfidf", vectorizer), ("clf", classifier)])
# ================================
# 3. SHAP EXPLAINER SETUP
# ================================
# LinearExplainer is the fast, appropriate choice for a linear model.
# A small slice of the training matrix acts as the background
# distribution; densify it because shap expects a plain array here.
background_data = X_vectorized[:3].toarray()

explainer = shap.LinearExplainer(classifier, background_data)
# ================================
# 4. PREDICTION + SHAP FUNCTION
# ================================
def predict_note(note):
    """Return (probability, label) for 30-day readmission risk of *note*.

    The probability is the model's class-1 score; the label applies a
    fixed 0.5 decision threshold.
    """
    positive_proba = model.predict_proba([note])[0][1]
    if positive_proba >= 0.5:
        risk_label = "High Readmission Risk"
    else:
        risk_label = "Low Readmission Risk"
    return float(positive_proba), risk_label
def explain_note(note):
    """Render an HTML table of the words that most influenced the model.

    Vectorizes *note*, asks the SHAP LinearExplainer for per-feature
    attributions, and returns an HTML string listing the strongest words
    by absolute SHAP value (zero-contribution features are dropped so the
    table only shows words that actually matter for this note). On any
    failure an error string is returned so the UI degrades gracefully
    instead of crashing.
    """
    import html as _html  # stdlib; escape tokens before embedding in HTML

    try:
        note_str = str(note)
        # Dense 1 x n_features array for this single document.
        note_vectorized = vectorizer.transform([note_str]).toarray()
        shap_values = explainer.shap_values(note_vectorized)

        # Some SHAP versions return a per-class list for binary models;
        # others return a single array. Normalize to the positive class.
        if isinstance(shap_values, list):
            shap_vals = shap_values[1]
        else:
            shap_vals = shap_values

        feature_names = vectorizer.get_feature_names_out()
        shap_vals_flat = np.asarray(shap_vals).ravel()

        # Rank by |SHAP|, keep the 20 strongest, and drop exact zeros:
        # words absent from this note would otherwise pad the table with
        # meaningless 0.0000 rows.
        top_indices = np.argsort(np.abs(shap_vals_flat))[-20:][::-1]
        top_indices = [i for i in top_indices if shap_vals_flat[i] != 0]

        html_parts = ["<div style='font-family: monospace; padding: 10px;'>"]
        html_parts.append("<h4>Top Contributing Words:</h4>")
        html_parts.append("<table border='1' style='border-collapse: collapse; width: 100%;'>")
        html_parts.append("<tr><th>Word</th><th>SHAP Value</th><th>Impact</th></tr>")
        for idx in top_indices:
            # Escape defensively: feature names derive from user text.
            word = _html.escape(str(feature_names[idx]))
            shap_val = shap_vals_flat[idx]
            color = "red" if shap_val > 0 else "blue"
            impact = "↑ Increases" if shap_val > 0 else "↓ Decreases"
            html_parts.append(
                f"<tr><td>{word}</td><td style='color: {color};'>{shap_val:.4f}</td><td>{impact}</td></tr>"
            )
        html_parts.append("</table>")

        # Base value: the explainer's expected model output over the
        # background data, shown so users can see the starting point.
        expected = explainer.expected_value
        if isinstance(expected, (list, np.ndarray)):
            base_val = expected[1] if len(expected) > 1 else expected[0]
        else:
            base_val = expected
        html_parts.append(f"<p><strong>Base value:</strong> {base_val:.4f}</p>")
        html_parts.append("</div>")
        return "".join(html_parts)
    except Exception as e:
        # Best-effort UI: surface the problem rather than crash the app.
        return f"Error generating explanation: {str(e)}\nPlease try a different note."
# ================================
# 5. GRADIO UI
# ================================
def full_pipeline(note):
    """Run prediction + explanation for the UI; returns (summary, html)."""
    probability, risk_label = predict_note(note)
    explanation_html = explain_note(note)
    summary = f"Readmission Probability: {probability:.3f}\nPrediction: {risk_label}"
    return summary, explanation_html
# Gradio front-end: one text input, a button, and two outputs
# (plain-text prediction + HTML SHAP table).
with gr.Blocks() as demo:
    gr.Markdown(
        "# 🏥 Knee Replacement 30-Day Readmission Predictor\n### NLP + SHAP Explainability"
    )
    note_input = gr.Textbox(
        label="Enter Clinical Note",
        placeholder="Example: Patient reports severe swelling and fever on post-op day 3...",
    )
    predict_button = gr.Button("Predict Readmission Risk")
    prediction_output = gr.Textbox(label="Model Prediction")
    shap_output = gr.HTML(label="SHAP Explanation")
    predict_button.click(
        full_pipeline,
        inputs=note_input,
        outputs=[prediction_output, shap_output],
    )

# share=True publishes a temporary public gradio.live URL for the demo.
demo.launch(share=True)