Files changed (1) hide show
  1. app.py +121 -42
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import streamlit as st
2
  import tensorflow as tf
3
  from tensorflow.keras.models import load_model
 
 
4
  import numpy as np
5
  from PIL import Image, UnidentifiedImageError
6
  import os
@@ -9,7 +11,10 @@ import os
9
  # CONFIGURATION
10
  # -----------------------------
11
  MODEL_PATH = "waste_classifier.h5"
 
12
  IMG_SIZE = (128, 128)
 
 
13
 
14
  # Fixed class labels
15
  CLASSES = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
@@ -33,51 +38,121 @@ st.set_page_config(
33
  )
34
 
35
  # -----------------------------
36
- # LOAD MODEL
37
  # -----------------------------
38
- @st.cache_resource
39
- def load_ai_model():
40
  """
41
- Load trained TensorFlow model safely
42
  """
43
- if not os.path.exists(MODEL_PATH):
44
- st.error("❌ Model file 'waste_classifier.h5' not found.")
45
  st.stop()
46
 
47
- try:
48
- model = load_model(MODEL_PATH)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- # Validate output classes
51
- if model.output_shape[-1] != len(CLASSES):
52
- st.error(
53
- f"❌ Model output mismatch. Expected {len(CLASSES)} classes, got {model.output_shape[-1]}."
54
- )
55
- st.stop()
56
 
57
- return model
58
 
59
- except Exception as e:
60
- st.error(f"❌ Error loading model: {str(e)}")
61
- st.stop()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
 
64
  model = load_ai_model()
65
 
66
  # -----------------------------
67
- # IMAGE PREPROCESSING FUNCTION
68
  # -----------------------------
69
  def preprocess_image(image):
70
- """
71
- Resize and normalize uploaded image
72
- """
73
  image = image.convert("RGB")
74
  image = image.resize(IMG_SIZE)
75
 
76
  img_array = np.array(image, dtype=np.float32) / 255.0
77
 
78
- # Ensure proper shape
79
  if img_array.shape != (128, 128, 3):
80
- raise ValueError("Image shape mismatch after preprocessing.")
81
 
82
  img_array = np.expand_dims(img_array, axis=0)
83
 
@@ -88,18 +163,12 @@ def preprocess_image(image):
88
  # PREDICTION FUNCTION
89
  # -----------------------------
90
  def predict_waste(image):
91
- """
92
- Predict waste category
93
- """
94
  processed_img = preprocess_image(image)
95
 
96
  prediction = model.predict(processed_img, verbose=0)
97
 
98
  probabilities = prediction[0]
99
 
100
- if len(probabilities) != len(CLASSES):
101
- raise ValueError("Prediction output size mismatch.")
102
-
103
  predicted_index = np.argmax(probabilities)
104
  predicted_class = CLASSES[predicted_index]
105
  confidence = probabilities[predicted_index] * 100
@@ -111,10 +180,25 @@ def predict_waste(image):
111
  # UI HEADER
112
  # -----------------------------
113
  st.title("♻️ AI Smart Waste Classification")
114
- st.write("Upload an image to classify waste for smart recycling.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
  # -----------------------------
117
- # FILE UPLOAD
118
  # -----------------------------
119
  uploaded_file = st.file_uploader(
120
  "Upload Waste Image",
@@ -122,43 +206,38 @@ uploaded_file = st.file_uploader(
122
  )
123
 
124
  # -----------------------------
125
- # IMAGE PREDICTION
126
  # -----------------------------
127
  if uploaded_file is not None:
128
  try:
129
- # Load image
130
  image = Image.open(uploaded_file)
131
 
132
- # Display image
133
  st.image(
134
  image,
135
  caption=f"Uploaded Image: {uploaded_file.name}",
136
  use_container_width=True
137
  )
138
 
139
- # Predict
140
  with st.spinner("πŸ” Analyzing waste type..."):
141
  predicted_class, confidence, probabilities = predict_waste(image)
142
 
143
- # -----------------------------
144
- # DISPLAY RESULTS
145
- # -----------------------------
146
  st.subheader("πŸ“Š Prediction Scores")
147
 
148
  for i, class_name in enumerate(CLASSES):
149
  st.progress(float(probabilities[i]))
150
- st.write(f"{class_name.upper()}: {probabilities[i] * 100:.2f}%")
151
 
152
  st.success(f"βœ… Predicted Type: {predicted_class.upper()}")
153
  st.info(f"🎯 Confidence: {confidence:.2f}%")
154
  st.write(f"πŸ“ Uploaded File: {uploaded_file.name}")
155
 
156
- # Sustainability tip
157
  st.subheader("🌱 Sustainability Suggestion")
158
  st.write(TIPS.get(predicted_class, "Dispose responsibly."))
159
 
160
  except UnidentifiedImageError:
161
- st.error("❌ Invalid image file. Please upload JPG, JPEG, or PNG.")
162
 
163
  except Exception as e:
164
  st.error(f"❌ Error processing image: {str(e)}")
 
1
  import streamlit as st
2
  import tensorflow as tf
3
  from tensorflow.keras.models import load_model
4
+ from tensorflow.keras.preprocessing.image import ImageDataGenerator
5
+ from tensorflow.keras import layers, models
6
  import numpy as np
7
  from PIL import Image, UnidentifiedImageError
8
  import os
 
11
  # CONFIGURATION
12
  # -----------------------------
13
  MODEL_PATH = "waste_classifier.h5"
14
+ DATASET_DIR = "dataset-resized"
15
  IMG_SIZE = (128, 128)
16
+ BATCH_SIZE = 32
17
+ EPOCHS = 10
18
 
19
  # Fixed class labels
20
  CLASSES = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
 
38
  )
39
 
40
  # -----------------------------
41
+ # MODEL TRAINING FUNCTION
42
  # -----------------------------
43
+ def train_and_save_model():
 
44
  """
45
+ Train CNN model if model file doesn't exist
46
  """
47
+ if not os.path.exists(DATASET_DIR):
48
+ st.error(f"❌ Dataset folder '{DATASET_DIR}' not found.")
49
  st.stop()
50
 
51
+ st.info("βš™οΈ Model not found. Training a new model... This may take several minutes.")
52
+
53
+ datagen = ImageDataGenerator(
54
+ rescale=1./255,
55
+ validation_split=0.2
56
+ )
57
+
58
+ train_data = datagen.flow_from_directory(
59
+ DATASET_DIR,
60
+ target_size=IMG_SIZE,
61
+ batch_size=BATCH_SIZE,
62
+ class_mode='categorical',
63
+ subset='training'
64
+ )
65
+
66
+ val_data = datagen.flow_from_directory(
67
+ DATASET_DIR,
68
+ target_size=IMG_SIZE,
69
+ batch_size=BATCH_SIZE,
70
+ class_mode='categorical',
71
+ subset='validation'
72
+ )
73
+
74
+ # CNN Architecture
75
+ model = models.Sequential([
76
+ layers.Input(shape=(128,128,3)),
77
+
78
+ layers.Conv2D(32, (3,3), activation='relu'),
79
+ layers.MaxPooling2D(2,2),
80
+
81
+ layers.Conv2D(64, (3,3), activation='relu'),
82
+ layers.MaxPooling2D(2,2),
83
+
84
+ layers.Conv2D(128, (3,3), activation='relu'),
85
+ layers.MaxPooling2D(2,2),
86
+
87
+ layers.Flatten(),
88
+
89
+ layers.Dense(128, activation='relu'),
90
+ layers.Dropout(0.5),
91
+
92
+ layers.Dense(len(CLASSES), activation='softmax')
93
+ ])
94
+
95
+ model.compile(
96
+ optimizer='adam',
97
+ loss='categorical_crossentropy',
98
+ metrics=['accuracy']
99
+ )
100
+
101
+ # Progress bar
102
+ progress_bar = st.progress(0)
103
+
104
+ for epoch in range(EPOCHS):
105
+ model.fit(
106
+ train_data,
107
+ validation_data=val_data,
108
+ epochs=1,
109
+ verbose=0
110
+ )
111
+ progress_bar.progress((epoch + 1) / EPOCHS)
112
 
113
+ model.save(MODEL_PATH)
 
 
 
 
 
114
 
115
+ st.success("βœ… Model trained and saved successfully!")
116
 
117
+ return model
118
+
119
+
120
+ # -----------------------------
121
+ # LOAD OR TRAIN MODEL
122
+ # -----------------------------
123
+ @st.cache_resource
124
+ def load_ai_model():
125
+ if os.path.exists(MODEL_PATH):
126
+ try:
127
+ model = load_model(MODEL_PATH)
128
+
129
+ if model.output_shape[-1] != len(CLASSES):
130
+ st.warning("⚠️ Model output mismatch. Retraining model...")
131
+ return train_and_save_model()
132
+
133
+ return model
134
+
135
+ except Exception:
136
+ st.warning("⚠️ Corrupted model file. Retraining...")
137
+ return train_and_save_model()
138
+
139
+ else:
140
+ return train_and_save_model()
141
 
142
 
143
  model = load_ai_model()
144
 
145
  # -----------------------------
146
+ # IMAGE PREPROCESSING
147
  # -----------------------------
148
  def preprocess_image(image):
 
 
 
149
  image = image.convert("RGB")
150
  image = image.resize(IMG_SIZE)
151
 
152
  img_array = np.array(image, dtype=np.float32) / 255.0
153
 
 
154
  if img_array.shape != (128, 128, 3):
155
+ raise ValueError("Image preprocessing failed.")
156
 
157
  img_array = np.expand_dims(img_array, axis=0)
158
 
 
163
  # PREDICTION FUNCTION
164
  # -----------------------------
165
  def predict_waste(image):
 
 
 
166
  processed_img = preprocess_image(image)
167
 
168
  prediction = model.predict(processed_img, verbose=0)
169
 
170
  probabilities = prediction[0]
171
 
 
 
 
172
  predicted_index = np.argmax(probabilities)
173
  predicted_class = CLASSES[predicted_index]
174
  confidence = probabilities[predicted_index] * 100
 
180
  # UI HEADER
181
  # -----------------------------
182
  st.title("♻️ AI Smart Waste Classification")
183
+ st.write("Upload an image to classify waste and encourage sustainable recycling.")
184
+
185
+ # -----------------------------
186
+ # DATASET CHECK
187
+ # -----------------------------
188
+ with st.sidebar:
189
+ st.header("πŸ“‚ Dataset Status")
190
+
191
+ if os.path.exists(DATASET_DIR):
192
+ st.success("Dataset Found")
193
+
194
+ for folder in os.listdir(DATASET_DIR):
195
+ st.write(f"βœ”οΈ {folder}")
196
+
197
+ else:
198
+ st.error("Dataset Missing")
199
 
200
  # -----------------------------
201
+ # FILE UPLOADER
202
  # -----------------------------
203
  uploaded_file = st.file_uploader(
204
  "Upload Waste Image",
 
206
  )
207
 
208
  # -----------------------------
209
+ # IMAGE ANALYSIS
210
  # -----------------------------
211
  if uploaded_file is not None:
212
  try:
 
213
  image = Image.open(uploaded_file)
214
 
 
215
  st.image(
216
  image,
217
  caption=f"Uploaded Image: {uploaded_file.name}",
218
  use_container_width=True
219
  )
220
 
 
221
  with st.spinner("πŸ” Analyzing waste type..."):
222
  predicted_class, confidence, probabilities = predict_waste(image)
223
 
224
+ # Results
 
 
225
  st.subheader("πŸ“Š Prediction Scores")
226
 
227
  for i, class_name in enumerate(CLASSES):
228
  st.progress(float(probabilities[i]))
229
+ st.write(f"{class_name.upper()}: {probabilities[i]*100:.2f}%")
230
 
231
  st.success(f"βœ… Predicted Type: {predicted_class.upper()}")
232
  st.info(f"🎯 Confidence: {confidence:.2f}%")
233
  st.write(f"πŸ“ Uploaded File: {uploaded_file.name}")
234
 
235
+ # Sustainability Tip
236
  st.subheader("🌱 Sustainability Suggestion")
237
  st.write(TIPS.get(predicted_class, "Dispose responsibly."))
238
 
239
  except UnidentifiedImageError:
240
+ st.error("❌ Invalid image format. Upload JPG, JPEG, or PNG.")
241
 
242
  except Exception as e:
243
  st.error(f"❌ Error processing image: {str(e)}")