Terence9 commited on
Commit
9df438a
·
verified ·
1 Parent(s): de6d610

Upload 6 files

Browse files
Files changed (6) hide show
  1. README.md +45 -0
  2. app.py +86 -0
  3. model.py +22 -0
  4. predict.py +32 -0
  5. requirements.txt +10 -0
  6. waste_classifier.pth +3 -0
README.md ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Waste Classification Model
2
+
3
+ This is a deep learning model for classifying waste images into two categories: Dry Waste and Wet Waste. The model is built using PyTorch and can be used for automated waste sorting systems.
4
+
5
+ ## Model Details
6
+
7
+ - **Model Type**: Convolutional Neural Network (CNN)
8
+ - **Input**: RGB images
9
+ - **Output**: Binary classification (Dry Waste / Wet Waste)
10
+ - **Framework**: PyTorch
11
+
12
+ ## Usage
13
+
14
+ ```python
15
+ import torch
16
+ from model import WasteClassifier
17
+
18
+ # Load the model
19
+ model = WasteClassifier()
20
+ model.load_state_dict(torch.load('waste_classifier.pth'))
21
+ model.eval()
22
+
23
+ # Make predictions
24
+ def predict(image):
25
+ with torch.no_grad():
26
+ output = model(image)
27
+ prediction = torch.argmax(output, dim=1)
28
+ return "Dry Waste" if prediction.item() == 0 else "Wet Waste"
29
+ ```
30
+
31
+ ## Requirements
32
+
33
+ The model requires the following dependencies:
34
+ - PyTorch
35
+ - torchvision
36
+ - PIL
37
+ - numpy
38
+
39
+ ## Training
40
+
41
+ The model was trained on a custom dataset of waste images. The training notebook (`training.ipynb`) contains the complete training pipeline and data preprocessing steps.
42
+
43
+ ## License
44
+
45
+ This model is released under the MIT License.
app.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import base64
4
+ from datetime import datetime
5
+ from flask import Flask, render_template, request, jsonify
6
+ from werkzeug.utils import secure_filename
7
+ from predict import predict_waste
8
+
9
+ app = Flask(__name__)
10
+ app.config['UPLOAD_FOLDER'] = 'uploads'
11
+ app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max file size
12
+ app.config['HISTORY_FILE'] = 'prediction_history.json'
13
+
14
+ # Ensure upload directory exists
15
+ os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
16
+
17
+ ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
18
+
19
+ def allowed_file(filename):
20
+ return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
21
+
22
+ def load_history():
23
+ if os.path.exists(app.config['HISTORY_FILE']):
24
+ try:
25
+ with open(app.config['HISTORY_FILE'], 'r') as f:
26
+ return json.load(f)
27
+ except json.JSONDecodeError:
28
+ return []
29
+ return []
30
+
31
+ def save_history(history):
32
+ with open(app.config['HISTORY_FILE'], 'w') as f:
33
+ json.dump(history, f)
34
+
35
+ @app.route('/')
36
+ def home():
37
+ history = load_history()
38
+ return render_template('index.html', history=history)
39
+
40
+ @app.route('/predict', methods=['POST'])
41
+ def predict():
42
+ if 'file' not in request.files:
43
+ return jsonify({'error': 'No file uploaded'}), 400
44
+
45
+ file = request.files['file']
46
+ if file.filename == '':
47
+ return jsonify({'error': 'No file selected'}), 400
48
+
49
+ if file and allowed_file(file.filename):
50
+ filename = secure_filename(file.filename)
51
+ filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
52
+ file.save(filepath)
53
+
54
+ try:
55
+ # Read the file for history before prediction
56
+ with open(filepath, 'rb') as img_file:
57
+ img_data = base64.b64encode(img_file.read()).decode('utf-8')
58
+
59
+ prediction = predict_waste(filepath)
60
+
61
+ # Save to history
62
+ history = load_history()
63
+ history.append({
64
+ 'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
65
+ 'prediction': prediction,
66
+ 'image': img_data
67
+ })
68
+ # Keep only last 10 predictions
69
+ history = history[-10:]
70
+ save_history(history)
71
+
72
+ # Clean up the uploaded file
73
+ os.remove(filepath)
74
+ return jsonify({'prediction': prediction})
75
+ except Exception as e:
76
+ return jsonify({'error': str(e)}), 500
77
+
78
+ return jsonify({'error': 'Invalid file type'}), 400
79
+
80
+ @app.route('/clear-history', methods=['POST'])
81
+ def clear_history():
82
+ save_history([])
83
+ return jsonify({'success': True})
84
+
85
+ if __name__ == '__main__':
86
+ app.run(debug=True)
model.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ class WasteCNN(nn.Module):
4
+ def __init__(self):
5
+ super(WasteCNN, self).__init__()
6
+ self.conv_layer = nn.Sequential(
7
+ nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
8
+ nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
9
+ nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
10
+ )
11
+ self.fc_layer = nn.Sequential(
12
+ nn.Flatten(),
13
+ nn.Linear(128 * 16 * 16, 128),
14
+ nn.ReLU(),
15
+ nn.Dropout(0.5),
16
+ nn.Linear(128, 2) # 2 classes: dry/wet
17
+ )
18
+
19
+ def forward(self, x):
20
+ x = self.conv_layer(x)
21
+ x = self.fc_layer(x)
22
+ return x
predict.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms
3
+ from PIL import Image
4
+ from model import WasteCNN # Import the model architecture
5
+
6
+ def predict_waste(image_path):
7
+ # Load the model
8
+ model = WasteCNN()
9
+ model.load_state_dict(torch.load('waste_classifier.pth', map_location=torch.device('cpu')))
10
+ model.eval()
11
+
12
+ # Prepare the image
13
+ transform = transforms.Compose([
14
+ transforms.Resize((128, 128)),
15
+ transforms.ToTensor(),
16
+ ])
17
+
18
+ image = Image.open(image_path).convert('RGB')
19
+ image = transform(image).unsqueeze(0) # Add batch dimension
20
+
21
+ # Make prediction
22
+ with torch.no_grad():
23
+ output = model(image)
24
+ _, predicted = torch.max(output, 1)
25
+
26
+ return "Dry Waste" if predicted.item() == 0 else "Wet Waste"
27
+
28
+ if __name__ == "__main__":
29
+ # Example usage
30
+ image_path = input("Enter the path to your waste image: ")
31
+ result = predict_waste(image_path)
32
+ print(f"Prediction: {result}")
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=1.9.0
2
+ torchvision>=0.10.0
3
+ pandas>=1.3.0
4
+ Pillow>=8.3.1
5
+ scikit-learn>=0.24.2
6
+ matplotlib>=3.4.2
7
+ # notebook>=6.4.0
8
+ # ipykernel>=6.0.0
9
+ flask>=2.0.0
10
+ werkzeug>=2.0.0
waste_classifier.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7c17fd64e9b2681b18bc091832d9feba72864af4875aff84205c10f11fa155b
3
+ size 17156086