import gradio as gr
import shap
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split

# ============================================================
# 1. FAKE TRAINING DATA (Replace with your real clinical notes)
# ============================================================
data = pd.DataFrame({
    "note": [
        "Patient experienced surgical site drainage, elevated pain, difficulty ambulating, high BMI.",
        "Uncomplicated post-op course, normal vitals, ambulating independently.",
        "Severe swelling, infection suspected, fever noted post-op day 3.",
        "Routine knee replacement recovery, discharged home POD2.",
        "History of diabetes, hypertension, wound healing slow, required re-evaluation.",
        "Strong recovery, no complications, outpatient follow-up scheduled.",
    ],
    "readmit_30d": [1, 0, 1, 0, 1, 0],
})

X = data["note"]
y = data["readmit_30d"]

# ================================
# 2. BUILD & FIT PIPELINE
# ================================
vectorizer = TfidfVectorizer(stop_words="english", max_features=3000)
classifier = LogisticRegression(max_iter=500)

# Fit via the pipeline: this fits both steps in place, so `vectorizer` and
# `classifier` (the same objects referenced above) come out fitted too —
# no need to fit them separately and then wrap them afterwards.
model = Pipeline([
    ("tfidf", vectorizer),
    ("clf", classifier),
])
model.fit(X, y)

# Vectorized training matrix, reused below as the SHAP background data.
X_vectorized = vectorizer.transform(X)

# ================================
# 3. SHAP EXPLAINER SETUP
# ================================
# LinearExplainer is appropriate (and fast) for a linear model like
# LogisticRegression.
# NOTE(review): a 3-row background set is very small, so the expected value
# it yields is noisy; for real data consider passing the full training
# matrix (or a larger sample) as background — confirm before production use.
background_data = X_vectorized[:3].toarray()
explainer = shap.LinearExplainer(
    classifier,
    background_data,
)
# ================================
# 4. PREDICTION + SHAP FUNCTIONS
# ================================
def predict_note(note):
    """Return (probability of 30-day readmission, human-readable label).

    Uses the module-level fitted ``model`` pipeline; the 0.5 threshold
    maps the positive-class probability to a risk label.
    """
    proba = model.predict_proba([note])[0][1]
    label = "High Readmission Risk" if proba >= 0.5 else "Low Readmission Risk"
    return float(proba), label


def explain_note(note):
    """Build an HTML table of the top-20 tokens by |SHAP value| for *note*.

    Returns an HTML fragment (string) for display in a ``gr.HTML``
    component; on failure returns a plain-text error message instead of
    raising, so the Gradio UI stays responsive.
    """
    import html as _html  # local import: only needed here, for escaping

    try:
        # Ensure note is a string before vectorizing.
        note_str = str(note)
        note_vectorized = vectorizer.transform([note_str]).toarray()

        # SHAP values for this single vectorized note.
        shap_values = explainer.shap_values(note_vectorized)

        # Some SHAP versions return a per-class list for classifiers;
        # take the positive class (index 1) in that case.
        if isinstance(shap_values, list):
            shap_vals = shap_values[1]
        else:
            shap_vals = shap_values

        feature_names = vectorizer.get_feature_names_out()
        shap_vals_flat = np.asarray(shap_vals)[0]  # flatten to 1-D

        # Indices of the 20 largest |SHAP| values, most important first.
        top_indices = np.argsort(np.abs(shap_vals_flat))[-20:][::-1]

        html_parts = ["<div>"]
        html_parts.append("<h3>Top Contributing Words:</h3>")
        html_parts.append("<table border='1' cellpadding='4'>")
        html_parts.append(
            "<tr><th>Word</th><th>SHAP Value</th><th>Impact</th></tr>"
        )
        for idx in top_indices:
            # Escape the token: it derives from user-supplied text and is
            # rendered as raw HTML by gr.HTML (XSS otherwise).
            word = _html.escape(str(feature_names[idx]))
            shap_val = shap_vals_flat[idx]
            color = "red" if shap_val > 0 else "blue"
            impact = "↑ Increases" if shap_val > 0 else "↓ Decreases"
            html_parts.append(
                f"<tr><td style='color:{color}'>{word}</td>"
                f"<td>{shap_val:.4f}</td><td>{impact}</td></tr>"
            )
        html_parts.append("</table>")

        # Base (expected) value of the explainer — scalar or per-class.
        if isinstance(explainer.expected_value, (list, np.ndarray)):
            ev = explainer.expected_value
            base_val = ev[1] if len(ev) > 1 else ev[0]
        else:
            base_val = explainer.expected_value
        html_parts.append(f"<p><b>Base value:</b> {base_val:.4f}</p>")
        html_parts.append("</div>")
        return "".join(html_parts)
    except Exception as e:
        # Surface the problem in the UI rather than crashing the server.
        return f"Error generating explanation: {str(e)}\nPlease try a different note."


# ================================
# 5. GRADIO UI
# ================================
def full_pipeline(note):
    """Run prediction + explanation; returns (summary text, SHAP HTML)."""
    proba, label = predict_note(note)
    shap_html = explain_note(note)
    return (
        f"Readmission Probability: {proba:.3f}\nPrediction: {label}",
        shap_html,
    )


with gr.Blocks() as demo:
    gr.Markdown(
        "# 🏥 Knee Replacement 30-Day Readmission Predictor\n"
        "### NLP + SHAP Explainability"
    )
    input_note = gr.Textbox(
        label="Enter Clinical Note",
        placeholder="Example: Patient reports severe swelling and fever on post-op day 3...",
    )
    btn = gr.Button("Predict Readmission Risk")
    output_pred = gr.Textbox(label="Model Prediction")
    output_shap = gr.HTML(label="SHAP Explanation")
    btn.click(full_pipeline, inputs=input_note, outputs=[output_pred, output_shap])

# Guard the launch so importing this module does not start a public
# (share=True) server as an import side effect.
if __name__ == "__main__":
    demo.launch(share=True)