File size: 6,392 Bytes
aca21d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159

import gradio as gr
import pandas as pd
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from typing import List, Tuple

# Dataset CSV read once at import time to fit the TF-IDF index. It is a
# relative path, so it is resolved against the current working directory.
DATASET_PATH = "differential_diagnosis_dataset.csv"

def load_dataset(path: str):
    """Load the symptom/diagnosis CSV and build a searchable text column.

    The diagnosis column is the one named ``diagnosis`` (case-insensitive),
    or else the first column whose name contains ``diagnos``. Symptom columns
    are those whose names contain ``symptom``, ``feature`` or ``sign``. The
    diagnosis column is never treated as a symptom column, so the label cannot
    leak into the text that queries are matched against. If no column matches
    the heuristic, every non-diagnosis column is used.

    Args:
        path: Location of the CSV (anything accepted by ``pandas.read_csv``).

    Returns:
        ``(df, diagnosis_col, symptom_cols)``. ``df`` is the data with a fresh
        0..n-1 index plus a ``_symptom_text`` column holding the non-missing,
        non-blank symptom values joined by ``", "``.

    Raises:
        ValueError: if no diagnosis column can be identified.
    """
    df = pd.read_csv(path)
    cols_lower = [str(c).lower() for c in df.columns]

    # Prefer an exact (case-insensitive) "diagnosis" column, else the first
    # column whose name merely contains "diagnos" (e.g. "final_diagnosis").
    if "diagnosis" in cols_lower:
        diagnosis_col = df.columns[cols_lower.index("diagnosis")]
    else:
        diagnosis_col = next(
            (c for c, low in zip(df.columns, cols_lower) if "diagnos" in low),
            None,
        )
        if diagnosis_col is None:
            raise ValueError("No diagnosis column found. Please ensure a 'diagnosis' column exists.")

    # Symptom columns by name heuristic. The diagnosis column is excluded even
    # when its name matches (e.g. "diagnosis_features"); otherwise the label
    # itself would be part of the text each query is scored against.
    keywords = ("symptom", "feature", "sign")
    symptom_cols = [
        c for c, low in zip(df.columns, cols_lower)
        if c != diagnosis_col and any(k in low for k in keywords)
    ]
    if not symptom_cols:
        symptom_cols = [c for c in df.columns if c != diagnosis_col]

    def row_to_text(row):
        # Join the non-missing, non-blank symptom values in column order.
        parts = []
        for c in symptom_cols:
            val = row.get(c, "")
            if pd.notna(val) and str(val).strip():
                parts.append(str(val))
        return ", ".join(parts)

    # Fresh positional index: callers map positions in a similarity vector
    # back to rows, so row labels must equal row positions.
    df = df.reset_index(drop=True)
    # A comprehension (rather than df.apply) also works on a 0-row frame.
    df["_symptom_text"] = [row_to_text(row) for _, row in df.iterrows()]
    return df, diagnosis_col, symptom_cols

# Global fit (loaded at startup). If anything fails, the app still starts:
# every model global is reset to None/empty and the reason is kept in
# _startup_error, which the UI displays and infer_differential checks.
try:
    _df, _diagnosis_col, _symptom_cols = load_dataset(DATASET_PATH)
    _corpus = _df["_symptom_text"].fillna("").astype(str).tolist()
    _vectorizer = TfidfVectorizer(ngram_range=(1,2), min_df=1)
    _X = _vectorizer.fit_transform(_corpus)
except Exception as e:  # top-level boundary: surface any failure in the UI
    _df, _diagnosis_col, _symptom_cols = None, None, None
    _corpus, _vectorizer, _X = [], None, None
    # str(e) is "" for exceptions raised without a message; fall back to the
    # type name so the flag is never empty (empty means "startup OK").
    _startup_error = str(e) or type(e).__name__
else:
    _startup_error = ""

def normalize_symptoms(text: str) -> str:
    """Canonicalize a comma-separated symptom string.

    Each comma-separated item is lower-cased and its whitespace runs are
    collapsed to single spaces (which also strips it). Empty items are
    dropped, and repeats are removed keeping the first occurrence, so
    "Fever, fever" counts as one symptom. ``None`` is treated as "".

    Returns:
        The cleaned items joined by ``", "`` ("" when nothing remains).
    """
    items = (" ".join(part.split()).lower() for part in (text or "").split(","))
    # dict.fromkeys de-duplicates while preserving the user's order.
    return ", ".join(dict.fromkeys(item for item in items if item))

def few_shot_prompt(symptoms: List[str], k_examples: int = 3) -> str:
    """Format the dataset rows most similar to ``symptoms`` as examples.

    Uses the module-level TF-IDF index built at startup. Rows are ranked by
    cosine similarity between the ", "-joined symptoms and each row's
    ``_symptom_text``. The top ``k_examples`` are rendered one per line as
    ``Symptoms: { ... } -> Diagnosis: { ... }``.

    Args:
        symptoms: Symptom strings, joined with ", " to form the query.
        k_examples: Number of examples to return; values <= 0 yield "".

    Returns:
        Newline-joined examples, or "Model not initialized." when the index
        was not built.
    """
    if _X is None or _vectorizer is None:
        return "Model not initialized."
    # Clamp: a negative slice bound would return almost every row.
    k = max(k_examples, 0)
    query = ", ".join(symptoms)
    sims = cosine_similarity(_vectorizer.transform([query]), _X).ravel()
    top_idx = sims.argsort()[::-1][:k]
    # Entries of ``sims`` are row positions, so rows are looked up by
    # position (iloc) rather than by index label (loc).
    texts = _df["_symptom_text"]
    labels = _df[_diagnosis_col]
    return "\n".join(
        f"Symptoms: {{ {texts.iloc[i]} }} -> Diagnosis: {{ {labels.iloc[i]} }}"
        for i in top_idx
    )

def chain_of_reasoning(symptoms: List[str]) -> List[str]:
    """Return the fixed, human-readable outline of the analysis pipeline.

    The outline is identical for every input; ``symptoms`` is not inspected.
    A new list is built on every call, so callers may mutate the result.

    NOTE(review): step 4 describes a red-flag check, but the scoring code in
    this module does not perform one. Consider implementing it or rewording
    the step.
    """
    return [
        "1) Parse input and standardize terms.",
        "2) Match symptom pattern with dataset using TF-IDF similarity.",
        "3) Collect top diagnoses and compute normalized scores.",
        "4) Check for red-flag patterns (e.g., chest pain + shortness of breath).",
        "5) Return differential list with triage flags. For care, consult a clinician.",
    ]

def tree_of_hypotheses(symptoms: List[str], top_n: int = 5) -> List[str]:
    """Map symptoms onto coarse organ-system hypothesis branches.

    A system matches when any of its keywords occurs as a substring of the
    comma-joined, lower-cased symptom text, so matching is case-insensitive.
    Branches come out in the fixed order cardio, neuro, gi, pulm, id. When
    nothing matches, a single "general" branch is returned.

    Args:
        symptoms: Symptom strings.
        top_n: Maximum number of branches to return. The default of 5 equals
            the number of systems, so default calls are never truncated.
            Values <= 0 return an empty list.

    Returns:
        Strings of the form "Hypothesis branch: <system>".
    """
    text = ", ".join(symptoms).lower()
    buckets = {
        "cardio": ["chest pain","palpitations","syncope","shortness of breath"],
        "neuro": ["headache","weakness","numbness","confusion","seizure"],
        "gi": ["abdominal pain","nausea","vomiting","diarrhea","constipation"],
        "pulm": ["cough","wheezing","dyspnea","hemoptysis"],
        "id": ["fever","chills","night sweats","fatigue"]
    }
    matched = [
        system for system, keys in buckets.items()
        if any(k in text for k in keys)
    ]
    if not matched:
        matched = ["general"]
    return [f"Hypothesis branch: {m}" for m in matched[:max(top_n, 0)]]

def infer_differential(symptoms_text: str, top_k: int = 7) -> Tuple[str, list, list, list]:
    """Rank candidate diagnoses for a comma-separated symptom string.

    Each dataset row is scored by TF-IDF cosine similarity to the query, and
    each diagnosis takes the best score among its rows. Rows with a missing
    diagnosis are skipped. Diagnoses with zero similarity (no shared terms)
    are not reported. The remaining scores are divided by the top score, so
    the best match is 1.0.

    Args:
        symptoms_text: Raw comma-separated user input; ``None`` counts as "".
        top_k: Maximum number of diagnoses to return; values <= 0 return none.

    Returns:
        ``(error, diffs, chain, prompt)``. On failure, ``error`` is a message
        and the three lists are empty. Otherwise ``error`` is "".
        ``diffs`` holds "• <dx> — score <s>" lines, ``chain`` holds the
        reasoning outline, and ``prompt`` holds the few-shot examples
        followed by hypothesis branches.
    """
    if _startup_error or _vectorizer is None or _X is None:
        # Also guard on the model objects in case startup recorded no message.
        return f"Startup error: {_startup_error or 'model not initialized'}", [], [], []

    cleaned = normalize_symptoms(symptoms_text or "")
    # The normalized text is joined with ", ", so each piece is stripped to
    # avoid leading spaces. Repeats are dropped so one symptom typed several
    # times cannot satisfy the minimum below.
    tokens = list(dict.fromkeys(t.strip() for t in cleaned.split(",") if t.strip()))
    if len(tokens) < 3:
        return "Enter at least 3 symptoms (comma-separated).", [], [], []

    query = ", ".join(tokens)
    sims = cosine_similarity(_vectorizer.transform([query]), _X).ravel()

    # Best similarity per diagnosis. zip pairs values by position, which is
    # how ``sims`` is ordered, independent of the DataFrame's index labels.
    scores: dict = {}
    for label, s in zip(_df[_diagnosis_col], sims):
        if pd.isna(label):
            continue
        dx = str(label)
        scores[dx] = max(scores.get(dx, 0.0), float(s))

    ranked = [
        (dx, score)
        for dx, score in sorted(scores.items(), key=lambda x: x[1], reverse=True)
        if score > 0
    ][:max(top_k, 0)]
    if ranked:
        max_s = ranked[0][1]  # strictly positive thanks to the filter above
        ranked = [(dx, round(score / max_s, 3)) for dx, score in ranked]

    diffs = [f"• {dx} — score {score}" for dx, score in ranked]
    fewshot = few_shot_prompt(tokens, k_examples=3)
    chain = chain_of_reasoning(tokens)
    tree = tree_of_hypotheses(tokens)

    return "", diffs, chain, [fewshot] + tree

# Shown at the top of the UI through gr.Markdown. The warnings are separated
# by blank lines because Markdown treats a single "\n" as a soft break and
# would otherwise run all three together on one line.
disclaimer = (
    "Educational use only. Not medical advice.\n\n"
    "Do not use this tool for emergencies.\n\n"
    "If symptoms are severe or worsening, seek licensed care immediately."
)

def _as_markdown(items) -> str:
    """Render every non-blank line of *items* as its own Markdown paragraph.

    Items may themselves contain newlines (the few-shot block does). Lines
    are separated by blank lines because gr.Markdown treats a single "\\n"
    as a soft break and would otherwise print everything on one line.
    """
    lines = (line for item in items for line in str(item).splitlines())
    return "\n\n".join(line for line in lines if line.strip())


with gr.Blocks(title="AI-Powered Differential Diagnosis (Educational)") as demo:
    gr.Markdown("# AI-Powered Differential Diagnosis (Educational)")
    gr.Markdown(disclaimer)
    if _startup_error:
        gr.Markdown(f"**Startup error:** {_startup_error}")
    with gr.Row():
        inp = gr.Textbox(label="Enter symptoms (comma-separated). Minimum 3.", placeholder="fever, cough, shortness of breath")
    with gr.Row():
        btn = gr.Button("Analyze")
        clr = gr.ClearButton([inp])
    with gr.Row():
        dx = gr.Markdown(label="Differential Diagnoses")
    with gr.Row():
        chain_box = gr.Markdown(label="Structured reasoning steps")
    with gr.Row():
        prompt_box = gr.Markdown(label="Few-shot examples + hypothesis branches")
    # Clear the result panes too, so stale output never sits next to an
    # empty input box.
    clr.add([dx, chain_box, prompt_box])

    def on_analyze(text):
        """Run the analysis and return Markdown for the three output panes."""
        err, diffs, chain, prompt = infer_differential(text)
        if err:
            return err, "", ""
        dx_md = _as_markdown(diffs) if diffs else "No matches."
        return dx_md, _as_markdown(chain), _as_markdown(prompt)

    btn.click(on_analyze, inputs=[inp], outputs=[dx, chain_box, prompt_box])

if __name__ == "__main__":
    # Listen on all interfaces (0.0.0.0) on Gradio's usual port 7860.
    launch_kwargs = dict(server_name="0.0.0.0", server_port=7860)
    demo.launch(**launch_kwargs)