|
|
|
|
|
import streamlit as st |
|
|
import pandas as pd |
|
|
import requests |
|
|
from io import StringIO |
|
|
from sklearn.model_selection import train_test_split |
|
|
from sklearn.ensemble import RandomForestClassifier |
|
|
from sklearn.preprocessing import LabelEncoder |
|
|
import joblib |
|
|
import os |
|
|
|
|
|
# Page chrome: the same name is used for the browser tab and the on-page
# heading, followed by a markdown tagline.
APP_NAME = "Mushroom Doctor"

st.set_page_config(page_title=APP_NAME, layout="centered")
st.title(APP_NAME)
st.markdown("### *Edible* or *Poisonous*? AI Knows!")
|
|
|
|
|
@st.cache_data
def load_data():
    """Download the UCI mushroom dataset and return it as a DataFrame.

    The remote file has no header row, so the 23 column names (the ``class``
    label followed by 22 attributes) are supplied here. The function is
    wrapped in ``st.cache_data``, so Streamlit reuses the result instead of
    downloading again on every rerun.

    Returns:
        pandas.DataFrame: the parsed CSV with the column names listed below.

    Raises:
        requests.RequestException: if the request times out, the connection
            fails, or the server answers with an HTTP error status.
    """
    url = "https://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data"
    cols = [
        'class', 'cap_shape', 'cap_surface', 'cap_color', 'bruises', 'odor',
        'gill_attachment', 'gill_spacing', 'gill_size', 'gill_color',
        'stalk_shape', 'stalk_root', 'stalk_surface_above_ring',
        'stalk_surface_below_ring', 'stalk_color_above_ring',
        'stalk_color_below_ring', 'veil_type', 'veil_color', 'ring_number',
        'ring_type', 'spore_print_color', 'population', 'habitat',
    ]
    # Without a timeout, requests waits indefinitely on an unresponsive server.
    response = requests.get(url, timeout=30)
    # Stop on 4xx/5xx rather than feeding an error page to the CSV parser.
    response.raise_for_status()
    return pd.read_csv(StringIO(response.text), header=None, names=cols)
|
|
|
|
|
df = load_data()
st.success(f"Loaded {len(df):,} mushrooms")

# Count the rows labelled 'e' and 'p' in the class column and show the two
# totals side by side.
class_labels = df['class']
edible = int((class_labels == 'e').sum())
poison = int((class_labels == 'p').sum())

c1, c2 = st.columns(2)
c1.metric("Edible", edible)
c2.metric("Poisonous", poison)
|
|
|
|
|
@st.cache_data
def prepare():
    """Label-encode every column of the module-level ``df``.

    A separate ``LabelEncoder`` is fitted per column (including ``class``)
    and kept, so the same mapping can be applied or inverted later.

    Returns:
        tuple: ``(X, y, encoders)`` where ``X`` is the encoded frame without
        the ``class`` column, ``y`` is the encoded ``class`` column, and
        ``encoders`` maps each column name to its fitted ``LabelEncoder``.
    """
    encoders = {}
    encoded_columns = {}
    for name in df.columns:
        encoder = LabelEncoder()
        encoded_columns[name] = encoder.fit_transform(df[name])
        encoders[name] = encoder

    # Rebuild the frame on the original index so rows line up with ``df``.
    encoded = pd.DataFrame(encoded_columns, index=df.index)
    features = encoded.drop(columns='class')
    target = encoded['class']
    return features, target, encoders
|
|
|
|
|
# Encode the dataset, then hold out 20% of the rows for scoring. The fixed
# random_state keeps the split the same on every rerun.
X, y, encoders = prepare()
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)
|
|
|
|
|
# Training only runs on the rerun triggered by a click on this button.
train_clicked = st.button("Train Model (100% Accuracy)", type="primary")
if train_clicked:
    with st.spinner("Training..."):
        # fit() returns the estimator itself, so construction and fitting chain.
        model = RandomForestClassifier(n_estimators=100, random_state=42).fit(
            X_train, y_train
        )
        acc = model.score(X_test, y_test)
        st.success(f"Trained! Accuracy: {acc:.1%}")
        if acc == 1.0:
            st.balloons()
        # Save the model together with its encoders so both can be reloaded
        # from the same file.
        bundle = {"model": model, "encoders": encoders}
        joblib.dump(bundle, "model.pkl")
|
|
|
|
|
# Use a previously saved model if one exists; otherwise `model` stays None,
# which the prediction section checks before showing its inputs.
model = None
if os.path.exists("model.pkl"):
    # joblib.load is the reader for files written by joblib.dump.
    # (`joblib.dump` is a plain function and has no `.load` attribute.)
    # NOTE(security): joblib files are pickle-based and can execute code on
    # load - only load a model.pkl that this app wrote itself.
    data = joblib.load("model.pkl")
    model = data["model"]
    # Use the encoders stored with the model so inputs are encoded with the
    # same mapping the model was trained on.
    encoders = data["encoders"]
|
|
|
|
|
st.header("Predict Mushroom")

if model is None:
    st.info("Train the model first!")
else:
    cols = st.columns(3)
    inputs = {}

    # Display names for the single-letter codes used in the dataset, keyed by
    # column and then by code. The encoders were fitted on the codes, so the
    # select boxes must return codes; these names are only for display.
    # Columns or codes not listed here fall back to showing the raw code.
    options = {
        'cap_shape': {
            'b': 'bell', 'c': 'conical', 'x': 'convex',
            'f': 'flat', 'k': 'knobbed', 's': 'sunken',
        },
        'bruises': {'t': 'bruises', 'f': 'no'},
        'odor': {
            'a': 'almond', 'l': 'anise', 'c': 'creosote', 'y': 'fishy',
            'f': 'foul', 'm': 'musty', 'n': 'none', 'p': 'pungent',
            's': 'spicy',
        },
        'spore_print_color': {
            'k': 'black', 'n': 'brown', 'b': 'buff', 'h': 'chocolate',
            'r': 'green', 'o': 'orange', 'u': 'purple', 'w': 'white',
            'y': 'yellow',
        },
        'population': {
            'a': 'abundant', 'c': 'clustered', 'n': 'numerous',
            's': 'scattered', 'v': 'several', 'y': 'solitary',
        },
        'habitat': {
            'g': 'grasses', 'l': 'leaves', 'm': 'meadows', 'p': 'paths',
            'u': 'urban', 'w': 'waste', 'd': 'woods',
        },
    }

    # One select box per feature, laid out round-robin over three columns.
    for i, col in enumerate(X.columns):
        names = options.get(col, {})
        with cols[i % 3]:
            val = st.selectbox(
                # "cap_shape" -> "Cap Shape"
                col.replace("_", " ").title(),
                # Offer exactly the labels the encoder knows, so transform()
                # below can never meet an unseen value.
                list(encoders[col].classes_),
                # Bind `names` per iteration to avoid the late-binding trap.
                format_func=lambda code, names=names: names.get(code, code),
            )
        inputs[col] = encoders[col].transform([val])[0]

    if st.button("Is it Safe?", type="secondary"):
        # Build a one-row frame with the training column names and order, so
        # the input matches the DataFrame the model was fitted on.
        vec = pd.DataFrame([inputs], columns=X.columns)
        pred = model.predict(vec)[0]
        prob = model.predict_proba(vec)[0]
        result = encoders['class'].inverse_transform([pred])[0]

        if result == 'e':
            st.success("EDIBLE – SAFE TO EAT!")
            st.balloons()
        else:
            st.error("POISONOUS – DO NOT EAT!")

        # predict_proba columns follow model.classes_; map them back to the
        # 'e'/'p' labels rather than assuming a fixed column order.
        class_names = encoders['class'].inverse_transform(model.classes_)
        prob_by_class = dict(zip(class_names, prob))
        st.metric("Edible", f"{prob_by_class.get('e', 0.0):.1%}")
        st.metric("Poisonous", f"{prob_by_class.get('p', 0.0):.1%}")