File size: 3,594 Bytes
5682f2f cac9233 63e6641 046d8d3 63e6641 5682f2f 046d8d3 63e6641 5682f2f 2acc2bc 5682f2f 046d8d3 5682f2f 046d8d3 5682f2f 63e6641 5682f2f 2acc2bc 5682f2f 046d8d3 63e6641 5682f2f 2acc2bc 5682f2f 2acc2bc 63e6641 5682f2f 63e6641 5682f2f 63e6641 5682f2f 046d8d3 63e6641 046d8d3 5682f2f ea31469 046d8d3 5682f2f 63e6641 5682f2f 63e6641 5682f2f 2acc2bc 046d8d3 5682f2f ea31469 5682f2f 63e6641 5682f2f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
# app.py
import streamlit as st
import pandas as pd
import requests
from io import StringIO
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder
import joblib
import os
# Configure the browser tab title and page layout, then render the heading
# and the tagline shown at the top of the app.
st.set_page_config(
    layout="centered",
    page_title="Mushroom Doctor",
)
st.title("Mushroom Doctor")
st.markdown("### *Edible* or *Poisonous*? AI Knows!")
@st.cache_data
def load_data():
    """Download the UCI mushroom dataset and return it as a DataFrame.

    The remote file has no header row, so the 23 column names (the ``class``
    label followed by 22 attributes) are supplied here.

    Returns:
        pandas.DataFrame: the parsed dataset, one row per line of the file.

    Raises:
        requests.RequestException: if the download fails, times out, or the
            server answers with an HTTP error status.
    """
    url = "https://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data"
    # Bound the wait so a stalled server cannot hang the app indefinitely.
    r = requests.get(url, timeout=30)
    # Fail loudly instead of parsing an HTML error page as if it were CSV.
    r.raise_for_status()
    cols = [
        'class', 'cap_shape', 'cap_surface', 'cap_color', 'bruises', 'odor',
        'gill_attachment', 'gill_spacing', 'gill_size', 'gill_color',
        'stalk_shape', 'stalk_root', 'stalk_surface_above_ring',
        'stalk_surface_below_ring', 'stalk_color_above_ring',
        'stalk_color_below_ring', 'veil_type', 'veil_color', 'ring_number',
        'ring_type', 'spore_print_color', 'population', 'habitat',
    ]
    return pd.read_csv(StringIO(r.text), header=None, names=cols)
# Fetch the dataset and report how many rows were loaded.
df = load_data()
st.success(f"Loaded {len(df):,} mushrooms")

# Count rows per label: 'e' is shown as "Edible", 'p' as "Poisonous".
is_edible = df['class'] == 'e'
is_poisonous = df['class'] == 'p'
edible = int(is_edible.sum())
poison = int(is_poisonous.sum())

# Show the two counts side by side.
c1, c2 = st.columns(2)
c1.metric("Edible", edible)
c2.metric("Poisonous", poison)
@st.cache_data
def prepare():
    """Label-encode every column of the module-level ``df``.

    Each column gets its own ``LabelEncoder`` so that values can later be
    mapped to and from their integer codes column by column.

    Returns:
        tuple: ``(X, y, encoders)`` where ``X`` is the encoded frame without
        the ``class`` column, ``y`` is the encoded ``class`` column, and
        ``encoders`` maps each column name to its fitted ``LabelEncoder``.
    """
    # Fit one encoder per column, keeping the original column order.
    encoders = {name: LabelEncoder().fit(df[name]) for name in df.columns}

    # Build the integer-coded frame on the same index as the source data.
    encoded = pd.DataFrame(
        {name: encoder.transform(df[name]) for name, encoder in encoders.items()},
        index=df.index,
    )

    features = encoded.drop(columns='class')
    target = encoded['class']
    return features, target, encoders
# Encode the data, then hold out 20% of the rows for scoring the model.
X, y, encoders = prepare()
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

# Training only runs when the button is clicked.
train_clicked = st.button("Train Model (100% Accuracy)", type="primary")
if train_clicked:
    with st.spinner("Training..."):
        forest = RandomForestClassifier(n_estimators=100, random_state=42)
        forest.fit(X_train, y_train)
        acc = forest.score(X_test, y_test)
        st.success(f"Trained! Accuracy: {acc:.1%}")
        if acc == 1.0:
            st.balloons()
        # Persist the classifier together with its encoders in one file.
        bundle = {"model": forest, "encoders": encoders}
        joblib.dump(bundle, "model.pkl")
# Start without a model; one is only available once "model.pkl" exists.
model = None
if os.path.exists("model.pkl"):
    # joblib.load is the counterpart of joblib.dump (the original called the
    # non-existent ``joblib.dump.load``, which raises AttributeError).
    # NOTE(review): joblib.load unpickles the file, so "model.pkl" must only
    # ever come from a trusted source.
    data = joblib.load("model.pkl")
    model = data["model"]
    # Use the encoders stored with the model so inputs are encoded with the
    # same mapping the saved model was bundled with.
    encoders = data["encoders"]
st.header("Predict Mushroom")
if model is None:
    st.info("Train the model first!")
else:
    cols = st.columns(3)
    inputs = {}
    # Human-readable label -> single-letter code used in the dataset. The
    # encoders were fitted on the codes, so the code (not the label) is what
    # must be passed to ``transform``.
    # NOTE(review): codes follow the UCI attribute description for this
    # dataset; confirm against agaricus-lepiota.names.
    options = {
        'cap_shape': {
            'bell': 'b', 'conical': 'c', 'convex': 'x',
            'flat': 'f', 'knobbed': 'k', 'sunken': 's',
        },
        'bruises': {'bruises': 't', 'no': 'f'},
        'odor': {
            'almond': 'a', 'anise': 'l', 'creosote': 'c', 'fishy': 'y',
            'foul': 'f', 'musty': 'm', 'none': 'n', 'pungent': 'p',
            'spicy': 's',
        },
        'spore_print_color': {
            'black': 'k', 'brown': 'n', 'buff': 'b', 'chocolate': 'h',
            'green': 'r', 'orange': 'o', 'purple': 'u', 'white': 'w',
            'yellow': 'y',
        },
        'population': {
            'abundant': 'a', 'clustered': 'c', 'numerous': 'n',
            'scattered': 's', 'several': 'v', 'solitary': 'y',
        },
        'habitat': {
            'grasses': 'g', 'leaves': 'l', 'meadows': 'm', 'paths': 'p',
            'urban': 'u', 'waste': 'w', 'woods': 'd',
        },
    }
    for i, col in enumerate(X.columns):
        encoder = encoders[col]
        known_codes = set(encoder.classes_)
        # Offer only labels whose code the encoder has seen, so ``transform``
        # can never be handed an unknown value.
        choices = {
            label: code
            for label, code in options.get(col, {}).items()
            if code in known_codes
        }
        if not choices:
            # No friendly labels for this feature: show the raw codes.
            choices = {code: code for code in encoder.classes_}
        with cols[i % 3]:
            val = st.selectbox(col.replace("_", " ").title(), list(choices))
        inputs[col] = encoder.transform([choices[val]])[0]

    if st.button("Is it Safe?", type="secondary"):
        # One-row frame with the training column names and order.
        vec = pd.DataFrame([inputs], columns=X.columns)
        pred = model.predict(vec)[0]
        prob = model.predict_proba(vec)[0]
        result = encoders['class'].inverse_transform([pred])[0]
        if result == 'e':
            st.success("EDIBLE – SAFE TO EAT!")
            st.balloons()
        else:
            st.error("POISONOUS – DO NOT EAT!")
        # Key each probability by its decoded class label instead of relying
        # on the position of 'e' and 'p' in the model's class list.
        class_labels = encoders['class'].inverse_transform(model.classes_)
        prob_by_class = dict(zip(class_labels, prob))
        st.metric("Edible", f"{prob_by_class.get('e', 0.0):.1%}")
        st.metric("Poisonous", f"{prob_by_class.get('p', 0.0):.1%}")