"""Flask service that classifies an uploaded waste image with a ResNet-50 model."""

import io
import os

import torch
import torchvision.models as models
import torchvision.transforms as transforms
from flask import Flask, jsonify, render_template, request
from PIL import Image

app = Flask(__name__)

# Define categories. The position of each name is used as the class index of
# the model output, so the order must match the one used during training.
categories = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]

# Disposal advice returned with each prediction, keyed by category name.
# Built once at import time instead of on every request.
RECYCLING_GUIDELINES = {
    "cardboard": "Recycle in a dry, clean state. Remove any tape or labels.",
    "glass": "Rinse and recycle. Avoid broken glass.",
    "metal": "Rinse and place in the metal recycling bin.",
    "paper": "Keep dry. Do not include wax-coated paper.",
    "plastic": "Check recycling code on the item. Rinse before recycling.",
    "trash": "Non-recyclable. Dispose of responsibly.",
}

# Load the trained model
model_path = "smart_recycling_model1.pth"
model = models.resnet50(pretrained=False)
num_ftrs = model.fc.in_features
# Replace the final layer so it emits one logit per category.
model.fc = torch.nn.Linear(num_ftrs, len(categories))
# NOTE(review): torch.load unpickles the checkpoint, which can execute
# arbitrary code. Only point model_path at a file you trust.
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
model.eval()

# Define transformation applied to every uploaded image before inference.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


@app.route("/")
def home():
    """Render the upload page (``index.html``)."""
    return render_template("index.html")


@app.route("/predict", methods=["POST"])
def predict():
    """Classify the image uploaded in the ``file`` form field.

    Returns:
        A JSON response with ``prediction``, ``confidence``,
        ``all_probabilities`` and ``recycling_guidelines``. If the ``file``
        field is missing or its content cannot be decoded as an image, a
        JSON ``error`` message with HTTP status 400 is returned instead.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file uploaded"}), 400

    file = request.files["file"]
    try:
        with Image.open(io.BytesIO(file.read())) as uploaded:
            # Normalize expects exactly three channels; grayscale, palette
            # and RGBA uploads would otherwise fail inside the transform.
            # convert() also forces the lazy decode, so corrupt data is
            # detected here rather than later.
            image = uploaded.convert("RGB")
    except OSError:
        # Pillow signals unreadable or truncated image data with OSError
        # (UnidentifiedImageError is a subclass of it).
        return jsonify({"error": "Uploaded file is not a valid image"}), 400

    batch = transform(image).unsqueeze(0)  # Add batch dimension

    with torch.no_grad():
        output = model(batch)
        # Softmax over the class logits of the single image in the batch.
        probabilities = torch.nn.functional.softmax(output[0], dim=0)

    predicted_idx = torch.argmax(probabilities).item()
    percentages = [p * 100 for p in probabilities.tolist()]
    category = categories[predicted_idx]

    response = {
        "prediction": category,
        "confidence": f"{percentages[predicted_idx]:.2f}%",
        "all_probabilities": {
            name: f"{pct:.2f}%" for name, pct in zip(categories, percentages)
        },
        "recycling_guidelines": RECYCLING_GUIDELINES[category],
    }
    return jsonify(response)


if __name__ == "__main__":
    # The Werkzeug debugger allows arbitrary code execution, so it must not
    # be enabled unconditionally on a server bound to all interfaces.
    # Opt in explicitly with FLASK_DEBUG=1 during local development.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {
        "1",
        "true",
        "yes",
        "on",
    }
    app.run(host="0.0.0.0", port=5000, debug=debug_enabled)