| | import numpy as np |
| | import streamlit as st |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import torchvision.transforms as T |
| | from PIL import Image |
| |
|
# Browser-tab title for the app. This runs at import time, before main()
# renders any page content.
st.set_page_config(page_title="Garbage Classification")
| |
|
| |
|
| | |
class SimpleCNN(nn.Module):
    """Four-stage convolutional classifier followed by a three-layer MLP head.

    Each stage is an unpadded 3x3 convolution, a ReLU (applied in ``forward``)
    and a 2x2 max-pool with stride 2, so a square feature map of side ``s``
    shrinks to ``(s - 2) // 2`` per stage.

    Args:
        num_classes: Number of output logits produced by ``forward``.
        input_channels: Channel count of the input tensor (3 for RGB).
        input_size: Side length of the square input the model is built for.
            It determines the width of ``fc1``. The default of 224 yields the
            original ``256 * 12 * 12`` features, so checkpoints trained with
            the previously hard-coded layer still load unchanged.

    Raises:
        ValueError: If ``input_size`` is too small to survive the four
            conv/pool stages (the minimum is 46).
    """

    def __init__(self, num_classes, input_channels=3, input_size=224):
        super().__init__()

        # Feature extractor; channel width doubles every stage (32 -> 256).
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, padding=0)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=0)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=0)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=0)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.flatten = nn.Flatten()

        # fc1's width follows from the final feature-map side instead of being
        # hard-coded; for a 224x224 input that side is 12.
        side = self._feature_map_side(input_size)
        self.fc1 = nn.Linear(256 * side * side, 512)
        self.dropout1 = nn.Dropout(0.5)

        self.fc2 = nn.Linear(512, 512)
        self.dropout2 = nn.Dropout(0.5)

        self.fc3 = nn.Linear(512, num_classes)

    @staticmethod
    def _feature_map_side(input_size, num_stages=4):
        """Return the spatial side length left after ``num_stages`` conv/pool stages."""
        side = input_size
        for _ in range(num_stages):
            # 3x3 conv without padding removes 2 pixels; 2x2/stride-2 pool halves (floor).
            side = (side - 2) // 2
        if side < 1:
            raise ValueError(
                f"input_size={input_size} is too small for {num_stages} "
                "conv/pool stages"
            )
        return side

    def forward(self, x):
        """Return unnormalized class scores.

        Args:
            x: Batched image tensor, expected as
                (batch, input_channels, input_size, input_size).

        Returns:
            Tensor of shape (batch, num_classes). No softmax is applied here.
        """
        x = F.relu(self.conv1(x))
        x = self.pool1(x)

        x = F.relu(self.conv2(x))
        x = self.pool2(x)

        x = F.relu(self.conv3(x))
        x = self.pool3(x)

        x = F.relu(self.conv4(x))
        x = self.pool4(x)

        # Classifier head; dropout is active only in train() mode.
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)

        x = F.relu(self.fc2(x))
        x = self.dropout2(x)

        x = self.fc3(x)
        return x
| |
|
| |
|
| | |
# Class labels indexed by logit position: predict_image indexes this list with
# the argmax of the model output, so entry i names output unit i. Presumably
# this is the order used when the checkpoint was trained; keep it in sync.
CLASS_NAMES = (
    "battery biological cardboard clothes glass "
    "metal paper plastic shoes trash"
).split()
| |
|
| |
|
| | |
@st.cache_resource
def _load_trained_model():
    """Build SimpleCNN and load the trained weights from ``best_model.pth``.

    Cached by Streamlit for the server's lifetime. Failures are raised rather
    than returned because Streamlit does not cache a call that raises; the
    next rerun therefore retries instead of replaying a stale failure.
    """
    device = torch.device("cpu")
    # Head size follows the label list so the two cannot drift apart silently.
    model = SimpleCNN(num_classes=len(CLASS_NAMES))
    # Wrapped before loading so the state_dict keys line up. Presumably the
    # checkpoint was saved from a DataParallel model ("module."-prefixed keys)
    # -- verify against the training script.
    model = nn.DataParallel(model)
    # weights_only=True restricts unpickling to tensors/primitive containers,
    # which is all a state_dict needs, instead of arbitrary pickled objects.
    state_dict = torch.load(
        "best_model.pth", map_location=device, weights_only=True
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model


def load_model():
    """Load the trained model for CPU inference.

    Returns:
        A ``(model, device)`` tuple. ``model`` is the eval-mode network, or
        ``None`` if loading failed; in that case the error has already been
        shown in the page with ``st.error``. ``device`` is always the CPU
        device.
    """
    device = torch.device("cpu")
    try:
        model = _load_trained_model()
    except Exception as e:  # UI boundary: report the failure instead of crashing
        st.error(f"Error loading model: {e}")
        return None, device
    return model, device
| |
|
| |
|
def preprocess_image(image):
    """Turn a PIL image into a normalized, batched tensor for the classifier.

    Non-RGB images (grayscale, RGBA, palette, ...) are first converted to RGB.
    Otherwise they would yield 1 or 4 channels and the 3-channel ``Normalize``
    below would raise. The image is then resized so its shorter side is 224,
    center-cropped to 224x224, converted to a tensor and normalized with the
    standard ImageNet per-channel mean/std constants.

    Args:
        image: A PIL image.

    Returns:
        A float tensor of shape (1, 3, 224, 224).
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    transform = T.Compose(
        [
            T.Resize(224),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    # unsqueeze adds the batch dimension the model expects.
    image_tensor = transform(image).unsqueeze(0)
    return image_tensor
| |
|
| |
|
def predict_image(image, model, device):
    """Classify one PIL image.

    Args:
        image: PIL image, preprocessed with ``preprocess_image``.
        model: Callable returning logits for the batch. This function does not
            call ``model.eval()``, so the caller must put the model in eval mode.
            The output is assumed to be (1, len(CLASS_NAMES)); this is not
            checked here.
        device: Torch device the input batch is moved to.

    Returns:
        ``(predicted_class, confidence_score, all_probabilities)``: the
        top-scoring label, its softmax probability as a Python float, and a
        flat NumPy array of every class probability.
    """
    batch = preprocess_image(image).to(device)

    # No autograd bookkeeping is needed for pure inference.
    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits, dim=1)
        top_prob, top_idx = probs.max(dim=1)

    label = CLASS_NAMES[top_idx.item()]
    return label, top_prob.item(), probs.cpu().numpy().flatten()
| |
|
| |
|
def get_confidence_color(confidence):
    """Map a confidence score to a color class name.

    Scores of at least 0.7 map to "confidence-high" and scores of at least 0.4
    to "confidence-medium". Anything else, including NaN (which fails every
    comparison), maps to "confidence-low".
    """
    # Checked from the highest threshold down; the first match wins.
    for threshold, css_class in (
        (0.7, "confidence-high"),
        (0.4, "confidence-medium"),
    ):
        if confidence >= threshold:
            return css_class
    return "confidence-low"
| |
|
| |
|
def main():
    """Render the classification page.

    The page does the following:

    - Loads the model and shows a JPEG/PNG uploader.
    - For an uploaded image, displays the image next to every class
      probability, sorted from most to least likely.
    - Stops early, rendering nothing further, when the model could not be
      loaded (``load_model`` has already reported that error with
      ``st.error``).
    - Reports the problem instead of raising when the uploaded file cannot
      be decoded as an image.
    """
    model, device = load_model()

    st.header("Garbage Classification")
    if model is None:
        # Without this guard the first upload would crash calling None.
        return

    uploaded_file = st.file_uploader(
        "Choose an image file",
        type=["jpg", "jpeg", "png"],
    )
    if uploaded_file is None:
        return

    try:
        # Image.open only identifies the file; convert() forces the full decode,
        # so corrupt or truncated uploads fail here rather than in the model.
        with Image.open(uploaded_file) as raw_image:
            image = raw_image.convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        # UnidentifiedImageError ("not an image") is a subclass of OSError.
        st.error(f"Could not read the uploaded file as an image: {e}")
        return

    col1, col2 = st.columns([1, 1])
    with col1:
        st.image(image, caption="Uploaded Image", use_container_width=True)

    with st.spinner("🔍 Analyzing image..."):
        # Only the full probability vector is rendered below.
        _, _, probabilities = predict_image(image, model, device)

    container = col2.container(border=True)
    # argsort is ascending; reverse it so the most likely class is listed first.
    for idx in np.argsort(probabilities)[::-1]:
        container.write(f"{CLASS_NAMES[idx].title()}: {probabilities[idx]:.1%}")
| |
|
| |
|
# Script entry point: `streamlit run <file>` executes this file as the
# __main__ module, so main() renders the page.
if __name__ == "__main__":
    main()
| |
|