|
|
import base64
import mimetypes
import os
import tempfile
from io import BytesIO

import tensorflow as tf
from flask import Flask, request, render_template, redirect, url_for
|
|
|
|
|
|
|
|
# Flask application instance; the route handlers in this module register on it.
app = Flask(__name__)
|
|
|
|
|
|
|
|
# Path to the trained Keras model. It can be overridden with the MODEL_PATH
# environment variable; a relative path resolves against the working directory.
MODEL_PATH = os.environ.get('MODEL_PATH', 'waste_classifier_final_5.h5')

try:
    model = tf.keras.models.load_model(MODEL_PATH)
except Exception as e:
    # The app cannot serve predictions without the model, so abort startup.
    # SystemExit with a message prints it to stderr and exits with status 1.
    # A bare exit() exits with status 0 (masking the failure) and depends on
    # the ``site`` module having been loaded.
    raise SystemExit(f"Error loading image model: {e}") from e
print("Image classification model loaded successfully!")

# Output labels, looked up by the argmax index of the model's score vector.
# NOTE(review): this order must match the class order used during training.
CLASS_NAMES = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
|
|
|
|
|
|
|
|
def preprocess_image(image_path, target_size=(224, 224)):
    """Load an image file and convert it into a model-ready batch of one.

    Args:
        image_path: Path to an image file readable by Keras' ``load_img``.
        target_size: ``(height, width)`` to resize the image to. Defaults to
            the 224x224 size this app has always fed to its model.

    Returns:
        A tensor with a leading batch dimension of 1, passed through
        ``tf.keras.applications.efficientnet.preprocess_input``.
    """
    img = tf.keras.preprocessing.image.load_img(image_path, target_size=target_size)
    img_array = tf.keras.preprocessing.image.img_to_array(img)
    # The model predicts on batches, so prepend a batch axis of size 1.
    batch = tf.expand_dims(img_array, 0)
    return tf.keras.applications.efficientnet.preprocess_input(batch)
|
|
|
|
|
@app.route('/', methods=['GET'])
def index():
    """Render the upload page without any prediction context."""
    page = 'index.html'
    return render_template(page)
|
|
|
|
|
@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded image and render the result on the index page.

    Expects a multipart form upload in the ``file`` field. The image bytes go
    to a temporary file, because ``preprocess_image`` takes a path, and are
    run through the model. ``index.html`` is then rendered with the top class,
    its score as a percentage, and a base64 data URI of the upload.

    Requests with no ``file`` part or an empty filename are redirected to the
    index page. If the upload cannot be decoded as an image, the exception
    from preprocessing propagates (HTTP 500), but the temp file is still
    removed.
    """
    upload = request.files.get('file')
    # Redirect to the index page rather than request.url: this route is
    # POST-only, so the browser's follow-up GET to /predict would get a 405.
    if upload is None or not upload.filename:
        return redirect(url_for('index'))

    image_data = upload.read()

    # Label the data URI from the file extension. Fall back to JPEG (the old
    # hard-coded value) when the type is unknown or is not an image type.
    mime_type, _ = mimetypes.guess_type(upload.filename)
    if not mime_type or not mime_type.startswith('image/'):
        mime_type = 'image/jpeg'
    encoded_image = base64.b64encode(image_data).decode('utf-8')
    image_to_display = f"data:{mime_type};base64,{encoded_image}"

    # Keep the original extension so the image loader sees a familiar suffix.
    suffix = os.path.splitext(upload.filename)[1]
    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    try:
        # Write through this one handle and close it before the path is
        # reopened for reading (reopening an open temp file fails on Windows).
        with tmp_file:
            tmp_file.write(image_data)
        preprocessed_image = preprocess_image(tmp_file.name)
        prediction = model.predict(preprocessed_image)
    finally:
        # Always remove the temp file, even when writing, preprocessing or
        # prediction raises.
        os.remove(tmp_file.name)

    scores = prediction[0]
    predicted_class_index = tf.argmax(scores).numpy()
    predicted_class = CLASS_NAMES[predicted_class_index]
    confidence = tf.reduce_max(scores).numpy() * 100

    return render_template('index.html',
                           prediction=f'Prediction: {predicted_class}',
                           confidence=f'Confidence: {confidence:.2f}%',
                           uploaded_image=image_to_display)
|
|
|
|
|
if __name__ == '__main__':
    # Bind on all interfaces; the port comes from $PORT, falling back to 7860.
    listen_port = int(os.environ.get('PORT', '7860'))
    app.run(host='0.0.0.0', port=listen_port)