"""Streamlit app that classifies an uploaded waste image with a small Keras CNN.

On start-up the app loads the model stored at ``MODEL_PATH``. When the file is
missing, unreadable, or has a different number of outputs than ``CLASSES``, a
new model is trained from the images under ``DATASET_DIR`` and saved.
"""

import os

import numpy as np
import streamlit as st
import tensorflow as tf
from PIL import Image, UnidentifiedImageError
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras import layers, models
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# -----------------------------
# CONFIGURATION
# -----------------------------
MODEL_PATH = "waste_classifier.h5"
DATASET_DIR = "dataset-resized/dataset-resized"
IMG_SIZE = (128, 128)  # (height, width), as passed to Keras `target_size`
BATCH_SIZE = 32
EPOCHS = 20  # Increased for better accuracy
VALIDATION_SPLIT = 0.2
# If the "trash" probability exceeds this value, "trash" is reported even when
# another class scores higher (deliberate bias towards the trash class).
TRASH_THRESHOLD = 0.40

# Fixed class labels
CLASSES = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']

# Sustainability tips
TIPS = {
    'plastic': 'Recycle plastic properly to reduce pollution.',
    'paper': 'Reuse or recycle paper to save trees.',
    'metal': 'Metal can be recycled efficiently.',
    'glass': 'Glass is reusable and recyclable.',
    'trash': 'Dispose responsibly to reduce environmental damage.',
    'cardboard': 'Recycle cardboard to reduce waste.',
}

# AI Eco Insights
AI_MESSAGES = {
    'plastic': "🤖 AI Insight: This appears to be plastic waste. Recycling plastic helps reduce pollution and protects oceans.",
    'paper': "🤖 AI Insight: Paper waste detected. Recycling paper saves trees and reduces landfill burden.",
    'metal': "🤖 AI Insight: Metal detected. Metal recycling conserves raw materials and energy.",
    'glass': "🤖 AI Insight: Glass waste identified. Glass is highly recyclable and reusable.",
    'trash': "🤖 AI Insight: General waste detected. Proper disposal minimizes environmental damage.",
    'cardboard': "🤖 AI Insight: Cardboard detected. Recycling cardboard supports sustainable packaging.",
}

# -----------------------------
# PAGE SETTINGS
# -----------------------------
st.set_page_config(
    page_title="AI Smart Waste Classification",
    layout="centered"
)


# -----------------------------
# VALIDATE DATASET STRUCTURE
# -----------------------------
def _list_class_folders():
    """Return the sorted names of the sub-directories of ``DATASET_DIR``."""
    return sorted(
        entry
        for entry in os.listdir(DATASET_DIR)
        if os.path.isdir(os.path.join(DATASET_DIR, entry))
    )


def validate_dataset():
    """Check that ``DATASET_DIR`` exists and has one folder per class.

    When the directory is missing, or any name in ``CLASSES`` has no matching
    sub-folder, an error is rendered and the script run is halted with
    ``st.stop()``.

    Returns:
        The sorted list of sub-folder names found in ``DATASET_DIR``.
    """
    if not os.path.exists(DATASET_DIR):
        st.error(f"❌ Dataset folder '{DATASET_DIR}' not found.")
        st.stop()

    found_folders = _list_class_folders()
    missing = [cls for cls in CLASSES if cls not in found_folders]

    if missing:
        st.error("❌ Dataset structure is incorrect.")
        st.write("Expected folders:")
        for cls in CLASSES:
            st.write(f"- {cls}")
        st.write("Missing folders:")
        for name in missing:
            st.write(f"- {name}")
        st.stop()

    return found_folders


# -----------------------------
# TRAIN MODEL
# -----------------------------
def train_and_save_model():
    """Train the CNN on ``DATASET_DIR``, save it to ``MODEL_PATH``, return it.

    The dataset is validated first. Training images are augmented; validation
    images are only rescaled. Progress is reported through a Streamlit
    progress bar, one step per epoch.
    """
    validate_dataset()

    st.info("⚙️ Model not found. Training a new model... This may take several minutes.")

    # Augmentation is applied to the training subset only.
    train_datagen = ImageDataGenerator(
        rescale=1./255,
        validation_split=VALIDATION_SPLIT,
        rotation_range=15,
        zoom_range=0.1,
        horizontal_flip=True
    )
    # Validation images must not be augmented, otherwise validation metrics
    # are measured on distorted inputs. The same `validation_split` is used so
    # both generators partition the files identically.
    val_datagen = ImageDataGenerator(
        rescale=1./255,
        validation_split=VALIDATION_SPLIT
    )

    train_data = train_datagen.flow_from_directory(
        DATASET_DIR,
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        classes=CLASSES,
        class_mode='categorical',
        subset='training',
        shuffle=True
    )

    val_data = val_datagen.flow_from_directory(
        DATASET_DIR,
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        classes=CLASSES,
        class_mode='categorical',
        subset='validation',
        shuffle=False  # order is irrelevant for evaluation
    )

    if train_data.samples == 0:
        st.error("❌ No training images found in the dataset folders.")
        st.stop()

    # -----------------------------
    # CLASS WEIGHTS FOR BALANCED TRAINING
    # -----------------------------
    labels = train_data.classes
    present_classes = np.unique(labels)
    weights = compute_class_weight(
        class_weight='balanced',
        classes=present_classes,
        y=labels
    )
    # Key by the real class index: enumerate() would shift the weights onto
    # the wrong classes whenever a class has no training images.
    class_weights = {
        int(cls_idx): float(weight)
        for cls_idx, weight in zip(present_classes, weights)
    }

    # -----------------------------
    # CNN MODEL
    # -----------------------------
    model = models.Sequential([
        layers.Input(shape=(*IMG_SIZE, 3)),

        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.MaxPooling2D(2, 2),

        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D(2, 2),

        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.MaxPooling2D(2, 2),

        layers.Flatten(),
        layers.Dense(256, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.3),
        layers.Dense(len(CLASSES), activation='softmax')
    ])

    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    progress_bar = st.progress(0)

    # One fit() call per epoch so the progress bar can be advanced in between.
    for epoch in range(EPOCHS):
        model.fit(
            train_data,
            validation_data=val_data,
            epochs=1,
            verbose=1,
            class_weight=class_weights
        )
        progress_bar.progress((epoch + 1) / EPOCHS)

    model.save(MODEL_PATH)
    st.success("✅ Model trained and saved successfully!")

    return model


# -----------------------------
# LOAD MODEL
# -----------------------------
@st.cache_resource
def load_ai_model():
    """Return the saved model, training a new one when none is usable.

    A new model is trained when ``MODEL_PATH`` does not exist, when loading it
    fails, or when its last output dimension differs from ``len(CLASSES)``.
    """
    if not os.path.exists(MODEL_PATH):
        return train_and_save_model()

    # Only loading/inspection is guarded. Training runs outside the `try`, so
    # a training failure is not misreported as a corrupted model and retried.
    try:
        loaded = load_model(MODEL_PATH)
        n_outputs = loaded.output_shape[-1]
    except Exception:
        st.warning("⚠️ Corrupted model. Retraining...")
        return train_and_save_model()

    if n_outputs != len(CLASSES):
        st.warning("⚠️ Model mismatch. Retraining...")
        return train_and_save_model()

    return loaded


model = load_ai_model()


# -----------------------------
# PREPROCESS IMAGE
# -----------------------------
def preprocess_image(image):
    """Convert a PIL image into a model-ready batch of one.

    The image is converted to RGB, resized to ``IMG_SIZE`` and scaled to the
    range [0, 1].

    Returns:
        A float32 array of shape ``(1, height, width, 3)``.

    Raises:
        ValueError: If the resized array does not have the expected shape.
    """
    height, width = IMG_SIZE
    # PIL expects (width, height); IMG_SIZE is stored as (height, width).
    rgb = image.convert("RGB").resize((width, height))

    img_array = np.array(rgb, dtype=np.float32) / 255.0

    if img_array.shape != (height, width, 3):
        raise ValueError("Image preprocessing failed.")

    return np.expand_dims(img_array, axis=0)


# -----------------------------
# PREDICT
# -----------------------------
def predict_waste(image):
    """Classify a PIL image.

    Returns:
        A tuple ``(predicted_class, confidence, probabilities)`` where
        ``predicted_class`` is an entry of ``CLASSES``, ``confidence`` is the
        probability of that class as a percentage, and ``probabilities`` is
        the model's output row for the image.
    """
    processed_img = preprocess_image(image)

    probabilities = model.predict(processed_img, verbose=0)[0]

    trash_index = CLASSES.index("trash")

    # Trash threshold boost
    if probabilities[trash_index] > TRASH_THRESHOLD:
        predicted_index = trash_index
    else:
        predicted_index = int(np.argmax(probabilities))

    predicted_class = CLASSES[predicted_index]
    confidence = float(probabilities[predicted_index]) * 100

    return predicted_class, confidence, probabilities


# -----------------------------
# UI HEADER
# -----------------------------
st.title("♻️ AI Smart Waste Classification")
st.write("Upload an image to classify waste and support sustainable recycling.")

# -----------------------------
# SIDEBAR
# -----------------------------
with st.sidebar:
    st.header("📂 Dataset Status")

    if os.path.exists(DATASET_DIR):
        folders = _list_class_folders()

        if folders:
            st.success("Dataset Found")
            for folder in folders:
                st.write(f"✔️ {folder}")
        else:
            st.error("No class folders found")
    else:
        st.error("Dataset Missing")

# -----------------------------
# FILE UPLOADER
# -----------------------------
uploaded_file = st.file_uploader(
    "Upload Waste Image",
    type=["jpg", "jpeg", "png"]
)

# -----------------------------
# ANALYSIS
# -----------------------------
if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file)

        st.image(
            image,
            caption=f"Uploaded Image: {uploaded_file.name}",
            use_container_width=True
        )

        with st.spinner("🔍 Analyzing waste type..."):
            predicted_class, confidence, probabilities = predict_waste(image)

        # Prediction Scores
        st.subheader("📊 Prediction Scores")

        for class_name, probability in zip(CLASSES, probabilities):
            st.progress(float(probability))
            st.write(f"{class_name.upper()}: {probability * 100:.2f}%")

        # Main Output
        st.success(f"✅ Predicted Type: {predicted_class.upper()}")
        st.info(f"🎯 Confidence: {confidence:.2f}%")
        st.write(f"📁 Uploaded File: {uploaded_file.name}")

        # Sustainability Tip
        st.subheader("🌱 Sustainability Suggestion")
        st.write(TIPS.get(predicted_class, "Dispose responsibly."))

        # AI Analysis
        st.subheader("🤖 AI Environmental Analysis")
        st.success(
            AI_MESSAGES.get(
                predicted_class,
                "AI recommends responsible disposal."
            )
        )

    except UnidentifiedImageError:
        st.error("❌ Invalid image file. Upload JPG, JPEG, or PNG.")

    except Exception as e:
        # Top-level UI boundary: report any failure instead of crashing.
        st.error(f"❌ Error processing image: {str(e)}")

# -----------------------------
# SAMPLE GUIDE
# -----------------------------
st.markdown("---")
st.subheader("🖼️ Recommended Test Images")

st.write("Your dataset folder should look like:")
# Rendered as preformatted text: Markdown would merge the tree's lines into a
# single paragraph.
st.code(
    "dataset-resized/\n"
    "└── dataset-resized/\n"
    "    ├── cardboard/\n"
    "    ├── glass/\n"
    "    ├── metal/\n"
    "    ├── paper/\n"
    "    ├── plastic/\n"
    "    └── trash/",
    language=None
)

# -----------------------------
# FOOTER
# -----------------------------
st.markdown("---")
st.caption("Built using TensorFlow + Streamlit for Sustainable AI")