Man0707 committed on
Commit
5682f2f
·
verified ·
1 Parent(s): 6538db6

Update src/app.py

Browse files
Files changed (1) hide show
  1. src/app.py +46 -67
src/app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import streamlit as st
2
  import pandas as pd
3
  import requests
@@ -8,109 +9,87 @@ from sklearn.preprocessing import LabelEncoder
8
  import joblib
9
  import os
10
 
11
- # Load Dataset
 
 
 
12
  @st.cache_data
13
  def load_data():
14
  url = "https://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data"
15
- res = requests.get(url)
16
- res.raise_for_status()
17
  cols = ['class','cap_shape','cap_surface','cap_color','bruises','odor','gill_attachment',
18
  'gill_spacing','gill_size','gill_color','stalk_shape','stalk_root','stalk_surface_above_ring',
19
  'stalk_surface_below_ring','stalk_color_above_ring','stalk_color_below_ring','veil_type',
20
  'veil_color','ring_number','ring_type','spore_print_color','population','habitat']
21
- return pd.read_csv(StringIO(res.text), header=None, names=cols)
22
 
23
  df = load_data()
24
- st.set_page_config(page_title="Mushroom Doctor", layout="centered")
25
- st.title("🍄 Mushroom Doctor")
26
- st.markdown("### *Edible* or *Poisonous*? AI Will Tell You!")
27
 
28
- # Stats
29
- edible = len(df[df['class'] == 'e'])
30
- poison = len(df[df['class'] == 'p'])
31
- col1, col2 = st.columns(2)
32
- col1.metric("🍄 Edible", edible)
33
- col2.metric("☠ Poisonous", poison)
34
 
35
- # Preprocess
36
  @st.cache_data
37
- def encode_data(df):
38
  encoders = {}
39
- df_enc = df.copy()
40
  for col in df.columns:
41
  le = LabelEncoder()
42
- df_enc[col] = le.fit_transform(df[col])
43
  encoders[col] = le
44
- X = df_enc.drop('class', axis=1)
45
- y = df_enc['class']
46
  return X, y, encoders
47
 
48
- X, y, encoders = encode_data(df)
49
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
50
 
51
- # Train
52
- if st.button("🚀 Train Model (100% Accuracy!)", type="primary"):
53
- with st.spinner("Training Random Forest..."):
54
  model = RandomForestClassifier(n_estimators=100, random_state=42)
55
  model.fit(X_train, y_train)
56
- accuracy = model.score(X_test, y_test)
57
- st.success(f"✅ Trained! Accuracy: {accuracy:.1%}")
58
- if accuracy == 1.0:
59
- st.balloons()
60
- st.markdown("🎉 PERFECT – 100% Accurate!")
61
  joblib.dump({"model": model, "encoders": encoders}, "model.pkl")
62
- st.session_state.model_trained = True
63
 
64
- # Load Model
65
  model = None
66
  if os.path.exists("model.pkl"):
67
- data = joblib.load("model.pkl")
68
  model = data["model"]
69
  encoders = data["encoders"]
70
- st.session_state.model_trained = True
71
 
72
- # Predict
73
- st.header("🧪 Test Your Mushroom")
74
- if not st.session_state.get("model_trained", False):
75
- st.info("👆 Click 'Train Model' first!")
76
  else:
77
  cols = st.columns(3)
78
  inputs = {}
79
- feature_options = {
80
- 'cap_shape': ['bell', 'conical', 'convex', 'flat', 'knobbed', 'sunken'],
81
- 'cap_surface': ['fibrous', 'grooves', 'scaly', 'smooth'],
82
- 'cap_color': ['brown', 'buff', 'cinnamon', 'gray', 'green', 'pink', 'purple', 'red', 'white', 'yellow'],
83
- 'bruises': ['bruises', 'no'],
84
- 'odor': ['almond', 'anise', 'creosote', 'fishy', 'foul', 'musty', 'none', 'pungent', 'spicy'],
85
- 'gill_color': ['black', 'brown', 'buff', 'chocolate', 'gray', 'green', 'orange', 'pink', 'purple', 'red', 'white', 'yellow'],
86
- 'stalk_shape': ['enlarging', 'tapering'],
87
- 'spore_print_color': ['black', 'brown', 'buff', 'chocolate', 'green', 'orange', 'purple', 'white', 'yellow'],
88
- 'population': ['abundant', 'clustered', 'numerous', 'scattered', 'several', 'solitary'],
89
- 'habitat': ['grasses', 'leaves', 'meadows', 'paths', 'urban', 'waste', 'woods']
90
  }
91
-
92
  for i, col in enumerate(X.columns):
93
  with cols[i % 3]:
94
- opts = feature_options.get(col, list(encoders[col].classes_))
95
- val = st.selectbox(col.replace('_', ' ').title(), opts, key=f"{col}_sel")
96
  inputs[col] = encoders[col].transform([val])[0]
97
 
98
- if st.button("🔮 Predict: Safe or Deadly?", type="secondary"):
99
- input_data = [[inputs[col] for col in X.columns]]
100
- prediction = model.predict(input_data)[0]
101
- probs = model.predict_proba(input_data)[0]
102
-
103
- result = encoders['class'].inverse_transform([prediction])[0]
104
  if result == 'e':
105
- st.success("🍄 *EDIBLE – SAFE TO EAT!* 🎉")
106
  st.balloons()
107
  else:
108
- st.error("☠ *POISONOUS – DO NOT EAT!* ⚠")
109
- st.warning("This could be fatal!")
110
-
111
- col1, col2 = st.columns(2)
112
- col1.metric("Safe to Eat", f"{probs[0]:.1%}")
113
- col2.metric("Dangerous", f"{probs[1]:.1%}")
114
-
115
- st.markdown("---")
116
- st.caption("🍄 Mushroom Doctor | UCI Dataset | Powered by Streamlit & Hugging Face")
 
1
+ # app.py
2
  import streamlit as st
3
  import pandas as pd
4
  import requests
 
9
  import joblib
10
  import os
11
 
12
+ st.set_page_config(page_title="Mushroom Doctor", layout="centered")
13
+ st.title("Mushroom Doctor")
14
+ st.markdown("### *Edible* or *Poisonous*? AI Knows!")
15
+
16
  @st.cache_data
17
  def load_data():
18
  url = "https://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data"
19
+ r = requests.get(url)
 
20
  cols = ['class','cap_shape','cap_surface','cap_color','bruises','odor','gill_attachment',
21
  'gill_spacing','gill_size','gill_color','stalk_shape','stalk_root','stalk_surface_above_ring',
22
  'stalk_surface_below_ring','stalk_color_above_ring','stalk_color_below_ring','veil_type',
23
  'veil_color','ring_number','ring_type','spore_print_color','population','habitat']
24
+ return pd.read_csv(StringIO(r.text), header=None, names=cols)
25
 
26
  df = load_data()
27
+ st.success(f"Loaded {len(df):,} mushrooms")
 
 
28
 
29
+ edible = len(df[df['class']=='e'])
30
+ poison = len(df[df['class']=='p'])
31
+ c1, c2 = st.columns(2)
32
+ c1.metric("Edible", edible)
33
+ c2.metric("Poisonous", poison)
 
34
 
 
35
  @st.cache_data
36
+ def prepare():
37
  encoders = {}
38
+ df2 = df.copy()
39
  for col in df.columns:
40
  le = LabelEncoder()
41
+ df2[col] = le.fit_transform(df[col])
42
  encoders[col] = le
43
+ X = df2.drop('class', axis=1)
44
+ y = df2['class']
45
  return X, y, encoders
46
 
47
+ X, y, encoders = prepare()
48
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
49
 
50
+ if st.button("Train Model (100% Accuracy)", type="primary"):
51
+ with st.spinner("Training..."):
 
52
  model = RandomForestClassifier(n_estimators=100, random_state=42)
53
  model.fit(X_train, y_train)
54
+ acc = model.score(X_test, y_test)
55
+ st.success(f"Trained! Accuracy: {acc:.1%}")
56
+ if acc == 1.0: st.balloons()
 
 
57
  joblib.dump({"model": model, "encoders": encoders}, "model.pkl")
 
58
 
 
59
  model = None
60
  if os.path.exists("model.pkl"):
61
+ data = joblib.dump.load("model.pkl")
62
  model = data["model"]
63
  encoders = data["encoders"]
 
64
 
65
+ st.header("Predict Mushroom")
66
+ if model is None:
67
+ st.info("Train the model first!")
 
68
  else:
69
  cols = st.columns(3)
70
  inputs = {}
71
+ options = {
72
+ 'cap_shape': ['bell','conical','convex','flat','knobbed','sunken'],
73
+ 'bruises': ['bruises','no'],
74
+ 'odor': ['almond','anise','creosote','fishy','foul','musty','none','pungent','spicy'],
75
+ 'spore_print_color': ['black','brown','buff','chocolate','green','orange','purple','white','yellow'],
76
+ 'population': ['abundant','clustered','numerous','scattered','several','solitary'],
77
+ 'habitat': ['grasses','leaves','meadows','paths','urban','waste','woods']
 
 
 
 
78
  }
 
79
  for i, col in enumerate(X.columns):
80
  with cols[i % 3]:
81
+ val = st.selectbox(col.replace(""," ").title(), options.get(col, list(encoders[col].classes)))
 
82
  inputs[col] = encoders[col].transform([val])[0]
83
 
84
+ if st.button("Is it Safe?", type="secondary"):
85
+ vec = [[inputs[c] for c in X.columns]]
86
+ pred = model.predict(vec)[0]
87
+ prob = model.predict_proba(vec)[0]
88
+ result = encoders['class'].inverse_transform([pred])[0]
 
89
  if result == 'e':
90
+ st.success("EDIBLE – SAFE TO EAT!")
91
  st.balloons()
92
  else:
93
+ st.error("POISONOUS – DO NOT EAT!")
94
+ st.metric("Edible", f"{prob[0]:.1%}")
95
+ st.metric("Poisonous", f"{prob[1]:.1%}")