File size: 6,437 Bytes
ed1abdf
 
ff88581
 
 
 
 
 
 
 
 
 
 
ed1abdf
ff88581
 
 
 
 
 
 
 
 
 
cf0aaae
ff88581
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed1abdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff88581
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
from flask import Flask, render_template, request, flash, redirect, url_for, jsonify
from flask_cors import CORS
from tensorflow.keras.models import load_model
import numpy as np
from PIL import Image
import io
import cv2
import os
import tensorflow as tf
import json
import time

# Directory containing this file; bundled assets (labels, model) live next to it.
_BASE_DIR = os.path.dirname(os.path.abspath(__file__))

app = Flask(__name__)
CORS(app)  # Enable CORS for mobile app

# Secret key used to sign the session cookie that carries flash messages.
# A deployment can supply its own key through the FLASK_SECRET_KEY environment
# variable; the built-in value is kept only as a fallback so existing setups
# keep working.
# NOTE(review): the fallback key is committed to source control and should not
# be relied on outside local development.
app.secret_key = os.environ.get('FLASK_SECRET_KEY') or b'_5#y2L"F4Q8z\n\xec]/'

# Define allowed extensions for image uploads
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}

# Load the classification labels from a JSON file.
# The encoding is given explicitly so the result does not depend on the
# platform's default locale encoding.
with open(os.path.join(_BASE_DIR, 'labels.json'), 'r', encoding='utf-8') as f:
    CLASSIFICATION_LABELS = json.load(f)

# Get the absolute path to the classification model
model_path = os.path.join(_BASE_DIR, 'models/flagship_model.keras')
# Load the pre-trained classification model once at import time so every
# request reuses the same in-memory model.
classification_model = load_model(model_path)

def allowed_file(filename):
    """
    Tell whether a filename carries one of the permitted image extensions.

    Args:
        filename (str): The name of the file.

    Returns:
        bool: True when the text after the last '.' (compared
        case-insensitively) is in ALLOWED_EXTENSIONS; False when the name has
        no '.' or the extension is not allowed.
    """
    # rpartition yields an empty separator when the name contains no dot.
    _stem, dot, extension = filename.rpartition('.')
    if not dot:
        return False
    return extension.lower() in ALLOWED_EXTENSIONS

@app.route('/')
def index():
    """
    Serve the landing page of the web application.

    Returns:
        The rendered 'index.html' template.
    """
    landing_page = render_template('index.html')
    return landing_page

@app.route('/predict', methods=['POST'])
def predict():
    """
    Handles image uploads, preprocesses the image, makes a prediction using the
    classification model, and displays the result.

    The uploaded file is read from the 'file' form field, decoded, forced to
    3-channel RGB, resized to 300x300 and passed to the classification model.
    A BGR copy of the full-size image is written to the 'static' folder next
    to this file so the result page can display it.

    Returns:
        The rendered 'result.html' page on success; otherwise a redirect to
        the index page with a flashed error message.
    """
    # This route only accepts POST, so redirecting to request.url would make
    # the browser issue a GET against it and receive 405. Send the user back
    # to the upload form instead.
    if 'file' not in request.files:
        flash('No file part')
        return redirect(url_for('index'))

    file = request.files['file']

    # Check if a file was selected
    if file.filename == '':
        flash('No selected file')
        return redirect(url_for('index'))

    # Reject anything whose extension is not an allowed image type
    if not (file and allowed_file(file.filename)):
        flash('Invalid file type. Please upload an image (png, jpg, jpeg).')
        return redirect(url_for('index'))

    try:
        # Force 3-channel RGB: PNGs may be RGBA, palette or grayscale, which
        # would break both the RGB->BGR conversion and the (1, 300, 300, 3)
        # reshape below.
        img = Image.open(io.BytesIO(file.read())).convert('RGB')
    except (OSError, ValueError):
        # The extension looked fine but the bytes are not a decodable image.
        flash('Invalid file type. Please upload an image (png, jpg, jpeg).')
        return redirect(url_for('index'))

    img_np = np.array(img)

    # OpenCV writes images in BGR channel order, so keep a BGR copy for saving
    img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)

    # Preprocess the image for the classification model (Expected input: 300x300 for EfficientNetV2B3)
    img_resized_classification = cv2.resize(img_np, (300, 300))
    # Add the leading batch dimension expected by predict()
    # EfficientNetV2B3 handles normalization internally (expects 0-255 inputs),
    # so the resized pixels are passed through unchanged.
    img_preprocessed = np.reshape(img_resized_classification, (1, 300, 300, 3))

    # Run the classification model to get predictions
    prediction = classification_model.predict(img_preprocessed)
    label_index = np.argmax(prediction)  # Index of the highest-scoring class
    label = CLASSIFICATION_LABELS[label_index]  # Corresponding label string

    # Generate a filename for the output image using a timestamp
    timestamp = str(int(time.time()))
    output_image_filename = f'output_{timestamp}.jpg'
    # Resolve the static folder relative to this file (not the working
    # directory) so the image lands in the same folder that
    # cleanup_old_images() scans.
    static_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'static')
    os.makedirs(static_dir, exist_ok=True)
    output_image_path = os.path.join(static_dir, output_image_filename)
    # Save the full-size BGR image to the static folder
    cv2.imwrite(output_image_path, img_bgr)

    # Cleanup old images (older than 1 hour)
    cleanup_old_images()

    # Render the result page with the predicted label and image path
    return render_template('result.html', image_path=output_image_filename, label=label, timestamp=timestamp)

@app.route('/api/predict', methods=['POST'])
def api_predict():
    """
    JSON API endpoint for mobile app predictions.

    Expects a multipart upload with the image in the 'file' field.

    Returns:
        On success, a JSON object with 'label' (the predicted class label) and
        'confidence' (the model's score for that class as a float).
        On failure, a JSON object with an 'error' message and HTTP status 400.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file provided'}), 400

    file = request.files['file']

    if file.filename == '':
        return jsonify({'error': 'No file selected'}), 400

    if not (file and allowed_file(file.filename)):
        return jsonify({'error': 'Invalid file type'}), 400

    try:
        # Force 3-channel RGB: RGBA, palette or grayscale uploads would
        # otherwise fail the (1, 300, 300, 3) reshape below.
        img = Image.open(io.BytesIO(file.read())).convert('RGB')
    except (OSError, ValueError):
        # Allowed extension, but the payload is not a decodable image:
        # report a client error instead of an unhandled 500.
        return jsonify({'error': 'Invalid file type'}), 400

    img_np = np.array(img)

    # Preprocess image for classification model: resize to 300x300 and add
    # the leading batch dimension.
    img_resized = cv2.resize(img_np, (300, 300))
    img_reshaped = np.reshape(img_resized, (1, 300, 300, 3))

    # Run prediction
    prediction = classification_model.predict(img_reshaped)
    label_index = np.argmax(prediction)
    label = CLASSIFICATION_LABELS[label_index]
    # Convert the NumPy scalar to a plain float so it is JSON-serializable
    confidence = float(prediction[0][label_index])

    return jsonify({
        'label': label,
        'confidence': confidence
    })

def cleanup_old_images(folder='static', age_seconds=3600):
    """
    Removes files in the specified folder that are older than age_seconds.

    Only entries whose names start with 'output_' and end with '.jpg' are
    considered. Age is measured from each file's last-modification time.
    Cleanup is best-effort: problems are printed and never raised, and a
    failure on one entry does not prevent the remaining entries from being
    processed.

    Args:
        folder (str): Folder to clean, resolved relative to the directory
            containing this file (an absolute path is used as given).
        age_seconds (int | float): Minimum age, in seconds, for a file to be
            deleted.

    Returns:
        None
    """
    try:
        current_time = time.time()
        folder_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), folder)
        filenames = os.listdir(folder_path)
    except Exception as e:
        # Nothing can be cleaned if the folder cannot be listed.
        print(f"Error cleaning up images: {e}")
        return

    for filename in filenames:
        if not (filename.startswith('output_') and filename.endswith('.jpg')):
            continue
        file_path = os.path.join(folder_path, filename)
        # Handle errors per entry: another request may delete the same file
        # concurrently, or an entry may not be removable; neither should stop
        # the rest of the sweep.
        try:
            file_modified_time = os.path.getmtime(file_path)
            if current_time - file_modified_time > age_seconds:
                os.remove(file_path)
                print(f"Deleted old image: {filename}")
        except Exception as e:
            print(f"Error cleaning up images: {e}")

if __name__ == '__main__':
    # Get the port from environment variable or use 5000 as default
    port = int(os.environ.get('PORT', 5000))
    # The server listens on all interfaces, and the interactive debugger lets
    # anyone who can reach it execute code, so debug mode is off unless it is
    # explicitly enabled through the FLASK_DEBUG environment variable.
    debug = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on')
    # Run the Flask development server
    app.run(host='0.0.0.0', port=port, debug=debug)