# GradiotestApp / app.py
# Konaguy's picture
# Upload 4 files
# aca21d8 verified
import gradio as gr
import pandas as pd
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from typing import List, Tuple
# CSV file handed to load_dataset() when the module is imported. The path is
# relative, so it resolves against the process's working directory.
DATASET_PATH = "differential_diagnosis_dataset.csv"
def load_dataset(path: str):
    """Load the diagnosis CSV and build one symptom string per row.

    Column detection is name-based:

    * The diagnosis column is the first column named exactly "diagnosis"
      (case-insensitive); failing that, the first column whose name contains
      "diagnos".
    * Symptom columns are the columns whose names contain "symptom",
      "feature" or "sign". If none match, every column except the diagnosis
      column is used.

    The diagnosis column is never treated as a symptom column, even when its
    name happens to contain one of the symptom keywords; otherwise the label
    would be copied into the text that is later matched against.

    Args:
        path: Location of the CSV file, passed to ``pandas.read_csv``.

    Returns:
        A ``(df, diagnosis_col, symptom_cols)`` tuple. ``df`` is a copy of the
        loaded frame with an added ``_symptom_text`` column holding the
        non-empty symptom values of the row joined by ", ".

    Raises:
        ValueError: If no column name contains "diagnos".
    """
    df = pd.read_csv(path)

    # Pair every column with its lower-cased name once; reused by both lookups.
    lowered = [(col, col.lower()) for col in df.columns]

    # Prefer an exact (case-insensitive) "diagnosis" column, then a partial match.
    diagnosis_col = next((col for col, low in lowered if low == "diagnosis"), None)
    if diagnosis_col is None:
        diagnosis_col = next((col for col, low in lowered if "diagnos" in low), None)
    if diagnosis_col is None:
        raise ValueError("No diagnosis column found. Please ensure a 'diagnosis' column exists.")

    # Symptom columns heuristic; the label column is excluded explicitly so a
    # name such as "Diagnosis_Sign" cannot leak the answer into the text.
    keywords = ("symptom", "feature", "sign")
    symptom_cols = [
        col
        for col, low in lowered
        if col != diagnosis_col and any(word in low for word in keywords)
    ]
    if not symptom_cols:
        symptom_cols = [col for col in df.columns if col != diagnosis_col]

    def row_to_text(row):
        """Join the row's non-missing, non-blank symptom values with ', '."""
        values = (row.get(col, "") for col in symptom_cols)
        return ", ".join(str(v) for v in values if pd.notna(v) and str(v).strip())

    df = df.copy()
    df["_symptom_text"] = df.apply(row_to_text, axis=1)
    return df, diagnosis_col, symptom_cols
# Global fit (loaded at startup)
# The TF-IDF index is built once at import time. A failure is recorded in
# _startup_error instead of being raised, so the rest of the module still loads
# and the request path can report the problem.
_startup_error = ""
try:
    _df, _diagnosis_col, _symptom_cols = load_dataset(DATASET_PATH)
    _corpus = _df["_symptom_text"].fillna("").astype(str).tolist()
    # Unigrams and bigrams; min_df=1 keeps terms that occur in a single row.
    _vectorizer = TfidfVectorizer(ngram_range=(1, 2), min_df=1)
    _X = _vectorizer.fit_transform(_corpus)
except Exception as exc:
    # Reset every global so later code sees one consistent "not initialised"
    # state, whichever step failed.
    _df = _diagnosis_col = _symptom_cols = None
    _corpus = []
    _vectorizer = _X = None
    _startup_error = str(exc)
def normalize_symptoms(text: str) -> str:
    """Return the comma-separated symptoms lower-cased, trimmed and re-joined.

    Empty entries (for example from ",," or a trailing comma) are dropped and
    the remaining entries are joined with ", ".

    Args:
        text: Raw user input. ``None`` is tolerated and treated like an empty
            string, because a UI textbox can hand over ``None`` when cleared.

    Returns:
        The normalized symptom string, or "" when there is nothing to keep.
    """
    if not text:
        return ""
    pieces = (piece.strip().lower() for piece in text.split(","))
    return ", ".join(piece for piece in pieces if piece)
def few_shot_prompt(symptoms: List[str], k_examples: int = 3) -> str:
    """Format the dataset rows most similar to ``symptoms`` as example lines.

    The symptoms are joined with ", ", vectorized with the module-level TF-IDF
    vectorizer and compared to every dataset row by cosine similarity.

    Args:
        symptoms: Symptom strings describing the query.
        k_examples: Number of highest-similarity rows to include.

    Returns:
        One "Symptoms: { ... } -> Diagnosis: { ... }" line per selected row,
        joined by newlines, or "Model not initialized." when the startup fit
        did not produce a vectorizer and matrix.
    """
    if _vectorizer is None or _X is None:
        return "Model not initialized."

    query_vec = _vectorizer.transform([", ".join(symptoms)])
    similarity = cosine_similarity(query_vec, _X).ravel()
    # argsort is ascending, so reverse it before taking the first k rows.
    nearest_rows = similarity.argsort()[::-1][:k_examples]

    return "\n".join(
        f"Symptoms: {{ {_df.loc[row, '_symptom_text']} }} -> Diagnosis: {{ {_df.loc[row, _diagnosis_col]} }}"
        for row in nearest_rows
    )
def chain_of_reasoning(symptoms: List[str]) -> List[str]:
    """Return the fixed, numbered description of the analysis steps.

    Args:
        symptoms: Accepted for interface symmetry with the other helpers; the
            returned steps do not depend on it.

    Returns:
        A new list of five step descriptions on every call.
    """
    return [
        "1) Parse input and standardize terms.",
        "2) Match symptom pattern with dataset using TF-IDF similarity.",
        "3) Collect top diagnoses and compute normalized scores.",
        "4) Check for red-flag patterns (e.g., chest pain + shortness of breath).",
        "5) Return differential list with triage flags. For care, consult a clinician.",
    ]
# Body-system buckets and the lower-case symptom phrases that select them.
# Insertion order is the order in which branches are reported.
_HYPOTHESIS_BUCKETS = {
    "cardio": ("chest pain", "palpitations", "syncope", "shortness of breath"),
    "neuro": ("headache", "weakness", "numbness", "confusion", "seizure"),
    "gi": ("abdominal pain", "nausea", "vomiting", "diarrhea", "constipation"),
    "pulm": ("cough", "wheezing", "dyspnea", "hemoptysis"),
    "id": ("fever", "chills", "night sweats", "fatigue"),
}


def tree_of_hypotheses(symptoms: List[str], top_n: int = 5) -> List[str]:
    """Map the symptoms onto coarse body-system hypothesis branches.

    The symptoms are joined into one string and each bucket is selected when
    any of its phrases occurs in it as a substring. Matching is
    case-insensitive, so "Chest Pain" selects the same branch as "chest pain".

    Args:
        symptoms: Symptom strings describing the query.
        top_n: Accepted for interface compatibility; not used by the matching.

    Returns:
        One "Hypothesis branch: <system>" string per matched bucket, in bucket
        order, or a single "Hypothesis branch: general" entry when nothing
        matches.
    """
    # Bucket phrases are lower-case, so lower the input once before matching.
    text = ", ".join(symptoms).lower()
    matched = [
        system
        for system, phrases in _HYPOTHESIS_BUCKETS.items()
        if any(phrase in text for phrase in phrases)
    ]
    if not matched:
        matched = ["general"]
    return [f"Hypothesis branch: {system}" for system in matched]
def infer_differential(symptoms_text: str, top_k: int = 7) -> Tuple[str, list, list, list]:
    """Rank dataset diagnoses by TF-IDF similarity to the entered symptoms.

    Each diagnosis is scored by the best cosine similarity of any of its
    dataset rows to the query. Diagnoses with no similarity at all are dropped
    rather than listed with a score of 0, so an input that shares no term with
    the dataset yields an empty differential (which the UI reports as
    "No matches."). Remaining scores are divided by the top score, so the best
    match is always reported as 1.0.

    Args:
        symptoms_text: Comma-separated symptoms as typed by the user. ``None``
            is treated like an empty string.
        top_k: Maximum number of diagnoses to return; values below zero are
            treated as zero.

    Returns:
        A ``(error, diffs, chain, extras)`` tuple:

        * ``error`` - "" on success, otherwise a message (startup failure or
          fewer than three symptoms); the three lists are then empty.
        * ``diffs`` - "• <diagnosis> — score <relative score>" lines, best
          first.
        * ``chain`` - the reasoning-step descriptions.
        * ``extras`` - the few-shot example text (only when something
          matched) followed by the hypothesis-branch lines.
    """
    if _startup_error:
        return f"Startup error: {_startup_error}", [], [], []

    cleaned = normalize_symptoms(symptoms_text or "")
    # normalize_symptoms joins with ", ", so strip each piece again to avoid
    # passing tokens with a leading space to the helpers below.
    tokens = [piece.strip() for piece in cleaned.split(",") if piece.strip()]
    if len(tokens) < 3:
        return "Enter at least 3 symptoms (comma-separated).", [], [], []

    query = ", ".join(tokens)
    query_vec = _vectorizer.transform([query])
    sims = cosine_similarity(query_vec, _X).ravel()

    # Best similarity per diagnosis. Rows are paired with similarities by
    # position, which is how the rows of _X were built from the frame.
    scores = {}
    for label, sim in zip(_df[_diagnosis_col].tolist(), sims):
        name = str(label)
        scores[name] = max(scores.get(name, 0.0), float(sim))

    # A similarity of 0 means the row shares no term with the query; such
    # diagnoses are not candidates and must not pad the list.
    candidates = [(name, score) for name, score in scores.items() if score > 0.0]
    ranked = sorted(candidates, key=lambda item: item[1], reverse=True)[: max(top_k, 0)]
    if ranked:
        top_score = ranked[0][1]
        ranked = [(name, round(score / top_score, 3)) for name, score in ranked]

    diffs = [f"• {name} — score {score}" for name, score in ranked]
    chain = chain_of_reasoning(tokens)
    tree = tree_of_hypotheses(tokens)
    # Nearest-row examples are only meaningful when something actually matched.
    fewshot = [few_shot_prompt(tokens, k_examples=3)] if ranked else []
    return "", diffs, chain, fewshot + tree
# Safety notice shown to users; three sentences, one per line.
disclaimer = "\n".join(
    (
        "Educational use only. Not medical advice.",
        "Do not use this tool for emergencies.",
        "If symptoms are severe or worsening, seek licensed care immediately.",
    )
)
# UI definition. Components are created in display order; the event handler is
# attached inside the Blocks context so it is registered with this app.
with gr.Blocks(title="AI-Powered Differential Diagnosis (Educational)") as demo:
    gr.Markdown("# AI-Powered Differential Diagnosis (Educational)")
    # Markdown folds single newlines into one paragraph; blank lines keep each
    # disclaimer sentence on its own line.
    gr.Markdown(disclaimer.replace("\n", "\n\n"))
    if _startup_error:
        gr.Markdown(f"**Startup error:** {_startup_error}")
    with gr.Row():
        inp = gr.Textbox(label="Enter symptoms (comma-separated). Minimum 3.", placeholder="fever, cough, shortness of breath")
    with gr.Row():
        btn = gr.Button("Analyze")
        clr = gr.ClearButton([inp])
    with gr.Row():
        dx = gr.Markdown(label="Differential Diagnoses")
    with gr.Row():
        chain_box = gr.Markdown(label="Structured reasoning steps")
    with gr.Row():
        prompt_box = gr.Markdown(label="Few-shot examples + hypothesis branches")

    def _md_lines(items):
        """Render every non-blank line of every item as its own Markdown paragraph."""
        lines = [
            line
            for item in items
            for line in str(item).split("\n")
            if line.strip()
        ]
        # A blank line between entries stops Markdown from merging them.
        return "\n\n".join(lines)

    def on_analyze(text):
        """Run the analysis and return the text for the three output panels.

        Returns ``(differential, reasoning, examples)``. On an error the
        message goes to the first panel and the other two are cleared.
        """
        # A cleared textbox can arrive as None; treat it as empty input.
        err, diffs, chain, prompt = infer_differential(text or "")
        if err:
            return err, "", ""
        dx_md = _md_lines(diffs) if diffs else "No matches."
        return dx_md, _md_lines(chain), _md_lines(prompt)

    btn.click(on_analyze, inputs=[inp], outputs=[dx, chain_box, prompt_box])
if __name__ == "__main__":
    # Only start the server when run as a script; listens on every network
    # interface at port 7860.
    demo.launch(server_port=7860, server_name="0.0.0.0")