MediScan_AI / app.py
Umranz's picture
Update MediScan AI app with improved UI and fixed models loading
3d98fdb
import gradio as gr
from huggingface_hub import hf_hub_download
import pickle
import numpy as np
import os
# Hugging Face Hub repository that hosts the pickled artifacts.
REPO_ID = "Umranz/mediscan-symptom-classifier"

# Artifacts fetched by default; each is stored under its name minus ".pkl".
MODEL_FILES = (
    "svm.pkl",
    "logistic.pkl",
    "random_forest.pkl",
    "naive_bayes.pkl",
    "voting_ensemble.pkl",
    "label_encoder.pkl",
    "tfidf.pkl",
)

_PICKLE_SUFFIX = ".pkl"


def load_models(repo_id=REPO_ID, filenames=MODEL_FILES):
    """Download and unpickle the model artifacts from the Hugging Face Hub.

    Args:
        repo_id: Hub repository to download from. Defaults to ``REPO_ID``.
        filenames: Iterable of file names inside the repository. Defaults to
            ``MODEL_FILES``.

    Returns:
        A dict mapping each file name, with a trailing ``.pkl`` removed, to
        the object unpickled from that file.

    Raises:
        Whatever ``hf_hub_download``, ``open`` or ``pickle.load`` raise; no
        error is swallowed here.
    """
    loaded = {}
    for filename in filenames:
        path = hf_hub_download(repo_id=repo_id, filename=filename)

        # Strip only a trailing ".pkl"; str.replace would also remove the
        # substring from the middle of a name.
        if filename.endswith(_PICKLE_SUFFIX):
            key = filename[: -len(_PICKLE_SUFFIX)]
        else:
            key = filename

        # SECURITY: unpickling runs arbitrary code embedded in the file.
        # Only point `repo_id` at a repository whose contents you trust.
        with open(path, "rb") as fh:
            loaded[key] = pickle.load(fh)
    return loaded
# Load every artifact once at import time so each request reuses the same
# in-memory objects instead of re-downloading / re-unpickling them.
print("Loading models...")
M = load_models()

# Objects shared by all predictions: the text vectorizer, the label encoder
# and the voting ensemble that produces the headline result.
tfidf = M["tfidf"]
le = M["label_encoder"]
ensemble = M["voting_ensemble"]

# Individual classifiers, keyed by the display name shown in the UI.
models = {
    "SVM": M["svm"],
    "Logistic Reg": M["logistic"],
    "Random Forest": M["random_forest"],
    "Naive Bayes": M["naive_bayes"],
}
# The status marker was mis-encoded (mojibake); restored to the check mark.
print("✅ Models loaded!")
# Severity tiers as (indicator emoji, label). The emoji were mis-encoded
# (UTF-8 bytes shown through a single-byte codepage); they are restored here
# to the green / yellow / orange / red circle code points.
_LOW = ("🟢", "Low")
_MILD = ("🟡", "Mild")
_MODERATE = ("🟠", "Moderate")
_SEVERE = ("🔴", "Severe")
_CRITICAL = ("🔴", "Critical")

# Maps a predicted disease label to its severity tier.
# NOTE(review): keys are looked up with the classifier's label strings, so
# spellings such as "hepatitis A", "Osteoarthristis" and
# "Dimorphic hemmorhoids(piles)" are presumably the dataset's own labels —
# confirm against the label encoder before "correcting" any of them.
SEVERITY = {
    "Fungal infection": _MILD,
    "Allergy": _MILD,
    "GERD": _MILD,
    "Chronic cholestasis": _MODERATE,
    "Drug Reaction": _MODERATE,
    "Peptic ulcer disease": _MODERATE,
    "AIDS": _SEVERE,
    "Diabetes": _MODERATE,
    "Gastroenteritis": _MILD,
    "Bronchial Asthma": _MODERATE,
    "Hypertension": _SEVERE,
    "Migraine": _MILD,
    "Cervical spondylosis": _MILD,
    "Paralysis (brain hemorrhage)": _SEVERE,
    "Jaundice": _MODERATE,
    "Malaria": _SEVERE,
    "Chicken pox": _MILD,
    "Dengue": _SEVERE,
    "Typhoid": _MODERATE,
    "hepatitis A": _MODERATE,
    "Hepatitis B": _SEVERE,
    "Hepatitis C": _SEVERE,
    "Hepatitis D": _SEVERE,
    "Hepatitis E": _MODERATE,
    "Alcoholic hepatitis": _MODERATE,
    "Tuberculosis": _SEVERE,
    "Common Cold": _LOW,
    "Pneumonia": _SEVERE,
    "Dimorphic hemmorhoids(piles)": _MILD,
    "Heart attack": _CRITICAL,
    "Varicose veins": _MILD,
    "Hypothyroidism": _MODERATE,
    "Hyperthyroidism": _MODERATE,
    "Hypoglycemia": _SEVERE,
    "Osteoarthristis": _MILD,
    "Arthritis": _MILD,
    "Vertigo": _MILD,
    "Acne": _LOW,
    "Urinary tract infection": _MILD,
    "Psoriasis": _MILD,
    "Impetigo": _MILD,
}
# Number of cells in the text confidence bar; each cell stands for 5%.
_BAR_CELLS = 20


def _confidence_bar(percent):
    """Render a 0-100 percentage as a fixed-width bar of full/empty cells."""
    filled = int(percent // 5)
    # Clamp so an out-of-range value can never change the bar's width.
    filled = max(0, min(_BAR_CELLS, filled))
    return "█" * filled + "░" * (_BAR_CELLS - filled)


def predict(symptoms, threshold):
    """Classify free-text symptoms and format the result as Markdown.

    Uses the module-level ``tfidf`` vectorizer, ``ensemble`` classifier,
    ``le`` label encoder, per-model ``models`` dict and ``SEVERITY`` table.

    Args:
        symptoms: Symptom description typed by the user. ``None``, an empty
            string or whitespace-only text yields a prompt to enter symptoms.
        threshold: Minimum confidence, in percent, for the top prediction to
            be reported as a result rather than a low-confidence warning.

    Returns:
        A 4-tuple of Markdown strings: ``(main_result, top3_result,
        agreement, disclaimer)``. For empty input or a below-threshold
        prediction only the first element is non-empty.
    """
    # Guard against None as well as blank text; calling .strip() on None
    # would raise AttributeError.
    if not symptoms or not symptoms.strip():
        return ("⚠️ Please enter your symptoms.", "", "", "")

    vec = tfidf.transform([symptoms])
    proba = ensemble.predict_proba(vec)[0]
    # Indices of the three highest probabilities, best first.
    top3 = np.argsort(proba)[::-1][:3]
    top_idx = top3[0]
    # NOTE(review): assumes probability column i corresponds to
    # le.classes_[i] — confirm against how the ensemble was trained.
    top_label = le.classes_[top_idx]
    top_conf = proba[top_idx] * 100

    if top_conf < threshold:
        low_confidence = (
            f"⚠️ **Low Confidence ({top_conf:.1f}%)** — Please provide more specific symptoms.\n\n"
            f"Best guess: **{top_label}** but confidence is below your threshold of {threshold}%."
        )
        return low_confidence, "", "", ""

    sev_emoji, sev_label = SEVERITY.get(top_label, ("⚪", "Unknown"))
    main_result = (
        f"## {sev_emoji} {top_label}\n"
        f"**Confidence:** {top_conf:.1f}%\n\n"
        f"**Severity:** {sev_emoji} {sev_label}\n\n"
        f"{_confidence_bar(top_conf)} {top_conf:.1f}%"
    )

    top3_parts = ["## 🏆 Top 3 Predictions\n\n"]
    for rank, idx in enumerate(top3, start=1):
        label = le.classes_[idx]
        conf = proba[idx] * 100
        s_emoji, s_label = SEVERITY.get(label, ("⚪", "Unknown"))
        top3_parts.append(
            f"**{rank}. {label}** {s_emoji} {s_label}\n"
            f"{_confidence_bar(conf)} {conf:.1f}%\n\n"
        )
    top3_result = "".join(top3_parts)

    agreement_parts = ["## 🤖 Model Votes\n\n"]
    votes = []
    for name, model in models.items():
        pred = le.classes_[model.predict(vec)[0]]
        votes.append(pred)
        # Mark whether this model matches the ensemble's top label.
        match = "✅" if pred == top_label else "🔄"
        agreement_parts.append(f"{match} **{name}** → {pred}\n\n")
    # "Agree" means the individual models agree with each other.
    all_agree = len(set(votes)) == 1
    agreement_parts.append(
        "\n🟢 **All models agree!**" if all_agree
        else "\n🟡 **Models have different opinions — consider consulting a doctor.**"
    )
    agreement = "".join(agreement_parts)

    disclaimer = (
        "## ⚠️ Medical Disclaimer\n\n"
        "This tool is for **educational purposes only** and does **NOT** replace "
        "professional medical advice. Always consult a qualified healthcare provider "
        "for diagnosis and treatment.\n\n"
        "**If you have a medical emergency, call your local emergency number immediately.**"
    )
    return main_result, top3_result, agreement, disclaimer
# Threshold (in percent) paired with every sample input below.
_EXAMPLE_THRESHOLD = 50

# Sample symptom descriptions offered to the user as one-click examples.
_EXAMPLE_SYMPTOMS = (
    "fever, chills, headache, muscle pain, sweating",
    "itching, skin rash, nodal skin eruptions, dischromic patches",
    "chest pain, shortness of breath, fatigue, sweating",
    "sneezing, runny nose, cough, sore throat, congestion",
    "fatigue, weight loss, high fever, night sweats, cough",
)

# One [symptoms, threshold] row per example.
EXAMPLES = [[text, _EXAMPLE_THRESHOLD] for text in _EXAMPLE_SYMPTOMS]
# Page header. Blank lines separate the Markdown blocks so the heading,
# paragraph and quote render as distinct elements.
_HEADER_MD = """
# 🩺 MediScan AI — Medical Symptom Classifier

**4 ML Models + Voting Ensemble** | DistilBERT-level accuracy with traditional ML

> Enter your symptoms separated by commas for instant multi-model analysis.
"""

# Body of the collapsible "About" panel.
# NOTE(review): this text says "24 unique conditions" while the SEVERITY table
# in this file lists 41 labels — confirm which figure is correct.
_ABOUT_MD = """
## 🧠 How It Works

MediScan AI runs your symptoms through **4 independent ML models simultaneously:**

| Model | Strength |
|---|---|
| **SVM** | Best accuracy on text classification |
| **Logistic Regression** | Fast, reliable baseline |
| **Random Forest** | Handles noisy input well |
| **Naive Bayes** | Great for keyword-based symptoms |

A **Soft Voting Ensemble** combines all 4 predictions for the final result.

## 📊 Dataset

- **Source:** Gretel AI Symptom to Diagnosis dataset
- **Diseases:** 24 unique conditions
- **Features:** TF-IDF with bigrams (5000 features)

## 👨‍💻 Built By

Umranz — [HuggingFace Profile](https://huggingface.co/Umranz)
"""

# NOTE(review): the original indentation was lost, so the nesting of rows and
# columns below is a reconstruction — confirm against the intended layout.
with gr.Blocks(title="MediScan AI") as demo:
    gr.Markdown(_HEADER_MD)

    with gr.Row():
        # Left column: user inputs.
        with gr.Column(scale=2):
            symptoms_input = gr.Textbox(
                lines=4,
                placeholder="e.g. fever, chills, headache, muscle pain, fatigue...",
                label="🔍 Describe Your Symptoms",
                max_lines=8,
            )
            threshold_slider = gr.Slider(
                minimum=10,
                maximum=90,
                value=50,
                step=5,
                label="⚙️ Confidence Threshold (%)",
                info="Predictions below this % will show a low-confidence warning",
            )
            analyze_btn = gr.Button(
                "🔍 Analyze Symptoms",
                variant="primary",
                size="lg",
            )
        # Right column: headline prediction.
        with gr.Column(scale=3):
            main_output = gr.Markdown(label="Primary Diagnosis")

    with gr.Row():
        top3_output = gr.Markdown(label="Top 3 Predictions")
        agreement_output = gr.Markdown(label="Model Agreement")

    disclaimer_output = gr.Markdown()

    gr.Examples(
        examples=EXAMPLES,
        inputs=[symptoms_input, threshold_slider],
        label="💡 Try These Examples",
    )

    with gr.Accordion("ℹ️ About MediScan AI", open=False):
        gr.Markdown(_ABOUT_MD)

    # The button click and pressing Enter in the textbox run the same
    # prediction with the same inputs and outputs.
    for _register in (analyze_btn.click, symptoms_input.submit):
        _register(
            fn=predict,
            inputs=[symptoms_input, threshold_slider],
            outputs=[main_output, top3_output, agreement_output, disclaimer_output],
        )
# Start the web server only when this file is executed as a script, so that
# importing the module (e.g. from a test or another tool) does not block on
# a running server.
if __name__ == "__main__":
    demo.launch()