Muthuraja18 commited on
Commit
08fe3d7
·
1 Parent(s): ce1ed1e
Files changed (1) hide show
  1. app.py +48 -33
app.py CHANGED
@@ -18,7 +18,6 @@ IMG_SIZE = (128, 128)
18
  BATCH_SIZE = 16
19
  EPOCHS = 5
20
 
21
- # Fixed class labels
22
  CLASSES = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
23
 
24
  # -----------------------------
@@ -29,6 +28,32 @@ st.set_page_config(
29
  layout="centered"
30
  )
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  # -----------------------------
33
  # CLEAN DATASET
34
  # -----------------------------
@@ -40,7 +65,6 @@ def clean_dataset(dataset_path):
40
  for file in files:
41
  file_path = os.path.join(root, file)
42
 
43
- # Remove invalid extensions
44
  if not file.lower().endswith(valid_extensions):
45
  try:
46
  os.remove(file_path)
@@ -49,7 +73,6 @@ def clean_dataset(dataset_path):
49
  pass
50
  continue
51
 
52
- # Remove corrupted images
53
  try:
54
  with Image.open(file_path) as img:
55
  img.verify()
@@ -66,6 +89,18 @@ def clean_dataset(dataset_path):
66
  # TRAIN MODEL
67
  # -----------------------------
68
  def train_model():
 
 
 
 
 
 
 
 
 
 
 
 
69
  removed_files = clean_dataset(DATASET_DIR)
70
  st.info(f"Removed {removed_files} corrupted/invalid files.")
71
 
@@ -97,6 +132,13 @@ def train_model():
97
  classes=CLASSES
98
  )
99
 
 
 
 
 
 
 
 
100
  model = Sequential([
101
  Conv2D(32, (3,3), activation='relu', input_shape=(128,128,3)),
102
  MaxPooling2D(2,2),
@@ -134,7 +176,7 @@ def train_model():
134
  return model
135
 
136
  # -----------------------------
137
- # LOAD OR TRAIN
138
  # -----------------------------
139
  def load_or_train_model():
140
  if not os.path.exists(MODEL_PATH) or not os.path.exists(CLASS_FILE):
@@ -145,7 +187,6 @@ def load_or_train_model():
145
  model = load_model(MODEL_PATH)
146
  saved_classes = np.load(CLASS_FILE, allow_pickle=True).tolist()
147
 
148
- # Force retrain if mismatch
149
  if saved_classes != CLASSES or model.output_shape[-1] != len(CLASSES):
150
  st.warning("Old model mismatch detected. Retraining...")
151
  os.remove(MODEL_PATH)
@@ -171,13 +212,9 @@ st.write("Upload an image to classify waste and support sustainable recycling.")
171
 
172
  uploaded_file = st.file_uploader(
173
  "Upload Waste Image",
174
- type=["jpg", "jpeg", "png"],
175
- accept_multiple_files=False
176
  )
177
 
178
- # -----------------------------
179
- # PREDICTION
180
- # -----------------------------
181
  if uploaded_file is not None:
182
  try:
183
  image = Image.open(uploaded_file).convert("RGB")
@@ -188,12 +225,10 @@ if uploaded_file is not None:
188
  use_container_width=True
189
  )
190
 
191
- # Preprocess
192
  img = image.resize(IMG_SIZE)
193
  img_array = np.array(img) / 255.0
194
  img_array = np.expand_dims(img_array, axis=0)
195
 
196
- # Predict
197
  with st.spinner("Analyzing waste type..."):
198
  prediction = model.predict(img_array, verbose=0)
199
 
@@ -203,9 +238,6 @@ if uploaded_file is not None:
203
  predicted_class = CLASSES[predicted_index]
204
  confidence = probabilities[predicted_index] * 100
205
 
206
- # -----------------------------
207
- # DISPLAY SCORES
208
- # -----------------------------
209
  st.subheader("📊 Prediction Scores")
210
 
211
  for i, class_name in enumerate(CLASSES):
@@ -213,7 +245,6 @@ if uploaded_file is not None:
213
  f"{class_name.upper()}: {probabilities[i]*100:.2f}%"
214
  )
215
 
216
- # Main result
217
  st.success(
218
  f"Predicted Type: {predicted_class.upper()}"
219
  )
@@ -222,7 +253,6 @@ if uploaded_file is not None:
222
  f"Confidence: {confidence:.2f}%"
223
  )
224
 
225
- # Sustainability Tips
226
  tips = {
227
  'plastic': 'Recycle plastic properly to reduce pollution.',
228
  'paper': 'Reuse or recycle paper to save trees.',
@@ -241,26 +271,11 @@ if uploaded_file is not None:
241
  )
242
 
243
  except UnidentifiedImageError:
244
- st.error("Invalid image file. Please upload a valid JPG, JPEG, or PNG image.")
245
 
246
  except Exception as e:
247
  st.error(f"Error processing image: {str(e)}")
248
 
249
- # -----------------------------
250
- # SAMPLE TEST IMAGE IDEAS
251
- # -----------------------------
252
- st.markdown("---")
253
- st.subheader("🖼️ Sample Images to Test")
254
- st.write("""
255
- Use images like these:
256
- - plastic_bottle.jpg
257
- - newspaper.jpg
258
- - soda_can.jpg
259
- - glass_bottle.jpg
260
- - cardboard_box.jpg
261
- - trash_bag.jpg
262
- """)
263
-
264
  # -----------------------------
265
  # FOOTER
266
  # -----------------------------
 
18
  BATCH_SIZE = 16
19
  EPOCHS = 5
20
 
 
21
  CLASSES = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
22
 
23
  # -----------------------------
 
28
  layout="centered"
29
  )
30
 
31
+ # -----------------------------
32
+ # VALIDATE DATASET
33
+ # -----------------------------
34
+ def validate_dataset():
35
+ missing_classes = []
36
+ total_images = 0
37
+
38
+ for class_name in CLASSES:
39
+ class_path = os.path.join(DATASET_DIR, class_name)
40
+
41
+ if not os.path.exists(class_path):
42
+ missing_classes.append(class_name)
43
+ continue
44
+
45
+ files = [
46
+ f for f in os.listdir(class_path)
47
+ if f.lower().endswith((".jpg", ".jpeg", ".png"))
48
+ ]
49
+
50
+ total_images += len(files)
51
+
52
+ if len(files) == 0:
53
+ missing_classes.append(class_name)
54
+
55
+ return missing_classes, total_images
56
+
57
  # -----------------------------
58
  # CLEAN DATASET
59
  # -----------------------------
 
65
  for file in files:
66
  file_path = os.path.join(root, file)
67
 
 
68
  if not file.lower().endswith(valid_extensions):
69
  try:
70
  os.remove(file_path)
 
73
  pass
74
  continue
75
 
 
76
  try:
77
  with Image.open(file_path) as img:
78
  img.verify()
 
89
  # TRAIN MODEL
90
  # -----------------------------
91
  def train_model():
92
+ missing_classes, total_images = validate_dataset()
93
+
94
+ if total_images == 0:
95
+ st.error("Dataset is empty. Please upload proper waste images.")
96
+ st.stop()
97
+
98
+ if missing_classes:
99
+ st.error(
100
+ f"Missing or empty class folders: {', '.join(missing_classes)}"
101
+ )
102
+ st.stop()
103
+
104
  removed_files = clean_dataset(DATASET_DIR)
105
  st.info(f"Removed {removed_files} corrupted/invalid files.")
106
 
 
132
  classes=CLASSES
133
  )
134
 
135
+ # Safety check
136
+ if train_data.samples == 0 or val_data.samples == 0:
137
+ st.error(
138
+ "Dataset loading failed. Ensure each folder contains enough valid images."
139
+ )
140
+ st.stop()
141
+
142
  model = Sequential([
143
  Conv2D(32, (3,3), activation='relu', input_shape=(128,128,3)),
144
  MaxPooling2D(2,2),
 
176
  return model
177
 
178
  # -----------------------------
179
+ # LOAD OR TRAIN MODEL
180
  # -----------------------------
181
  def load_or_train_model():
182
  if not os.path.exists(MODEL_PATH) or not os.path.exists(CLASS_FILE):
 
187
  model = load_model(MODEL_PATH)
188
  saved_classes = np.load(CLASS_FILE, allow_pickle=True).tolist()
189
 
 
190
  if saved_classes != CLASSES or model.output_shape[-1] != len(CLASSES):
191
  st.warning("Old model mismatch detected. Retraining...")
192
  os.remove(MODEL_PATH)
 
212
 
213
  uploaded_file = st.file_uploader(
214
  "Upload Waste Image",
215
+ type=["jpg", "jpeg", "png"]
 
216
  )
217
 
 
 
 
218
  if uploaded_file is not None:
219
  try:
220
  image = Image.open(uploaded_file).convert("RGB")
 
225
  use_container_width=True
226
  )
227
 
 
228
  img = image.resize(IMG_SIZE)
229
  img_array = np.array(img) / 255.0
230
  img_array = np.expand_dims(img_array, axis=0)
231
 
 
232
  with st.spinner("Analyzing waste type..."):
233
  prediction = model.predict(img_array, verbose=0)
234
 
 
238
  predicted_class = CLASSES[predicted_index]
239
  confidence = probabilities[predicted_index] * 100
240
 
 
 
 
241
  st.subheader("📊 Prediction Scores")
242
 
243
  for i, class_name in enumerate(CLASSES):
 
245
  f"{class_name.upper()}: {probabilities[i]*100:.2f}%"
246
  )
247
 
 
248
  st.success(
249
  f"Predicted Type: {predicted_class.upper()}"
250
  )
 
253
  f"Confidence: {confidence:.2f}%"
254
  )
255
 
 
256
  tips = {
257
  'plastic': 'Recycle plastic properly to reduce pollution.',
258
  'paper': 'Reuse or recycle paper to save trees.',
 
271
  )
272
 
273
  except UnidentifiedImageError:
274
+ st.error("Invalid image file uploaded.")
275
 
276
  except Exception as e:
277
  st.error(f"Error processing image: {str(e)}")
278
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
  # -----------------------------
280
  # FOOTER
281
  # -----------------------------