"""Gradio demo: TF-IDF symptom matching against a diagnosis dataset.

Educational use only. The dataset is loaded and vectorized once at import
time; if that fails, the error text is kept in ``_startup_error`` and shown
in the UI instead of crashing the app.
"""

from typing import List, Tuple

import gradio as gr
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

DATASET_PATH = "differential_diagnosis_dataset.csv"


def load_dataset(path: str):
    """Load the CSV at *path* and derive a per-row symptom text column.

    The diagnosis column is the first column whose lower-cased name equals
    ``"diagnosis"``, otherwise the first whose name contains ``"diagnos"``.
    Symptom columns are those whose name contains ``"symptom"``,
    ``"feature"`` or ``"sign"``; if none match, every non-diagnosis column
    is used.

    Returns:
        A tuple ``(df, diagnosis_col, symptom_cols)`` where ``df`` is a copy
        of the loaded frame with an added ``"_symptom_text"`` column.

    Raises:
        ValueError: If no diagnosis column can be identified.
    """
    df = pd.read_csv(path)

    # Identify diagnosis column (case-insensitive exact match, then "contains").
    cols_lower = [c.lower() for c in df.columns]
    if "diagnosis" in cols_lower:
        diagnosis_col = df.columns[cols_lower.index("diagnosis")]
    else:
        diagnosis_col = next(
            (c for c in df.columns if "diagnos" in c.lower()), None
        )
    if diagnosis_col is None:
        raise ValueError(
            "No diagnosis column found. Please ensure a 'diagnosis' column exists."
        )

    # Symptom columns heuristic. The diagnosis column is always excluded so
    # the label can never leak into the text that is matched against.
    keywords = ("symptom", "feature", "sign")
    symptom_cols = [
        c
        for c in df.columns
        if c != diagnosis_col and any(k in c.lower() for k in keywords)
    ]
    if not symptom_cols:
        symptom_cols = [c for c in df.columns if c != diagnosis_col]

    def row_to_text(row):
        """Join the non-empty symptom cells of one row with ', '."""
        parts = []
        for c in symptom_cols:
            val = row.get(c, "")
            if pd.notna(val) and str(val).strip():
                parts.append(str(val))
        return ", ".join(parts)

    df = df.copy()
    df["_symptom_text"] = df.apply(row_to_text, axis=1)
    return df, diagnosis_col, symptom_cols


# Global fit (loaded at startup). This is a deliberate top-level boundary:
# any failure is recorded and surfaced in the UI rather than aborting import.
try:
    _df, _diagnosis_col, _symptom_cols = load_dataset(DATASET_PATH)
    _corpus = _df["_symptom_text"].fillna("").astype(str).tolist()
    _vectorizer = TfidfVectorizer(ngram_range=(1, 2), min_df=1)
    _X = _vectorizer.fit_transform(_corpus)
except Exception as e:
    _df, _diagnosis_col, _symptom_cols = None, None, None
    _corpus, _vectorizer, _X = [], None, None
    _startup_error = str(e)
else:
    _startup_error = ""


def normalize_symptoms(text: str) -> str:
    """Lower-case and trim comma-separated symptoms, dropping empty entries.

    Returns the cleaned entries re-joined with ``", "``; falsy input
    (``None`` or ``""``) yields ``""``.
    """
    if not text:
        return ""
    return ", ".join(t.strip().lower() for t in text.split(",") if t.strip())


def few_shot_prompt(symptoms: List[str], k_examples: int = 3) -> str:
    """Return up to *k_examples* dataset rows most similar to *symptoms*.

    Each row is rendered as ``Symptoms: { ... } -> Diagnosis: { ... }`` and
    rows are separated by newlines. Rows with zero similarity are omitted,
    so the result may be an empty string. If the model failed to load,
    ``"Model not initialized."`` is returned.
    """
    if _X is None or _vectorizer is None:
        return "Model not initialized."

    query = ", ".join(symptoms)
    sims = cosine_similarity(_vectorizer.transform([query]), _X).ravel()
    top_idx = sims.argsort()[::-1][:k_examples]

    # argsort yields positions, so use positional access rather than labels.
    texts = _df["_symptom_text"]
    labels = _df[_diagnosis_col]
    examples = [
        f"Symptoms: {{ {texts.iloc[i]} }} -> Diagnosis: {{ {labels.iloc[i]} }}"
        for i in top_idx
        if sims[i] > 0  # a zero-similarity row shares no terms with the query
    ]
    return "\n".join(examples)


def chain_of_reasoning(symptoms: List[str]) -> List[str]:
    """Return the fixed list of reasoning-step descriptions shown in the UI.

    *symptoms* is accepted for interface symmetry but is not used.
    """
    return [
        "1) Parse input and standardize terms.",
        "2) Match symptom pattern with dataset using TF-IDF similarity.",
        "3) Collect top diagnoses and compute normalized scores.",
        "4) Check for red-flag patterns (e.g., chest pain + shortness of breath).",
        "5) Return differential list with triage flags. For care, consult a clinician.",
    ]


def tree_of_hypotheses(symptoms: List[str], top_n: int = 5) -> List[str]:
    """Map *symptoms* to coarse body-system branches by keyword matching.

    A branch is reported when any of its keywords occurs as a substring of
    the joined symptom text; ``"general"`` is used when nothing matches.
    *top_n* is currently unused.
    """
    text = ", ".join(symptoms)
    buckets = {
        "cardio": ["chest pain", "palpitations", "syncope", "shortness of breath"],
        "neuro": ["headache", "weakness", "numbness", "confusion", "seizure"],
        "gi": ["abdominal pain", "nausea", "vomiting", "diarrhea", "constipation"],
        "pulm": ["cough", "wheezing", "dyspnea", "hemoptysis"],
        "id": ["fever", "chills", "night sweats", "fatigue"],
    }
    matched = [
        system
        for system, keys in buckets.items()
        if any(k in text for k in keys)
    ]
    if not matched:
        matched = ["general"]
    return [f"Hypothesis branch: {m}" for m in matched]


def infer_differential(symptoms_text: str, top_k: int = 7) -> Tuple[str, list, list, list]:
    """Rank dataset diagnoses by TF-IDF similarity to the entered symptoms.

    Returns:
        ``(error, diffs, chain, prompt)``. ``error`` is a non-empty message
        (and the lists are empty) on a startup failure or when fewer than
        three symptoms are given. Otherwise ``diffs`` holds up to *top_k*
        formatted lines with scores normalized to the best match, ``chain``
        the reasoning steps, and ``prompt`` the few-shot text followed by
        the hypothesis branches. ``diffs`` is empty when no dataset row
        shares any term with the query.
    """
    if _startup_error:
        return f"Startup error: {_startup_error}", [], [], []

    cleaned = normalize_symptoms(symptoms_text)
    # normalize_symptoms joins with ", ", so strip the space left by split(",").
    tokens = [t.strip() for t in cleaned.split(",") if t.strip()]
    if len(tokens) < 3:
        return "Enter at least 3 symptoms (comma-separated).", [], [], []

    query = ", ".join(tokens)
    sims = cosine_similarity(_vectorizer.transform([query]), _X).ravel()

    # Best similarity per diagnosis. Labels are read once, in row order, so
    # they line up with the rows of _X regardless of the frame's index.
    labels = [str(v) for v in _df[_diagnosis_col].tolist()]
    scores = {}
    for label, sim in zip(labels, sims):
        scores[label] = max(scores.get(label, 0.0), float(sim))

    # Drop diagnoses with no overlap at all; they are not matches.
    ranked = [
        (label, score)
        for label, score in sorted(scores.items(), key=lambda x: x[1], reverse=True)
        if score > 0
    ][:top_k]
    if ranked:
        max_s = ranked[0][1]
        ranked = [(label, round(score / max_s, 3)) for label, score in ranked]

    diffs = [f"• {label} — score {score}" for label, score in ranked]
    fewshot = few_shot_prompt(tokens, k_examples=3)
    chain = chain_of_reasoning(tokens)
    tree = tree_of_hypotheses(tokens)
    return "", diffs, chain, [fewshot] + tree


disclaimer = (
    "Educational use only. Not medical advice.\n"
    "Do not use this tool for emergencies.\n"
    "If symptoms are severe or worsening, seek licensed care immediately."
)


def _as_markdown(items) -> str:
    """Render *items* one per line in Markdown.

    A single newline is only a soft break in Markdown, so lines are joined
    with a hard break (two trailing spaces). Embedded newlines are split
    into separate lines and blank lines are dropped.
    """
    lines = [
        line
        for item in items
        for line in str(item).split("\n")
        if line.strip()
    ]
    return "  \n".join(lines)


def on_analyze(text):
    """Click handler: return Markdown for the three output panels.

    On an error the message goes to the first panel and the other two are
    left empty.
    """
    err, diffs, chain, prompt = infer_differential(text)
    if err:
        return err, "", ""
    dx_md = _as_markdown(diffs) if diffs else "No matches."
    chain_md = _as_markdown(chain) if chain else ""
    prompt_md = _as_markdown(prompt) if prompt else ""
    return dx_md, chain_md, prompt_md


with gr.Blocks(title="AI-Powered Differential Diagnosis (Educational)") as demo:
    gr.Markdown("# AI-Powered Differential Diagnosis (Educational)")
    gr.Markdown(disclaimer)
    if _startup_error:
        gr.Markdown(f"**Startup error:** {_startup_error}")
    with gr.Row():
        inp = gr.Textbox(
            label="Enter symptoms (comma-separated). Minimum 3.",
            placeholder="fever, cough, shortness of breath",
        )
    with gr.Row():
        btn = gr.Button("Analyze")
        clr = gr.ClearButton([inp])
    with gr.Row():
        dx = gr.Markdown(label="Differential Diagnoses")
    with gr.Row():
        chain_box = gr.Markdown(label="Structured reasoning steps")
    with gr.Row():
        prompt_box = gr.Markdown(label="Few-shot examples + hypothesis branches")

    btn.click(on_analyze, inputs=[inp], outputs=[dx, chain_box, prompt_box])


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)