# Header residue from the web file-viewer page this file was copied from:
# author Muthuraja18 — commit c77fef2 ("Update app.py", verified), 7.21 kB.
import streamlit as st
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
import numpy as np
from PIL import Image, UnidentifiedImageError
import os
# -----------------------------
# CONFIGURATION
# -----------------------------
# Relative paths: resolved against the directory the app is started from.
DATASET_DIR = "dataset-resized"      # expected to hold one sub-folder per class
MODEL_PATH = "waste_classifier.h5"   # saved Keras model
CLASS_FILE = "classes.npy"           # class list saved alongside the model

IMG_SIZE = (128, 128)                # (height, width) fed to the network
BATCH_SIZE = 16
EPOCHS = 5

# Label order is fixed: output unit i of the model corresponds to CLASSES[i].
CLASSES = [
    'cardboard',
    'glass',
    'metal',
    'paper',
    'plastic',
    'trash',
]
# -----------------------------
# PAGE SETTINGS
# -----------------------------
# Browser-tab title and a centered single-column layout for the app.
st.set_page_config(page_title="AI Smart Waste Classification", layout="centered")
# -----------------------------
# CLEAN DATASET
# -----------------------------
def _is_readable_image(file_path):
    """Return True if Pillow can open *file_path* and ``verify()`` succeeds."""
    try:
        with Image.open(file_path) as img:
            img.verify()
    except Exception:
        # Pillow signals unreadable/corrupt files with several exception
        # types (UnidentifiedImageError, OSError, SyntaxError, ...), so catch
        # Exception here -- but not BaseException (Ctrl+C must still work).
        return False
    return True


def clean_dataset(dataset_path, valid_extensions=(".jpg", ".jpeg", ".png")):
    """Delete non-image and unreadable files under *dataset_path*.

    Walks the directory tree recursively. A file is deleted when its name
    does not end with one of *valid_extensions* (case-insensitive), or when
    Pillow cannot open and verify it.

    WARNING: files are removed from disk permanently.

    Args:
        dataset_path: Root directory to clean. A non-existent path is treated
            as empty (``os.walk`` yields nothing), so 0 is returned.
        valid_extensions: Filename suffix, or tuple of suffixes, to keep.
            Defaults to the JPEG/PNG suffixes the app accepts.

    Returns:
        The number of files actually deleted. Files that cannot be deleted
        (e.g. permission errors) are left in place and not counted.
    """
    if isinstance(valid_extensions, str):
        valid_extensions = (valid_extensions,)
    valid_extensions = tuple(ext.lower() for ext in valid_extensions)

    removed = 0
    for root, _dirs, files in os.walk(dataset_path):
        for name in files:
            file_path = os.path.join(root, name)
            # Short-circuit: only files with an accepted suffix are opened.
            if name.lower().endswith(valid_extensions) and _is_readable_image(file_path):
                continue
            try:
                os.remove(file_path)
            except OSError:
                # Best effort: skip files we are not allowed to delete.
                continue
            removed += 1
    return removed
# -----------------------------
# TRAIN MODEL
# -----------------------------
def _make_data_generators():
    """Build the training (augmented) and validation (rescale-only) iterators.

    Both generators use the same ``validation_split`` on the same directory.
    Keras assigns files to subsets deterministically (sorted filenames per
    class), so the two subsets are disjoint. Augmentation is applied only to
    the training subset; validation images are just rescaled, as at inference.
    """
    validation_split = 0.2
    common = dict(
        directory=DATASET_DIR,
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        classes=CLASSES,  # fixes the label order to match CLASSES
    )
    train_gen = ImageDataGenerator(
        rescale=1. / 255,
        validation_split=validation_split,
        rotation_range=20,
        zoom_range=0.2,
        horizontal_flip=True,
    )
    val_gen = ImageDataGenerator(
        rescale=1. / 255,
        validation_split=validation_split,
    )
    train_data = train_gen.flow_from_directory(subset='training', shuffle=True, **common)
    val_data = val_gen.flow_from_directory(subset='validation', shuffle=False, **common)
    return train_data, val_data


def _build_model():
    """Return a compiled CNN whose input/output sizes follow IMG_SIZE/CLASSES."""
    model = Sequential([
        # IMG_SIZE is (height, width), the same order Keras' target_size uses.
        Conv2D(32, (3, 3), activation='relu', input_shape=(*IMG_SIZE, 3)),
        MaxPooling2D(2, 2),
        Conv2D(64, (3, 3), activation='relu'),
        MaxPooling2D(2, 2),
        Conv2D(128, (3, 3), activation='relu'),
        MaxPooling2D(2, 2),
        Flatten(),
        Dense(256, activation='relu'),
        Dropout(0.5),
        Dense(len(CLASSES), activation='softmax'),
    ])
    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model


def train_model():
    """Clean the dataset, train a fresh CNN, and save it to disk.

    Side effects: may delete files under DATASET_DIR (via clean_dataset),
    writes MODEL_PATH and CLASS_FILE, and renders Streamlit messages.
    Stops the script run with an error message (``st.stop``) when the
    dataset folder is missing or contains no usable training images.

    Returns:
        The trained Keras model.
    """
    if not os.path.isdir(DATASET_DIR):
        st.error(f"Dataset folder '{DATASET_DIR}' not found. Cannot train the model.")
        st.stop()

    removed_files = clean_dataset(DATASET_DIR)
    st.info(f"Removed {removed_files} corrupted/invalid files.")

    train_data, val_data = _make_data_generators()
    if train_data.samples == 0:
        st.error(f"No training images found in '{DATASET_DIR}'. Cannot train the model.")
        st.stop()

    model = _build_model()
    with st.spinner("Training AI model... Please wait..."):
        model.fit(
            train_data,
            # Skip validation instead of failing when the split is empty
            # (very small datasets).
            validation_data=val_data if val_data.samples > 0 else None,
            epochs=EPOCHS,
        )

    model.save(MODEL_PATH)
    # Stored so load_or_train_model can detect a label-set change later.
    np.save(CLASS_FILE, CLASSES)
    return model
# -----------------------------
# LOAD OR TRAIN
# -----------------------------
def load_or_train_model():
    """Return a Keras model for CLASSES, training a new one when needed.

    When both MODEL_PATH and CLASS_FILE exist, they are loaded. The saved
    model is reused only if its stored class list equals CLASSES and its
    last output dimension equals len(CLASSES). Otherwise, or if loading
    fails, a new model is trained via train_model().

    Returns:
        A Keras model ready for ``predict``.
    """
    if not (os.path.exists(MODEL_PATH) and os.path.exists(CLASS_FILE)):
        st.warning("Training model for first-time use...")
        return train_model()

    # Guard only the loading step. Previously train_model() also ran inside
    # this try, so any training failure was misreported as a corrupted model
    # and triggered a second full training run.
    try:
        model = load_model(MODEL_PATH)
        saved_classes = np.load(CLASS_FILE, allow_pickle=True).tolist()
        n_outputs = model.output_shape[-1]
    except Exception:
        st.warning("Model corrupted. Retraining...")
        return train_model()

    if saved_classes != CLASSES or n_outputs != len(CLASSES):
        st.warning("Old model mismatch detected. Retraining...")
        # Drop the stale artifacts. This is best effort only: train_model
        # rewrites both files once training succeeds.
        for stale_path in (MODEL_PATH, CLASS_FILE):
            try:
                os.remove(stale_path)
            except OSError:
                pass
        return train_model()

    return model
# -----------------------------
# LOAD MODEL
# -----------------------------
# Runs at module level, i.e. on each execution of this script.
model = load_or_train_model()

# -----------------------------
# UI
# -----------------------------
st.title("♻️ AI Smart Waste Classification")
st.write("Upload an image to classify waste and support sustainable recycling.")

# Single-file uploader restricted to the formats the model pipeline reads.
uploaded_file = st.file_uploader(
    label="Upload Waste Image",
    type=["jpg", "jpeg", "png"],
    accept_multiple_files=False,
)
# -----------------------------
# PREDICTION
# -----------------------------
if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file).convert("RGB")
        st.image(
            image,
            caption="Uploaded Image",
            use_container_width=True,
        )

        # Preprocess the same way the training iterators do. Keras'
        # flow_from_directory resizes to target_size=(height, width) with
        # nearest-neighbour interpolation by default, then the generator
        # rescales by 1/255. PIL's resize() takes (width, height) and uses
        # bicubic filtering by default, so both are set explicitly here.
        height, width = IMG_SIZE
        img = image.resize((width, height), resample=Image.NEAREST)
        img_array = np.asarray(img, dtype=np.float32) / 255.0
        img_array = np.expand_dims(img_array, axis=0)  # batch of one

        # Predict
        with st.spinner("Analyzing waste type..."):
            prediction = model.predict(img_array, verbose=0)

        probabilities = prediction.flatten()
        predicted_index = int(np.argmax(probabilities))
        predicted_class = CLASSES[predicted_index]
        confidence = probabilities[predicted_index] * 100

        # -----------------------------
        # DISPLAY SCORES
        # -----------------------------
        st.subheader("📊 Prediction Scores")
        # Model output i corresponds to CLASSES[i] (fixed order at training).
        for class_name, probability in zip(CLASSES, probabilities):
            st.write(f"{class_name.upper()}: {probability*100:.2f}%")

        # Main result
        st.success(f"Predicted Type: {predicted_class.upper()}")
        st.info(f"Confidence: {confidence:.2f}%")

        # Sustainability Tips
        tips = {
            'plastic': 'Recycle plastic properly to reduce pollution.',
            'paper': 'Reuse or recycle paper to save trees.',
            'metal': 'Metal can be recycled efficiently.',
            'glass': 'Glass is reusable and recyclable.',
            'trash': 'Dispose responsibly to reduce environmental damage.',
            'cardboard': 'Recycle cardboard to reduce waste.',
        }
        st.subheader("🌱 Sustainability Suggestion")
        st.write(tips.get(predicted_class, "Dispose responsibly."))

    except UnidentifiedImageError:
        st.error("Invalid image file. Please upload a valid JPG, JPEG, or PNG image.")
    except Exception as e:
        # Top-level UI boundary: report any other failure instead of crashing.
        st.error(f"Error processing image: {e}")
# -----------------------------
# SAMPLE TEST IMAGE IDEAS
# -----------------------------
st.markdown("---")
st.subheader("🖼️ Sample Images to Test")
# Markdown bullet list of example file names, one per line.
st.write(
    "\n"
    "Use images like these:\n"
    "- plastic_bottle.jpg\n"
    "- newspaper.jpg\n"
    "- soda_can.jpg\n"
    "- glass_bottle.jpg\n"
    "- cardboard_box.jpg\n"
    "- trash_bag.jpg\n"
)

# -----------------------------
# FOOTER
# -----------------------------
st.markdown("---")
st.caption("Built using TensorFlow + Streamlit for Sustainable AI")