File size: 5,427 Bytes
ff570a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import gradio as gr
import shap
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split

# ============================================================
# 1. FAKE TRAINING DATA (Replace with your real clinical notes)
# ============================================================

# Toy corpus: six post-op notes paired with a binary label
# (1 = readmitted within 30 days, 0 = not readmitted).
_notes = [
    "Patient experienced surgical site drainage, elevated pain, difficulty ambulating, high BMI.",
    "Uncomplicated post-op course, normal vitals, ambulating independently.",
    "Severe swelling, infection suspected, fever noted post-op day 3.",
    "Routine knee replacement recovery, discharged home POD2.",
    "History of diabetes, hypertension, wound healing slow, required re-evaluation.",
    "Strong recovery, no complications, outpatient follow-up scheduled.",
]
_labels = [1, 0, 1, 0, 1, 0]

data = pd.DataFrame({"note": _notes, "readmit_30d": _labels})

# Features are the raw note text; target is the readmission flag.
X = data["note"]
y = data["readmit_30d"]

# ================================
# 2. BUILD PIPELINE & EXTRACT COMPONENTS
# ================================
# Vectorizer and classifier are fitted separately (rather than via
# Pipeline.fit) so the fitted vectorizer can be reused directly by the
# SHAP explainer below.
vectorizer = TfidfVectorizer(stop_words="english", max_features=3000)
classifier = LogisticRegression(max_iter=500)

# Transform text to features
X_vectorized = vectorizer.fit_transform(X)
classifier.fit(X_vectorized, y)

# Create pipeline for easy prediction
# NOTE: both steps are already fitted above, so this pipeline is only a
# convenience wrapper for predict/predict_proba — do not call model.fit(),
# which would refit both steps.
model = Pipeline([
    ("tfidf", vectorizer),
    ("clf", classifier)
])

# ================================
# 3. SHAP EXPLAINER SETUP
# ================================
# Use LinearExplainer for linear models (much faster and more appropriate)
# Use a sample of vectorized features as background
# NOTE(review): only 3 background rows — the expected (base) value is
# estimated from a very small sample; consider using the full training
# matrix for real data. Also confirm this positional-argument form
# matches the installed shap version's LinearExplainer signature.
background_data = X_vectorized[:3].toarray()  # Use first 3 samples as background
explainer = shap.LinearExplainer(
    classifier,
    background_data
)

# ================================
# 4. PREDICTION + SHAP FUNCTION
# ================================
def predict_note(note):
    """Score one clinical note for 30-day readmission risk.

    Returns a ``(probability, label)`` tuple: the model's positive-class
    probability and a human-readable risk label (threshold 0.5).
    """
    risk = model.predict_proba([note])[0][1]
    if risk >= 0.5:
        verdict = "High Readmission Risk"
    else:
        verdict = "Low Readmission Risk"
    return float(risk), verdict


def explain_note(note):
    """Build an HTML table of the top-20 words driving the prediction.

    Words are ranked by absolute SHAP value; positive values (red) push
    toward readmission, negative (blue) push away. Returns an HTML string,
    or a plain error message on failure (deliberate best-effort behavior
    so the UI never crashes on a bad note).
    """
    try:
        # Ensure note is a string
        note_str = str(note)
        
        # Transform the note to vectorized features
        # Dense array: the explainer operates on the TF-IDF feature space,
        # not raw text.
        note_vectorized = vectorizer.transform([note_str]).toarray()
        
        # Get SHAP values for the vectorized features
        shap_values = explainer.shap_values(note_vectorized)
        
        # Handle binary classification
        # Some shap versions return a per-class list, others a single array
        # — normalize to one array before indexing.
        if isinstance(shap_values, list):
            shap_vals = shap_values[1]  # Get class 1 (positive class) SHAP values
        else:
            shap_vals = shap_values
        
        # Get feature names (words) from vectorizer
        feature_names = vectorizer.get_feature_names_out()
        
        # Get top contributing features
        shap_vals_flat = shap_vals[0]  # Flatten to 1D
        
        # Get indices sorted by absolute SHAP value
        # argsort is ascending, so take the last 20 and reverse for
        # descending |SHAP| order.
        top_indices = np.argsort(np.abs(shap_vals_flat))[-20:][::-1]  # Top 20 features
        
        # Create HTML explanation showing top contributing words
        html_parts = ["<div style='font-family: monospace; padding: 10px;'>"]
        html_parts.append("<h4>Top Contributing Words:</h4>")
        html_parts.append("<table border='1' style='border-collapse: collapse; width: 100%;'>")
        html_parts.append("<tr><th>Word</th><th>SHAP Value</th><th>Impact</th></tr>")
        
        for idx in top_indices:
            word = feature_names[idx]
            shap_val = shap_vals_flat[idx]
            # Red = pushes prediction toward readmission, blue = away.
            color = "red" if shap_val > 0 else "blue"
            impact = "↑ Increases" if shap_val > 0 else "↓ Decreases"
            html_parts.append(
                f"<tr><td>{word}</td><td style='color: {color};'>{shap_val:.4f}</td><td>{impact}</td></tr>"
            )
        
        html_parts.append("</table>")
        # Get expected value (base prediction)
        # expected_value may be a scalar or per-class array depending on
        # the shap version; pick the positive class when possible.
        if isinstance(explainer.expected_value, (list, np.ndarray)):
            base_val = explainer.expected_value[1] if len(explainer.expected_value) > 1 else explainer.expected_value[0]
        else:
            base_val = explainer.expected_value
        html_parts.append(f"<p><strong>Base value:</strong> {base_val:.4f}</p>")
        html_parts.append("</div>")
        
        return "".join(html_parts)
        
    except Exception as e:
        # Broad catch is intentional: any explanation failure degrades to a
        # message rather than breaking the Gradio response.
        return f"Error generating explanation: {str(e)}\nPlease try a different note."


# ================================
# 5. GRADIO UI
# ================================
def full_pipeline(note):
    """Gradio callback: run prediction and SHAP explanation for one note."""
    probability, risk_label = predict_note(note)
    summary = f"Readmission Probability: {probability:.3f}\nPrediction: {risk_label}"
    return summary, explain_note(note)


# Assemble the web UI: one text input, a button, and two outputs
# (prediction text + SHAP explanation HTML).
with gr.Blocks() as demo:
    gr.Markdown("# 🏥 Knee Replacement 30-Day Readmission Predictor\n### NLP + SHAP Explainability")

    input_note = gr.Textbox(
        label="Enter Clinical Note",
        placeholder="Example: Patient reports severe swelling and fever on post-op day 3..."
    )

    btn = gr.Button("Predict Readmission Risk")

    output_pred = gr.Textbox(label="Model Prediction")
    output_shap = gr.HTML(label="SHAP Explanation")

    btn.click(full_pipeline, inputs=input_note, outputs=[output_pred, output_shap])

# Launch only when executed as a script, so importing this module no
# longer starts a server as a side effect.
# NOTE(review): share=True publishes the app on a public gradio.live URL —
# confirm that is acceptable before using real clinical data.
if __name__ == "__main__":
    demo.launch(share=True)