File size: 3,642 Bytes
85a4d35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ed2351
85a4d35
 
 
5ed2351
85a4d35
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import gradio as gr
import joblib
import numpy as np

# ✅ Load the trained classifier once, at import time, so every request reuses it.
# The path is relative, so it is resolved against the current working directory.
# NOTE: joblib.load unpickles the file, which can execute arbitrary code — only
# ship/load a model file you trust.
_MODEL_FILE = "random_forest_model.pkl"
model = joblib.load(_MODEL_FILE)

# ✅ Preprocessing function: hand-encodes the raw form answers into the model's feature row
def preprocess(age, gender, bmi, activity, smoking, alcohol, diabetes, hypertension,
               cholesterol, family_history, cognitive_score, depression, sleep_quality,
               diet, apoe4, social_engagement, stress):
    """Encode one set of form answers into a 1 x 25 feature matrix for the model.

    Categorical answers are one-hot encoded by hand with one reference level
    dropped per variable (e.g. activity "High", smoking "Current", alcohol
    "Never", sleep/diet "Average"), following the manual encoding used for
    training. Column order is significant and must match the training columns.

    Args:
        age, bmi, cognitive_score: numeric values, passed through unchanged.
        gender: "Male" encodes as 1, anything else as 0.
        activity, depression, social_engagement, stress: "Low" / "Medium" / "High".
        smoking: "Never" / "Former" / "Current".
        alcohol: "Never" / "Occasionally" / "Regularly".
        diabetes, hypertension, family_history, apoe4: "Yes" encodes as 1, else 0.
        cholesterol: "Normal" encodes as 1, anything else as 0.
        sleep_quality: "Poor" / "Average" / "Good".
        diet: "Healthy" / "Average" / "Unhealthy".

    Returns:
        numpy array of shape (1, 25).

    Raises:
        ValueError: if any answer is None (e.g. a radio button left unselected).
            Without this check a missing answer would silently be encoded as the
            reference category and produce a misleading prediction.
    """
    answers = {
        "Age": age,
        "Gender": gender,
        "BMI": bmi,
        "Physical Activity Level": activity,
        "Smoking Status": smoking,
        "Alcohol Consumption": alcohol,
        "Diabetes": diabetes,
        "Hypertension": hypertension,
        "Cholesterol Level": cholesterol,
        "Family History of Alzheimer's": family_history,
        "Cognitive Test Score": cognitive_score,
        "Depression Level": depression,
        "Sleep Quality": sleep_quality,
        "Dietary Habits": diet,
        "Genetic Risk Factor (APOE-ε4)": apoe4,
        "Social Engagement Level": social_engagement,
        "Stress Levels": stress,
    }
    missing = [label for label, value in answers.items() if value is None]
    if missing:
        raise ValueError(
            "Please answer every question before predicting. Missing: "
            + ", ".join(missing)
        )

    # 🔢 Manual encoding just like training — keep this exact column order.
    features = [
        age,
        bmi,
        int(diabetes == "Yes"),
        int(hypertension == "Yes"),
        int(family_history == "Yes"),
        cognitive_score,
        int(apoe4 == "Yes"),
        int(gender == "Male"),
        int(activity == "Low"), int(activity == "Medium"),                  # ref: High
        int(smoking == "Former"), int(smoking == "Never"),                  # ref: Current
        int(alcohol == "Occasionally"), int(alcohol == "Regularly"),        # ref: Never
        int(cholesterol == "Normal"),                                       # ref: not Normal
        int(depression == "Low"), int(depression == "Medium"),              # ref: High
        int(sleep_quality == "Good"), int(sleep_quality == "Poor"),         # ref: Average
        int(diet == "Healthy"), int(diet == "Unhealthy"),                   # ref: Average
        int(social_engagement == "Low"), int(social_engagement == "Medium"),  # ref: High
        int(stress == "Low"), int(stress == "Medium"),                      # ref: High
    ]

    return np.array(features).reshape(1, -1)

# ✅ Prediction function
def predict(*inputs):
    """Run the model on one set of form answers and return a readable verdict.

    Args:
        *inputs: the form values, forwarded positionally to ``preprocess`` and
            therefore in the same order as its parameters.

    Returns:
        A string naming the more likely outcome together with the model's
        probability for that outcome, as a percentage rounded to 2 decimals.

    Raises:
        gr.Error: when ``preprocess`` rejects the inputs with ValueError, so the
            message is shown to the user in the Gradio UI.
    """
    # Only the preprocessing call is guarded; model errors propagate unchanged.
    try:
        features = preprocess(*inputs)
    except ValueError as err:
        raise gr.Error(str(err)) from err

    # Column 1 of predict_proba corresponds to model.classes_[1] — presumably
    # the positive (Alzheimer's) label; confirm against the training labels.
    proba = model.predict_proba(features)[0][1]

    if proba >= 0.5:
        percent = round(proba * 100, 2)
        return f"🧠 Likely Alzheimer’s Positive ({percent}%)"
    # Round the complement directly: subtracting an already-rounded float
    # (100 - round(...)) can display artefacts such as 72.66999999999999.
    percent = round((1 - proba) * 100, 2)
    return f"✅ Likely Alzheimer’s Negative ({percent}%)"

# Build the web form. predict() forwards these values positionally to
# preprocess(), so the order of `inputs` must match preprocess()'s parameters.
demo = gr.Interface(
    fn=predict,
    title="🧠 Alzheimer's Disease Prediction (Random Forest)",
    inputs=[
        gr.Slider(40, 100, value=60, label="Age"),
        gr.Radio(["Male", "Female"], label="Gender"),
        gr.Slider(15, 40, value=25, label="BMI"),
        gr.Radio(["Low", "Medium", "High"], label="Physical Activity Level"),
        gr.Radio(["Never", "Former", "Current"], label="Smoking Status"),
        gr.Radio(["Never", "Occasionally", "Regularly"], label="Alcohol Consumption"),
        gr.Radio(["No", "Yes"], label="Diabetes"),
        gr.Radio(["No", "Yes"], label="Hypertension"),
        # preprocess() encodes cholesterol only as Normal vs. not-Normal, so a
        # "Low" choice would reach the model exactly like "High"; offer only the
        # two levels the encoding can distinguish.
        gr.Radio(["Normal", "High"], label="Cholesterol Level"),
        gr.Radio(["No", "Yes"], label="Family History of Alzheimer’s"),
        gr.Slider(0, 100, value=60, label="Cognitive Test Score"),
        gr.Radio(["Low", "Medium", "High"], label="Depression Level"),
        gr.Radio(["Poor", "Average", "Good"], label="Sleep Quality"),
        gr.Radio(["Healthy", "Average", "Unhealthy"], label="Dietary Habits"),
        gr.Radio(["No", "Yes"], label="Genetic Risk Factor (APOE-ε4)"),
        gr.Radio(["Low", "Medium", "High"], label="Social Engagement Level"),
        gr.Radio(["Low", "Medium", "High"], label="Stress Levels"),
    ],
    outputs=gr.Text(label="Prediction Result"),
    theme="default"
)


if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not when imported.
    demo.launch()