MAntrA / src /app.py
Man0707's picture
Update src/app.py
5682f2f verified
# app.py
import streamlit as st
import pandas as pd
import requests
from io import StringIO
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder
import joblib
import os
# Page chrome: the same title is used for the browser tab and the on-page
# heading, followed by a one-line tagline rendered as a markdown sub-heading.
PAGE_TITLE = "Mushroom Doctor"
st.set_page_config(layout="centered", page_title=PAGE_TITLE)
st.title(PAGE_TITLE)
st.markdown("### *Edible* or *Poisonous*? AI Knows!")
@st.cache_data
def load_data():
    """Download the UCI mushroom dataset and return it as a DataFrame.

    The remote file has no header row, so the column names are supplied
    here; the first column (``class``) is the edible/poisonous label and the
    remaining 22 columns are categorical features.

    Returns:
        pandas.DataFrame with one row per mushroom and 23 string columns.

    Raises:
        requests.RequestException: if the download times out, the connection
            fails, or the server answers with an HTTP error status.
    """
    url = "https://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data"
    # A timeout keeps the app from hanging forever on a stalled connection.
    response = requests.get(url, timeout=30)
    # Fail loudly on 4xx/5xx instead of parsing (and caching) an error page
    # as if it were CSV data.
    response.raise_for_status()
    cols = [
        'class', 'cap_shape', 'cap_surface', 'cap_color', 'bruises', 'odor',
        'gill_attachment', 'gill_spacing', 'gill_size', 'gill_color',
        'stalk_shape', 'stalk_root', 'stalk_surface_above_ring',
        'stalk_surface_below_ring', 'stalk_color_above_ring',
        'stalk_color_below_ring', 'veil_type', 'veil_color', 'ring_number',
        'ring_type', 'spore_print_color', 'population', 'habitat',
    ]
    return pd.read_csv(StringIO(response.text), header=None, names=cols)
# Load the dataset once (cached by load_data) and show a quick summary:
# total row count plus how many rows carry each class label.
df = load_data()
st.success(f"Loaded {len(df):,} mushrooms")

# Summing a boolean mask counts the rows where the label matches.
edible = int((df['class'] == 'e').sum())
poison = int((df['class'] == 'p').sum())

# Render the two counts side by side.
summary = [("Edible", edible), ("Poisonous", poison)]
for column, (label, count) in zip(st.columns(2), summary):
    column.metric(label, count)
@st.cache_data
def prepare():
    """Label-encode every column of the module-level ``df``.

    Each column gets its own ``LabelEncoder`` so that values can later be
    mapped back and forth between the original strings and integer codes.

    Returns:
        A tuple ``(X, y, encoders)`` where ``X`` is the encoded frame without
        the ``class`` column, ``y`` is the encoded ``class`` column, and
        ``encoders`` maps each column name (including ``class``) to its
        fitted encoder.
    """
    encoded = df.copy()
    fitted = {}
    for name in df.columns:
        fitted[name] = LabelEncoder()
        # Always fit on the untouched source frame, writing into the copy.
        encoded[name] = fitted[name].fit_transform(df[name])
    features = encoded.drop(columns='class')
    target = encoded['class']
    return features, target, fitted
# Encode the data and hold out 20% of the rows for evaluation; the fixed
# random_state keeps the split reproducible across reruns.
X, y, encoders = prepare()
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Training only happens on the rerun triggered by the button click.
train_clicked = st.button("Train Model (100% Accuracy)", type="primary")
if train_clicked:
    with st.spinner("Training..."):
        # fit() returns the estimator itself, so construction and fitting
        # can be chained.
        model = RandomForestClassifier(n_estimators=100, random_state=42).fit(
            X_train, y_train
        )
        accuracy = model.score(X_test, y_test)
        st.success(f"Trained! Accuracy: {accuracy:.1%}")
        if accuracy == 1.0:
            st.balloons()
        # Persist the model together with its encoders so that later reruns
        # can load both from disk.
        joblib.dump({"model": model, "encoders": encoders}, "model.pkl")
# Pick up a previously trained model from disk, if there is one. `model`
# stays None when nothing has been trained yet, which the prediction UI
# uses to decide whether to show the input form.
model = None
if os.path.exists("model.pkl"):
    # joblib.load is the counterpart of joblib.dump (the original called the
    # non-existent `joblib.dump.load`, which raised AttributeError).
    # NOTE: joblib.load unpickles arbitrary objects -- only load a model.pkl
    # that this app wrote itself.
    data = joblib.load("model.pkl")
    model = data["model"]
    # Use the encoders saved alongside the model so codes match what the
    # model was trained on.
    encoders = data["encoders"]
st.header("Predict Mushroom")
if model is None:
    st.info("Train the model first!")
else:
    # Human-readable names for the single-letter codes used in the UCI data
    # file. The encoders were fitted on the codes, so every widget must hand
    # back a code; these names are display-only. Columns not listed here
    # simply show their raw codes.
    code_names = {
        'cap_shape': {'b': 'bell', 'c': 'conical', 'x': 'convex',
                      'f': 'flat', 'k': 'knobbed', 's': 'sunken'},
        'bruises': {'t': 'bruises', 'f': 'no'},
        'odor': {'a': 'almond', 'l': 'anise', 'c': 'creosote', 'y': 'fishy',
                 'f': 'foul', 'm': 'musty', 'n': 'none', 'p': 'pungent',
                 's': 'spicy'},
        'spore_print_color': {'k': 'black', 'n': 'brown', 'b': 'buff',
                              'h': 'chocolate', 'r': 'green', 'o': 'orange',
                              'u': 'purple', 'w': 'white', 'y': 'yellow'},
        'population': {'a': 'abundant', 'c': 'clustered', 'n': 'numerous',
                       's': 'scattered', 'v': 'several', 'y': 'solitary'},
        'habitat': {'g': 'grasses', 'l': 'leaves', 'm': 'meadows',
                    'p': 'paths', 'u': 'urban', 'w': 'waste', 'd': 'woods'},
    }

    cols = st.columns(3)
    inputs = {}
    for i, col in enumerate(X.columns):
        names = code_names.get(col, {})
        # Offer exactly the codes the encoder knows (`classes_`), ordered by
        # the label the user will see.
        choices = sorted(encoders[col].classes_, key=lambda code: names.get(code, code))
        with cols[i % 3]:
            val = st.selectbox(
                # "cap_shape" -> "Cap Shape"
                col.replace("_", " ").title(),
                choices,
                # Bind `names` as a default so each widget keeps its own map.
                format_func=lambda code, names=names: names.get(code, code),
            )
        inputs[col] = encoders[col].transform([val])[0]

    if st.button("Is it Safe?", type="secondary"):
        # A one-row frame with the training column names, so the model sees
        # the same feature names it was fitted with.
        vec = pd.DataFrame([[inputs[c] for c in X.columns]], columns=X.columns)
        pred = model.predict(vec)[0]
        prob = model.predict_proba(vec)[0]
        result = encoders['class'].inverse_transform([pred])[0]
        if result == 'e':
            st.success("EDIBLE – SAFE TO EAT!")
            st.balloons()
        else:
            st.error("POISONOUS – DO NOT EAT!")

        # predict_proba columns follow model.classes_; look probabilities up
        # by decoded class label rather than assuming a fixed position.
        class_labels = encoders['class'].inverse_transform(model.classes_)
        prob_by_label = dict(zip(class_labels, prob))
        st.metric("Edible", f"{prob_by_label.get('e', 0.0):.1%}")
        st.metric("Poisonous", f"{prob_by_label.get('p', 0.0):.1%}")