Spaces:
Sleeping
Sleeping
File size: 6,437 Bytes
ed1abdf ff88581 ed1abdf ff88581 cf0aaae ff88581 ed1abdf ff88581 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
from flask import Flask, render_template, request, flash, redirect, url_for, jsonify
from flask_cors import CORS
from tensorflow.keras.models import load_model
import numpy as np
from PIL import Image
import io
import cv2
import os
import tensorflow as tf
import json
import time
# --- Application setup -------------------------------------------------------
# Directory containing this file; bundled resources are resolved against it so
# the app does not depend on the process working directory.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))

app = Flask(__name__)
CORS(app)  # Enable CORS for mobile app

# Secret key used to sign the session cookie that carries flash messages.
# Prefer a value supplied via the SECRET_KEY environment variable so the key
# does not have to live in source control; the literal is kept only as a
# backward-compatible fallback for environments that do not set it.
app.secret_key = os.environ.get('SECRET_KEY') or b'_5#y2L"F4Q8z\n\xec]/'

# Define allowed extensions for image uploads
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}

# Load the classification labels from a JSON file.
# The encoding is given explicitly so the result does not depend on the
# platform's locale default.
with open(os.path.join(_APP_DIR, 'labels.json'), 'r', encoding='utf-8') as f:
    CLASSIFICATION_LABELS = json.load(f)

# Get the absolute path to the classification model
model_path = os.path.join(_APP_DIR, 'models/flagship_model.keras')

# Load the pre-trained classification model once at import time so every
# request reuses the same in-memory model.
classification_model = load_model(model_path)
def allowed_file(filename):
    """
    Checks if a given filename has an allowed image extension.

    The extension is the text after the last dot, compared case-insensitively
    against ALLOWED_EXTENSIONS.

    Args:
        filename (str | None): The name of the file. A missing (None) or empty
            name is treated as not allowed instead of raising.

    Returns:
        bool: True if the file extension is allowed, False otherwise.
    """
    # Uploads without a usable name cannot be validated by extension.
    if not filename:
        return False
    _, dot, extension = filename.rpartition('.')
    # An empty `dot` means the name contains no '.' at all.
    return bool(dot) and extension.lower() in ALLOWED_EXTENSIONS
@app.route('/')
def index():
    """
    Serves the landing page by rendering the ``index.html`` template.
    """
    landing_page = render_template('index.html')
    return landing_page
@app.route('/predict', methods=['POST'])
def predict():
    """
    Handles an image upload from the web form, classifies it, and renders the result.

    The upload is read from the ``file`` form field, converted to a 300x300
    RGB batch, and passed to the classification model. A copy of the uploaded
    image is written to the static folder so the result page can display it.

    Returns:
        The rendered ``result.html`` page on success, or a redirect to the
        index page (with a flashed message) when the upload is missing, has a
        disallowed extension, or cannot be decoded as an image.
    """
    # This route only accepts POST, so redirecting to request.url would issue
    # a GET to /predict and fail with 405; send the user back to the form.
    if 'file' not in request.files:
        flash('No file part')
        return redirect(url_for('index'))

    file = request.files['file']

    # Check if a file was selected
    if file.filename == '':
        flash('No selected file')
        return redirect(url_for('index'))

    if not (file and allowed_file(file.filename)):
        # Flash an error message for invalid file types and redirect to the index page
        flash('Invalid file type. Please upload an image (png, jpg, jpeg).')
        return redirect(url_for('index'))

    try:
        # Force 3-channel RGB so grayscale, palette and RGBA uploads all fit
        # the (1, 300, 300, 3) model input and the RGB->BGR conversion below.
        img = Image.open(io.BytesIO(file.read())).convert('RGB')
    except (OSError, ValueError):
        # The extension looked fine but the bytes are not a decodable image.
        flash('Could not read the uploaded image. Please upload a valid png, jpg, or jpeg file.')
        return redirect(url_for('index'))

    img_np = np.array(img)

    # OpenCV writes images in BGR channel order, so keep a BGR copy for saving.
    img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)

    # Preprocess the image for the classification model (expected input: 300x300).
    # Per the original note, the model handles normalization internally and
    # expects 0-255 inputs, so the resized pixels are passed through unscaled.
    img_resized_classification = cv2.resize(img_np, (300, 300))
    img_preprocessed = np.reshape(img_resized_classification, (1, 300, 300, 3))

    # Run the classification model and pick the highest-scoring class.
    prediction = classification_model.predict(img_preprocessed)
    label_index = int(np.argmax(prediction))
    label = CLASSIFICATION_LABELS[label_index]

    # Generate a filename for the output image using a timestamp
    timestamp = str(int(time.time()))
    output_image_filename = f'output_{timestamp}.jpg'

    # Save next to this file (not relative to the working directory) so the
    # image lands in the same static folder that cleanup_old_images() sweeps.
    static_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'static')
    os.makedirs(static_dir, exist_ok=True)
    output_image_path = os.path.join(static_dir, output_image_filename)
    cv2.imwrite(output_image_path, img_bgr)

    # Cleanup old images (older than 1 hour)
    cleanup_old_images()

    # Render the result page with the predicted label and image path
    return render_template('result.html', image_path=output_image_filename, label=label, timestamp=timestamp)
@app.route('/api/predict', methods=['POST'])
def api_predict():
    """
    JSON API endpoint for mobile app predictions.

    Expects a multipart upload in the ``file`` field.

    Returns:
        On success, JSON with ``label`` (the predicted class label) and
        ``confidence`` (the model's score for that class as a float).
        On a missing file, disallowed extension, or undecodable image data,
        JSON with an ``error`` message and HTTP status 400.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file provided'}), 400

    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No file selected'}), 400

    if not (file and allowed_file(file.filename)):
        return jsonify({'error': 'Invalid file type'}), 400

    try:
        # Force 3-channel RGB so grayscale, palette and RGBA uploads all fit
        # the (1, 300, 300, 3) model input.
        img = Image.open(io.BytesIO(file.read())).convert('RGB')
    except (OSError, ValueError):
        # The extension was acceptable but the payload is not a readable image;
        # report it as a client error rather than failing with a 500.
        return jsonify({'error': 'Invalid image data'}), 400

    img_np = np.array(img)

    # Preprocess image for classification model
    img_resized = cv2.resize(img_np, (300, 300))
    img_reshaped = np.reshape(img_resized, (1, 300, 300, 3))

    # Run prediction and report the top class with its score
    prediction = classification_model.predict(img_reshaped)
    label_index = int(np.argmax(prediction))
    label = CLASSIFICATION_LABELS[label_index]
    confidence = float(prediction[0][label_index])

    return jsonify({
        'label': label,
        'confidence': confidence
    })
def cleanup_old_images(folder='static', age_seconds=3600):
    """
    Removes generated output images in a folder that are older than age_seconds.

    Only regular files named ``output_*.jpg`` are considered, and age is
    measured from each file's last modification time. The sweep is best
    effort: problems are printed and never raised, and a failure on one file
    does not stop the remaining files from being processed.

    Args:
        folder (str): Folder to sweep. Joined onto this file's directory, so a
            relative name is resolved next to this file and an absolute path
            is used as given.
        age_seconds (int | float): Files modified more than this many seconds
            ago are deleted.

    Returns:
        None
    """
    current_time = time.time()
    folder_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), folder)

    try:
        filenames = os.listdir(folder_path)
    except OSError as e:
        # Missing or unreadable folder: nothing to clean up.
        print(f"Error cleaning up images: {e}")
        return

    for filename in filenames:
        if not (filename.startswith('output_') and filename.endswith('.jpg')):
            continue
        file_path = os.path.join(folder_path, filename)
        try:
            # Skip directories (or anything else) that merely match the pattern.
            if not os.path.isfile(file_path):
                continue
            file_modified_time = os.path.getmtime(file_path)
            if current_time - file_modified_time > age_seconds:
                os.remove(file_path)
                print(f"Deleted old image: {filename}")
        except OSError as e:
            # E.g. the file was removed by a concurrent request; keep sweeping.
            print(f"Error cleaning up images: {e}")
if __name__ == '__main__':
    # Get the port from environment variable or use 5000 as default
    port = int(os.environ.get('PORT', 5000))

    # The server binds to all interfaces, and the Werkzeug debugger allows
    # arbitrary code execution from the browser, so debug mode must be an
    # explicit opt-in (FLASK_DEBUG=1/true/yes/on) rather than always enabled.
    debug = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on')

    # Run the Flask application
    app.run(host='0.0.0.0', port=port, debug=debug)
|