# Hugging Face Space app — 30-day readmission predictor demo
# (page-status chrome from the scraped Spaces listing removed)
import html

import gradio as gr
import numpy as np
import pandas as pd
import shap
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
# ------------------------------------------------------------
# 1. Toy training corpus — swap in real clinical notes here.
# ------------------------------------------------------------
training_df = pd.DataFrame({
    "note": [
        "Patient experienced surgical site drainage, elevated pain, difficulty ambulating, high BMI.",
        "Uncomplicated post-op course, normal vitals, ambulating independently.",
        "Severe swelling, infection suspected, fever noted post-op day 3.",
        "Routine knee replacement recovery, discharged home POD2.",
        "History of diabetes, hypertension, wound healing slow, required re-evaluation.",
        "Strong recovery, no complications, outpatient follow-up scheduled.",
    ],
    "readmit_30d": [1, 0, 1, 0, 1, 0],
})

notes = training_df["note"]
labels = training_df["readmit_30d"]

# ------------------------------------------------------------
# 2. Fit TF-IDF features and a logistic-regression classifier.
# ------------------------------------------------------------
vectorizer = TfidfVectorizer(stop_words="english", max_features=3000)
classifier = LogisticRegression(max_iter=500)

note_features = vectorizer.fit_transform(notes)
classifier.fit(note_features, labels)

# Pipeline wrapper so callers can predict straight from raw text.
model = Pipeline([
    ("tfidf", vectorizer),
    ("clf", classifier),
])

# ------------------------------------------------------------
# 3. SHAP explainer — LinearExplainer fits the linear model and
#    is much faster than the model-agnostic explainers.
# ------------------------------------------------------------
# Densified slice of the first three training rows as background data.
explainer = shap.LinearExplainer(classifier, note_features[:3].toarray())
# ------------------------------------------------------------
# 4. Prediction + SHAP explanation functions
# ------------------------------------------------------------
def predict_note(note):
    """Score a clinical note; return (readmission probability, risk label)."""
    risk = float(model.predict_proba([note])[0][1])
    verdict = "High Readmission Risk" if risk >= 0.5 else "Low Readmission Risk"
    return risk, verdict
def explain_note(note):
    """Build an HTML table of the words that most influenced the prediction.

    The note is vectorized with the global TF-IDF ``vectorizer`` and scored
    with the global SHAP ``explainer``; the 20 features with the largest
    absolute SHAP value are tabulated, colored by direction of impact.

    Parameters
    ----------
    note : str
        Raw clinical-note text (coerced to ``str`` defensively).

    Returns
    -------
    str
        HTML markup — also on failure, because the result feeds a ``gr.HTML``
        component (the original plain-text error path rendered poorly there).
    """
    try:
        # Ensure note is a string (Gradio normally passes str already).
        note_str = str(note)
        # Vectorize exactly as at training time; densify for SHAP.
        note_vectorized = vectorizer.transform([note_str]).toarray()
        shap_values = explainer.shap_values(note_vectorized)
        # Some SHAP versions return a per-class list for classifiers,
        # others a single array for binary linear models — handle both.
        if isinstance(shap_values, list):
            shap_vals = shap_values[1]  # positive (readmission) class
        else:
            shap_vals = shap_values
        # Map feature indices back to vocabulary words.
        feature_names = vectorizer.get_feature_names_out()
        shap_vals_flat = shap_vals[0]  # single-sample batch -> 1D vector
        # Indices of the 20 largest |SHAP| contributions, strongest first.
        top_indices = np.argsort(np.abs(shap_vals_flat))[-20:][::-1]
        # Assemble the explanation table.
        html_parts = ["<div style='font-family: monospace; padding: 10px;'>"]
        html_parts.append("<h4>Top Contributing Words:</h4>")
        html_parts.append("<table border='1' style='border-collapse: collapse; width: 100%;'>")
        html_parts.append("<tr><th>Word</th><th>SHAP Value</th><th>Impact</th></tr>")
        for idx in top_indices:
            word = feature_names[idx]
            shap_val = shap_vals_flat[idx]
            color = "red" if shap_val > 0 else "blue"
            impact = "↑ Increases" if shap_val > 0 else "↓ Decreases"
            html_parts.append(
                f"<tr><td>{word}</td><td style='color: {color};'>{shap_val:.4f}</td><td>{impact}</td></tr>"
            )
        html_parts.append("</table>")
        # Expected value (base prediction) may be scalar or per-class array.
        if isinstance(explainer.expected_value, (list, np.ndarray)):
            base_val = explainer.expected_value[1] if len(explainer.expected_value) > 1 else explainer.expected_value[0]
        else:
            base_val = explainer.expected_value
        html_parts.append(f"<p><strong>Base value:</strong> {base_val:.4f}</p>")
        html_parts.append("</div>")
        return "".join(html_parts)
    except Exception as e:
        # Fix: emit valid HTML (the component is gr.HTML, so "\n" was ignored)
        # and escape the exception text so it cannot inject markup.
        return (
            "<div style='font-family: monospace; padding: 10px; color: red;'>"
            f"Error generating explanation: {html.escape(str(e))}<br>"
            "Please try a different note.</div>"
        )
# ------------------------------------------------------------
# 5. Gradio click handler
# ------------------------------------------------------------
def full_pipeline(note):
    """Run prediction and SHAP explanation; return (summary text, HTML)."""
    probability, verdict = predict_note(note)
    summary = f"Readmission Probability: {probability:.3f}\nPrediction: {verdict}"
    return summary, explain_note(note)
# ------------------------------------------------------------
# 6. Gradio interface
# ------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown(
        "# 🏥 Knee Replacement 30-Day Readmission Predictor\n"
        "### NLP + SHAP Explainability"
    )
    input_note = gr.Textbox(
        label="Enter Clinical Note",
        placeholder="Example: Patient reports severe swelling and fever on post-op day 3...",
    )
    btn = gr.Button("Predict Readmission Risk")
    output_pred = gr.Textbox(label="Model Prediction")
    output_shap = gr.HTML(label="SHAP Explanation")
    # Wire the button to the combined predict + explain handler.
    btn.click(full_pipeline, inputs=input_note, outputs=[output_pred, output_shap])

demo.launch(share=True)