Spaces:
Sleeping
Sleeping
File size: 6,392 Bytes
aca21d8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import gradio as gr
import pandas as pd
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from typing import List, Tuple
DATASET_PATH = "differential_diagnosis_dataset.csv"


def load_dataset(path: str) -> Tuple[pd.DataFrame, str, List[str]]:
    """Load the diagnosis CSV and build a per-row symptom text column.

    Args:
        path: Location of a CSV file readable by ``pandas.read_csv``.

    Returns:
        A ``(df, diagnosis_col, symptom_cols)`` tuple where ``df`` is a copy
        of the loaded frame with an extra ``_symptom_text`` column holding the
        non-empty symptom values of each row joined by ``", "``.

    Raises:
        ValueError: If no column name equals or contains "diagnos(is)".
    """
    df = pd.read_csv(path)

    # Identify the diagnosis column: a case-insensitive exact match wins,
    # otherwise take the first column whose name contains "diagnos".
    lowered = [(col, str(col).lower()) for col in df.columns]
    diagnosis_col = next((col for col, low in lowered if low == "diagnosis"), None)
    if diagnosis_col is None:
        diagnosis_col = next((col for col, low in lowered if "diagnos" in low), None)
    if diagnosis_col is None:
        raise ValueError("No diagnosis column found. Please ensure a 'diagnosis' column exists.")

    # Symptom columns heuristic. The diagnosis column is excluded up front so
    # a label column such as "diagnosis_features" can never leak into the
    # symptom text that the similarity model is fitted on.
    keywords = ("symptom", "feature", "sign")
    candidates = [(col, low) for col, low in lowered if col != diagnosis_col]
    symptom_cols = [col for col, low in candidates if any(k in low for k in keywords)]
    if not symptom_cols:
        symptom_cols = [col for col, _ in candidates]

    def row_to_text(row):
        # Keep only cells that are present and not blank.
        parts = []
        for col in symptom_cols:
            val = row.get(col, "")
            if pd.notna(val) and str(val).strip():
                parts.append(str(val))
        return ", ".join(parts)

    # Build a 'text' column by concatenating symptoms.
    df = df.copy()
    df["_symptom_text"] = df.apply(row_to_text, axis=1)
    return df, diagnosis_col, symptom_cols
# Global fit (loaded at startup).
# The dataset is read and the TF-IDF index is fitted once at import time.
# Any failure is recorded in _startup_error instead of crashing the app, so
# the UI can still come up and show what went wrong.
_startup_error = ""
try:
    _df, _diagnosis_col, _symptom_cols = load_dataset(DATASET_PATH)
    _corpus = _df["_symptom_text"].fillna("").astype(str).tolist()
    _vectorizer = TfidfVectorizer(ngram_range=(1, 2), min_df=1)
    _X = _vectorizer.fit_transform(_corpus)
except Exception as exc:  # top-level boundary: report, don't crash
    _df, _diagnosis_col, _symptom_cols = None, None, None
    _corpus, _vectorizer, _X = [], None, None
    # Some exceptions have an empty message; fall back to the class name so
    # the error flag is never an empty (falsy) string while the model is unset.
    _startup_error = str(exc) or type(exc).__name__
def normalize_symptoms(text: str) -> str:
    """Return a canonical ``", "``-separated, lowercase symptom string.

    The input is split on commas; each piece is lowercased, has its
    surrounding whitespace removed and internal whitespace runs collapsed to
    a single space. Empty pieces are dropped. ``None`` or an empty string
    yields ``""``.
    """
    if not text:
        return ""
    pieces = (" ".join(part.split()).lower() for part in text.split(","))
    return ", ".join(piece for piece in pieces if piece)
def few_shot_prompt(symptoms: List[str], k_examples: int = 3) -> str:
    """Return up to ``k_examples`` dataset rows most similar to the symptoms.

    Args:
        symptoms: Symptom phrases; they are joined with ``", "`` to form the
            TF-IDF query.
        k_examples: Maximum number of examples to return. Values <= 0 yield
            an empty string.

    Returns:
        One ``Symptoms: { ... } -> Diagnosis: { ... }`` line per example,
        newline-separated, or ``"Model not initialized."`` when the global
        vectorizer/matrix is unavailable.
    """
    if _X is None or _vectorizer is None:
        return "Model not initialized."
    if k_examples <= 0:
        # A negative slice bound would otherwise return "all but the last k".
        return ""

    query = ", ".join(symptoms)
    qv = _vectorizer.transform([query])
    sims = cosine_similarity(qv, _X).ravel()

    # Stable descending order: ties keep dataset order, so output is
    # deterministic across runs.
    top_idx = (-sims).argsort(kind="stable")[:k_examples]

    # Indices are positions into the fitted matrix, so use positional access
    # rather than label-based .loc.
    texts = _df["_symptom_text"]
    labels = _df[_diagnosis_col]
    examples = [
        f"Symptoms: {{ {texts.iloc[i]} }} -> Diagnosis: {{ {labels.iloc[i]} }}"
        for i in top_idx
        if sims[i] > 0  # rows sharing no term with the query are not examples
    ]
    return "\n".join(examples)
def chain_of_reasoning(symptoms: List[str]) -> List[str]:
    """Return the fixed, numbered description of the analysis pipeline.

    The ``symptoms`` argument is accepted for interface symmetry but does not
    influence the returned steps.
    """
    return [
        "1) Parse input and standardize terms.",
        "2) Match symptom pattern with dataset using TF-IDF similarity.",
        "3) Collect top diagnoses and compute normalized scores.",
        "4) Check for red-flag patterns (e.g., chest pain + shortness of breath).",
        "5) Return differential list with triage flags. For care, consult a clinician.",
    ]
def tree_of_hypotheses(symptoms: List[str], top_n: int = 5) -> List[str]:
    """Map symptoms to coarse organ-system hypothesis branches.

    Args:
        symptoms: Symptom phrases; matching is case-insensitive substring
            search over the phrases joined with ``", "``.
        top_n: Maximum number of branches to return (in bucket order).
            Values <= 0 yield an empty list.

    Returns:
        ``"Hypothesis branch: <system>"`` strings, or a single ``general``
        branch when no keyword matches.
    """
    # Keywords are lowercase, so lowercase the haystack too.
    text = ", ".join(symptoms).lower()
    buckets = {
        "cardio": ["chest pain", "palpitations", "syncope", "shortness of breath"],
        "neuro": ["headache", "weakness", "numbness", "confusion", "seizure"],
        "gi": ["abdominal pain", "nausea", "vomiting", "diarrhea", "constipation"],
        "pulm": ["cough", "wheezing", "dyspnea", "hemoptysis"],
        "id": ["fever", "chills", "night sweats", "fatigue"],
    }
    matched = [
        system
        for system, keys in buckets.items()
        if any(key in text for key in keys)
    ]
    if not matched:
        matched = ["general"]
    # Honor the caller's cap; the default (5) equals the number of buckets.
    return [f"Hypothesis branch: {m}" for m in matched[: max(top_n, 0)]]
def infer_differential(symptoms_text: str, top_k: int = 7) -> Tuple[str, list, list, list]:
    """Rank dataset diagnoses by TF-IDF similarity to the entered symptoms.

    Args:
        symptoms_text: Comma-separated symptoms as typed by the user.
        top_k: Maximum number of diagnoses to list.

    Returns:
        ``(error, diffs, chain, prompt)``. ``error`` is a non-empty message
        (and the lists are empty) when startup failed or fewer than three
        symptoms were given. Otherwise ``diffs`` holds one bullet line per
        diagnosis with a score relative to the best match, ``chain`` the
        reasoning steps, and ``prompt`` the few-shot text followed by the
        hypothesis branches. ``diffs`` is empty when no dataset row shares
        any term with the query.
    """
    if _startup_error:
        return f"Startup error: {_startup_error}", [], [], []

    cleaned = normalize_symptoms(symptoms_text or "")
    # normalize_symptoms joins with ", ", so strip each piece again to avoid
    # handing " cough"-style tokens to the helpers below.
    tokens = [piece.strip() for piece in cleaned.split(",") if piece.strip()]
    if len(tokens) < 3:
        return "Enter at least 3 symptoms (comma-separated).", [], [], []

    query = ", ".join(tokens)
    qv = _vectorizer.transform([query])
    sims = cosine_similarity(qv, _X).ravel()

    # Best similarity per diagnosis. Rows are paired positionally with the
    # similarity vector. Only strictly positive similarities are recorded, so
    # diagnoses unrelated to the query never appear with a score of 0.
    scores = {}
    for label, sim in zip(_df[_diagnosis_col].astype(str), sims):
        sim = float(sim)
        if sim > scores.get(label, 0.0):
            scores[label] = sim

    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)[: max(top_k, 0)]
    if ranked:
        # Scale so the best match scores 1.0; it is > 0 by construction.
        best = ranked[0][1]
        ranked = [(label, round(score / best, 3)) for label, score in ranked]
    diffs = [f"• {label} — score {score}" for label, score in ranked]

    fewshot = few_shot_prompt(tokens, k_examples=3)
    chain = chain_of_reasoning(tokens)
    tree = tree_of_hypotheses(tokens)
    return "", diffs, chain, [fewshot] + tree
# Safety notice rendered at the top of the UI, one sentence per line.
disclaimer = "\n".join(
    [
        "Educational use only. Not medical advice.",
        "Do not use this tool for emergencies.",
        "If symptoms are severe or worsening, seek licensed care immediately.",
    ]
)
# UI layout: header and disclaimer, one input box, Analyze/Clear buttons and
# three Markdown output panes filled by on_analyze.
with gr.Blocks(title="AI-Powered Differential Diagnosis (Educational)") as demo:
    gr.Markdown("# AI-Powered Differential Diagnosis (Educational)")
    gr.Markdown(disclaimer)
    if _startup_error:
        gr.Markdown(f"**Startup error:** {_startup_error}")
    with gr.Row():
        inp = gr.Textbox(label="Enter symptoms (comma-separated). Minimum 3.", placeholder="fever, cough, shortness of breath")
    with gr.Row():
        btn = gr.Button("Analyze")
        clr = gr.ClearButton([inp])
    with gr.Row():
        dx = gr.Markdown(label="Differential Diagnoses")
    with gr.Row():
        chain_box = gr.Markdown(label="Structured reasoning steps")
    with gr.Row():
        prompt_box = gr.Markdown(label="Few-shot examples + hypothesis branches")

    # The outputs are created after the button, so register them afterwards;
    # otherwise Clear would leave stale results on screen.
    clr.add([dx, chain_box, prompt_box])

    def on_analyze(text):
        """Run the analysis and return Markdown for the three output panes."""

        def as_markdown(items):
            # Markdown collapses single newlines into one paragraph; a
            # trailing double space forces a hard line break per entry.
            lines = [line for item in items for line in str(item).split("\n")]
            return "  \n".join(lines)

        err, diffs, chain, prompt = infer_differential(text or "")
        if err:
            return err, "", ""
        dx_md = as_markdown(diffs) if diffs else "No matches."
        chain_md = as_markdown(chain) if chain else ""
        prompt_md = as_markdown(prompt) if prompt else ""
        return dx_md, chain_md, prompt_md

    btn.click(on_analyze, inputs=[inp], outputs=[dx, chain_box, prompt_box])
# Script entry point: serve the app on all interfaces, port 7860.
if __name__ == "__main__":
    demo.launch(
        server_port=7860,
        server_name="0.0.0.0",
    )
|