Spaces:
Sleeping
Sleeping
import logging
from typing import List, Tuple

import gradio as gr
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
# CSV loaded once at import time by load_dataset(); it must contain a diagnosis
# column (load_dataset raises ValueError otherwise). The path is relative, so
# it is resolved against the process's current working directory.
DATASET_PATH = "differential_diagnosis_dataset.csv"
def load_dataset(path: str):
    """Load the case CSV and add a ``_symptom_text`` column for TF-IDF matching.

    The diagnosis column is the one named ``diagnosis`` (case-insensitive) if
    present, otherwise the first column whose name contains ``"diagnos"``.
    Symptom columns are those whose names contain ``symptom``, ``feature`` or
    ``sign`` (case-insensitive substring match, diagnosis column excluded);
    if none do, every non-diagnosis column is used.

    Rows with a missing diagnosis are dropped (their label would otherwise
    surface as the string ``"nan"``), and the index is reset to ``0..n-1`` so
    row positions line up with rows of any matrix built from the result.

    Args:
        path: anything accepted by ``pandas.read_csv`` (a path or file-like).

    Returns:
        ``(df, diagnosis_col, symptom_cols)`` where ``df`` is a new DataFrame
        with an extra ``_symptom_text`` column: the non-empty symptom values
        of each row, stringified and joined with ``", "``.

    Raises:
        ValueError: if no diagnosis column can be identified.
    """
    df = pd.read_csv(path)
    # str() guards against non-string headers (e.g. integer column labels).
    col_names = [str(c).lower() for c in df.columns]

    if "diagnosis" in col_names:
        diagnosis_col = df.columns[col_names.index("diagnosis")]
    else:
        diagnosis_col = next(
            (c for c, name in zip(df.columns, col_names) if "diagnos" in name),
            None,
        )
    if diagnosis_col is None:
        raise ValueError("No diagnosis column found. Please ensure a 'diagnosis' column exists.")

    # NOTE(review): substring matching means "sign" also hits names like
    # "design"; confirm the heuristic against the real dataset's headers.
    symptom_keywords = ("symptom", "feature", "sign")
    symptom_cols = [
        c
        for c, name in zip(df.columns, col_names)
        if c != diagnosis_col and any(k in name for k in symptom_keywords)
    ]
    if not symptom_cols:
        symptom_cols = [c for c in df.columns if c != diagnosis_col]

    # Unlabeled rows cannot contribute a diagnosis; drop them and renumber.
    df = df.loc[df[diagnosis_col].notna()].reset_index(drop=True).copy()

    # A list comprehension (rather than DataFrame.apply) also behaves on a
    # zero-row frame, where apply(axis=1) returns a DataFrame, not a Series.
    df["_symptom_text"] = [
        ", ".join(str(v) for v in values if pd.notna(v) and str(v).strip())
        for values in df[symptom_cols].itertuples(index=False, name=None)
    ]
    return df, diagnosis_col, symptom_cols
# Global fit (loaded at startup): read the dataset and build the TF-IDF index
# once. Any failure is recorded in _startup_error (and logged) so the UI can
# still start and display the problem instead of the app crashing on import.
try:
    _df, _diagnosis_col, _symptom_cols = load_dataset(DATASET_PATH)
    _corpus = _df["_symptom_text"].fillna("").astype(str).tolist()
    # Unigrams + bigrams so multi-word symptoms ("chest pain") carry weight.
    _vectorizer = TfidfVectorizer(ngram_range=(1, 2), min_df=1)
    _X = _vectorizer.fit_transform(_corpus)
except Exception as e:  # top-level boundary: degrade to an error banner
    logging.getLogger(__name__).exception("Failed to build the differential-diagnosis index")
    _df, _diagnosis_col, _symptom_cols = None, None, None
    _corpus, _vectorizer, _X = [], None, None
    # Prefix the exception type: str(e) can be empty (e.g. KeyError()), and an
    # empty _startup_error would make `if _startup_error:` checks treat the
    # failed startup as a success.
    _startup_error = f"{type(e).__name__}: {e}"
else:
    _startup_error = ""
def normalize_symptoms(text: str) -> str:
    """Canonicalise a comma-separated symptom string.

    Each comma-separated term is stripped, has internal whitespace runs
    collapsed to single spaces, and is lower-cased; empty terms are dropped.
    Terms are re-joined with ``", "``.

    Args:
        text: raw user input; ``None`` or ``""`` yields ``""``.

    Returns:
        The normalised string, e.g. ``"Fever,  Chest   Pain"`` ->
        ``"fever, chest pain"``.
    """
    if not text:
        return ""
    # " ".join(s.split()) strips and collapses tabs/newlines/multiple spaces,
    # so keyword matching elsewhere sees "chest pain", not "chest   pain".
    terms = (" ".join(part.split()).lower() for part in text.split(","))
    return ", ".join(t for t in terms if t)
def few_shot_prompt(symptoms: List[str], k_examples: int = 3) -> str:
    """Format the dataset rows most similar to ``symptoms`` as few-shot examples.

    Rows are ranked by TF-IDF cosine similarity to the comma-joined symptoms.
    At most ``k_examples`` rows are kept, and rows with zero similarity (no
    shared terms with the query) are skipped. Each example is one line of the
    form ``Symptoms: { ... } -> Diagnosis: { ... }``.

    Args:
        symptoms: symptom terms making up the query.
        k_examples: maximum number of examples to return.

    Returns:
        The example lines joined by newlines; ``""`` when nothing matches or
        ``k_examples <= 0``; ``"Model not initialized."`` when the module-level
        TF-IDF index was not built.
    """
    if _X is None or _vectorizer is None or _df is None:
        return "Model not initialized."
    if k_examples <= 0:
        # Guard: a negative slice bound would otherwise select almost every row.
        return ""

    sims = cosine_similarity(_vectorizer.transform([", ".join(symptoms)]), _X).ravel()
    # Stable sort on negated scores: best match first, ties keep dataset order.
    top_idx = np.argsort(-sims, kind="stable")[:k_examples]

    # Positional access: row i of _X corresponds to the i-th row of _df.
    texts = _df["_symptom_text"]
    labels = _df[_diagnosis_col]
    examples = [
        f"Symptoms: {{ {texts.iloc[i]} }} -> Diagnosis: {{ {labels.iloc[i]} }}"
        for i in top_idx
        if sims[i] > 0
    ]
    return "\n".join(examples)
def chain_of_reasoning(symptoms: List[str]) -> List[str]:
    """Return the numbered, human-readable outline of the analysis pipeline.

    ``symptoms`` is currently unused: the same five steps are returned for
    every query. A fresh list is built on each call, so callers may mutate it.
    """
    return [
        "1) Parse input and standardize terms.",
        "2) Match symptom pattern with dataset using TF-IDF similarity.",
        "3) Collect top diagnoses and compute normalized scores.",
        "4) Check for red-flag patterns (e.g., chest pain + shortness of breath).",
        "5) Return differential list with triage flags. For care, consult a clinician.",
    ]
def tree_of_hypotheses(symptoms: List[str], top_n: int = 5) -> List[str]:
    """Map symptoms onto coarse organ-system "hypothesis branches" by keyword.

    A branch is emitted when any of its keywords occurs as a substring of the
    lower-cased, comma-joined symptom text. Branches come out in the fixed
    bucket order below; ``"general"`` is used when no bucket matches.

    Args:
        symptoms: symptom phrases (matched case-insensitively).
        top_n: maximum number of branches returned (the default of 5 equals
            the number of buckets, i.e. no truncation).

    Returns:
        Lines of the form ``"Hypothesis branch: <system>"``.
    """
    # Lower-case here so direct callers need not pre-normalise their input.
    text = ", ".join(symptoms).lower()
    buckets = {
        "cardio": ["chest pain", "palpitations", "syncope", "shortness of breath"],
        "neuro": ["headache", "weakness", "numbness", "confusion", "seizure"],
        "gi": ["abdominal pain", "nausea", "vomiting", "diarrhea", "constipation"],
        "pulm": ["cough", "wheezing", "dyspnea", "hemoptysis"],
        "id": ["fever", "chills", "night sweats", "fatigue"],
    }
    matched = [
        system for system, keys in buckets.items() if any(k in text for k in keys)
    ] or ["general"]
    # max(..., 0) so a negative top_n yields no branches instead of mis-slicing.
    return [f"Hypothesis branch: {m}" for m in matched[: max(top_n, 0)]]
def infer_differential(symptoms_text: str, top_k: int = 7) -> Tuple[str, list, list, list]:
    """Rank diagnoses for a comma-separated symptom string.

    Each diagnosis is scored by the best TF-IDF cosine similarity between the
    query and any dataset row carrying that diagnosis. Scores are then divided
    by the top score, so the best match reads 1.0. Diagnoses whose rows share
    no terms with the query (similarity 0) are omitted rather than listed
    with a meaningless score of 0.

    Args:
        symptoms_text: raw user input, comma-separated; at least 3 terms.
        top_k: maximum number of diagnoses to return.

    Returns:
        ``(error, diffs, chain, prompt_parts)``. When analysis cannot run,
        ``error`` is a non-empty message and the three lists are empty.
        Otherwise ``error`` is ``""``, ``diffs`` holds formatted differential
        lines, ``chain`` the reasoning steps, and ``prompt_parts`` the
        few-shot example block followed by hypothesis-branch lines.
    """
    if _startup_error:
        return f"Startup error: {_startup_error}", [], [], []
    if _vectorizer is None or _X is None or _df is None:
        # Defensive: the index is missing even though no startup error was recorded.
        return "Model not initialized.", [], [], []

    cleaned = normalize_symptoms(symptoms_text)
    # normalize_symptoms joins with ", ", so strip each piece after splitting.
    tokens = [t.strip() for t in cleaned.split(",") if t.strip()]
    if len(tokens) < 3:
        return "Enter at least 3 symptoms (comma-separated).", [], [], []

    query = ", ".join(tokens)
    sims = cosine_similarity(_vectorizer.transform([query]), _X).ravel()

    # Best similarity per diagnosis. The dict keeps first-seen order, which the
    # stable sort below uses to break ties.
    best = {}
    for dx, s in zip(_df[_diagnosis_col].astype(str), sims):
        s = float(s)
        if s > 0.0:
            best[dx] = max(best.get(dx, 0.0), s)

    ranked = sorted(best.items(), key=lambda item: item[1], reverse=True)[: max(top_k, 0)]
    if ranked:
        top_score = ranked[0][1]  # strictly positive: zero-similarity rows were skipped
        ranked = [(dx, round(score / top_score, 3)) for dx, score in ranked]
    diffs = [f"• {dx} — score {score}" for dx, score in ranked]

    fewshot = few_shot_prompt(tokens, k_examples=3)
    chain = chain_of_reasoning(tokens)
    tree = tree_of_hypotheses(tokens)
    return "", diffs, chain, [fewshot] + tree
# Safety notice shown at the top of the UI; one sentence per line.
disclaimer = "\n".join(
    [
        "Educational use only. Not medical advice.",
        "Do not use this tool for emergencies.",
        "If symptoms are severe or worsening, seek licensed care immediately.",
    ]
)
def _as_markdown_paragraphs(lines) -> str:
    """Join text lines so each one renders on its own line in ``gr.Markdown``.

    Markdown folds a single newline into a space, so a plain newline join
    would run all lines together into one paragraph. Instead, every non-blank
    line (including lines embedded inside an item) becomes its own paragraph.
    """
    return "\n\n".join(
        line for item in lines for line in str(item).splitlines() if line.strip()
    )


with gr.Blocks(title="AI-Powered Differential Diagnosis (Educational)") as demo:
    gr.Markdown("# AI-Powered Differential Diagnosis (Educational)")
    gr.Markdown(_as_markdown_paragraphs(disclaimer.splitlines()))
    if _startup_error:
        gr.Markdown(f"**Startup error:** {_startup_error}")
    with gr.Row():
        inp = gr.Textbox(
            label="Enter symptoms (comma-separated). Minimum 3.",
            placeholder="fever, cough, shortness of breath",
        )
    with gr.Row():
        btn = gr.Button("Analyze")
        clr = gr.ClearButton([inp])
    with gr.Row():
        dx = gr.Markdown(label="Differential Diagnoses")
    with gr.Row():
        chain_box = gr.Markdown(label="Structured reasoning steps")
    with gr.Row():
        prompt_box = gr.Markdown(label="Few-shot examples + hypothesis branches")
    # Clearing should also wipe the previous results, not only the input box;
    # the output components exist only now, so they are registered here.
    clr.add([dx, chain_box, prompt_box])

    def on_analyze(text):
        """Run the analysis for ``text`` and return Markdown for the three result panes."""
        err, diffs, chain, prompt = infer_differential(text)
        if err:
            return err, "", ""
        dx_md = _as_markdown_paragraphs(diffs) if diffs else "No matches."
        return dx_md, _as_markdown_paragraphs(chain), _as_markdown_paragraphs(prompt)

    btn.click(on_analyze, inputs=[inp], outputs=[dx, chain_box, prompt_box])
if __name__ == "__main__":
    # When run as a script, serve the UI on all network interfaces, port 7860.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )