Files changed (1) hide show
  1. app.py +54 -43
app.py CHANGED
@@ -13,15 +13,19 @@ import os
13
  DATASET_DIR = "dataset-resized"
14
  MODEL_PATH = "waste_classifier.h5"
15
  CLASS_FILE = "classes.npy"
 
16
  IMG_SIZE = (128, 128)
17
  BATCH_SIZE = 16
18
  EPOCHS = 5
19
 
 
 
 
20
  # -----------------------------
21
- # PAGE CONFIG
22
  # -----------------------------
23
  st.set_page_config(
24
- page_title="AI Waste Classifier",
25
  layout="centered"
26
  )
27
 
@@ -36,6 +40,7 @@ def clean_dataset(dataset_path):
36
  for file in files:
37
  file_path = os.path.join(root, file)
38
 
 
39
  if not file.lower().endswith(valid_extensions):
40
  try:
41
  os.remove(file_path)
@@ -44,6 +49,7 @@ def clean_dataset(dataset_path):
44
  pass
45
  continue
46
 
 
47
  try:
48
  with Image.open(file_path) as img:
49
  img.verify()
@@ -77,7 +83,8 @@ def train_model():
77
  batch_size=BATCH_SIZE,
78
  class_mode='categorical',
79
  subset='training',
80
- shuffle=True
 
81
  )
82
 
83
  val_data = datagen.flow_from_directory(
@@ -86,11 +93,10 @@ def train_model():
86
  batch_size=BATCH_SIZE,
87
  class_mode='categorical',
88
  subset='validation',
89
- shuffle=False
 
90
  )
91
 
92
- classes = list(train_data.class_indices.keys())
93
-
94
  model = Sequential([
95
  Conv2D(32, (3,3), activation='relu', input_shape=(128,128,3)),
96
  MaxPooling2D(2,2),
@@ -106,7 +112,7 @@ def train_model():
106
  Dense(256, activation='relu'),
107
  Dropout(0.5),
108
 
109
- Dense(len(classes), activation='softmax')
110
  ])
111
 
112
  model.compile(
@@ -122,34 +128,31 @@ def train_model():
122
  epochs=EPOCHS
123
  )
124
 
125
- # Save model + classes
126
  model.save(MODEL_PATH)
127
- np.save(CLASS_FILE, classes)
128
 
129
- return model, classes
130
 
131
  # -----------------------------
132
- # LOAD MODEL
133
  # -----------------------------
134
  def load_or_train_model():
135
  if not os.path.exists(MODEL_PATH) or not os.path.exists(CLASS_FILE):
136
- st.warning("Training model for first-time use. Please wait...")
137
  return train_model()
138
 
139
  try:
140
  model = load_model(MODEL_PATH)
141
- classes = np.load(CLASS_FILE, allow_pickle=True).tolist()
142
-
143
- # Verify output layer
144
- output_classes = model.output_shape[-1]
145
 
146
- if output_classes != len(classes):
147
- st.warning("Old incorrect model detected. Retraining...")
 
148
  os.remove(MODEL_PATH)
149
  os.remove(CLASS_FILE)
150
  return train_model()
151
 
152
- return model, classes
153
 
154
  except:
155
  st.warning("Model corrupted. Retraining...")
@@ -158,19 +161,23 @@ def load_or_train_model():
158
  # -----------------------------
159
  # LOAD MODEL
160
  # -----------------------------
161
- model, classes = load_or_train_model()
162
 
163
  # -----------------------------
164
- # STREAMLIT UI
165
  # -----------------------------
166
  st.title("♻️ AI Smart Waste Classification")
167
  st.write("Upload an image to classify waste and support sustainable recycling.")
168
 
169
  uploaded_file = st.file_uploader(
170
  "Upload Waste Image",
171
- type=["jpg", "jpeg", "png"]
 
172
  )
173
 
 
 
 
174
  if uploaded_file is not None:
175
  try:
176
  image = Image.open(uploaded_file).convert("RGB")
@@ -181,35 +188,32 @@ if uploaded_file is not None:
181
  use_container_width=True
182
  )
183
 
184
- # -----------------------------
185
- # PREPROCESS
186
- # -----------------------------
187
  img = image.resize(IMG_SIZE)
188
  img_array = np.array(img) / 255.0
189
  img_array = np.expand_dims(img_array, axis=0)
190
 
191
- # -----------------------------
192
- # PREDICT
193
- # -----------------------------
194
  with st.spinner("Analyzing waste type..."):
195
  prediction = model.predict(img_array, verbose=0)
196
 
197
  probabilities = prediction.flatten()
198
 
199
  predicted_index = np.argmax(probabilities)
200
- predicted_class = classes[predicted_index]
201
  confidence = probabilities[predicted_index] * 100
202
 
203
  # -----------------------------
204
- # RESULTS
205
  # -----------------------------
206
  st.subheader("📊 Prediction Scores")
207
 
208
- for i, class_name in enumerate(classes):
209
  st.write(
210
  f"{class_name.upper()}: {probabilities[i]*100:.2f}%"
211
  )
212
 
 
213
  st.success(
214
  f"Predicted Type: {predicted_class.upper()}"
215
  )
@@ -218,9 +222,7 @@ if uploaded_file is not None:
218
  f"Confidence: {confidence:.2f}%"
219
  )
220
 
221
- # -----------------------------
222
- # SUSTAINABILITY TIPS
223
- # -----------------------------
224
  tips = {
225
  'plastic': 'Recycle plastic properly to reduce pollution.',
226
  'paper': 'Reuse or recycle paper to save trees.',
@@ -239,19 +241,28 @@ if uploaded_file is not None:
239
  )
240
 
241
  except UnidentifiedImageError:
242
- st.error(
243
- "Invalid image file. Please upload a valid image."
244
- )
245
 
246
  except Exception as e:
247
- st.error(
248
- f"Error processing image: {str(e)}"
249
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
250
 
251
  # -----------------------------
252
  # FOOTER
253
  # -----------------------------
254
  st.markdown("---")
255
- st.caption(
256
- "Built using TensorFlow + Streamlit for Sustainable AI"
257
- )
 
13
  DATASET_DIR = "dataset-resized"
14
  MODEL_PATH = "waste_classifier.h5"
15
  CLASS_FILE = "classes.npy"
16
+
17
  IMG_SIZE = (128, 128)
18
  BATCH_SIZE = 16
19
  EPOCHS = 5
20
 
21
+ # Fixed class labels
22
+ CLASSES = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
23
+
24
  # -----------------------------
25
+ # PAGE SETTINGS
26
  # -----------------------------
27
  st.set_page_config(
28
+ page_title="AI Smart Waste Classification",
29
  layout="centered"
30
  )
31
 
 
40
  for file in files:
41
  file_path = os.path.join(root, file)
42
 
43
+ # Remove invalid extensions
44
  if not file.lower().endswith(valid_extensions):
45
  try:
46
  os.remove(file_path)
 
49
  pass
50
  continue
51
 
52
+ # Remove corrupted images
53
  try:
54
  with Image.open(file_path) as img:
55
  img.verify()
 
83
  batch_size=BATCH_SIZE,
84
  class_mode='categorical',
85
  subset='training',
86
+ shuffle=True,
87
+ classes=CLASSES
88
  )
89
 
90
  val_data = datagen.flow_from_directory(
 
93
  batch_size=BATCH_SIZE,
94
  class_mode='categorical',
95
  subset='validation',
96
+ shuffle=False,
97
+ classes=CLASSES
98
  )
99
 
 
 
100
  model = Sequential([
101
  Conv2D(32, (3,3), activation='relu', input_shape=(128,128,3)),
102
  MaxPooling2D(2,2),
 
112
  Dense(256, activation='relu'),
113
  Dropout(0.5),
114
 
115
+ Dense(len(CLASSES), activation='softmax')
116
  ])
117
 
118
  model.compile(
 
128
  epochs=EPOCHS
129
  )
130
 
 
131
  model.save(MODEL_PATH)
132
+ np.save(CLASS_FILE, CLASSES)
133
 
134
+ return model
135
 
136
  # -----------------------------
137
+ # LOAD OR TRAIN
138
  # -----------------------------
139
  def load_or_train_model():
140
  if not os.path.exists(MODEL_PATH) or not os.path.exists(CLASS_FILE):
141
+ st.warning("Training model for first-time use...")
142
  return train_model()
143
 
144
  try:
145
  model = load_model(MODEL_PATH)
146
+ saved_classes = np.load(CLASS_FILE, allow_pickle=True).tolist()
 
 
 
147
 
148
+ # Force retrain if mismatch
149
+ if saved_classes != CLASSES or model.output_shape[-1] != len(CLASSES):
150
+ st.warning("Old model mismatch detected. Retraining...")
151
  os.remove(MODEL_PATH)
152
  os.remove(CLASS_FILE)
153
  return train_model()
154
 
155
+ return model
156
 
157
  except:
158
  st.warning("Model corrupted. Retraining...")
 
161
  # -----------------------------
162
  # LOAD MODEL
163
  # -----------------------------
164
+ model = load_or_train_model()
165
 
166
  # -----------------------------
167
+ # UI
168
  # -----------------------------
169
  st.title("♻️ AI Smart Waste Classification")
170
  st.write("Upload an image to classify waste and support sustainable recycling.")
171
 
172
  uploaded_file = st.file_uploader(
173
  "Upload Waste Image",
174
+ type=["jpg", "jpeg", "png"],
175
+ accept_multiple_files=False
176
  )
177
 
178
+ # -----------------------------
179
+ # PREDICTION
180
+ # -----------------------------
181
  if uploaded_file is not None:
182
  try:
183
  image = Image.open(uploaded_file).convert("RGB")
 
188
  use_container_width=True
189
  )
190
 
191
+ # Preprocess
 
 
192
  img = image.resize(IMG_SIZE)
193
  img_array = np.array(img) / 255.0
194
  img_array = np.expand_dims(img_array, axis=0)
195
 
196
+ # Predict
 
 
197
  with st.spinner("Analyzing waste type..."):
198
  prediction = model.predict(img_array, verbose=0)
199
 
200
  probabilities = prediction.flatten()
201
 
202
  predicted_index = np.argmax(probabilities)
203
+ predicted_class = CLASSES[predicted_index]
204
  confidence = probabilities[predicted_index] * 100
205
 
206
  # -----------------------------
207
+ # DISPLAY SCORES
208
  # -----------------------------
209
  st.subheader("📊 Prediction Scores")
210
 
211
+ for i, class_name in enumerate(CLASSES):
212
  st.write(
213
  f"{class_name.upper()}: {probabilities[i]*100:.2f}%"
214
  )
215
 
216
+ # Main result
217
  st.success(
218
  f"Predicted Type: {predicted_class.upper()}"
219
  )
 
222
  f"Confidence: {confidence:.2f}%"
223
  )
224
 
225
+ # Sustainability Tips
 
 
226
  tips = {
227
  'plastic': 'Recycle plastic properly to reduce pollution.',
228
  'paper': 'Reuse or recycle paper to save trees.',
 
241
  )
242
 
243
  except UnidentifiedImageError:
244
+ st.error("Invalid image file. Please upload a valid JPG, JPEG, or PNG image.")
 
 
245
 
246
  except Exception as e:
247
+ st.error(f"Error processing image: {str(e)}")
248
+
249
+ # -----------------------------
250
+ # SAMPLE TEST IMAGE IDEAS
251
+ # -----------------------------
252
+ st.markdown("---")
253
+ st.subheader("🖼️ Sample Images to Test")
254
+ st.write("""
255
+ Use images like these:
256
+ - plastic_bottle.jpg
257
+ - newspaper.jpg
258
+ - soda_can.jpg
259
+ - glass_bottle.jpg
260
+ - cardboard_box.jpg
261
+ - trash_bag.jpg
262
+ """)
263
 
264
  # -----------------------------
265
  # FOOTER
266
  # -----------------------------
267
  st.markdown("---")
268
+ st.caption("Built using TensorFlow + Streamlit for Sustainable AI")