Update app.py
Browse files
app.py
CHANGED
|
@@ -38,12 +38,9 @@ st.set_page_config(
|
|
| 38 |
)
|
| 39 |
|
| 40 |
# -----------------------------
|
| 41 |
-
#
|
| 42 |
# -----------------------------
|
| 43 |
def train_and_save_model():
|
| 44 |
-
"""
|
| 45 |
-
Train CNN model if model file doesn't exist
|
| 46 |
-
"""
|
| 47 |
if not os.path.exists(DATASET_DIR):
|
| 48 |
st.error(f"β Dataset folder '{DATASET_DIR}' not found.")
|
| 49 |
st.stop()
|
|
@@ -55,12 +52,15 @@ def train_and_save_model():
|
|
| 55 |
validation_split=0.2
|
| 56 |
)
|
| 57 |
|
|
|
|
|
|
|
| 58 |
train_data = datagen.flow_from_directory(
|
| 59 |
DATASET_DIR,
|
| 60 |
target_size=IMG_SIZE,
|
| 61 |
batch_size=BATCH_SIZE,
|
| 62 |
class_mode='categorical',
|
| 63 |
-
subset='training'
|
|
|
|
| 64 |
)
|
| 65 |
|
| 66 |
val_data = datagen.flow_from_directory(
|
|
@@ -68,21 +68,29 @@ def train_and_save_model():
|
|
| 68 |
target_size=IMG_SIZE,
|
| 69 |
batch_size=BATCH_SIZE,
|
| 70 |
class_mode='categorical',
|
| 71 |
-
subset='validation'
|
|
|
|
| 72 |
)
|
| 73 |
|
| 74 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
model = models.Sequential([
|
| 76 |
-
layers.Input(shape=(128,128,3)),
|
| 77 |
|
| 78 |
-
layers.Conv2D(32, (3,3), activation='relu'),
|
| 79 |
-
layers.MaxPooling2D(2,2),
|
| 80 |
|
| 81 |
-
layers.Conv2D(64, (3,3), activation='relu'),
|
| 82 |
-
layers.MaxPooling2D(2,2),
|
| 83 |
|
| 84 |
-
layers.Conv2D(128, (3,3), activation='relu'),
|
| 85 |
-
layers.MaxPooling2D(2,2),
|
| 86 |
|
| 87 |
layers.Flatten(),
|
| 88 |
|
|
@@ -92,13 +100,13 @@ def train_and_save_model():
|
|
| 92 |
layers.Dense(len(CLASSES), activation='softmax')
|
| 93 |
])
|
| 94 |
|
|
|
|
| 95 |
model.compile(
|
| 96 |
optimizer='adam',
|
| 97 |
loss='categorical_crossentropy',
|
| 98 |
metrics=['accuracy']
|
| 99 |
)
|
| 100 |
|
| 101 |
-
# Progress bar
|
| 102 |
progress_bar = st.progress(0)
|
| 103 |
|
| 104 |
for epoch in range(EPOCHS):
|
|
@@ -106,8 +114,9 @@ def train_and_save_model():
|
|
| 106 |
train_data,
|
| 107 |
validation_data=val_data,
|
| 108 |
epochs=1,
|
| 109 |
-
verbose=
|
| 110 |
)
|
|
|
|
| 111 |
progress_bar.progress((epoch + 1) / EPOCHS)
|
| 112 |
|
| 113 |
model.save(MODEL_PATH)
|
|
@@ -118,7 +127,7 @@ def train_and_save_model():
|
|
| 118 |
|
| 119 |
|
| 120 |
# -----------------------------
|
| 121 |
-
# LOAD
|
| 122 |
# -----------------------------
|
| 123 |
@st.cache_resource
|
| 124 |
def load_ai_model():
|
|
@@ -127,13 +136,13 @@ def load_ai_model():
|
|
| 127 |
model = load_model(MODEL_PATH)
|
| 128 |
|
| 129 |
if model.output_shape[-1] != len(CLASSES):
|
| 130 |
-
st.warning("β οΈ Model
|
| 131 |
return train_and_save_model()
|
| 132 |
|
| 133 |
return model
|
| 134 |
|
| 135 |
except Exception:
|
| 136 |
-
st.warning("β οΈ Corrupted model
|
| 137 |
return train_and_save_model()
|
| 138 |
|
| 139 |
else:
|
|
@@ -143,7 +152,7 @@ def load_ai_model():
|
|
| 143 |
model = load_ai_model()
|
| 144 |
|
| 145 |
# -----------------------------
|
| 146 |
-
#
|
| 147 |
# -----------------------------
|
| 148 |
def preprocess_image(image):
|
| 149 |
image = image.convert("RGB")
|
|
@@ -160,7 +169,7 @@ def preprocess_image(image):
|
|
| 160 |
|
| 161 |
|
| 162 |
# -----------------------------
|
| 163 |
-
#
|
| 164 |
# -----------------------------
|
| 165 |
def predict_waste(image):
|
| 166 |
processed_img = preprocess_image(image)
|
|
@@ -169,6 +178,9 @@ def predict_waste(image):
|
|
| 169 |
|
| 170 |
probabilities = prediction[0]
|
| 171 |
|
|
|
|
|
|
|
|
|
|
| 172 |
predicted_index = np.argmax(probabilities)
|
| 173 |
predicted_class = CLASSES[predicted_index]
|
| 174 |
confidence = probabilities[predicted_index] * 100
|
|
@@ -180,10 +192,10 @@ def predict_waste(image):
|
|
| 180 |
# UI HEADER
|
| 181 |
# -----------------------------
|
| 182 |
st.title("β»οΈ AI Smart Waste Classification")
|
| 183 |
-
st.write("Upload an image to classify waste and
|
| 184 |
|
| 185 |
# -----------------------------
|
| 186 |
-
#
|
| 187 |
# -----------------------------
|
| 188 |
with st.sidebar:
|
| 189 |
st.header("π Dataset Status")
|
|
@@ -191,14 +203,16 @@ with st.sidebar:
|
|
| 191 |
if os.path.exists(DATASET_DIR):
|
| 192 |
st.success("Dataset Found")
|
| 193 |
|
| 194 |
-
|
|
|
|
|
|
|
| 195 |
st.write(f"βοΈ {folder}")
|
| 196 |
|
| 197 |
else:
|
| 198 |
st.error("Dataset Missing")
|
| 199 |
|
| 200 |
# -----------------------------
|
| 201 |
-
# FILE
|
| 202 |
# -----------------------------
|
| 203 |
uploaded_file = st.file_uploader(
|
| 204 |
"Upload Waste Image",
|
|
@@ -206,7 +220,7 @@ uploaded_file = st.file_uploader(
|
|
| 206 |
)
|
| 207 |
|
| 208 |
# -----------------------------
|
| 209 |
-
#
|
| 210 |
# -----------------------------
|
| 211 |
if uploaded_file is not None:
|
| 212 |
try:
|
|
@@ -226,18 +240,18 @@ if uploaded_file is not None:
|
|
| 226 |
|
| 227 |
for i, class_name in enumerate(CLASSES):
|
| 228 |
st.progress(float(probabilities[i]))
|
| 229 |
-
st.write(f"{class_name.upper()}: {probabilities[i]*100:.2f}%")
|
| 230 |
|
| 231 |
st.success(f"β
Predicted Type: {predicted_class.upper()}")
|
| 232 |
st.info(f"π― Confidence: {confidence:.2f}%")
|
| 233 |
st.write(f"π Uploaded File: {uploaded_file.name}")
|
| 234 |
|
| 235 |
-
# Sustainability
|
| 236 |
st.subheader("π± Sustainability Suggestion")
|
| 237 |
st.write(TIPS.get(predicted_class, "Dispose responsibly."))
|
| 238 |
|
| 239 |
except UnidentifiedImageError:
|
| 240 |
-
st.error("β Invalid image
|
| 241 |
|
| 242 |
except Exception as e:
|
| 243 |
st.error(f"β Error processing image: {str(e)}")
|
|
|
|
| 38 |
)
|
| 39 |
|
| 40 |
# -----------------------------
|
| 41 |
+
# TRAIN MODEL
|
| 42 |
# -----------------------------
|
| 43 |
def train_and_save_model():
|
|
|
|
|
|
|
|
|
|
| 44 |
if not os.path.exists(DATASET_DIR):
|
| 45 |
st.error(f"β Dataset folder '{DATASET_DIR}' not found.")
|
| 46 |
st.stop()
|
|
|
|
| 52 |
validation_split=0.2
|
| 53 |
)
|
| 54 |
|
| 55 |
+
# IMPORTANT FIX:
|
| 56 |
+
# Use categorical labels instead of binary
|
| 57 |
train_data = datagen.flow_from_directory(
|
| 58 |
DATASET_DIR,
|
| 59 |
target_size=IMG_SIZE,
|
| 60 |
batch_size=BATCH_SIZE,
|
| 61 |
class_mode='categorical',
|
| 62 |
+
subset='training',
|
| 63 |
+
shuffle=True
|
| 64 |
)
|
| 65 |
|
| 66 |
val_data = datagen.flow_from_directory(
|
|
|
|
| 68 |
target_size=IMG_SIZE,
|
| 69 |
batch_size=BATCH_SIZE,
|
| 70 |
class_mode='categorical',
|
| 71 |
+
subset='validation',
|
| 72 |
+
shuffle=True
|
| 73 |
)
|
| 74 |
|
| 75 |
+
# Verify class count
|
| 76 |
+
if train_data.num_classes != len(CLASSES):
|
| 77 |
+
st.error(
|
| 78 |
+
f"β Dataset class mismatch. Expected {len(CLASSES)} classes but found {train_data.num_classes}."
|
| 79 |
+
)
|
| 80 |
+
st.stop()
|
| 81 |
+
|
| 82 |
+
# CNN Model
|
| 83 |
model = models.Sequential([
|
| 84 |
+
layers.Input(shape=(128, 128, 3)),
|
| 85 |
|
| 86 |
+
layers.Conv2D(32, (3, 3), activation='relu'),
|
| 87 |
+
layers.MaxPooling2D(2, 2),
|
| 88 |
|
| 89 |
+
layers.Conv2D(64, (3, 3), activation='relu'),
|
| 90 |
+
layers.MaxPooling2D(2, 2),
|
| 91 |
|
| 92 |
+
layers.Conv2D(128, (3, 3), activation='relu'),
|
| 93 |
+
layers.MaxPooling2D(2, 2),
|
| 94 |
|
| 95 |
layers.Flatten(),
|
| 96 |
|
|
|
|
| 100 |
layers.Dense(len(CLASSES), activation='softmax')
|
| 101 |
])
|
| 102 |
|
| 103 |
+
# COMPILE FIX:
|
| 104 |
model.compile(
|
| 105 |
optimizer='adam',
|
| 106 |
loss='categorical_crossentropy',
|
| 107 |
metrics=['accuracy']
|
| 108 |
)
|
| 109 |
|
|
|
|
| 110 |
progress_bar = st.progress(0)
|
| 111 |
|
| 112 |
for epoch in range(EPOCHS):
|
|
|
|
| 114 |
train_data,
|
| 115 |
validation_data=val_data,
|
| 116 |
epochs=1,
|
| 117 |
+
verbose=1
|
| 118 |
)
|
| 119 |
+
|
| 120 |
progress_bar.progress((epoch + 1) / EPOCHS)
|
| 121 |
|
| 122 |
model.save(MODEL_PATH)
|
|
|
|
| 127 |
|
| 128 |
|
| 129 |
# -----------------------------
|
| 130 |
+
# LOAD MODEL
|
| 131 |
# -----------------------------
|
| 132 |
@st.cache_resource
|
| 133 |
def load_ai_model():
|
|
|
|
| 136 |
model = load_model(MODEL_PATH)
|
| 137 |
|
| 138 |
if model.output_shape[-1] != len(CLASSES):
|
| 139 |
+
st.warning("β οΈ Model mismatch. Retraining...")
|
| 140 |
return train_and_save_model()
|
| 141 |
|
| 142 |
return model
|
| 143 |
|
| 144 |
except Exception:
|
| 145 |
+
st.warning("β οΈ Corrupted model. Retraining...")
|
| 146 |
return train_and_save_model()
|
| 147 |
|
| 148 |
else:
|
|
|
|
| 152 |
model = load_ai_model()
|
| 153 |
|
| 154 |
# -----------------------------
|
| 155 |
+
# PREPROCESS IMAGE
|
| 156 |
# -----------------------------
|
| 157 |
def preprocess_image(image):
|
| 158 |
image = image.convert("RGB")
|
|
|
|
| 169 |
|
| 170 |
|
| 171 |
# -----------------------------
|
| 172 |
+
# PREDICT
|
| 173 |
# -----------------------------
|
| 174 |
def predict_waste(image):
|
| 175 |
processed_img = preprocess_image(image)
|
|
|
|
| 178 |
|
| 179 |
probabilities = prediction[0]
|
| 180 |
|
| 181 |
+
if len(probabilities) != len(CLASSES):
|
| 182 |
+
raise ValueError("Prediction output mismatch.")
|
| 183 |
+
|
| 184 |
predicted_index = np.argmax(probabilities)
|
| 185 |
predicted_class = CLASSES[predicted_index]
|
| 186 |
confidence = probabilities[predicted_index] * 100
|
|
|
|
| 192 |
# UI HEADER
|
| 193 |
# -----------------------------
|
| 194 |
st.title("β»οΈ AI Smart Waste Classification")
|
| 195 |
+
st.write("Upload an image to classify waste and support sustainable recycling.")
|
| 196 |
|
| 197 |
# -----------------------------
|
| 198 |
+
# SIDEBAR
|
| 199 |
# -----------------------------
|
| 200 |
with st.sidebar:
|
| 201 |
st.header("π Dataset Status")
|
|
|
|
| 203 |
if os.path.exists(DATASET_DIR):
|
| 204 |
st.success("Dataset Found")
|
| 205 |
|
| 206 |
+
folders = sorted(os.listdir(DATASET_DIR))
|
| 207 |
+
|
| 208 |
+
for folder in folders:
|
| 209 |
st.write(f"βοΈ {folder}")
|
| 210 |
|
| 211 |
else:
|
| 212 |
st.error("Dataset Missing")
|
| 213 |
|
| 214 |
# -----------------------------
|
| 215 |
+
# FILE UPLOAD
|
| 216 |
# -----------------------------
|
| 217 |
uploaded_file = st.file_uploader(
|
| 218 |
"Upload Waste Image",
|
|
|
|
| 220 |
)
|
| 221 |
|
| 222 |
# -----------------------------
|
| 223 |
+
# ANALYSIS
|
| 224 |
# -----------------------------
|
| 225 |
if uploaded_file is not None:
|
| 226 |
try:
|
|
|
|
| 240 |
|
| 241 |
for i, class_name in enumerate(CLASSES):
|
| 242 |
st.progress(float(probabilities[i]))
|
| 243 |
+
st.write(f"{class_name.upper()}: {probabilities[i] * 100:.2f}%")
|
| 244 |
|
| 245 |
st.success(f"β
Predicted Type: {predicted_class.upper()}")
|
| 246 |
st.info(f"π― Confidence: {confidence:.2f}%")
|
| 247 |
st.write(f"π Uploaded File: {uploaded_file.name}")
|
| 248 |
|
| 249 |
+
# Sustainability tip
|
| 250 |
st.subheader("π± Sustainability Suggestion")
|
| 251 |
st.write(TIPS.get(predicted_class, "Dispose responsibly."))
|
| 252 |
|
| 253 |
except UnidentifiedImageError:
|
| 254 |
+
st.error("β Invalid image file. Upload JPG, JPEG, or PNG.")
|
| 255 |
|
| 256 |
except Exception as e:
|
| 257 |
st.error(f"β Error processing image: {str(e)}")
|