Man0707 commited on
Commit
63e6641
·
verified ·
1 Parent(s): fc728ce

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +192 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,194 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ # mushroom_app.py
 
 
2
  import streamlit as st
3
+ import pandas as pd
4
+ import numpy as np
5
+ import joblib
6
+ import requests
7
+ from io import StringIO
8
+ from sklearn.model_selection import train_test_split
9
+ from sklearn.preprocessing import LabelEncoder
10
+ from sklearn.ensemble import RandomForestClassifier
11
+ from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
12
+ import matplotlib.pyplot as plt
13
+ import seaborn as sns
14
+
15
+ # -------------------------
16
+ # Auto Download Mushroom Dataset
17
+ # -------------------------
18
+ @st.cache_data(show_spinner="Downloading Mushroom Dataset from UCI...")
19
+ def load_mushroom_data():
20
+ url = "https://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data"
21
+ response = requests.get(url)
22
+ if response.status_code == 200:
23
+ # Column names as per UCI description
24
+ columns = [
25
+ 'class', 'cap-shape', 'cap-surface', 'cap-color', 'bruises', 'odor',
26
+ 'gill-attachment', 'gill-spacing', 'gill-size', 'gill-color',
27
+ 'stalk-shape', 'stalk-root', 'stalk-surface-above-ring',
28
+ 'stalk-surface-below-ring', 'stalk-color-above-ring',
29
+ 'stalk-color-below-ring', 'veil-type', 'veil-color', 'ring-number',
30
+ 'ring-type', 'spore-print-color', 'population', 'habitat'
31
+ ]
32
+ df = pd.read_csv(StringIO(response.text), header=None, names=columns)
33
+ return df
34
+ else:
35
+ st.error("Failed to download dataset.")
36
+ return None
37
+
38
+ # Load data
39
+ df = load_mushroom_data()
40
+ if df is None:
41
+ st.stop()
42
+
43
+ st.set_page_config(page_title="Mushroom Classification", layout="centered")
44
+ st.title("Mushroom Classification")
45
+ st.markdown("### Is it *Edible* or *Poisonous*?")
46
+ st.success(f"Dataset loaded: {df.shape[0]:,} mushrooms | {df.shape[1]} features")
47
+
48
+ # Show sample
49
+ st.write("First 10 mushrooms:")
50
+ st.dataframe(df.head(10), use_container_width=True)
51
+
52
+ # Mapping
53
+ class_map = {'e': 'Edible', 'p': 'Poisonous'}
54
+ st.write("*Target Distribution:*")
55
+ col1, col2 = st.columns(2)
56
+ with col1:
57
+ st.metric("Edible (e)", df['class'].value_counts().get('e', 0))
58
+ with col2:
59
+ st.metric("Poisonous (p)", df['class'].value_counts().get('p', 0))
60
+
61
+ # Encode categorical features
62
+ @st.cache_data
63
+ def preprocess_data(df):
64
+ le_dict = {}
65
+ df_encoded = df.copy()
66
+
67
+ for column in df.columns:
68
+ le = LabelEncoder()
69
+ df_encoded[column] = le.fit_transform(df[column])
70
+ le_dict[column] = le
71
+
72
+ X = df_encoded.drop('class', axis=1)
73
+ y = df_encoded['class']
74
+
75
+ return X, y, le_dict, df_encoded
76
+
77
+ X, y, label_encoders, df_encoded = preprocess_data(df)
78
+
79
+ # Train-test split
80
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
81
+
82
+ # -------------------------
83
+ # Train Model
84
+ # -------------------------
85
+ st.header("Train Random Forest Model (Best for this dataset)")
86
+ if st.button("Train Model Now", type="primary"):
87
+ with st.spinner("Training Random Forest... (this dataset is 100% separable!)"):
88
+ model = RandomForestClassifier(n_estimators=100, random_state=42)
89
+ model.fit(X_train, y_train)
90
+
91
+ y_pred = model.predict(X_test)
92
+ acc = accuracy_score(y_test, y_pred)
93
+
94
+ st.success(f"Model Trained! Accuracy: {acc*100:.2f}%")
95
+
96
+ if acc >= 0.999:
97
+ st.balloons()
98
+ st.markdown("### PERFECT CLASSIFICATION! (100% Accuracy)")
99
+
100
+ # Confusion Matrix
101
+ cm = confusion_matrix(y_test, y_pred)
102
+ fig, ax = plt.subplots(figsize=(6,5))
103
+ sns.heatmap(cm, annot=True, fmt='d', cmap="Greens", ax=ax,
104
+ xticklabels=['Edible', 'Poisonous'],
105
+ yticklabels=['Edible', 'Poisonous'])
106
+ ax.set_xlabel("Predicted")
107
+ ax.set_ylabel("Actual")
108
+ ax.set_title("Confusion Matrix")
109
+ st.pyplot(fig)
110
+
111
+ # Classification Report
112
+ report = classification_report(y_test, y_pred, target_names=['Edible', 'Poisonous'], output_dict=True)
113
+ report_df = pd.DataFrame(report).transpose()
114
+ st.write("### Classification Report")
115
+ st.dataframe(report_df.style.format("{:.3f}"))
116
+
117
+ # Save model
118
+ joblib.dump({
119
+ "model": model,
120
+ "label_encoders": label_encoders,
121
+ "features": X.columns.tolist()
122
+ }, "mushroom_model.pkl")
123
+
124
+ with open("mushroom_model.pkl", "rb") as f:
125
+ st.download_button("Download Model (.pkl)", f, "mushroom_model.pkl")
126
+
127
+ # -------------------------
128
+ # Single Prediction (Perfect & Safe)
129
+ # -------------------------
130
+ st.header("Will You Eat This Mushroom?")
131
+ st.markdown("Select features of a mushroom to predict if it's *safe to eat*")
132
+
133
+ # Load model if trained
134
+ model = None
135
+ if 'model' in globals() or st.session_state.get("model_trained", False):
136
+ try:
137
+ loaded = joblib.load("mushroom_model.pkl")
138
+ model = loaded["model"]
139
+ label_encoders = loaded["label_encoders"]
140
+ st.session_state.model_trained = True
141
+ except:
142
+ pass
143
+
144
+ if model is None:
145
+ st.info("Train the model above first to make predictions!")
146
+ else:
147
+ cols = st.columns(3)
148
+ inputs = {}
149
+
150
+ feature_descriptions = {
151
+ 'cap-shape': ['bell', 'conical', 'flat', 'knobbed', 'sunken', 'convex'],
152
+ 'cap-surface': ['fibrous', 'grooves', 'smooth', 'scaly'],
153
+ 'cap-color': ['buff', 'cinnamon', 'red', 'gray', 'brown', 'pink', 'green', 'purple', 'white', 'yellow'],
154
+ 'bruises': ['yes', 'no'],
155
+ 'odor': ['almond', 'creosote', 'foul', 'anise', 'musty', 'none', 'pungent', 'spicy', 'fishy'],
156
+ 'gill-color': ['buff', 'red', 'gray', 'chocolate', 'black', 'brown', 'orange', 'pink', 'green', 'purple', 'white', 'yellow'],
157
+ 'stalk-shape': ['enlarging', 'tapering'],
158
+ 'stalk-root': ['bulbous', 'club', 'equal', 'rooted', 'missing'],
159
+ 'spore-print-color': ['black', 'brown', 'buff', 'chocolate', 'green', 'orange', 'purple', 'white', 'yellow'],
160
+ 'population': ['abundant', 'clustered', 'numerous', 'scattered', 'several', 'solitary'],
161
+ 'habitat': ['woods', 'grasses', 'leaves', 'meadows', 'paths', 'urban', 'waste']
162
+ }
163
+
164
+ selected = {}
165
+ for i, col in enumerate(X.columns):
166
+ with cols[i % 3]:
167
+ options = feature_descriptions.get(col, ['?'])
168
+ idx = st.selectbox(f"{col.replace('-', ' ').title()}", options, key=col)
169
+ # Map to encoded value
170
+ original_values = label_encoders[col].classes_
171
+ if idx in original_values:
172
+ encoded_val = label_encoders[col].transform([idx])[0]
173
+ selected[col] = encoded_val
174
+
175
+ if st.button("Predict Edible or Poisonous?", type="secondary"):
176
+ input_vec = [selected[f] for f in X.columns]
177
+ input_df = pd.DataFrame([input_vec], columns=X.columns)
178
+
179
+ pred = model.predict(input_df)[0]
180
+ prob = model.predict_proba(input_df)[0]
181
+
182
+ if pred == label_encoders['class'].transform(['e'])[0]:
183
+ st.success("EDIBLE & SAFE TO EAT!")
184
+ st.balloons()
185
+ else:
186
+ st.error("POISONOUS! DO NOT EAT!")
187
+ st.warning("This mushroom is highly toxic!")
188
+
189
+ col1, col2 = st.columns(2)
190
+ col1.metric("Edible Probability", f"{prob[0]:.1%}")
191
+ col2.metric("Poisonous Probability", f"{prob[1]:.1%}")
192
 
193
+ st.markdown("---")
194
+ st.caption("Mushroom Classification UCI Dataset • 100% Accuracy Possible • Built with Streamlit")