"""MediScan AI: Gradio front-end for a TF-IDF + classical-ML symptom classifier.

On import this module downloads the pickled models from the Hugging Face Hub,
builds the Gradio Blocks UI and launches it.
"""

import os
import pickle

import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download

REPO_ID = "Umranz/mediscan-symptom-classifier"

# Pickled artifacts expected in REPO_ID; each is stored under its file stem.
MODEL_FILES = (
    "svm.pkl",
    "logistic.pkl",
    "random_forest.pkl",
    "naive_bayes.pkl",
    "voting_ensemble.pkl",
    "label_encoder.pkl",
    "tfidf.pkl",
)


def load_models(files=MODEL_FILES, repo_id=REPO_ID):
    """Download and unpickle each artifact in *files* from *repo_id*.

    Args:
        files: file names to fetch from the Hub repository.
        repo_id: Hugging Face Hub repository to download from.

    Returns:
        dict mapping each file stem (e.g. "svm", "tfidf") to the unpickled object.

    SECURITY NOTE: ``pickle.load`` can execute arbitrary code embedded in the
    file. Only point this at a repository you trust; consider pinning a
    ``revision=`` in ``hf_hub_download`` so later uploads cannot change what runs.
    """
    loaded = {}
    for filename in files:
        path = hf_hub_download(repo_id=repo_id, filename=filename)
        with open(path, "rb") as fh:
            loaded[os.path.splitext(filename)[0]] = pickle.load(fh)
    return loaded


print("Loading models...")
M = load_models()
tfidf = M["tfidf"]
le = M["label_encoder"]
ensemble = M["voting_ensemble"]
models = {
    "SVM": M["svm"],
    "Logistic Reg": M["logistic"],
    "Random Forest": M["random_forest"],
    "Naive Bayes": M["naive_bayes"],
}
print("✅ Models loaded!")

# (emoji, level) per disease label. Keys are looked up with the labels from
# le.classes_, so they keep the dataset's own spellings (e.g. "Osteoarthristis").
SEVERITY = {
    "Fungal infection": ("🟡", "Mild"),
    "Allergy": ("🟡", "Mild"),
    "GERD": ("🟡", "Mild"),
    "Chronic cholestasis": ("🟠", "Moderate"),
    "Drug Reaction": ("🟠", "Moderate"),
    "Peptic ulcer disease": ("🟠", "Moderate"),
    "AIDS": ("🔴", "Severe"),
    "Diabetes": ("🟠", "Moderate"),
    "Gastroenteritis": ("🟡", "Mild"),
    "Bronchial Asthma": ("🟠", "Moderate"),
    "Hypertension": ("🔴", "Severe"),
    "Migraine": ("🟡", "Mild"),
    "Cervical spondylosis": ("🟡", "Mild"),
    "Paralysis (brain hemorrhage)": ("🔴", "Severe"),
    "Jaundice": ("🟠", "Moderate"),
    "Malaria": ("🔴", "Severe"),
    "Chicken pox": ("🟡", "Mild"),
    "Dengue": ("🔴", "Severe"),
    "Typhoid": ("🟠", "Moderate"),
    "hepatitis A": ("🟠", "Moderate"),
    "Hepatitis B": ("🔴", "Severe"),
    "Hepatitis C": ("🔴", "Severe"),
    "Hepatitis D": ("🔴", "Severe"),
    "Hepatitis E": ("🟠", "Moderate"),
    "Alcoholic hepatitis": ("🟠", "Moderate"),
    "Tuberculosis": ("🔴", "Severe"),
    "Common Cold": ("🟢", "Low"),
    "Pneumonia": ("🔴", "Severe"),
    "Dimorphic hemmorhoids(piles)": ("🟡", "Mild"),
    "Heart attack": ("🔴", "Critical"),
    "Varicose veins": ("🟡", "Mild"),
    "Hypothyroidism": ("🟠", "Moderate"),
    "Hyperthyroidism": ("🟠", "Moderate"),
    "Hypoglycemia": ("🔴", "Severe"),
    "Osteoarthristis": ("🟡", "Mild"),
    "Arthritis": ("🟡", "Mild"),
    "Vertigo": ("🟡", "Mild"),
    "Acne": ("🟢", "Low"),
    "Urinary tract infection": ("🟡", "Mild"),
    "Psoriasis": ("🟡", "Mild"),
    "Impetigo": ("🟡", "Mild"),
}

_UNKNOWN_SEVERITY = ("⚪", "Unknown")
# The keys above mix casings ("hepatitis A" vs "Hepatitis B"); this index lets a
# label whose casing/surrounding whitespace differs still find its severity.
_SEVERITY_CASEFOLD = {name.casefold(): sev for name, sev in SEVERITY.items()}

_BAR_WIDTH = 20  # one cell per 5 percentage points

DISCLAIMER = (
    "## ⚠️ Medical Disclaimer\n\n"
    "This tool is for **educational purposes only** and does **NOT** replace "
    "professional medical advice. Always consult a qualified healthcare provider "
    "for diagnosis and treatment.\n\n"
    "**If you have a medical emergency, call your local emergency number immediately.**"
)


def _severity(label):
    """Return the (emoji, level) pair for *label*, or ("⚪", "Unknown") if unlisted."""
    key = str(label)
    exact = SEVERITY.get(key)
    if exact is not None:
        return exact
    return _SEVERITY_CASEFOLD.get(key.strip().casefold(), _UNKNOWN_SEVERITY)


def _confidence_bar(percent):
    """Render a 20-cell text bar (█ filled, ░ empty) for a 0-100 percentage."""
    filled = min(max(int(percent // 5), 0), _BAR_WIDTH)
    return "█" * filled + "░" * (_BAR_WIDTH - filled)


def predict(symptoms, threshold):
    """Classify free-text symptoms and build the four Markdown result panels.

    Args:
        symptoms: symptom text from the UI textbox (comma-separated by convention).
        threshold: minimum ensemble confidence, in percent, required before a
            full diagnosis (top-3 and model votes) is shown.

    Returns:
        Tuple ``(main, top3, agreement, disclaimer)`` of Markdown strings; panels
        that do not apply are ``""``. The disclaimer accompanies every result that
        names a disease.
    """
    if not symptoms or not symptoms.strip():
        return ("⚠️ Please enter your symptoms.", "", "", "")

    vec = tfidf.transform([symptoms])
    proba = ensemble.predict_proba(vec)[0]
    # NOTE(review): assumes the ensemble was trained on le-encoded integer labels,
    # so column i of predict_proba corresponds to le.classes_[i] — confirm.
    top3 = np.argsort(proba)[::-1][:3]  # three most probable classes, best first
    top_idx = top3[0]
    top_label = le.classes_[top_idx]
    top_conf = proba[top_idx] * 100

    if top_conf < threshold:
        main_result = (
            f"⚠️ **Low Confidence ({top_conf:.1f}%)** — Please provide more specific symptoms.\n\n"
            f"Best guess: **{top_label}** but confidence is below your threshold of {threshold}%."
        )
        # A disease name is still displayed, so the disclaimer must be too.
        return main_result, "", "", DISCLAIMER

    sev_emoji, sev_label = _severity(top_label)
    main_result = (
        f"## {sev_emoji} {top_label}\n"
        f"**Confidence:** {top_conf:.1f}%\n\n"
        f"**Severity:** {sev_emoji} {sev_label}\n\n"
        f"{_confidence_bar(top_conf)} {top_conf:.1f}%"
    )

    top3_parts = ["## 🏆 Top 3 Predictions\n\n"]
    for rank, idx in enumerate(top3, start=1):
        label = le.classes_[idx]
        conf = proba[idx] * 100
        s_emoji, s_label = _severity(label)
        top3_parts.append(
            f"**{rank}. {label}** {s_emoji} {s_label}\n"
            f"{_confidence_bar(conf)} {conf:.1f}%\n\n"
        )
    top3_result = "".join(top3_parts)

    # Each base model votes independently; predict() presumably returns an
    # le-encoded index (same assumption as above).
    votes = {name: le.classes_[model.predict(vec)[0]] for name, model in models.items()}
    agreement_parts = ["## 🤖 Model Votes\n\n"]
    for name, pred in votes.items():
        mark = "✅" if pred == top_label else "🔄"
        agreement_parts.append(f"{mark} **{name}** → {pred}\n\n")
    all_agree = len(set(votes.values())) == 1
    agreement_parts.append(
        "\n🟢 **All models agree!**"
        if all_agree
        else "\n🟡 **Models have different opinions — consider consulting a doctor.**"
    )
    agreement = "".join(agreement_parts)

    return main_result, top3_result, agreement, DISCLAIMER


EXAMPLES = [
    ["fever, chills, headache, muscle pain, sweating", 50],
    ["itching, skin rash, nodal skin eruptions, dischromic patches", 50],
    ["chest pain, shortness of breath, fatigue, sweating", 50],
    ["sneezing, runny nose, cough, sore throat, congestion", 50],
    ["fatigue, weight loss, high fever, night sweats, cough", 50],
]

# Markdown kept flush-left: indented lines would render as code blocks.
HEADER_MD = """
# 🩺 MediScan AI — Medical Symptom Classifier
**4 ML Models + Voting Ensemble** | DistilBERT-level accuracy with traditional ML

> Enter your symptoms separated by commas for instant multi-model analysis.
"""

# The condition count is read from the loaded label encoder so it always
# matches the deployed model.
ABOUT_MD = f"""
## 🧠 How It Works
MediScan AI runs your symptoms through **4 independent ML models simultaneously:**

| Model | Strength |
|---|---|
| **SVM** | Best accuracy on text classification |
| **Logistic Regression** | Fast, reliable baseline |
| **Random Forest** | Handles noisy input well |
| **Naive Bayes** | Great for keyword-based symptoms |

A **Soft Voting Ensemble** combines all 4 predictions for the final result.

## 📊 Dataset
- **Source:** Gretel AI Symptom to Diagnosis dataset
- **Diseases:** {len(le.classes_)} unique conditions
- **Features:** TF-IDF with bigrams (5000 features)

## 👨‍💻 Built By
Umranz — [HuggingFace Profile](https://huggingface.co/Umranz)
"""

with gr.Blocks(title="MediScan AI") as demo:
    gr.Markdown(HEADER_MD)

    with gr.Row():
        with gr.Column(scale=2):
            symptoms_input = gr.Textbox(
                lines=4,
                placeholder="e.g. fever, chills, headache, muscle pain, fatigue...",
                label="🔍 Describe Your Symptoms",
                max_lines=8,
            )
            threshold_slider = gr.Slider(
                minimum=10,
                maximum=90,
                value=50,
                step=5,
                label="⚙️ Confidence Threshold (%)",
                info="Predictions below this % will show a low-confidence warning",
            )
            analyze_btn = gr.Button("🔍 Analyze Symptoms", variant="primary", size="lg")
        with gr.Column(scale=3):
            main_output = gr.Markdown(label="Primary Diagnosis")

    with gr.Row():
        top3_output = gr.Markdown(label="Top 3 Predictions")
        agreement_output = gr.Markdown(label="Model Agreement")

    disclaimer_output = gr.Markdown()

    gr.Examples(
        examples=EXAMPLES,
        inputs=[symptoms_input, threshold_slider],
        label="💡 Try These Examples",
    )

    with gr.Accordion("ℹ️ About MediScan AI", open=False):
        gr.Markdown(ABOUT_MD)

    # Button click and pressing Enter in the textbox trigger the same analysis.
    predict_event = {
        "fn": predict,
        "inputs": [symptoms_input, threshold_slider],
        "outputs": [main_output, top3_output, agreement_output, disclaimer_output],
    }
    analyze_btn.click(**predict_event)
    symptoms_input.submit(**predict_event)

demo.launch()