# Author: Muthuraja18
# Last change: "Update app.py (#21)"
# Commit: 5abf58d
import streamlit as st
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers, models
import numpy as np
from PIL import Image, UnidentifiedImageError
import os
from sklearn.utils.class_weight import compute_class_weight
# -----------------------------
# CONFIGURATION
# -----------------------------
MODEL_PATH = "waste_classifier.h5"  # saved Keras model; retrained if missing or unusable
DATASET_DIR = "dataset-resized/dataset-resized"  # must contain one sub-folder per class
IMG_SIZE = (128, 128)  # (height, width), passed as Keras target_size
BATCH_SIZE = 32
EPOCHS = 20  # Increased for better accuracy

# Fixed class labels. Their order defines the model's output indices:
# output unit i is the probability of CLASSES[i].
CLASSES = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']

# Sustainability tips, keyed by class label.
TIPS = {
    'plastic': 'Recycle plastic properly to reduce pollution.',
    'paper': 'Reuse or recycle paper to save trees.',
    'metal': 'Metal can be recycled efficiently.',
    'glass': 'Glass is reusable and recyclable.',
    'trash': 'Dispose responsibly to reduce environmental damage.',
    'cardboard': 'Recycle cardboard to reduce waste.'
}

# AI Eco Insights, keyed by class label.
AI_MESSAGES = {
    'plastic': "🤖 AI Insight: This appears to be plastic waste. Recycling plastic helps reduce pollution and protects oceans.",
    'paper': "🤖 AI Insight: Paper waste detected. Recycling paper saves trees and reduces landfill burden.",
    'metal': "🤖 AI Insight: Metal detected. Metal recycling conserves raw materials and energy.",
    'glass': "🤖 AI Insight: Glass waste identified. Glass is highly recyclable and reusable.",
    'trash': "🤖 AI Insight: General waste detected. Proper disposal minimizes environmental damage.",
    'cardboard': "🤖 AI Insight: Cardboard detected. Recycling cardboard supports sustainable packaging."
}
# -----------------------------
# PAGE SETTINGS
# -----------------------------
# Browser-tab title and a centered layout. Kept ahead of every other
# st.* call, which older Streamlit versions require for this command.
st.set_page_config(page_title="AI Smart Waste Classification", layout="centered")
# -----------------------------
# VALIDATE DATASET STRUCTURE
# -----------------------------
def validate_dataset():
    """Check that DATASET_DIR exists and has a sub-folder for every class.

    When the folder is absent or any name in CLASSES has no matching
    sub-folder, the problem is reported in the UI and ``st.stop()`` is
    called.

    Returns:
        Sorted list of the sub-directory names found in DATASET_DIR.
    """
    if not os.path.exists(DATASET_DIR):
        st.error(f"❌ Dataset folder '{DATASET_DIR}' not found.")
        st.stop()

    found_folders = sorted(
        entry for entry in os.listdir(DATASET_DIR)
        if os.path.isdir(os.path.join(DATASET_DIR, entry))
    )
    present = set(found_folders)
    missing = [label for label in CLASSES if label not in present]
    if not missing:
        return found_folders

    # Show the full expected layout, then what is absent.
    st.error("❌ Dataset structure is incorrect.")
    st.write("Expected folders:")
    for label in CLASSES:
        st.write(f"- {label}")
    st.write("Missing folders:")
    for absent in missing:
        st.write(f"- {absent}")
    st.stop()
    return found_folders
# -----------------------------
# TRAIN MODEL
# -----------------------------
def _balanced_class_weights(labels):
    """Return a Keras ``class_weight`` dict with a key for every CLASSES index.

    Args:
        labels: integer class index of every training sample
            (``DirectoryIterator.classes``).

    Returns:
        dict mapping each index ``0 .. len(CLASSES) - 1`` to a float weight.
        Classes present in *labels* get sklearn's "balanced" weight. Classes
        with no samples get 1.0; they never contribute to the loss, but
        Keras expects a key for every class index.
    """
    present = np.unique(labels)
    weights = compute_class_weight(
        class_weight='balanced',
        classes=present,
        y=labels
    )
    class_weights = {index: 1.0 for index in range(len(CLASSES))}
    # Key each weight by its real class index. enumerate() would renumber
    # the present classes 0..k-1, attaching weights to the wrong classes
    # whenever one class folder has no training images.
    class_weights.update(
        {int(cls): float(weight) for cls, weight in zip(present, weights)}
    )
    return class_weights


def _build_model():
    """Build and compile the CNN: IMG_SIZE RGB input, one softmax unit per class."""
    model = models.Sequential([
        layers.Input(shape=(*IMG_SIZE, 3)),
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.MaxPooling2D(2, 2),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D(2, 2),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.MaxPooling2D(2, 2),
        layers.Flatten(),
        layers.Dense(256, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.3),
        layers.Dense(len(CLASSES), activation='softmax')
    ])
    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )
    return model


def train_and_save_model():
    """Train a fresh CNN on DATASET_DIR, save it to MODEL_PATH and return it.

    Progress is shown through Streamlit widgets. If the dataset is missing,
    malformed, or has no training images, the problem is reported in the UI
    and ``st.stop()`` is called.

    Returns:
        The trained, compiled Keras model.
    """
    validate_dataset()
    st.info("⚙️ Training a new model... This may take several minutes.")

    # Augment the training subset only, so validation metrics are measured on
    # unmodified images. Both generators use the same validation_split. Keras
    # assigns files to subsets by their position in each class folder's file
    # list (not randomly), so the two subsets do not overlap.
    train_datagen = ImageDataGenerator(
        rescale=1./255,
        validation_split=0.2,
        rotation_range=15,
        zoom_range=0.1,
        horizontal_flip=True
    )
    val_datagen = ImageDataGenerator(
        rescale=1./255,
        validation_split=0.2
    )
    train_data = train_datagen.flow_from_directory(
        DATASET_DIR,
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        classes=CLASSES,
        class_mode='categorical',
        subset='training',
        shuffle=True
    )
    val_data = val_datagen.flow_from_directory(
        DATASET_DIR,
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        classes=CLASSES,
        class_mode='categorical',
        subset='validation',
        shuffle=False  # sample order is irrelevant for evaluation
    )
    # validate_dataset() only checks that the folders exist, not that they
    # contain images.
    if train_data.samples == 0:
        st.error("❌ No training images found in the dataset folders.")
        st.stop()

    # -----------------------------
    # CLASS WEIGHTS FOR BALANCED TRAINING
    # -----------------------------
    class_weights = _balanced_class_weights(train_data.classes)

    # -----------------------------
    # CNN MODEL
    # -----------------------------
    model = _build_model()

    # One fit() call. The callback advances the progress bar after each epoch.
    progress_bar = st.progress(0)
    update_progress = tf.keras.callbacks.LambdaCallback(
        on_epoch_end=lambda epoch, logs: progress_bar.progress((epoch + 1) / EPOCHS)
    )
    model.fit(
        train_data,
        # A very small dataset can leave the validation subset empty;
        # skip validation in that case instead of failing.
        validation_data=val_data if val_data.samples > 0 else None,
        epochs=EPOCHS,
        verbose=1,
        class_weight=class_weights,
        callbacks=[update_progress]
    )
    model.save(MODEL_PATH)
    st.success("✅ Model trained and saved successfully!")
    return model
# -----------------------------
# LOAD MODEL
# -----------------------------
@st.cache_resource
def load_ai_model():
    """Return the classifier: load MODEL_PATH, or train and save a new one.

    A new model is trained when:
    - the file is missing,
    - it cannot be loaded, or
    - its output size or input shape does not match the current CLASSES
      and IMG_SIZE.

    Wrapped in ``st.cache_resource`` so the result is reused across reruns.
    """
    if not os.path.exists(MODEL_PATH):
        return train_and_save_model()

    # Guard only the load itself. A failure during training must not be
    # mistaken for a corrupted file and trigger a second training run.
    try:
        model = load_model(MODEL_PATH)
    except Exception:
        st.warning("⚠️ Corrupted model. Retraining...")
        return train_and_save_model()

    # preprocess_image produces IMG_SIZE RGB batches and predict_waste maps
    # output units to CLASSES. A model built for other settings would
    # fail or mislabel.
    expected_input = (*IMG_SIZE, 3)
    if (model.output_shape[-1] != len(CLASSES)
            or tuple(model.input_shape[1:]) != expected_input):
        st.warning("⚠️ Model mismatch. Retraining...")
        return train_and_save_model()
    return model


model = load_ai_model()
# -----------------------------
# PREPROCESS IMAGE
# -----------------------------
def preprocess_image(image):
    """Turn a PIL image into a model-ready batch of one.

    Mirrors the training pipeline (``flow_from_directory`` with
    ``rescale=1./255``): convert to RGB, resize to IMG_SIZE with
    nearest-neighbour sampling, scale pixels to [0, 1].

    Args:
        image: a PIL ``Image`` in any mode.

    Returns:
        float32 array of shape ``(1, IMG_SIZE[0], IMG_SIZE[1], 3)``.

    Raises:
        ValueError: if the converted array does not have the expected shape.
    """
    rgb = image.convert("RGB")
    # IMG_SIZE uses Keras' (height, width) order; PIL's resize takes
    # (width, height).
    height, width = IMG_SIZE
    # flow_from_directory resizes with "nearest" interpolation by default,
    # while Pillow defaults to bicubic for RGB. Use the training filter to
    # avoid train/inference skew.
    resized = rgb.resize((width, height), resample=Image.NEAREST)
    img_array = np.asarray(resized, dtype=np.float32) / 255.0
    if img_array.shape != (height, width, 3):
        raise ValueError("Image preprocessing failed.")
    return np.expand_dims(img_array, axis=0)
# -----------------------------
# PREDICT
# -----------------------------
def predict_waste(image):
    """Classify *image* with the module-level ``model``.

    Returns:
        ``(predicted_class, confidence, probabilities)``:
        - the chosen label from CLASSES;
        - its probability times 100;
        - the model's full per-class probability vector.
    """
    scores = model.predict(preprocess_image(image), verbose=0)[0]
    trash_idx = CLASSES.index("trash")
    # Deliberate bias: "trash" wins as soon as its probability passes 0.40,
    # even if another class scores higher. Otherwise take the top class.
    best_idx = trash_idx if scores[trash_idx] > 0.40 else np.argmax(scores)
    return CLASSES[best_idx], scores[best_idx] * 100, scores
# -----------------------------
# UI HEADER
# -----------------------------
st.title("♻️ AI Smart Waste Classification")
st.write("Upload an image to classify waste and support sustainable recycling.")
# -----------------------------
# SIDEBAR
# -----------------------------
# Lists the class sub-folders currently present in DATASET_DIR.
with st.sidebar:
    st.header("📂 Dataset Status")
    # isdir rather than exists: a plain file at DATASET_DIR would make
    # os.listdir raise and crash the page.
    if os.path.isdir(DATASET_DIR):
        folders = sorted(
            folder for folder in os.listdir(DATASET_DIR)
            if os.path.isdir(os.path.join(DATASET_DIR, folder))
        )
        if folders:
            st.success("Dataset Found")
            for folder in folders:
                st.write(f"✔️ {folder}")
        else:
            st.error("No class folders found")
    else:
        st.error("Dataset Missing")
# -----------------------------
# FILE UPLOADER
# -----------------------------
uploaded_file = st.file_uploader(
    "Upload Waste Image",
    type=["jpg", "jpeg", "png"]
)
# -----------------------------
# ANALYSIS
# -----------------------------
# Classifies the uploaded image and renders scores, tips and insights.
if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file)
        st.image(
            image,
            caption=f"Uploaded Image: {uploaded_file.name}",
            use_container_width=True
        )
        with st.spinner("🔍 Analyzing waste type..."):
            predicted_class, confidence, probabilities = predict_waste(image)

        # Prediction Scores
        st.subheader("📊 Prediction Scores")
        for class_name, prob in zip(CLASSES, probabilities):
            # st.progress raises for values outside [0, 1]. Float32 softmax
            # output may overshoot slightly through rounding, so clamp it
            # rather than aborting the results half-rendered.
            st.progress(min(max(float(prob), 0.0), 1.0))
            st.write(f"{class_name.upper()}: {prob * 100:.2f}%")

        # Main Output
        st.success(f"✅ Predicted Type: {predicted_class.upper()}")
        st.info(f"🎯 Confidence: {confidence:.2f}%")
        st.write(f"📁 Uploaded File: {uploaded_file.name}")

        # Sustainability Tip
        st.subheader("🌱 Sustainability Suggestion")
        st.write(TIPS.get(predicted_class, "Dispose responsibly."))

        # AI Analysis
        st.subheader("🤖 AI Environmental Analysis")
        st.success(
            AI_MESSAGES.get(
                predicted_class,
                "AI recommends responsible disposal."
            )
        )
    except UnidentifiedImageError:
        st.error("❌ Invalid image file. Upload JPG, JPEG, or PNG.")
    except Exception as e:
        # Top-level UI boundary: report any failure instead of crashing the page.
        st.error(f"❌ Error processing image: {str(e)}")
# -----------------------------
# SAMPLE GUIDE
# -----------------------------
st.markdown("---")
st.subheader("🖼️ Recommended Test Images")
st.write("Your dataset folder should look like:")
# st.code keeps line breaks and indentation. st.write renders Markdown,
# which would collapse the tree into one paragraph. The class folders are
# generated from CLASSES so the guide cannot drift from the real labels.
_class_dirs = [f"    ├── {name}/" for name in CLASSES[:-1]]
_class_dirs.append(f"    └── {CLASSES[-1]}/")
st.code(
    "\n".join(["dataset-resized/", "└── dataset-resized/", *_class_dirs]),
    language=None
)
# -----------------------------
# FOOTER
# -----------------------------
st.markdown("---")
st.caption("Built using TensorFlow + Streamlit for Sustainable AI")