# Page header captured together with the file (not part of the program):
#   Muthuraja18's picture
#   Update app.py (#1)
#   87593f0
#   raw / history / blame
#   3.67 kB
import streamlit as st
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
import numpy as np
from PIL import Image
import os
# -----------------------------
# CONFIGURATION
# -----------------------------
# Directory handed to flow_from_directory when the model is trained.
DATASET_DIR: str = "dataset"

# File the trained model is saved to and later loaded from.
MODEL_PATH: str = "waste_classifier.h5"

# Size images are resized to, both for training batches and for uploads.
IMG_SIZE: tuple = (128, 128)

# Number of images per batch produced by the data generators.
BATCH_SIZE: int = 32

# Number of passes over the training data in model.fit.
EPOCHS: int = 5
# -----------------------------
# TRAIN MODEL FUNCTION
# -----------------------------
def train_model():
    """Train a small CNN on the images under DATASET_DIR and save it.

    Images are read with ``flow_from_directory``, rescaled to the 0-1 range,
    resized to ``IMG_SIZE`` and split 80/20 into training and validation
    subsets. The fitted model is written to ``MODEL_PATH``.

    Returns:
        A ``(model, class_names)`` tuple: the fitted Keras model and the
        class labels as a list, in the iteration order of
        ``train_data.class_indices``.

    Raises:
        ValueError: If no training images are found under ``DATASET_DIR``.
    """
    datagen = ImageDataGenerator(
        rescale=1./255,
        validation_split=0.2,
    )

    # Both generators share every setting except the subset they draw from.
    flow_options = dict(
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical',
    )
    train_data = datagen.flow_from_directory(
        DATASET_DIR, subset='training', **flow_options
    )
    val_data = datagen.flow_from_directory(
        DATASET_DIR, subset='validation', **flow_options
    )

    # Fail early with a readable message instead of an opaque Keras error
    # when the dataset directory contains nothing to train on.
    if train_data.samples == 0 or train_data.num_classes == 0:
        raise ValueError(
            f"No training images found under {DATASET_DIR!r}."
        )

    model = Sequential([
        # Input shape follows IMG_SIZE so the network always matches the
        # size the generators (and the prediction code) resize images to.
        Conv2D(32, (3, 3), activation='relu', input_shape=(*IMG_SIZE, 3)),
        MaxPooling2D(2, 2),
        Conv2D(64, (3, 3), activation='relu'),
        MaxPooling2D(2, 2),
        Flatten(),
        Dense(128, activation='relu'),
        # One softmax output per class discovered in the dataset.
        Dense(train_data.num_classes, activation='softmax'),
    ])
    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    model.fit(
        train_data,
        validation_data=val_data,
        epochs=EPOCHS,
    )
    model.save(MODEL_PATH)

    return model, list(train_data.class_indices.keys())
# -----------------------------
# PAGE SETUP
# -----------------------------
# set_page_config has to run before any other Streamlit command, including
# the st.warning shown below while the model is trained on first use.
st.set_page_config(page_title="AI Waste Classifier")

# -----------------------------
# LOAD MODEL
# -----------------------------
if not os.path.exists(MODEL_PATH):
    # No saved model yet: train one now and use the classes found on disk.
    st.warning("Training model for first-time use. Please wait...")
    model, classes = train_model()
else:
    model = tf.keras.models.load_model(MODEL_PATH)
    # The saved model carries no label names, so a fixed list is used here.
    # NOTE(review): presumably this matches the dataset the saved model was
    # trained on -- the prediction code below checks the count at least.
    classes = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']

# -----------------------------
# STREAMLIT UI
# -----------------------------
st.title("♻️ AI Smart Waste Classification")
st.write("Upload an image to classify waste and support sustainable recycling.")

uploaded_file = st.file_uploader(
    "Upload Waste Image",
    type=["jpg", "jpeg", "png"]
)

if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # Pillow reports unidentifiable or truncated files as OSError
        # subclasses; show a message instead of a stack trace.
        image = None
        st.error("Could not read the uploaded file as an image.")

    if image is not None:
        st.image(image, caption="Uploaded Image", use_column_width=True)

        # Preprocess image: same size and 0-1 scaling as used in training,
        # plus a leading batch axis for model.predict.
        img = image.resize(IMG_SIZE)
        img_array = np.array(img) / 255.0
        img_array = np.expand_dims(img_array, axis=0)

        # Predict: one row of per-class scores for the single uploaded image.
        scores = model.predict(img_array)[0]

        if len(scores) != len(classes):
            # A model trained on a different set of classes would otherwise
            # be mislabelled or raise IndexError on the lookup below.
            st.error(
                f"The model predicts {len(scores)} classes but "
                f"{len(classes)} class names are configured."
            )
        else:
            best_index = int(np.argmax(scores))
            predicted_class = classes[best_index]
            confidence = float(scores[best_index]) * 100

            # Display Results
            st.success(f"Predicted Type: {predicted_class.upper()}")
            st.info(f"Confidence: {confidence:.2f}%")

            # Sustainability Tips
            tips = {
                'plastic': 'Recycle plastic properly to reduce pollution.',
                'paper': 'Reuse or recycle paper to save trees.',
                'metal': 'Metal can be recycled efficiently.',
                'glass': 'Glass is reusable and recyclable.',
                'trash': 'Dispose responsibly to reduce environmental damage.',
                'cardboard': 'Recycle cardboard to reduce waste.'
            }
            st.subheader("🌱 Sustainability Suggestion")
            # Unknown labels fall back to a generic suggestion.
            st.write(tips.get(predicted_class, "Dispose responsibly."))

# -----------------------------
# FOOTER
# -----------------------------
st.markdown("---")
st.caption("Built using TensorFlow + Streamlit for Sustainable AI")