Muthuraja18 commited on
Commit
f51cb29
Β·
verified Β·
1 Parent(s): 08fe3d7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -202
app.py CHANGED
@@ -1,8 +1,6 @@
1
  import streamlit as st
2
  import tensorflow as tf
3
- from tensorflow.keras.preprocessing.image import ImageDataGenerator
4
- from tensorflow.keras.models import Sequential, load_model
5
- from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
6
  import numpy as np
7
  from PIL import Image, UnidentifiedImageError
8
  import os
@@ -10,16 +8,22 @@ import os
10
  # -----------------------------
11
  # CONFIGURATION
12
  # -----------------------------
13
- DATASET_DIR = "dataset-resized"
14
  MODEL_PATH = "waste_classifier.h5"
15
- CLASS_FILE = "classes.npy"
16
-
17
  IMG_SIZE = (128, 128)
18
- BATCH_SIZE = 16
19
- EPOCHS = 5
20
 
 
21
  CLASSES = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
22
 
 
 
 
 
 
 
 
 
 
 
23
  # -----------------------------
24
  # PAGE SETTINGS
25
  # -----------------------------
@@ -29,252 +33,150 @@ st.set_page_config(
29
  )
30
 
31
  # -----------------------------
32
- # VALIDATE DATASET
33
  # -----------------------------
34
- def validate_dataset():
35
- missing_classes = []
36
- total_images = 0
37
-
38
- for class_name in CLASSES:
39
- class_path = os.path.join(DATASET_DIR, class_name)
 
 
40
 
41
- if not os.path.exists(class_path):
42
- missing_classes.append(class_name)
43
- continue
44
 
45
- files = [
46
- f for f in os.listdir(class_path)
47
- if f.lower().endswith((".jpg", ".jpeg", ".png"))
48
- ]
 
 
49
 
50
- total_images += len(files)
51
 
52
- if len(files) == 0:
53
- missing_classes.append(class_name)
 
54
 
55
- return missing_classes, total_images
56
 
57
- # -----------------------------
58
- # CLEAN DATASET
59
- # -----------------------------
60
- def clean_dataset(dataset_path):
61
- valid_extensions = (".jpg", ".jpeg", ".png")
62
- removed = 0
63
-
64
- for root, dirs, files in os.walk(dataset_path):
65
- for file in files:
66
- file_path = os.path.join(root, file)
67
-
68
- if not file.lower().endswith(valid_extensions):
69
- try:
70
- os.remove(file_path)
71
- removed += 1
72
- except:
73
- pass
74
- continue
75
-
76
- try:
77
- with Image.open(file_path) as img:
78
- img.verify()
79
- except:
80
- try:
81
- os.remove(file_path)
82
- removed += 1
83
- except:
84
- pass
85
-
86
- return removed
87
 
88
  # -----------------------------
89
- # TRAIN MODEL
90
  # -----------------------------
91
- def train_model():
92
- missing_classes, total_images = validate_dataset()
93
-
94
- if total_images == 0:
95
- st.error("Dataset is empty. Please upload proper waste images.")
96
- st.stop()
97
-
98
- if missing_classes:
99
- st.error(
100
- f"Missing or empty class folders: {', '.join(missing_classes)}"
101
- )
102
- st.stop()
103
-
104
- removed_files = clean_dataset(DATASET_DIR)
105
- st.info(f"Removed {removed_files} corrupted/invalid files.")
106
-
107
- datagen = ImageDataGenerator(
108
- rescale=1./255,
109
- validation_split=0.2,
110
- rotation_range=20,
111
- zoom_range=0.2,
112
- horizontal_flip=True
113
- )
114
-
115
- train_data = datagen.flow_from_directory(
116
- DATASET_DIR,
117
- target_size=IMG_SIZE,
118
- batch_size=BATCH_SIZE,
119
- class_mode='categorical',
120
- subset='training',
121
- shuffle=True,
122
- classes=CLASSES
123
- )
124
-
125
- val_data = datagen.flow_from_directory(
126
- DATASET_DIR,
127
- target_size=IMG_SIZE,
128
- batch_size=BATCH_SIZE,
129
- class_mode='categorical',
130
- subset='validation',
131
- shuffle=False,
132
- classes=CLASSES
133
- )
134
-
135
- # Safety check
136
- if train_data.samples == 0 or val_data.samples == 0:
137
- st.error(
138
- "Dataset loading failed. Ensure each folder contains enough valid images."
139
- )
140
- st.stop()
141
-
142
- model = Sequential([
143
- Conv2D(32, (3,3), activation='relu', input_shape=(128,128,3)),
144
- MaxPooling2D(2,2),
145
-
146
- Conv2D(64, (3,3), activation='relu'),
147
- MaxPooling2D(2,2),
148
-
149
- Conv2D(128, (3,3), activation='relu'),
150
- MaxPooling2D(2,2),
151
-
152
- Flatten(),
153
-
154
- Dense(256, activation='relu'),
155
- Dropout(0.5),
156
 
157
- Dense(len(CLASSES), activation='softmax')
158
- ])
159
 
160
- model.compile(
161
- optimizer='adam',
162
- loss='categorical_crossentropy',
163
- metrics=['accuracy']
164
- )
165
 
166
- with st.spinner("Training AI model... Please wait..."):
167
- model.fit(
168
- train_data,
169
- validation_data=val_data,
170
- epochs=EPOCHS
171
- )
172
 
173
- model.save(MODEL_PATH)
174
- np.save(CLASS_FILE, CLASSES)
175
 
176
- return model
177
 
178
  # -----------------------------
179
- # LOAD OR TRAIN MODEL
180
  # -----------------------------
181
- def load_or_train_model():
182
- if not os.path.exists(MODEL_PATH) or not os.path.exists(CLASS_FILE):
183
- st.warning("Training model for first-time use...")
184
- return train_model()
 
185
 
186
- try:
187
- model = load_model(MODEL_PATH)
188
- saved_classes = np.load(CLASS_FILE, allow_pickle=True).tolist()
189
 
190
- if saved_classes != CLASSES or model.output_shape[-1] != len(CLASSES):
191
- st.warning("Old model mismatch detected. Retraining...")
192
- os.remove(MODEL_PATH)
193
- os.remove(CLASS_FILE)
194
- return train_model()
195
 
196
- return model
 
197
 
198
- except:
199
- st.warning("Model corrupted. Retraining...")
200
- return train_model()
 
 
201
 
202
- # -----------------------------
203
- # LOAD MODEL
204
- # -----------------------------
205
- model = load_or_train_model()
206
 
207
  # -----------------------------
208
- # UI
209
  # -----------------------------
210
  st.title("♻️ AI Smart Waste Classification")
211
- st.write("Upload an image to classify waste and support sustainable recycling.")
212
 
 
 
 
213
  uploaded_file = st.file_uploader(
214
  "Upload Waste Image",
215
  type=["jpg", "jpeg", "png"]
216
  )
217
 
 
 
 
218
  if uploaded_file is not None:
219
  try:
220
- image = Image.open(uploaded_file).convert("RGB")
 
221
 
 
222
  st.image(
223
  image,
224
- caption="Uploaded Image",
225
  use_container_width=True
226
  )
227
 
228
- img = image.resize(IMG_SIZE)
229
- img_array = np.array(img) / 255.0
230
- img_array = np.expand_dims(img_array, axis=0)
231
-
232
- with st.spinner("Analyzing waste type..."):
233
- prediction = model.predict(img_array, verbose=0)
234
-
235
- probabilities = prediction.flatten()
236
-
237
- predicted_index = np.argmax(probabilities)
238
- predicted_class = CLASSES[predicted_index]
239
- confidence = probabilities[predicted_index] * 100
240
 
 
 
 
241
  st.subheader("πŸ“Š Prediction Scores")
242
 
243
  for i, class_name in enumerate(CLASSES):
244
- st.write(
245
- f"{class_name.upper()}: {probabilities[i]*100:.2f}%"
246
- )
247
-
248
- st.success(
249
- f"Predicted Type: {predicted_class.upper()}"
250
- )
251
-
252
- st.info(
253
- f"Confidence: {confidence:.2f}%"
254
- )
255
 
256
- tips = {
257
- 'plastic': 'Recycle plastic properly to reduce pollution.',
258
- 'paper': 'Reuse or recycle paper to save trees.',
259
- 'metal': 'Metal can be recycled efficiently.',
260
- 'glass': 'Glass is reusable and recyclable.',
261
- 'trash': 'Dispose responsibly to reduce environmental damage.',
262
- 'cardboard': 'Recycle cardboard to reduce waste.'
263
- }
264
 
 
265
  st.subheader("🌱 Sustainability Suggestion")
266
- st.write(
267
- tips.get(
268
- predicted_class,
269
- "Dispose responsibly."
270
- )
271
- )
272
 
273
  except UnidentifiedImageError:
274
- st.error("Invalid image file uploaded.")
275
 
276
  except Exception as e:
277
- st.error(f"Error processing image: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
 
279
  # -----------------------------
280
  # FOOTER
 
1
  import streamlit as st
2
  import tensorflow as tf
3
+ from tensorflow.keras.models import load_model
 
 
4
  import numpy as np
5
  from PIL import Image, UnidentifiedImageError
6
  import os
 
8
  # -----------------------------
9
  # CONFIGURATION
10
  # -----------------------------
 
11
  MODEL_PATH = "waste_classifier.h5"
 
 
12
  IMG_SIZE = (128, 128)
 
 
13
 
14
+ # Fixed class labels
15
  CLASSES = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
16
 
17
+ # Sustainability tips
18
+ TIPS = {
19
+ 'plastic': 'Recycle plastic properly to reduce pollution.',
20
+ 'paper': 'Reuse or recycle paper to save trees.',
21
+ 'metal': 'Metal can be recycled efficiently.',
22
+ 'glass': 'Glass is reusable and recyclable.',
23
+ 'trash': 'Dispose responsibly to reduce environmental damage.',
24
+ 'cardboard': 'Recycle cardboard to reduce waste.'
25
+ }
26
+
27
  # -----------------------------
28
  # PAGE SETTINGS
29
  # -----------------------------
 
33
  )
34
 
35
  # -----------------------------
36
+ # LOAD MODEL
37
  # -----------------------------
38
+ @st.cache_resource
39
+ def load_ai_model():
40
+ """
41
+ Load trained TensorFlow model safely
42
+ """
43
+ if not os.path.exists(MODEL_PATH):
44
+ st.error("❌ Model file 'waste_classifier.h5' not found.")
45
+ st.stop()
46
 
47
+ try:
48
+ model = load_model(MODEL_PATH)
 
49
 
50
+ # Validate output classes
51
+ if model.output_shape[-1] != len(CLASSES):
52
+ st.error(
53
+ f"❌ Model output mismatch. Expected {len(CLASSES)} classes, got {model.output_shape[-1]}."
54
+ )
55
+ st.stop()
56
 
57
+ return model
58
 
59
+ except Exception as e:
60
+ st.error(f"❌ Error loading model: {str(e)}")
61
+ st.stop()
62
 
 
63
 
64
+ model = load_ai_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  # -----------------------------
67
+ # IMAGE PREPROCESSING FUNCTION
68
  # -----------------------------
69
+ def preprocess_image(image):
70
+ """
71
+ Resize and normalize uploaded image
72
+ """
73
+ image = image.convert("RGB")
74
+ image = image.resize(IMG_SIZE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
+ img_array = np.array(image, dtype=np.float32) / 255.0
 
77
 
78
+ # Ensure proper shape
79
+ if img_array.shape != (128, 128, 3):
80
+ raise ValueError("Image shape mismatch after preprocessing.")
 
 
81
 
82
+ img_array = np.expand_dims(img_array, axis=0)
 
 
 
 
 
83
 
84
+ return img_array
 
85
 
 
86
 
87
  # -----------------------------
88
+ # PREDICTION FUNCTION
89
  # -----------------------------
90
+ def predict_waste(image):
91
+ """
92
+ Predict waste category
93
+ """
94
+ processed_img = preprocess_image(image)
95
 
96
+ prediction = model.predict(processed_img, verbose=0)
 
 
97
 
98
+ probabilities = prediction[0]
 
 
 
 
99
 
100
+ if len(probabilities) != len(CLASSES):
101
+ raise ValueError("Prediction output size mismatch.")
102
 
103
+ predicted_index = np.argmax(probabilities)
104
+ predicted_class = CLASSES[predicted_index]
105
+ confidence = probabilities[predicted_index] * 100
106
+
107
+ return predicted_class, confidence, probabilities
108
 
 
 
 
 
109
 
110
  # -----------------------------
111
+ # UI HEADER
112
  # -----------------------------
113
  st.title("♻️ AI Smart Waste Classification")
114
+ st.write("Upload an image to classify waste for smart recycling.")
115
 
116
+ # -----------------------------
117
+ # FILE UPLOAD
118
+ # -----------------------------
119
  uploaded_file = st.file_uploader(
120
  "Upload Waste Image",
121
  type=["jpg", "jpeg", "png"]
122
  )
123
 
124
+ # -----------------------------
125
+ # IMAGE PREDICTION
126
+ # -----------------------------
127
  if uploaded_file is not None:
128
  try:
129
+ # Load image
130
+ image = Image.open(uploaded_file)
131
 
132
+ # Display image
133
  st.image(
134
  image,
135
+ caption=f"Uploaded Image: {uploaded_file.name}",
136
  use_container_width=True
137
  )
138
 
139
+ # Predict
140
+ with st.spinner("πŸ” Analyzing waste type..."):
141
+ predicted_class, confidence, probabilities = predict_waste(image)
 
 
 
 
 
 
 
 
 
142
 
143
+ # -----------------------------
144
+ # DISPLAY RESULTS
145
+ # -----------------------------
146
  st.subheader("πŸ“Š Prediction Scores")
147
 
148
  for i, class_name in enumerate(CLASSES):
149
+ st.progress(float(probabilities[i]))
150
+ st.write(f"{class_name.upper()}: {probabilities[i] * 100:.2f}%")
 
 
 
 
 
 
 
 
 
151
 
152
+ st.success(f"βœ… Predicted Type: {predicted_class.upper()}")
153
+ st.info(f"🎯 Confidence: {confidence:.2f}%")
154
+ st.write(f"πŸ“ Uploaded File: {uploaded_file.name}")
 
 
 
 
 
155
 
156
+ # Sustainability tip
157
  st.subheader("🌱 Sustainability Suggestion")
158
+ st.write(TIPS.get(predicted_class, "Dispose responsibly."))
 
 
 
 
 
159
 
160
  except UnidentifiedImageError:
161
+ st.error("❌ Invalid image file. Please upload JPG, JPEG, or PNG.")
162
 
163
  except Exception as e:
164
+ st.error(f"❌ Error processing image: {str(e)}")
165
+
166
+ # -----------------------------
167
+ # SAMPLE GUIDE
168
+ # -----------------------------
169
+ st.markdown("---")
170
+ st.subheader("πŸ–ΌοΈ Recommended Test Images")
171
+ st.write("""
172
+ Try uploading:
173
+ - plastic_bottle.jpg
174
+ - glass_bottle.jpg
175
+ - cardboard_box.jpg
176
+ - soda_can.jpg
177
+ - newspaper.jpg
178
+ - trash_bag.jpg
179
+ """)
180
 
181
  # -----------------------------
182
  # FOOTER