Spaces:
Sleeping
Sleeping
Upload 14 files
Browse files
- .gitattributes +35 -35
- .python_version +1 -0
- README.md +15 -16
- chatbot/chatbot.py +25 -0
- chatbot/intents.py +11 -0
- data/load_data.py +15 -0
- data/preprocess.py +7 -0
- intents.py +11 -0
- models/bart_explainer.py +32 -0
- models/baseline_tfidf.py +17 -0
- notebooks/01_exploration.ipynb +54 -0
- notebooks/02_baseline_nlp.ipynb +95 -0
- notebooks/03_llm_explanations.ipynb +88 -0
- requirements.txt +7 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,35 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.python_version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.10
|
README.md
CHANGED
|
@@ -1,16 +1,15 @@
|
|
| 1 |
-
---
|
| 2 |
-
title: Diabetes
|
| 3 |
-
emoji: 💬
|
| 4 |
-
colorFrom: yellow
|
| 5 |
-
colorTo: purple
|
| 6 |
-
sdk: gradio
|
| 7 |
-
sdk_version: 5.42.0
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
-
hf_oauth: true
|
| 11 |
-
hf_oauth_scopes:
|
| 12 |
-
- inference-api
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Diabetes Readmission
|
| 3 |
+
emoji: 💬
|
| 4 |
+
colorFrom: yellow
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 5.42.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
hf_oauth: true
|
| 11 |
+
hf_oauth_scopes:
|
| 12 |
+
- inference-api
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
|
|
|
chatbot/chatbot.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from chatbot.intents import detect_intent
|
| 2 |
+
from models.bart_explainer import generate_explanation
|
| 3 |
+
|
| 4 |
+
def chatbot_response(question, note, risk_score):
    """Answer a patient question about their discharge and readmission risk.

    The question is first mapped to a coarse intent; each intent is answered
    from the discharge note text and/or the model's risk score. Unrecognised
    intents fall through to a safe capability/disclaimer message.
    """
    intent = detect_intent(question)

    if intent == "appointment":
        # Crude keyword scan: any mention of "follow" in the note is treated
        # as evidence of a scheduled follow-up.
        mentions_follow_up = "follow" in note.lower()
        if mentions_follow_up:
            return "Your discharge summary mentions a follow-up. Please consult your provider for scheduling."
        return "No specific follow-up appointment is mentioned in the discharge summary."

    if intent == "severity":
        # Bucket the continuous score into three coarse bands
        # (>0.6 high, >0.3 moderate, otherwise low).
        bands = (("high", 0.6), ("moderate", 0.3))
        level = next((name for name, cutoff in bands if risk_score > cutoff), "low")
        return f"Your readmission risk is considered {level}."

    if intent == "explanation":
        # Delegate free-text rationale generation to the BART explainer.
        return generate_explanation(note, risk_score)

    return "I can help explain risk, follow-up, or discharge information. I cannot provide medical advice."
|
chatbot/intents.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def detect_intent(question: str) -> str:
    """Classify a patient question into a coarse intent bucket.

    Returns one of "appointment", "severity", "explanation", or "unknown".
    Matching is keyword-based and checked in priority order, so a question
    like "why is my risk high" resolves to "severity" before "explanation".

    Fix: the original only matched the spelling "follow up", silently
    missing the common hyphenated form "follow-up" (and "followup").
    Hyphens are normalised to spaces before matching.
    """
    # Lowercase and normalise hyphenation so "Follow-up" matches "follow up".
    q = question.lower().replace("-", " ")

    if "appointment" in q or "follow up" in q or "followup" in q:
        return "appointment"
    if "how serious" in q or "severity" in q or "risk" in q:
        return "severity"
    if "why" in q or "explain" in q:
        return "explanation"

    return "unknown"
|
data/load_data.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datasets import load_dataset
|
| 2 |
+
|
| 3 |
+
def load_discharge_notes(category: str = "Discharge summary"):
    """Load MIMIC-III NOTEEVENTS from the Hugging Face Hub, filtered by category.

    Parameters
    ----------
    category:
        NOTEEVENTS ``CATEGORY`` value to keep. Defaults to
        ``"Discharge summary"``, preserving the original behaviour, but the
        parameter lets callers load other note categories (e.g. "Nursing")
        without duplicating this function.

    Returns
    -------
    A ``datasets.Dataset`` containing only rows of the requested category.

    Notes
    -----
    Requires network access to the Hugging Face Hub on first call; the
    dataset is cached locally afterwards.
    """
    notes = load_dataset(
        "ntphuc149/MIMIC-III-Clinical-Database",
        "NOTEEVENTS",
        split="train",
    )

    # Keep only the requested note category (discharge summaries by default).
    return notes.filter(lambda x: x["CATEGORY"] == category)
|
data/preprocess.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
|
| 3 |
+
def clean_text(text: str) -> str:
    """Lowercase a clinical note and strip MIMIC-III PHI placeholders.

    MIMIC-III de-identifies protected health information as ``[** ... **]``
    spans; those are removed entirely.

    Fix: the original collapsed only newline runs, so removing a PHI span
    surrounded by spaces left a doubled space in the text. All whitespace
    runs (spaces, tabs, newlines) are now collapsed to a single space,
    which also yields cleaner tokens for the downstream TF-IDF model.
    """
    text = text.lower()
    text = re.sub(r"\[\*\*.*?\*\*\]", "", text)  # remove PHI placeholders
    # Collapse every whitespace run, not just newlines, so PHI removal
    # does not leave doubled spaces behind.
    text = re.sub(r"\s+", " ", text)
    return text.strip()
|
intents.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def detect_intent(question: str) -> str:
    """Classify a patient question into a coarse intent bucket.

    Returns one of "appointment", "severity", "explanation", or "unknown".
    Matching is keyword-based and checked in priority order, so a question
    like "why is my risk high" resolves to "severity" before "explanation".

    Fix: the original only matched the spelling "follow up", silently
    missing the common hyphenated form "follow-up" (and "followup").
    Hyphens are normalised to spaces before matching.
    """
    # Lowercase and normalise hyphenation so "Follow-up" matches "follow up".
    q = question.lower().replace("-", " ")

    if "appointment" in q or "follow up" in q or "followup" in q:
        return "appointment"
    if "how serious" in q or "severity" in q or "risk" in q:
        return "severity"
    if "why" in q or "explain" in q:
        return "explanation"

    return "unknown"
|
models/bart_explainer.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Clinical BART checkpoint pretrained/fine-tuned on MIMIC-III notes.
MODEL_NAME = "dmacres/bart-large-mimiciii-v2"

# Loaded once at import time and shared by all calls.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

def generate_explanation(note: str, risk_score: float) -> str:
    """Generate a free-text rationale for a predicted readmission risk.

    The discharge note and the numeric risk score are embedded in a prompt
    and run through the seq2seq model with beam-search decoding. The note is
    truncated to the model's 1024-token input limit; output is capped at
    200 tokens.
    """
    prompt = f"""
Discharge summary:
{note}

Predicted readmission risk: {risk_score:.2f}

Explain the key clinical reasons for readmission risk.
"""

    # Tokenize with truncation so arbitrarily long notes stay within the
    # encoder's context window.
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=1024,
    )

    generated = model.generate(
        **encoded,
        max_length=200,
        num_beams=4,
        early_stopping=True,
    )

    # Decode the top beam, dropping BOS/EOS/pad tokens.
    return tokenizer.decode(generated[0], skip_special_tokens=True)
|
models/baseline_tfidf.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression

def build_tfidf_model():
    """Build the TF-IDF + logistic-regression baseline classifier.

    Returns an unfitted sklearn ``Pipeline``: uni/bi-gram TF-IDF features
    (vocabulary capped at 20k terms, English stop words removed) feeding a
    class-weighted logistic regression, which compensates for readmission
    label imbalance.
    """
    vectorizer = TfidfVectorizer(
        max_features=20000,
        ngram_range=(1, 2),
        stop_words="english",
    )
    classifier = LogisticRegression(
        max_iter=1000,
        class_weight="balanced",
    )
    return Pipeline([("tfidf", vectorizer), ("clf", classifier)])
|
notebooks/01_exploration.ipynb
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"id": "27d3ad11",
|
| 7 |
+
"metadata": {
|
| 8 |
+
"vscode": {
|
| 9 |
+
"languageId": "plaintext"
|
| 10 |
+
}
|
| 11 |
+
},
|
| 12 |
+
"outputs": [],
|
| 13 |
+
"source": [
|
| 14 |
+
"# Install dependencies\n",
|
| 15 |
+
"!pip install datasets pandas\n",
|
| 16 |
+
"\n",
|
| 17 |
+
"from datasets import load_dataset\n",
|
| 18 |
+
"import pandas as pd\n",
|
| 19 |
+
"\n",
|
| 20 |
+
"# Load NOTEEVENTS directly from Hugging Face (no manual download)\n",
|
| 21 |
+
"notes = load_dataset(\n",
|
| 22 |
+
" \"ntphuc149/MIMIC-III-Clinical-Database\",\n",
|
| 23 |
+
" \"NOTEEVENTS\",\n",
|
| 24 |
+
" split=\"train\"\n",
|
| 25 |
+
")\n",
|
| 26 |
+
"\n",
|
| 27 |
+
"# Inspect columns\n",
|
| 28 |
+
"print(\"Columns:\", notes.column_names)\n",
|
| 29 |
+
"\n",
|
| 30 |
+
"# Filter discharge summaries\n",
|
| 31 |
+
"discharge_notes = notes.filter(\n",
|
| 32 |
+
" lambda x: x[\"CATEGORY\"] == \"Discharge summary\"\n",
|
| 33 |
+
")\n",
|
| 34 |
+
"\n",
|
| 35 |
+
"print(\"Total discharge summaries:\", len(discharge_notes))\n",
|
| 36 |
+
"\n",
|
| 37 |
+
"# Convert small sample to pandas for inspection\n",
|
| 38 |
+
"df = discharge_notes.select(range(5)).to_pandas()\n",
|
| 39 |
+
"\n",
|
| 40 |
+
"# Print sample text\n",
|
| 41 |
+
"for i, text in enumerate(df[\"TEXT\"]):\n",
|
| 42 |
+
" print(f\"\\n--- Discharge Summary {i+1} ---\\n\")\n",
|
| 43 |
+
" print(text[:2000])\n"
|
| 44 |
+
]
|
| 45 |
+
}
|
| 46 |
+
],
|
| 47 |
+
"metadata": {
|
| 48 |
+
"language_info": {
|
| 49 |
+
"name": "python"
|
| 50 |
+
}
|
| 51 |
+
},
|
| 52 |
+
"nbformat": 4,
|
| 53 |
+
"nbformat_minor": 5
|
| 54 |
+
}
|
notebooks/02_baseline_nlp.ipynb
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"id": "1d9c9ac7",
|
| 7 |
+
"metadata": {
|
| 8 |
+
"vscode": {
|
| 9 |
+
"languageId": "plaintext"
|
| 10 |
+
}
|
| 11 |
+
},
|
| 12 |
+
"outputs": [],
|
| 13 |
+
"source": [
|
| 14 |
+
"# Install dependencies\n",
|
| 15 |
+
"!pip install datasets scikit-learn pandas numpy\n",
|
| 16 |
+
"\n",
|
| 17 |
+
"from datasets import load_dataset\n",
|
| 18 |
+
"import pandas as pd\n",
|
| 19 |
+
"import numpy as np\n",
|
| 20 |
+
"import re\n",
|
| 21 |
+
"from sklearn.model_selection import train_test_split\n",
|
| 22 |
+
"from sklearn.pipeline import Pipeline\n",
|
| 23 |
+
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
|
| 24 |
+
"from sklearn.linear_model import LogisticRegression\n",
|
| 25 |
+
"from sklearn.metrics import classification_report, roc_auc_score\n",
|
| 26 |
+
"\n",
|
| 27 |
+
"# Load discharge summaries\n",
|
| 28 |
+
"notes = load_dataset(\n",
|
| 29 |
+
" \"ntphuc149/MIMIC-III-Clinical-Database\",\n",
|
| 30 |
+
" \"NOTEEVENTS\",\n",
|
| 31 |
+
" split=\"train\"\n",
|
| 32 |
+
")\n",
|
| 33 |
+
"\n",
|
| 34 |
+
"notes = notes.filter(lambda x: x[\"CATEGORY\"] == \"Discharge summary\")\n",
|
| 35 |
+
"\n",
|
| 36 |
+
"# Minimal clinical-safe cleaning\n",
|
| 37 |
+
"def clean_text(text):\n",
|
| 38 |
+
" text = text.lower()\n",
|
| 39 |
+
" text = re.sub(r\"\\[\\*\\*.*?\\*\\*\\]\", \"\", text)\n",
|
| 40 |
+
" text = re.sub(r\"\\n+\", \" \", text)\n",
|
| 41 |
+
" return text.strip()\n",
|
| 42 |
+
"\n",
|
| 43 |
+
"notes = notes.map(lambda x: {\"clean_text\": clean_text(x[\"TEXT\"])})\n",
|
| 44 |
+
"\n",
|
| 45 |
+
"# TEMPORARY labels (replace with READMITTED_30D later)\n",
|
| 46 |
+
"np.random.seed(42)\n",
|
| 47 |
+
"labels = np.random.binomial(1, 0.35, size=len(notes))\n",
|
| 48 |
+
"\n",
|
| 49 |
+
"df = pd.DataFrame({\n",
|
| 50 |
+
" \"text\": notes[\"clean_text\"],\n",
|
| 51 |
+
" \"label\": labels\n",
|
| 52 |
+
"})\n",
|
| 53 |
+
"\n",
|
| 54 |
+
"# Train-test split\n",
|
| 55 |
+
"X_train, X_test, y_train, y_test = train_test_split(\n",
|
| 56 |
+
" df[\"text\"],\n",
|
| 57 |
+
" df[\"label\"],\n",
|
| 58 |
+
" test_size=0.2,\n",
|
| 59 |
+
" stratify=df[\"label\"],\n",
|
| 60 |
+
" random_state=42\n",
|
| 61 |
+
")\n",
|
| 62 |
+
"\n",
|
| 63 |
+
"# TF-IDF baseline model\n",
|
| 64 |
+
"model = Pipeline([\n",
|
| 65 |
+
" (\"tfidf\", TfidfVectorizer(\n",
|
| 66 |
+
" max_features=20000,\n",
|
| 67 |
+
" ngram_range=(1,2),\n",
|
| 68 |
+
" stop_words=\"english\"\n",
|
| 69 |
+
" )),\n",
|
| 70 |
+
" (\"clf\", LogisticRegression(\n",
|
| 71 |
+
" max_iter=1000,\n",
|
| 72 |
+
" class_weight=\"balanced\"\n",
|
| 73 |
+
" ))\n",
|
| 74 |
+
"])\n",
|
| 75 |
+
"\n",
|
| 76 |
+
"# Train\n",
|
| 77 |
+
"model.fit(X_train, y_train)\n",
|
| 78 |
+
"\n",
|
| 79 |
+
"# Evaluate\n",
|
| 80 |
+
"y_pred = model.predict(X_test)\n",
|
| 81 |
+
"y_prob = model.predict_proba(X_test)[:,1]\n",
|
| 82 |
+
"\n",
|
| 83 |
+
"print(classification_report(y_test, y_pred))\n",
|
| 84 |
+
"print(\"ROC-AUC:\", roc_auc_score(y_test, y_prob))\n"
|
| 85 |
+
]
|
| 86 |
+
}
|
| 87 |
+
],
|
| 88 |
+
"metadata": {
|
| 89 |
+
"language_info": {
|
| 90 |
+
"name": "python"
|
| 91 |
+
}
|
| 92 |
+
},
|
| 93 |
+
"nbformat": 4,
|
| 94 |
+
"nbformat_minor": 5
|
| 95 |
+
}
|
notebooks/03_llm_explanations.ipynb
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"id": "ca5c8ec7",
|
| 7 |
+
"metadata": {
|
| 8 |
+
"vscode": {
|
| 9 |
+
"languageId": "plaintext"
|
| 10 |
+
}
|
| 11 |
+
},
|
| 12 |
+
"outputs": [],
|
| 13 |
+
"source": [
|
| 14 |
+
"# Install dependencies\n",
|
| 15 |
+
"!pip install transformers torch datasets\n",
|
| 16 |
+
"\n",
|
| 17 |
+
"from transformers import AutoTokenizer, AutoModelForSeq2SeqLM\n",
|
| 18 |
+
"from datasets import load_dataset\n",
|
| 19 |
+
"import re\n",
|
| 20 |
+
"\n",
|
| 21 |
+
"# Load pretrained clinical BART model\n",
|
| 22 |
+
"MODEL_NAME = \"dmacres/bart-large-mimiciii-v2\"\n",
|
| 23 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
|
| 24 |
+
"model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)\n",
|
| 25 |
+
"\n",
|
| 26 |
+
"# Load discharge summaries\n",
|
| 27 |
+
"notes = load_dataset(\n",
|
| 28 |
+
" \"ntphuc149/MIMIC-III-Clinical-Database\",\n",
|
| 29 |
+
" \"NOTEEVENTS\",\n",
|
| 30 |
+
" split=\"train\"\n",
|
| 31 |
+
")\n",
|
| 32 |
+
"\n",
|
| 33 |
+
"notes = notes.filter(lambda x: x[\"CATEGORY\"] == \"Discharge summary\")\n",
|
| 34 |
+
"\n",
|
| 35 |
+
"# Clean text\n",
|
| 36 |
+
"def clean_text(text):\n",
|
| 37 |
+
" text = text.lower()\n",
|
| 38 |
+
" text = re.sub(r\"\\[\\*\\*.*?\\*\\*\\]\", \"\", text)\n",
|
| 39 |
+
" text = re.sub(r\"\\n+\", \" \", text)\n",
|
| 40 |
+
" return text.strip()\n",
|
| 41 |
+
"\n",
|
| 42 |
+
"sample_note = clean_text(notes[0][\"TEXT\"])\n",
|
| 43 |
+
"\n",
|
| 44 |
+
"# GenAI explanation function\n",
|
| 45 |
+
"def generate_explanation(note, risk_score):\n",
|
| 46 |
+
" prompt = f\"\"\"\n",
|
| 47 |
+
"Discharge summary:\n",
|
| 48 |
+
"{note}\n",
|
| 49 |
+
"\n",
|
| 50 |
+
"Predicted readmission risk: {risk_score:.2f}\n",
|
| 51 |
+
"\n",
|
| 52 |
+
"Explain the key clinical reasons for readmission risk.\n",
|
| 53 |
+
"\"\"\"\n",
|
| 54 |
+
"\n",
|
| 55 |
+
" inputs = tokenizer(\n",
|
| 56 |
+
" prompt,\n",
|
| 57 |
+
" return_tensors=\"pt\",\n",
|
| 58 |
+
" truncation=True,\n",
|
| 59 |
+
" max_length=1024\n",
|
| 60 |
+
" )\n",
|
| 61 |
+
"\n",
|
| 62 |
+
" outputs = model.generate(\n",
|
| 63 |
+
" **inputs,\n",
|
| 64 |
+
" max_length=200,\n",
|
| 65 |
+
" num_beams=4,\n",
|
| 66 |
+
" early_stopping=True\n",
|
| 67 |
+
" )\n",
|
| 68 |
+
"\n",
|
| 69 |
+
" return tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
|
| 70 |
+
"\n",
|
| 71 |
+
"# Demo\n",
|
| 72 |
+
"risk_score = 0.72 # from your classifier\n",
|
| 73 |
+
"explanation = generate_explanation(sample_note, risk_score)\n",
|
| 74 |
+
"\n",
|
| 75 |
+
"print(\"Predicted Risk:\", risk_score)\n",
|
| 76 |
+
"print(\"\\nGenerated Explanation:\\n\")\n",
|
| 77 |
+
"print(explanation)\n"
|
| 78 |
+
]
|
| 79 |
+
}
|
| 80 |
+
],
|
| 81 |
+
"metadata": {
|
| 82 |
+
"language_info": {
|
| 83 |
+
"name": "python"
|
| 84 |
+
}
|
| 85 |
+
},
|
| 86 |
+
"nbformat": 4,
|
| 87 |
+
"nbformat_minor": 5
|
| 88 |
+
}
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
transformers==4.36.2
|
| 2 |
+
huggingface_hub==0.20.3
|
| 3 |
+
datasets==2.16.1
|
| 4 |
+
torch
|
| 5 |
+
scikit-learn
|
| 6 |
+
pandas
|
| 7 |
+
numpy
|