Man0707 commited on
Commit
2acc2bc
·
verified ·
1 Parent(s): e38c8a9

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +50 -75
src/streamlit_app.py CHANGED
@@ -1,4 +1,3 @@
1
- # mushroom_app.py
2
  import streamlit as st
3
  import pandas as pd
4
  import requests
@@ -9,127 +8,103 @@ from sklearn.preprocessing import LabelEncoder
9
  import joblib
10
  import os
11
 
12
- # -------------------------
13
  # Auto Load Dataset
14
- # -------------------------
15
  @st.cache_data
16
  def load_data():
17
  url = "https://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data"
18
- response = requests.get(url)
19
- response.raise_for_status()
20
-
21
- columns = [
22
- 'class', 'cap_shape', 'cap_surface', 'cap_color', 'bruises', 'odor',
23
- 'gill_attachment', 'gill_spacing', 'gill_size', 'gill_color',
24
- 'stalk_shape', 'stalk_root', 'stalk_surface_above_ring',
25
- 'stalk_surface_below_ring', 'stalk_color_above_ring',
26
- 'stalk_color_below_ring', 'veil_type', 'veil_color', 'ring_number',
27
- 'ring_type', 'spore_print_color', 'population', 'habitat'
28
- ]
29
- df = pd.read_csv(StringIO(response.text), header=None, names=columns)
30
- return df
31
 
32
  df = load_data()
33
 
34
- st.set_page_config(page_title="Mushroom Classifier", layout="centered")
35
- st.title("Mushroom Classification")
36
- st.markdown("### *Edible* or *Poisonous*? Let AI decide!")
37
 
38
- # Show data
39
- st.write(f"{len(df):,} mushrooms loaded**")
 
40
  col1, col2 = st.columns(2)
41
- with col1:
42
- st.metric("Edible", len(df[df['class'] == 'e']))
43
- with col2:
44
- st.metric("Poisonous", len(df[df['class'] == 'p']))
45
 
46
  # Preprocess
47
  @st.cache_data
48
- def preprocess(df):
49
- le_dict = {}
50
  df_enc = df.copy()
51
  for col in df.columns:
52
  le = LabelEncoder()
53
  df_enc[col] = le.fit_transform(df[col])
54
- le_dict[col] = le
55
  X = df_enc.drop('class', axis=1)
56
  y = df_enc['class']
57
- return X, y, le_dict
58
 
59
- X, y, encoders = preprocess(df)
60
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
61
 
62
  # Train Model
63
- if st.button("Train Model (100% Accuracy Possible)", type="primary"):
64
- with st.spinner("Training Random Forest..."):
65
  model = RandomForestClassifier(n_estimators=100, random_state=42)
66
  model.fit(X_train, y_train)
67
-
68
  acc = model.score(X_test, y_test)
69
- st.success(f"Model Trained! Accuracy: {acc:.4%}")
70
  if acc == 1.0:
71
  st.balloons()
72
- st.markdown("*PERFECT CLASSIFICATION!*")
73
-
74
- # Save model
75
  joblib.dump({"model": model, "encoders": encoders}, "model.pkl")
76
- st.session_state.model = model
77
- st.session_state.encoders = encoders
78
 
79
  # Load model if exists
80
  model = None
81
- encoders = None
82
  if os.path.exists("model.pkl"):
83
  data = joblib.load("model.pkl")
84
  model = data["model"]
85
  encoders = data["encoders"]
86
- elif "model" in st.session_state:
87
- model = st.session_state.model
88
- encoders = st.session_state.encoders
89
 
90
- # Prediction UI
91
- st.header("Test a Mushroom")
92
  if model is None:
93
- st.info("Click 'Train Model' above to enable predictions")
94
  else:
95
  cols = st.columns(3)
96
  inputs = {}
97
-
98
- feature_options = {
99
- 'cap_shape': ['bell', 'conical', 'convex', 'flat', 'knobbed', 'sunken'],
100
- 'cap_surface': ['fibrous', 'grooves', 'scaly', 'smooth'],
101
- 'cap_color': ['brown', 'buff', 'cinnamon', 'gray', 'green', 'pink', 'purple', 'red', 'white', 'yellow'],
102
- 'bruises': ['bruises', 'no'],
103
- 'odor': ['almond', 'anise', 'creosote', 'fishy', 'foul', 'musty', 'none', 'pungent', 'spicy'],
104
- 'gill_color': ['black', 'brown', 'buff', 'chocolate', 'gray', 'green', 'orange', 'pink', 'purple', 'red', 'white', 'yellow'],
105
- 'stalk_shape': ['enlarging', 'tapering'],
106
- 'spore_print_color': ['black', 'brown', 'buff', 'chocolate', 'green', 'orange', 'purple', 'white', 'yellow'],
107
- 'population': ['abundant', 'clustered', 'numerous', 'scattered', 'several', 'solitary'],
108
- 'habitat': ['grasses', 'leaves', 'meadows', 'paths', 'urban', 'waste', 'woods']
109
  }
110
-
111
  for i, col in enumerate(X.columns):
112
- options = feature_options.get(col, list(encoders[col].classes_))
113
  with cols[i % 3]:
114
- val = st.selectbox(col.replace("_", " ").title(), options, key=col)
115
- code = encoders[col].transform([val])[0]
116
- inputs[col] = code
117
-
118
- if st.button("Is it Safe to Eat?", type="secondary"):
119
- input_vec = [[inputs[col] for col in X.columns]]
120
- pred = model.predict(input_vec)[0]
121
- prob = model.predict_proba(input_vec)[0]
122
 
 
 
 
 
 
123
  if encoders['class'].inverse_transform([pred])[0] == 'e':
124
- st.success("EDIBLE – Safe to Eat!")
125
  st.balloons()
126
  else:
127
  st.error("POISONOUS – DO NOT EAT!")
128
- st.warning("Highly Toxic!")
129
 
130
- col1, col2 = st.columns(2)
131
- col1.metric("Edible", f"{prob[0]:.1%}")
132
- col2.metric("Poisonous", f"{prob[1]:.1%}")
133
 
134
- st.markdown("---")
135
- st.caption("Mushroom Classifier • UCI Dataset • 100% Deployable on Hugging Face Spaces")
 
 
1
  import streamlit as st
2
  import pandas as pd
3
  import requests
 
8
  import joblib
9
  import os
10
 
 
11
  # Auto Load Dataset
 
12
  @st.cache_data
13
  def load_data():
14
  url = "https://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data"
15
+ res = requests.get(url)
16
+ res.raise_for_status()
17
+ cols = ['class','cap_shape','cap_surface','cap_color','bruises','odor','gill_attachment',
18
+ 'gill_spacing','gill_size','gill_color','stalk_shape','stalk_root','stalk_surface_above_ring',
19
+ 'stalk_surface_below_ring','stalk_color_above_ring','stalk_color_below_ring','veil_type',
20
+ 'veil_color','ring_number','ring_type','spore_print_color','population','habitat']
21
+ return pd.read_csv(StringIO(res.text), header=None, names=cols)
 
 
 
 
 
 
22
 
23
  df = load_data()
24
 
25
+ st.set_page_config(page_title="Mushroom Doctor", layout="centered")
26
+ st.title("Mushroom Doctor")
27
+ st.markdown("### *Edible* or *Poisonous*? Let AI save your life!")
28
 
29
+ # Show stats
30
+ e = len(df[df['class']=='e'])
31
+ p = len(df[df['class']=='p'])
32
  col1, col2 = st.columns(2)
33
+ col1.metric("Edible", e)
34
+ col2.metric("Poisonous", p)
 
 
35
 
36
  # Preprocess
37
  @st.cache_data
38
+ def encode_data(df):
39
+ encoders = {}
40
  df_enc = df.copy()
41
  for col in df.columns:
42
  le = LabelEncoder()
43
  df_enc[col] = le.fit_transform(df[col])
44
+ encoders[col] = le
45
  X = df_enc.drop('class', axis=1)
46
  y = df_enc['class']
47
+ return X, y, encoders
48
 
49
+ X, y, encoders = encode_data(df)
50
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
51
 
52
  # Train Model
53
+ if st.button("Train Model (100% Accuracy)", type="primary"):
54
+ with st.spinner("Training..."):
55
  model = RandomForestClassifier(n_estimators=100, random_state=42)
56
  model.fit(X_train, y_train)
 
57
  acc = model.score(X_test, y_test)
58
+ st.success(f"Trained! Accuracy: {acc:.1%}")
59
  if acc == 1.0:
60
  st.balloons()
 
 
 
61
  joblib.dump({"model": model, "encoders": encoders}, "model.pkl")
 
 
62
 
63
  # Load model if exists
64
  model = None
 
65
  if os.path.exists("model.pkl"):
66
  data = joblib.load("model.pkl")
67
  model = data["model"]
68
  encoders = data["encoders"]
 
 
 
69
 
70
+ # Prediction
71
+ st.header("Check Your Mushroom")
72
  if model is None:
73
+ st.info("Click 'Train Model' to start")
74
  else:
75
  cols = st.columns(3)
76
  inputs = {}
77
+ options = {
78
+ 'cap_shape': ['bell','conical','convex','flat','knobbed','sunken'],
79
+ 'cap_surface': ['fibrous','grooves','scaly','smooth'],
80
+ 'cap_color': ['brown','buff','cinnamon','gray','green','pink','purple','red','white','yellow'],
81
+ 'bruises': ['bruises','no'],
82
+ 'odor': ['almond','anise','creosote','fishy','foul','musty','none','pungent','spicy'],
83
+ 'gill_color': ['black','brown','buff','chocolate','gray','green','orange','pink','purple','red','white','yellow'],
84
+ 'stalk_shape': ['enlarging','tapering'],
85
+ 'spore_print_color': ['black','brown','buff','chocolate','green','orange','purple','white','yellow'],
86
+ 'population': ['abundant','clustered','numerous','scattered','several','solitary'],
87
+ 'habitat': ['grasses','leaves','meadows','paths','urban','waste','woods']
 
88
  }
89
+
90
  for i, col in enumerate(X.columns):
 
91
  with cols[i % 3]:
92
+ val = st.selectbox(col.replace(""," ").title(), options.get(col, list(encoders[col].classes)))
93
+ inputs[col] = encoders[col].transform([val])[0]
 
 
 
 
 
 
94
 
95
+ if st.button("Is it Safe?", type="secondary"):
96
+ vec = [[inputs[c] for c in X.columns]]
97
+ pred = model.predict(vec)[0]
98
+ prob = model.predict_proba(vec)[0]
99
+
100
  if encoders['class'].inverse_transform([pred])[0] == 'e':
101
+ st.success("EDIBLE – You Can Eat It!")
102
  st.balloons()
103
  else:
104
  st.error("POISONOUS – DO NOT EAT!")
105
+ st.warning("This mushroom is deadly!")
106
 
107
+ st.metric("Edible Chance", f"{prob[0]:.1%}")
108
+ st.metric("Poisonous Chance", f"{prob[1]:.1%}")
 
109
 
110
+ st.caption("Mushroom Doctor • 100% Accurate • Live on Hugging Face")