Update app.py

#18
by Muthuraja18 - opened
Files changed (1) hide show
  1. app.py +56 -29
app.py CHANGED
@@ -38,13 +38,44 @@ st.set_page_config(
38
  )
39
 
40
  # -----------------------------
41
- # TRAIN MODEL
42
  # -----------------------------
43
- def train_and_save_model():
 
 
 
44
  if not os.path.exists(DATASET_DIR):
45
  st.error(f"❌ Dataset folder '{DATASET_DIR}' not found.")
46
  st.stop()
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  st.info("βš™οΈ Model not found. Training a new model... This may take several minutes.")
49
 
50
  datagen = ImageDataGenerator(
@@ -52,12 +83,11 @@ def train_and_save_model():
52
  validation_split=0.2
53
  )
54
 
55
- # IMPORTANT FIX:
56
- # Use categorical labels instead of binary
57
  train_data = datagen.flow_from_directory(
58
  DATASET_DIR,
59
  target_size=IMG_SIZE,
60
  batch_size=BATCH_SIZE,
 
61
  class_mode='categorical',
62
  subset='training',
63
  shuffle=True
@@ -67,18 +97,12 @@ def train_and_save_model():
67
  DATASET_DIR,
68
  target_size=IMG_SIZE,
69
  batch_size=BATCH_SIZE,
 
70
  class_mode='categorical',
71
  subset='validation',
72
  shuffle=True
73
  )
74
 
75
- # Verify class count
76
- if train_data.num_classes != len(CLASSES):
77
- st.error(
78
- f"❌ Dataset class mismatch. Expected {len(CLASSES)} classes but found {train_data.num_classes}."
79
- )
80
- st.stop()
81
-
82
  # CNN Model
83
  model = models.Sequential([
84
  layers.Input(shape=(128, 128, 3)),
@@ -100,7 +124,6 @@ def train_and_save_model():
100
  layers.Dense(len(CLASSES), activation='softmax')
101
  ])
102
 
103
- # COMPILE FIX:
104
  model.compile(
105
  optimizer='adam',
106
  loss='categorical_crossentropy',
@@ -178,9 +201,6 @@ def predict_waste(image):
178
 
179
  probabilities = prediction[0]
180
 
181
- if len(probabilities) != len(CLASSES):
182
- raise ValueError("Prediction output mismatch.")
183
-
184
  predicted_index = np.argmax(probabilities)
185
  predicted_class = CLASSES[predicted_index]
186
  confidence = probabilities[predicted_index] * 100
@@ -201,18 +221,25 @@ with st.sidebar:
201
  st.header("πŸ“‚ Dataset Status")
202
 
203
  if os.path.exists(DATASET_DIR):
204
- st.success("Dataset Found")
 
 
 
 
 
 
205
 
206
- folders = sorted(os.listdir(DATASET_DIR))
 
207
 
208
- for folder in folders:
209
- st.write(f"βœ”οΈ {folder}")
210
 
211
  else:
212
  st.error("Dataset Missing")
213
 
214
  # -----------------------------
215
- # FILE UPLOAD
216
  # -----------------------------
217
  uploaded_file = st.file_uploader(
218
  "Upload Waste Image",
@@ -235,7 +262,6 @@ if uploaded_file is not None:
235
  with st.spinner("πŸ” Analyzing waste type..."):
236
  predicted_class, confidence, probabilities = predict_waste(image)
237
 
238
- # Results
239
  st.subheader("πŸ“Š Prediction Scores")
240
 
241
  for i, class_name in enumerate(CLASSES):
@@ -246,7 +272,6 @@ if uploaded_file is not None:
246
  st.info(f"🎯 Confidence: {confidence:.2f}%")
247
  st.write(f"πŸ“ Uploaded File: {uploaded_file.name}")
248
 
249
- # Sustainability tip
250
  st.subheader("🌱 Sustainability Suggestion")
251
  st.write(TIPS.get(predicted_class, "Dispose responsibly."))
252
 
@@ -262,13 +287,15 @@ if uploaded_file is not None:
262
  st.markdown("---")
263
  st.subheader("πŸ–ΌοΈ Recommended Test Images")
264
  st.write("""
265
- Try uploading:
266
- - plastic_bottle.jpg
267
- - glass_bottle.jpg
268
- - cardboard_box.jpg
269
- - soda_can.jpg
270
- - newspaper.jpg
271
- - trash_bag.jpg
 
 
272
  """)
273
 
274
  # -----------------------------
 
38
  )
39
 
40
  # -----------------------------
41
+ # VALIDATE DATASET STRUCTURE
42
  # -----------------------------
43
+ def validate_dataset():
44
+ """
45
+ Ensure dataset folder contains 6 required class subfolders
46
+ """
47
  if not os.path.exists(DATASET_DIR):
48
  st.error(f"❌ Dataset folder '{DATASET_DIR}' not found.")
49
  st.stop()
50
 
51
+ found_folders = sorted([
52
+ folder for folder in os.listdir(DATASET_DIR)
53
+ if os.path.isdir(os.path.join(DATASET_DIR, folder))
54
+ ])
55
+
56
+ missing = [cls for cls in CLASSES if cls not in found_folders]
57
+
58
+ if missing:
59
+ st.error("❌ Dataset structure is incorrect.")
60
+ st.write("Expected folders:")
61
+ for cls in CLASSES:
62
+ st.write(f"- {cls}")
63
+
64
+ st.write("Missing folders:")
65
+ for m in missing:
66
+ st.write(f"- {m}")
67
+
68
+ st.stop()
69
+
70
+ return found_folders
71
+
72
+
73
+ # -----------------------------
74
+ # TRAIN MODEL
75
+ # -----------------------------
76
+ def train_and_save_model():
77
+ validate_dataset()
78
+
79
  st.info("βš™οΈ Model not found. Training a new model... This may take several minutes.")
80
 
81
  datagen = ImageDataGenerator(
 
83
  validation_split=0.2
84
  )
85
 
 
 
86
  train_data = datagen.flow_from_directory(
87
  DATASET_DIR,
88
  target_size=IMG_SIZE,
89
  batch_size=BATCH_SIZE,
90
+ classes=CLASSES, # FORCE CORRECT CLASSES
91
  class_mode='categorical',
92
  subset='training',
93
  shuffle=True
 
97
  DATASET_DIR,
98
  target_size=IMG_SIZE,
99
  batch_size=BATCH_SIZE,
100
+ classes=CLASSES,
101
  class_mode='categorical',
102
  subset='validation',
103
  shuffle=True
104
  )
105
 
 
 
 
 
 
 
 
106
  # CNN Model
107
  model = models.Sequential([
108
  layers.Input(shape=(128, 128, 3)),
 
124
  layers.Dense(len(CLASSES), activation='softmax')
125
  ])
126
 
 
127
  model.compile(
128
  optimizer='adam',
129
  loss='categorical_crossentropy',
 
201
 
202
  probabilities = prediction[0]
203
 
 
 
 
204
  predicted_index = np.argmax(probabilities)
205
  predicted_class = CLASSES[predicted_index]
206
  confidence = probabilities[predicted_index] * 100
 
221
  st.header("πŸ“‚ Dataset Status")
222
 
223
  if os.path.exists(DATASET_DIR):
224
+ folders = sorted([
225
+ folder for folder in os.listdir(DATASET_DIR)
226
+ if os.path.isdir(os.path.join(DATASET_DIR, folder))
227
+ ])
228
+
229
+ if folders:
230
+ st.success("Dataset Found")
231
 
232
+ for folder in folders:
233
+ st.write(f"βœ”οΈ {folder}")
234
 
235
+ else:
236
+ st.error("No class folders found")
237
 
238
  else:
239
  st.error("Dataset Missing")
240
 
241
  # -----------------------------
242
+ # FILE UPLOADER
243
  # -----------------------------
244
  uploaded_file = st.file_uploader(
245
  "Upload Waste Image",
 
262
  with st.spinner("πŸ” Analyzing waste type..."):
263
  predicted_class, confidence, probabilities = predict_waste(image)
264
 
 
265
  st.subheader("πŸ“Š Prediction Scores")
266
 
267
  for i, class_name in enumerate(CLASSES):
 
272
  st.info(f"🎯 Confidence: {confidence:.2f}%")
273
  st.write(f"πŸ“ Uploaded File: {uploaded_file.name}")
274
 
 
275
  st.subheader("🌱 Sustainability Suggestion")
276
  st.write(TIPS.get(predicted_class, "Dispose responsibly."))
277
 
 
287
  st.markdown("---")
288
  st.subheader("πŸ–ΌοΈ Recommended Test Images")
289
  st.write("""
290
+ Your dataset folder should look like:
291
+
292
+ dataset-resized/
293
+ β”œβ”€β”€ cardboard/
294
+ β”œβ”€β”€ glass/
295
+ β”œβ”€β”€ metal/
296
+ β”œβ”€β”€ paper/
297
+ β”œβ”€β”€ plastic/
298
+ └── trash/
299
  """)
300
 
301
  # -----------------------------