File size: 3,594 Bytes
5682f2f
cac9233
63e6641
 
 
 
 
046d8d3
 
 
63e6641
5682f2f
 
 
 
046d8d3
 
63e6641
5682f2f
2acc2bc
 
 
 
5682f2f
046d8d3
 
5682f2f
046d8d3
5682f2f
 
 
 
 
63e6641
 
5682f2f
2acc2bc
5682f2f
046d8d3
63e6641
5682f2f
2acc2bc
5682f2f
 
2acc2bc
63e6641
5682f2f
 
63e6641
5682f2f
 
63e6641
 
5682f2f
 
 
046d8d3
63e6641
 
046d8d3
5682f2f
ea31469
 
046d8d3
5682f2f
 
 
63e6641
 
 
5682f2f
 
 
 
 
 
 
63e6641
 
 
5682f2f
2acc2bc
046d8d3
5682f2f
 
 
 
 
ea31469
5682f2f
63e6641
 
5682f2f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# app.py
import streamlit as st
import pandas as pd
import requests
from io import StringIO
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder
import joblib
import os

# Page chrome. The browser-tab title and the on-page heading share one string,
# so it is named once and reused.
PAGE_TITLE = "Mushroom Doctor"

st.set_page_config(page_title=PAGE_TITLE, layout="centered")
st.title(PAGE_TITLE)
st.markdown("### *Edible* or *Poisonous*? AI Knows!")

@st.cache_data
def load_data():
    """Download the UCI mushroom dataset and return it as a DataFrame.

    The raw file has no header row, so the column names are supplied here.
    The first column, ``class``, is the label; the remaining 22 columns are
    categorical features. The result is cached by Streamlit, so the download
    happens once rather than on every script rerun.

    Returns:
        pandas.DataFrame: one row per mushroom, with the columns listed in
        ``cols``.

    Raises:
        requests.RequestException: if the download times out, the connection
            fails, or the server answers with an HTTP error status.
    """
    url = "https://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data"
    # Without a timeout a stalled connection would hang the app indefinitely.
    r = requests.get(url, timeout=30)
    # Fail loudly on 4xx/5xx instead of parsing (and caching) an error page
    # as if it were CSV data.
    r.raise_for_status()
    cols = ['class','cap_shape','cap_surface','cap_color','bruises','odor','gill_attachment',
            'gill_spacing','gill_size','gill_color','stalk_shape','stalk_root','stalk_surface_above_ring',
            'stalk_surface_below_ring','stalk_color_above_ring','stalk_color_below_ring','veil_type',
            'veil_color','ring_number','ring_type','spore_print_color','population','habitat']
    return pd.read_csv(StringIO(r.text), header=None, names=cols)

# Load the dataset and show how many samples fall into each class.
df = load_data()
st.success(f"Loaded {len(df):,} mushrooms")

# One pass over the label column yields both class counts ('e' / 'p').
class_counts = df['class'].value_counts()
edible = int(class_counts.get('e', 0))
poison = int(class_counts.get('p', 0))

c1, c2 = st.columns(2)
c1.metric("Edible", edible)
c2.metric("Poisonous", poison)

@st.cache_data
def prepare():
    """Label-encode every column of the module-level ``df``.

    A separate ``LabelEncoder`` is fitted per column (including ``class``) and
    kept, so that values can later be encoded and decoded the same way.

    Returns:
        tuple: ``(X, y, encoders)`` where ``X`` is the encoded frame without
        the ``class`` column, ``y`` is the encoded ``class`` column, and
        ``encoders`` maps each column name to its fitted ``LabelEncoder``.
    """
    # Fit one encoder per column first, then apply them to a copy of the data.
    encoders = {name: LabelEncoder().fit(df[name]) for name in df.columns}

    encoded = df.copy()
    for name, encoder in encoders.items():
        encoded[name] = encoder.transform(df[name])

    features = encoded.drop(columns='class')
    target = encoded['class']
    return features, target, encoders

X, y, encoders = prepare()
# Fixed seed: the 80/20 train/test split is identical on every script rerun.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train on demand and persist the model together with the encoders it was
# trained with, so predictions can encode user input the same way.
if st.button("Train Model (100% Accuracy)", type="primary"):
    with st.spinner("Training..."):
        model = RandomForestClassifier(n_estimators=100, random_state=42)
        model.fit(X_train, y_train)
        acc = model.score(X_test, y_test)
        st.success(f"Trained! Accuracy: {acc:.1%}")
        if acc == 1.0:
            st.balloons()
        joblib.dump({"model": model, "encoders": encoders}, "model.pkl")

# Reload the persisted model (if any). `model` stays None until a model has
# been trained and saved at least once.
model = None
if os.path.exists("model.pkl"):
    # Fixed: the loader is `joblib.load`; `joblib.dump.load` does not exist and
    # raised AttributeError as soon as model.pkl was present.
    # NOTE(security): joblib.load unpickles the file, which can execute
    # arbitrary code. Only load a model.pkl that this app wrote itself.
    data = joblib.load("model.pkl")
    model = data["model"]
    encoders = data["encoders"]

st.header("Predict Mushroom")
if model is None:
    st.info("Train the model first!")
else:
    cols = st.columns(3)
    inputs = {}
    # Friendly display name -> single-letter code used in the dataset, for the
    # features where a readable choice is offered. The encoders were fitted on
    # the raw codes, so the display name must be translated back before
    # encoding; passing 'bell' etc. directly raises "unseen label".
    # Codes follow the UCI agaricus-lepiota attribute description.
    options = {
        'cap_shape': {'bell': 'b', 'conical': 'c', 'convex': 'x', 'flat': 'f',
                      'knobbed': 'k', 'sunken': 's'},
        'bruises': {'bruises': 't', 'no': 'f'},
        'odor': {'almond': 'a', 'anise': 'l', 'creosote': 'c', 'fishy': 'y', 'foul': 'f',
                 'musty': 'm', 'none': 'n', 'pungent': 'p', 'spicy': 's'},
        'spore_print_color': {'black': 'k', 'brown': 'n', 'buff': 'b', 'chocolate': 'h',
                              'green': 'r', 'orange': 'o', 'purple': 'u', 'white': 'w',
                              'yellow': 'y'},
        'population': {'abundant': 'a', 'clustered': 'c', 'numerous': 'n', 'scattered': 's',
                       'several': 'v', 'solitary': 'y'},
        'habitat': {'grasses': 'g', 'leaves': 'l', 'meadows': 'm', 'paths': 'p',
                    'urban': 'u', 'waste': 'w', 'woods': 'd'},
    }
    for i, col in enumerate(X.columns):
        with cols[i % 3]:
            # Features without a friendly mapping show the raw codes the
            # encoder learned (fitted attribute is `classes_`, not `classes`).
            label_to_code = options.get(col) or {code: code for code in encoders[col].classes_}
            # Fixed: replace underscores, not the empty string, so the label
            # reads "Cap Shape" rather than " C A P _ S H A P E ".
            val = st.selectbox(col.replace("_", " ").title(), list(label_to_code))
            inputs[col] = encoders[col].transform([label_to_code[val]])[0]

    if st.button("Is it Safe?", type="secondary"):
        # Single-row frame with the training column names and order, matching
        # what the model was fitted on.
        vec = pd.DataFrame([inputs], columns=X.columns)
        pred = model.predict(vec)[0]
        prob = model.predict_proba(vec)[0]
        result = encoders['class'].inverse_transform([pred])[0]
        if result == 'e':
            st.success("EDIBLE – SAFE TO EAT!")
            st.balloons()
        else:
            st.error("POISONOUS – DO NOT EAT!")
        # predict_proba columns follow model.classes_; look probabilities up by
        # decoded label instead of assuming index 0 is edible.
        class_labels = encoders['class'].inverse_transform(model.classes_)
        prob_by_label = dict(zip(class_labels, prob))
        st.metric("Edible", f"{prob_by_label.get('e', 0.0):.1%}")
        st.metric("Poisonous", f"{prob_by_label.get('p', 0.0):.1%}")