Man0707 committed on
Commit
046d8d3
·
verified ·
1 Parent(s): 5a5f884

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +85 -144
src/streamlit_app.py CHANGED
@@ -1,194 +1,135 @@
1
  # mushroom_app.py
2
  import streamlit as st
3
  import pandas as pd
4
- import numpy as np
5
- import joblib
6
  import requests
7
  from io import StringIO
8
  from sklearn.model_selection import train_test_split
9
- from sklearn.preprocessing import LabelEncoder
10
  from sklearn.ensemble import RandomForestClassifier
11
- from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
12
- import matplotlib.pyplot as plt
13
- import seaborn as sns
14
 
15
  # -------------------------
16
- # Auto Download Mushroom Dataset
17
  # -------------------------
18
- @st.cache_data(show_spinner="Downloading Mushroom Dataset from UCI...")
19
- def load_mushroom_data():
20
  url = "https://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data"
21
  response = requests.get(url)
22
- if response.status_code == 200:
23
- # Column names as per UCI description
24
- columns = [
25
- 'class', 'cap-shape', 'cap-surface', 'cap-color', 'bruises', 'odor',
26
- 'gill-attachment', 'gill-spacing', 'gill-size', 'gill-color',
27
- 'stalk-shape', 'stalk-root', 'stalk-surface-above-ring',
28
- 'stalk-surface-below-ring', 'stalk-color-above-ring',
29
- 'stalk-color-below-ring', 'veil-type', 'veil-color', 'ring-number',
30
- 'ring-type', 'spore-print-color', 'population', 'habitat'
31
- ]
32
- df = pd.read_csv(StringIO(response.text), header=None, names=columns)
33
- return df
34
- else:
35
- st.error("Failed to download dataset.")
36
- return None
37
-
38
- # Load data
39
- df = load_mushroom_data()
40
- if df is None:
41
- st.stop()
42
-
43
- st.set_page_config(page_title="Mushroom Classification", layout="centered")
44
  st.title("Mushroom Classification")
45
- st.markdown("### Is it *Edible* or *Poisonous*?")
46
- st.success(f"Dataset loaded: {df.shape[0]:,} mushrooms | {df.shape[1]} features")
47
-
48
- # Show sample
49
- st.write("First 10 mushrooms:")
50
- st.dataframe(df.head(10), use_container_width=True)
51
 
52
- # Mapping
53
- class_map = {'e': 'Edible', 'p': 'Poisonous'}
54
- st.write("*Target Distribution:*")
55
  col1, col2 = st.columns(2)
56
  with col1:
57
- st.metric("Edible (e)", df['class'].value_counts().get('e', 0))
58
  with col2:
59
- st.metric("Poisonous (p)", df['class'].value_counts().get('p', 0))
60
 
61
- # Encode categorical features
62
  @st.cache_data
63
- def preprocess_data(df):
64
  le_dict = {}
65
- df_encoded = df.copy()
66
-
67
- for column in df.columns:
68
  le = LabelEncoder()
69
- df_encoded[column] = le.fit_transform(df[column])
70
- le_dict[column] = le
71
-
72
- X = df_encoded.drop('class', axis=1)
73
- y = df_encoded['class']
74
-
75
- return X, y, le_dict, df_encoded
76
-
77
- X, y, label_encoders, df_encoded = preprocess_data(df)
78
 
79
- # Train-test split
80
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
81
 
82
- # -------------------------
83
  # Train Model
84
- # -------------------------
85
- st.header("Train Random Forest Model (Best for this dataset)")
86
- if st.button("Train Model Now", type="primary"):
87
- with st.spinner("Training Random Forest... (this dataset is 100% separable!)"):
88
  model = RandomForestClassifier(n_estimators=100, random_state=42)
89
  model.fit(X_train, y_train)
90
 
91
- y_pred = model.predict(X_test)
92
- acc = accuracy_score(y_test, y_pred)
93
-
94
- st.success(f"Model Trained! Accuracy: {acc*100:.2f}%")
95
-
96
- if acc >= 0.999:
97
  st.balloons()
98
- st.markdown("### PERFECT CLASSIFICATION! (100% Accuracy)")
99
-
100
- # Confusion Matrix
101
- cm = confusion_matrix(y_test, y_pred)
102
- fig, ax = plt.subplots(figsize=(6,5))
103
- sns.heatmap(cm, annot=True, fmt='d', cmap="Greens", ax=ax,
104
- xticklabels=['Edible', 'Poisonous'],
105
- yticklabels=['Edible', 'Poisonous'])
106
- ax.set_xlabel("Predicted")
107
- ax.set_ylabel("Actual")
108
- ax.set_title("Confusion Matrix")
109
- st.pyplot(fig)
110
-
111
- # Classification Report
112
- report = classification_report(y_test, y_pred, target_names=['Edible', 'Poisonous'], output_dict=True)
113
- report_df = pd.DataFrame(report).transpose()
114
- st.write("### Classification Report")
115
- st.dataframe(report_df.style.format("{:.3f}"))
116
 
117
  # Save model
118
- joblib.dump({
119
- "model": model,
120
- "label_encoders": label_encoders,
121
- "features": X.columns.tolist()
122
- }, "mushroom_model.pkl")
123
-
124
- with open("mushroom_model.pkl", "rb") as f:
125
- st.download_button("Download Model (.pkl)", f, "mushroom_model.pkl")
126
 
127
- # -------------------------
128
- # Single Prediction (Perfect & Safe)
129
- # -------------------------
130
- st.header("Will You Eat This Mushroom?")
131
- st.markdown("Select features of a mushroom to predict if it's *safe to eat*")
132
-
133
- # Load model if trained
134
  model = None
135
- if 'model' in globals() or st.session_state.get("model_trained", False):
136
- try:
137
- loaded = joblib.load("mushroom_model.pkl")
138
- model = loaded["model"]
139
- label_encoders = loaded["label_encoders"]
140
- st.session_state.model_trained = True
141
- except:
142
- pass
143
-
 
 
144
  if model is None:
145
- st.info("Train the model above first to make predictions!")
146
  else:
147
  cols = st.columns(3)
148
  inputs = {}
149
 
150
- feature_descriptions = {
151
- 'cap-shape': ['bell', 'conical', 'flat', 'knobbed', 'sunken', 'convex'],
152
- 'cap-surface': ['fibrous', 'grooves', 'smooth', 'scaly'],
153
- 'cap-color': ['buff', 'cinnamon', 'red', 'gray', 'brown', 'pink', 'green', 'purple', 'white', 'yellow'],
154
- 'bruises': ['yes', 'no'],
155
- 'odor': ['almond', 'creosote', 'foul', 'anise', 'musty', 'none', 'pungent', 'spicy', 'fishy'],
156
- 'gill-color': ['buff', 'red', 'gray', 'chocolate', 'black', 'brown', 'orange', 'pink', 'green', 'purple', 'white', 'yellow'],
157
- 'stalk-shape': ['enlarging', 'tapering'],
158
- 'stalk-root': ['bulbous', 'club', 'equal', 'rooted', 'missing'],
159
- 'spore-print-color': ['black', 'brown', 'buff', 'chocolate', 'green', 'orange', 'purple', 'white', 'yellow'],
160
  'population': ['abundant', 'clustered', 'numerous', 'scattered', 'several', 'solitary'],
161
- 'habitat': ['woods', 'grasses', 'leaves', 'meadows', 'paths', 'urban', 'waste']
162
  }
163
 
164
- selected = {}
165
  for i, col in enumerate(X.columns):
 
166
  with cols[i % 3]:
167
- options = feature_descriptions.get(col, ['?'])
168
- idx = st.selectbox(f"{col.replace('-', ' ').title()}", options, key=col)
169
- # Map to encoded value
170
- original_values = label_encoders[col].classes_
171
- if idx in original_values:
172
- encoded_val = label_encoders[col].transform([idx])[0]
173
- selected[col] = encoded_val
174
 
175
- if st.button("Predict Edible or Poisonous?", type="secondary"):
176
- input_vec = [selected[f] for f in X.columns]
177
- input_df = pd.DataFrame([input_vec], columns=X.columns)
178
-
179
- pred = model.predict(input_df)[0]
180
- prob = model.predict_proba(input_df)[0]
181
-
182
- if pred == label_encoders['class'].transform(['e'])[0]:
183
- st.success("EDIBLE & SAFE TO EAT!")
184
  st.balloons()
185
  else:
186
- st.error("POISONOUS! DO NOT EAT!")
187
- st.warning("This mushroom is highly toxic!")
188
 
189
  col1, col2 = st.columns(2)
190
- col1.metric("Edible Probability", f"{prob[0]:.1%}")
191
- col2.metric("Poisonous Probability", f"{prob[1]:.1%}")
192
 
193
  st.markdown("---")
194
- st.caption("Mushroom Classification • UCI Dataset • 100% Accuracy Possible Built with Streamlit")
 
1
  # mushroom_app.py
2
  import streamlit as st
3
  import pandas as pd
 
 
4
  import requests
5
  from io import StringIO
6
  from sklearn.model_selection import train_test_split
 
7
  from sklearn.ensemble import RandomForestClassifier
8
+ from sklearn.preprocessing import LabelEncoder
9
+ import joblib
10
+ import os
11
 
12
  # -------------------------
13
+ # Auto Load Dataset
14
  # -------------------------
15
+ @st.cache_data
16
+ def load_data():
17
  url = "https://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data"
18
  response = requests.get(url)
19
+ response.raise_for_status()
20
+
21
+ columns = [
22
+ 'class', 'cap_shape', 'cap_surface', 'cap_color', 'bruises', 'odor',
23
+ 'gill_attachment', 'gill_spacing', 'gill_size', 'gill_color',
24
+ 'stalk_shape', 'stalk_root', 'stalk_surface_above_ring',
25
+ 'stalk_surface_below_ring', 'stalk_color_above_ring',
26
+ 'stalk_color_below_ring', 'veil_type', 'veil_color', 'ring_number',
27
+ 'ring_type', 'spore_print_color', 'population', 'habitat'
28
+ ]
29
+ df = pd.read_csv(StringIO(response.text), header=None, names=columns)
30
+ return df
31
+
32
+ df = load_data()
33
+
34
+ st.set_page_config(page_title="Mushroom Classifier", layout="centered")
 
 
 
 
 
 
35
  st.title("Mushroom Classification")
36
+ st.markdown("### *Edible* or *Poisonous*? Let AI decide!")
 
 
 
 
 
37
 
38
+ # Show data
39
+ st.write(f"{len(df):,} mushrooms loaded**")
 
40
  col1, col2 = st.columns(2)
41
  with col1:
42
+ st.metric("Edible", len(df[df['class'] == 'e']))
43
  with col2:
44
+ st.metric("Poisonous", len(df[df['class'] == 'p']))
45
 
46
+ # Preprocess
47
  @st.cache_data
48
+ def preprocess(df):
49
  le_dict = {}
50
+ df_enc = df.copy()
51
+ for col in df.columns:
 
52
  le = LabelEncoder()
53
+ df_enc[col] = le.fit_transform(df[col])
54
+ le_dict[col] = le
55
+ X = df_enc.drop('class', axis=1)
56
+ y = df_enc['class']
57
+ return X, y, le_dict
 
 
 
 
58
 
59
+ X, y, encoders = preprocess(df)
60
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
61
 
 
62
  # Train Model
63
+ if st.button("Train Model (100% Accuracy Possible)", type="primary"):
64
+ with st.spinner("Training Random Forest..."):
 
 
65
  model = RandomForestClassifier(n_estimators=100, random_state=42)
66
  model.fit(X_train, y_train)
67
 
68
+ acc = model.score(X_test, y_test)
69
+ st.success(f"Model Trained! Accuracy: {acc:.4%}")
70
+ if acc == 1.0:
 
 
 
71
  st.balloons()
72
+ st.markdown("*PERFECT CLASSIFICATION!*")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  # Save model
75
+ joblib.dump({"model": model, "encoders": encoders}, "model.pkl")
76
+ st.session_state.model = model
77
+ st.session_state.encoders = encoders
 
 
 
 
 
78
 
79
+ # Load model if exists
 
 
 
 
 
 
80
  model = None
81
+ encoders = None
82
+ if os.path.exists("model.pkl"):
83
+ data = joblib.load("model.pkl")
84
+ model = data["model"]
85
+ encoders = data["encoders"]
86
+ elif "model" in st.session_state:
87
+ model = st.session_state.model
88
+ encoders = st.session_state.encoders
89
+
90
+ # Prediction UI
91
+ st.header("Test a Mushroom")
92
  if model is None:
93
+ st.info("Click 'Train Model' above to enable predictions")
94
  else:
95
  cols = st.columns(3)
96
  inputs = {}
97
 
98
+ feature_options = {
99
+ 'cap_shape': ['bell', 'conical', 'convex', 'flat', 'knobbed', 'sunken'],
100
+ 'cap_surface': ['fibrous', 'grooves', 'scaly', 'smooth'],
101
+ 'cap_color': ['brown', 'buff', 'cinnamon', 'gray', 'green', 'pink', 'purple', 'red', 'white', 'yellow'],
102
+ 'bruises': ['bruises', 'no'],
103
+ 'odor': ['almond', 'anise', 'creosote', 'fishy', 'foul', 'musty', 'none', 'pungent', 'spicy'],
104
+ 'gill_color': ['black', 'brown', 'buff', 'chocolate', 'gray', 'green', 'orange', 'pink', 'purple', 'red', 'white', 'yellow'],
105
+ 'stalk_shape': ['enlarging', 'tapering'],
106
+ 'spore_print_color': ['black', 'brown', 'buff', 'chocolate', 'green', 'orange', 'purple', 'white', 'yellow'],
 
107
  'population': ['abundant', 'clustered', 'numerous', 'scattered', 'several', 'solitary'],
108
+ 'habitat': ['grasses', 'leaves', 'meadows', 'paths', 'urban', 'waste', 'woods']
109
  }
110
 
 
111
  for i, col in enumerate(X.columns):
112
+ options = feature_options.get(col, list(encoders[col].classes_))
113
  with cols[i % 3]:
114
+ val = st.selectbox(col.replace("_", " ").title(), options, key=col)
115
+ code = encoders[col].transform([val])[0]
116
+ inputs[col] = code
 
 
 
 
117
 
118
+ if st.button("Is it Safe to Eat?", type="secondary"):
119
+ input_vec = [[inputs[col] for col in X.columns]]
120
+ pred = model.predict(input_vec)[0]
121
+ prob = model.predict_proba(input_vec)[0]
122
+
123
+ if encoders['class'].inverse_transform([pred])[0] == 'e':
124
+ st.success("EDIBLE – Safe to Eat!")
 
 
125
  st.balloons()
126
  else:
127
+ st.error("POISONOUS DO NOT EAT!")
128
+ st.warning("Highly Toxic!")
129
 
130
  col1, col2 = st.columns(2)
131
+ col1.metric("Edible", f"{prob[0]:.1%}")
132
+ col2.metric("Poisonous", f"{prob[1]:.1%}")
133
 
134
  st.markdown("---")
135
+ st.caption("Mushroom Classifier • UCI Dataset • 100% Deployable on Hugging Face Spaces")