# SHAP_NLP_TKA / app.py
# Uploaded by LianHP via huggingface_hub (commit ff570a2, verified)
import gradio as gr
import shap
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
# ============================================================
# 1. FAKE TRAINING DATA (Replace with your real clinical notes)
# ============================================================
# Six toy discharge-note snippets; readmit_30d is 1 when the patient was
# readmitted within 30 days.
data = pd.DataFrame(
    {
        "note": [
            "Patient experienced surgical site drainage, elevated pain, difficulty ambulating, high BMI.",
            "Uncomplicated post-op course, normal vitals, ambulating independently.",
            "Severe swelling, infection suspected, fever noted post-op day 3.",
            "Routine knee replacement recovery, discharged home POD2.",
            "History of diabetes, hypertension, wound healing slow, required re-evaluation.",
            "Strong recovery, no complications, outpatient follow-up scheduled.",
        ],
        "readmit_30d": [1, 0, 1, 0, 1, 0],
    }
)
X, y = data["note"], data["readmit_30d"]

# ================================
# 2. BUILD PIPELINE & EXTRACT COMPONENTS
# ================================
# Fit the vectorizer and classifier separately so each fitted component is
# available on its own for the SHAP explainer below.
vectorizer = TfidfVectorizer(stop_words="english", max_features=3000)
classifier = LogisticRegression(max_iter=500)
X_vectorized = vectorizer.fit_transform(X)
classifier.fit(X_vectorized, y)
# Wrap the already-fitted components so raw text can be scored in one call.
model = Pipeline([("tfidf", vectorizer), ("clf", classifier)])
# ================================
# 3. SHAP EXPLAINER SETUP
# ================================
# LinearExplainer suits logistic regression: fast, exact linear attributions.
# A small slice of the training matrix (densified) acts as the background
# distribution that SHAP values are measured against.
background_data = X_vectorized.toarray()[:3]
explainer = shap.LinearExplainer(classifier, background_data)
# ================================
# 4. PREDICTION + SHAP FUNCTION
# ================================
def predict_note(note):
    """Score one clinical note.

    Returns (probability of 30-day readmission, human-readable risk label).
    Uses the module-level fitted `model` pipeline.
    """
    positive_prob = model.predict_proba([note])[0][1]
    if positive_prob >= 0.5:
        label = "High Readmission Risk"
    else:
        label = "Low Readmission Risk"
    return float(positive_prob), label
def explain_note(note):
    """Render an HTML table of the words most influencing the prediction.

    Each word's SHAP value is its signed contribution (relative to the
    background sample) toward the positive / readmission class, computed by
    the module-level LinearExplainer over the module-level vectorizer's
    features. Returns an HTML string; on any failure, returns an HTML-safe
    error message instead of raising, so the UI stays responsive.
    """
    # Imported before the try block so the except handler can also use it.
    from html import escape

    try:
        # Ensure note is a string (Gradio should pass str; be defensive).
        note_str = str(note)
        note_vectorized = vectorizer.transform([note_str]).toarray()

        shap_values = explainer.shap_values(note_vectorized)
        # Some shap versions return a per-class list for classifiers;
        # take the positive class (index 1) when they do.
        shap_vals = shap_values[1] if isinstance(shap_values, list) else shap_values

        feature_names = vectorizer.get_feature_names_out()
        # np.ravel guards against an extra trailing dimension; one value
        # per vocabulary feature.
        shap_vals_flat = np.ravel(shap_vals[0])

        # Top 20 features by absolute contribution, strongest first.
        # Skip exact-zero contributions: with a short note most features
        # contribute nothing and would only pad the table with noise.
        top_indices = [
            idx
            for idx in np.argsort(np.abs(shap_vals_flat))[-20:][::-1]
            if shap_vals_flat[idx] != 0
        ]

        html_parts = ["<div style='font-family: monospace; padding: 10px;'>"]
        html_parts.append("<h4>Top Contributing Words:</h4>")
        html_parts.append("<table border='1' style='border-collapse: collapse; width: 100%;'>")
        html_parts.append("<tr><th>Word</th><th>SHAP Value</th><th>Impact</th></tr>")
        for idx in top_indices:
            # Escape tokens before interpolating into HTML markup.
            word = escape(str(feature_names[idx]))
            shap_val = shap_vals_flat[idx]
            color = "red" if shap_val > 0 else "blue"
            impact = "↑ Increases" if shap_val > 0 else "↓ Decreases"
            html_parts.append(
                f"<tr><td>{word}</td><td style='color: {color};'>{shap_val:.4f}</td><td>{impact}</td></tr>"
            )
        html_parts.append("</table>")

        # Base (expected) value: the explainer's mean output over the
        # background data; shape varies across shap versions.
        expected = explainer.expected_value
        if isinstance(expected, (list, np.ndarray)):
            base_val = expected[1] if len(expected) > 1 else expected[0]
        else:
            base_val = expected
        html_parts.append(f"<p><strong>Base value:</strong> {base_val:.4f}</p>")
        html_parts.append("</div>")
        return "".join(html_parts)
    except Exception as e:
        # This string is rendered by an HTML component: escape the exception
        # text and use <br> (a literal "\n" would render as nothing).
        return (
            "Error generating explanation: "
            f"{escape(str(e))}<br>Please try a different note."
        )
# ================================
# 5. GRADIO UI
# ================================
def full_pipeline(note):
    """Run prediction and SHAP explanation for one note.

    Returns (plain-text summary for the prediction box, HTML explanation).
    """
    probability, risk_label = predict_note(note)
    explanation_html = explain_note(note)
    summary = f"Readmission Probability: {probability:.3f}\nPrediction: {risk_label}"
    return summary, explanation_html
# Build the UI. Component creation order fixes the on-page layout:
# input box, then button, then the two output panes.
with gr.Blocks() as demo:
    gr.Markdown(
        "# 🏥 Knee Replacement 30-Day Readmission Predictor\n### NLP + SHAP Explainability"
    )
    note_box = gr.Textbox(
        label="Enter Clinical Note",
        placeholder="Example: Patient reports severe swelling and fever on post-op day 3...",
    )
    run_button = gr.Button("Predict Readmission Risk")
    prediction_box = gr.Textbox(label="Model Prediction")
    shap_panel = gr.HTML(label="SHAP Explanation")
    run_button.click(full_pipeline, inputs=note_box, outputs=[prediction_box, shap_panel])

# NOTE(review): share=True is ignored on Hugging Face Spaces; it only matters
# when running locally.
demo.launch(share=True)