Muthuraja18 commited on
Commit
8a8ec36
Β·
verified Β·
1 Parent(s): 1fff88a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -29
app.py CHANGED
@@ -38,12 +38,9 @@ st.set_page_config(
38
  )
39
 
40
  # -----------------------------
41
- # MODEL TRAINING FUNCTION
42
  # -----------------------------
43
  def train_and_save_model():
44
- """
45
- Train CNN model if model file doesn't exist
46
- """
47
  if not os.path.exists(DATASET_DIR):
48
  st.error(f"❌ Dataset folder '{DATASET_DIR}' not found.")
49
  st.stop()
@@ -55,12 +52,15 @@ def train_and_save_model():
55
  validation_split=0.2
56
  )
57
 
 
 
58
  train_data = datagen.flow_from_directory(
59
  DATASET_DIR,
60
  target_size=IMG_SIZE,
61
  batch_size=BATCH_SIZE,
62
  class_mode='categorical',
63
- subset='training'
 
64
  )
65
 
66
  val_data = datagen.flow_from_directory(
@@ -68,21 +68,29 @@ def train_and_save_model():
68
  target_size=IMG_SIZE,
69
  batch_size=BATCH_SIZE,
70
  class_mode='categorical',
71
- subset='validation'
 
72
  )
73
 
74
- # CNN Architecture
 
 
 
 
 
 
 
75
  model = models.Sequential([
76
- layers.Input(shape=(128,128,3)),
77
 
78
- layers.Conv2D(32, (3,3), activation='relu'),
79
- layers.MaxPooling2D(2,2),
80
 
81
- layers.Conv2D(64, (3,3), activation='relu'),
82
- layers.MaxPooling2D(2,2),
83
 
84
- layers.Conv2D(128, (3,3), activation='relu'),
85
- layers.MaxPooling2D(2,2),
86
 
87
  layers.Flatten(),
88
 
@@ -92,13 +100,13 @@ def train_and_save_model():
92
  layers.Dense(len(CLASSES), activation='softmax')
93
  ])
94
 
 
95
  model.compile(
96
  optimizer='adam',
97
  loss='categorical_crossentropy',
98
  metrics=['accuracy']
99
  )
100
 
101
- # Progress bar
102
  progress_bar = st.progress(0)
103
 
104
  for epoch in range(EPOCHS):
@@ -106,8 +114,9 @@ def train_and_save_model():
106
  train_data,
107
  validation_data=val_data,
108
  epochs=1,
109
- verbose=0
110
  )
 
111
  progress_bar.progress((epoch + 1) / EPOCHS)
112
 
113
  model.save(MODEL_PATH)
@@ -118,7 +127,7 @@ def train_and_save_model():
118
 
119
 
120
  # -----------------------------
121
- # LOAD OR TRAIN MODEL
122
  # -----------------------------
123
  @st.cache_resource
124
  def load_ai_model():
@@ -127,13 +136,13 @@ def load_ai_model():
127
  model = load_model(MODEL_PATH)
128
 
129
  if model.output_shape[-1] != len(CLASSES):
130
- st.warning("⚠️ Model output mismatch. Retraining model...")
131
  return train_and_save_model()
132
 
133
  return model
134
 
135
  except Exception:
136
- st.warning("⚠️ Corrupted model file. Retraining...")
137
  return train_and_save_model()
138
 
139
  else:
@@ -143,7 +152,7 @@ def load_ai_model():
143
  model = load_ai_model()
144
 
145
  # -----------------------------
146
- # IMAGE PREPROCESSING
147
  # -----------------------------
148
  def preprocess_image(image):
149
  image = image.convert("RGB")
@@ -160,7 +169,7 @@ def preprocess_image(image):
160
 
161
 
162
  # -----------------------------
163
- # PREDICTION FUNCTION
164
  # -----------------------------
165
  def predict_waste(image):
166
  processed_img = preprocess_image(image)
@@ -169,6 +178,9 @@ def predict_waste(image):
169
 
170
  probabilities = prediction[0]
171
 
 
 
 
172
  predicted_index = np.argmax(probabilities)
173
  predicted_class = CLASSES[predicted_index]
174
  confidence = probabilities[predicted_index] * 100
@@ -180,10 +192,10 @@ def predict_waste(image):
180
  # UI HEADER
181
  # -----------------------------
182
  st.title("♻️ AI Smart Waste Classification")
183
- st.write("Upload an image to classify waste and encourage sustainable recycling.")
184
 
185
  # -----------------------------
186
- # DATASET CHECK
187
  # -----------------------------
188
  with st.sidebar:
189
  st.header("πŸ“‚ Dataset Status")
@@ -191,14 +203,16 @@ with st.sidebar:
191
  if os.path.exists(DATASET_DIR):
192
  st.success("Dataset Found")
193
 
194
- for folder in os.listdir(DATASET_DIR):
 
 
195
  st.write(f"βœ”οΈ {folder}")
196
 
197
  else:
198
  st.error("Dataset Missing")
199
 
200
  # -----------------------------
201
- # FILE UPLOADER
202
  # -----------------------------
203
  uploaded_file = st.file_uploader(
204
  "Upload Waste Image",
@@ -206,7 +220,7 @@ uploaded_file = st.file_uploader(
206
  )
207
 
208
  # -----------------------------
209
- # IMAGE ANALYSIS
210
  # -----------------------------
211
  if uploaded_file is not None:
212
  try:
@@ -226,18 +240,18 @@ if uploaded_file is not None:
226
 
227
  for i, class_name in enumerate(CLASSES):
228
  st.progress(float(probabilities[i]))
229
- st.write(f"{class_name.upper()}: {probabilities[i]*100:.2f}%")
230
 
231
  st.success(f"βœ… Predicted Type: {predicted_class.upper()}")
232
  st.info(f"🎯 Confidence: {confidence:.2f}%")
233
  st.write(f"πŸ“ Uploaded File: {uploaded_file.name}")
234
 
235
- # Sustainability Tip
236
  st.subheader("🌱 Sustainability Suggestion")
237
  st.write(TIPS.get(predicted_class, "Dispose responsibly."))
238
 
239
  except UnidentifiedImageError:
240
- st.error("❌ Invalid image format. Upload JPG, JPEG, or PNG.")
241
 
242
  except Exception as e:
243
  st.error(f"❌ Error processing image: {str(e)}")
 
38
  )
39
 
40
  # -----------------------------
41
+ # TRAIN MODEL
42
  # -----------------------------
43
  def train_and_save_model():
 
 
 
44
  if not os.path.exists(DATASET_DIR):
45
  st.error(f"❌ Dataset folder '{DATASET_DIR}' not found.")
46
  st.stop()
 
52
  validation_split=0.2
53
  )
54
 
55
+ # IMPORTANT FIX:
56
+ # Use categorical labels instead of binary
57
  train_data = datagen.flow_from_directory(
58
  DATASET_DIR,
59
  target_size=IMG_SIZE,
60
  batch_size=BATCH_SIZE,
61
  class_mode='categorical',
62
+ subset='training',
63
+ shuffle=True
64
  )
65
 
66
  val_data = datagen.flow_from_directory(
 
68
  target_size=IMG_SIZE,
69
  batch_size=BATCH_SIZE,
70
  class_mode='categorical',
71
+ subset='validation',
72
+ shuffle=True
73
  )
74
 
75
+ # Verify class count
76
+ if train_data.num_classes != len(CLASSES):
77
+ st.error(
78
+ f"❌ Dataset class mismatch. Expected {len(CLASSES)} classes but found {train_data.num_classes}."
79
+ )
80
+ st.stop()
81
+
82
+ # CNN Model
83
  model = models.Sequential([
84
+ layers.Input(shape=(128, 128, 3)),
85
 
86
+ layers.Conv2D(32, (3, 3), activation='relu'),
87
+ layers.MaxPooling2D(2, 2),
88
 
89
+ layers.Conv2D(64, (3, 3), activation='relu'),
90
+ layers.MaxPooling2D(2, 2),
91
 
92
+ layers.Conv2D(128, (3, 3), activation='relu'),
93
+ layers.MaxPooling2D(2, 2),
94
 
95
  layers.Flatten(),
96
 
 
100
  layers.Dense(len(CLASSES), activation='softmax')
101
  ])
102
 
103
+ # COMPILE FIX:
104
  model.compile(
105
  optimizer='adam',
106
  loss='categorical_crossentropy',
107
  metrics=['accuracy']
108
  )
109
 
 
110
  progress_bar = st.progress(0)
111
 
112
  for epoch in range(EPOCHS):
 
114
  train_data,
115
  validation_data=val_data,
116
  epochs=1,
117
+ verbose=1
118
  )
119
+
120
  progress_bar.progress((epoch + 1) / EPOCHS)
121
 
122
  model.save(MODEL_PATH)
 
127
 
128
 
129
  # -----------------------------
130
+ # LOAD MODEL
131
  # -----------------------------
132
  @st.cache_resource
133
  def load_ai_model():
 
136
  model = load_model(MODEL_PATH)
137
 
138
  if model.output_shape[-1] != len(CLASSES):
139
+ st.warning("⚠️ Model mismatch. Retraining...")
140
  return train_and_save_model()
141
 
142
  return model
143
 
144
  except Exception:
145
+ st.warning("⚠️ Corrupted model. Retraining...")
146
  return train_and_save_model()
147
 
148
  else:
 
152
  model = load_ai_model()
153
 
154
  # -----------------------------
155
+ # PREPROCESS IMAGE
156
  # -----------------------------
157
  def preprocess_image(image):
158
  image = image.convert("RGB")
 
169
 
170
 
171
  # -----------------------------
172
+ # PREDICT
173
  # -----------------------------
174
  def predict_waste(image):
175
  processed_img = preprocess_image(image)
 
178
 
179
  probabilities = prediction[0]
180
 
181
+ if len(probabilities) != len(CLASSES):
182
+ raise ValueError("Prediction output mismatch.")
183
+
184
  predicted_index = np.argmax(probabilities)
185
  predicted_class = CLASSES[predicted_index]
186
  confidence = probabilities[predicted_index] * 100
 
192
  # UI HEADER
193
  # -----------------------------
194
  st.title("♻️ AI Smart Waste Classification")
195
+ st.write("Upload an image to classify waste and support sustainable recycling.")
196
 
197
  # -----------------------------
198
+ # SIDEBAR
199
  # -----------------------------
200
  with st.sidebar:
201
  st.header("πŸ“‚ Dataset Status")
 
203
  if os.path.exists(DATASET_DIR):
204
  st.success("Dataset Found")
205
 
206
+ folders = sorted(os.listdir(DATASET_DIR))
207
+
208
+ for folder in folders:
209
  st.write(f"βœ”οΈ {folder}")
210
 
211
  else:
212
  st.error("Dataset Missing")
213
 
214
  # -----------------------------
215
+ # FILE UPLOAD
216
  # -----------------------------
217
  uploaded_file = st.file_uploader(
218
  "Upload Waste Image",
 
220
  )
221
 
222
  # -----------------------------
223
+ # ANALYSIS
224
  # -----------------------------
225
  if uploaded_file is not None:
226
  try:
 
240
 
241
  for i, class_name in enumerate(CLASSES):
242
  st.progress(float(probabilities[i]))
243
+ st.write(f"{class_name.upper()}: {probabilities[i] * 100:.2f}%")
244
 
245
  st.success(f"βœ… Predicted Type: {predicted_class.upper()}")
246
  st.info(f"🎯 Confidence: {confidence:.2f}%")
247
  st.write(f"πŸ“ Uploaded File: {uploaded_file.name}")
248
 
249
+ # Sustainability tip
250
  st.subheader("🌱 Sustainability Suggestion")
251
  st.write(TIPS.get(predicted_class, "Dispose responsibly."))
252
 
253
  except UnidentifiedImageError:
254
+ st.error("❌ Invalid image file. Upload JPG, JPEG, or PNG.")
255
 
256
  except Exception as e:
257
  st.error(f"❌ Error processing image: {str(e)}")