File size: 4,163 Bytes
80c4760
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
854a9f8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# app.py

import os
import uuid

import torch
from flask import Flask, request, jsonify, render_template
from flask_cors import CORS
from werkzeug.utils import secure_filename
from ultralytics import YOLO
from dotenv import load_dotenv

# Read settings from the .env file before anything consults os.getenv
load_dotenv()

app = Flask(__name__)

# Allow cross-origin requests on every route
CORS(app)

# --- Configuration ---
UPLOAD_FOLDER = 'static/uploads'
MODELS_FOLDER = 'models'  # Directory that holds the model weight files
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER

# The model file name comes from MODEL_NAME in the environment (default: best.pt)
MODEL_NAME = os.getenv('MODEL_NAME', 'best.pt')
MODEL_PATH = os.path.join(MODELS_FOLDER, MODEL_NAME)

# Create the uploads, models and templates directories when they are missing
for _required_dir in (app.config['UPLOAD_FOLDER'], MODELS_FOLDER, 'templates'):
    os.makedirs(_required_dir, exist_ok=True)

# --- Determine Device and Load YOLO Model ---
# Default to the CPU and switch to CUDA only when torch reports it as available
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")

# The model is loaded a single time at import so requests can reuse it.
# It stays None when the file is missing or loading raises.
model = None
try:
    if os.path.exists(MODEL_PATH):
        model = YOLO(MODEL_PATH)
        model.to(device)  # Place the model on the chosen device
        print(f"Successfully loaded model '{MODEL_NAME}' on {device}.")
    else:
        print(f"Error: Model file not found at {MODEL_PATH}")
        print("Please make sure the model file exists and the MODEL_NAME in your .env file is correct.")
except Exception as load_error:
    print(f"Error loading YOLO model: {load_error}")

def allowed_file(filename):
    """Return True when the text after the last dot of *filename* is an allowed extension.

    The comparison is case-insensitive; a name without any dot is rejected.
    """
    _, dot, extension = filename.rpartition('.')
    if not dot:
        # No dot at all, so there is no extension to check
        return False
    return extension.lower() in ALLOWED_EXTENSIONS

@app.route('/')
def home():
    """Render the index.html template and return it as the landing page."""
    page = render_template('index.html')
    return page

@app.route('/predict', methods=['POST'])
def predict():
    """
    Receive an image, run YOLO classification, and return the single best prediction.

    Expects a POST whose uploaded image is available under the ``file`` key
    of ``request.files``.

    Returns:
        A JSON body ``{"class": <name>, "confidence": <float>}`` on success.
        A JSON body ``{"error": <message>}`` with status 400 when the file
        part is missing, has an empty filename, or has a disallowed extension.
        A JSON body ``{"error": <message>}`` with status 500 when the model
        is not loaded or when saving the upload / running inference fails.
    """
    if model is None:
        return jsonify({"error": "Model could not be loaded. Please check server logs."}), 500

    # 1. --- File Validation ---
    if 'file' not in request.files:
        return jsonify({"error": "No file part in the request"}), 400

    file = request.files['file']
    if file.filename == '':
        return jsonify({"error": "No selected file"}), 400

    if not file or not allowed_file(file.filename):
        return jsonify({"error": "File type not allowed"}), 400

    # 2. --- Choose a Unique Temporary Path ---
    # The client-supplied name is never used on disk. A server-generated name
    # keeps concurrent uploads with the same filename from overwriting each
    # other (and from deleting each other's file during cleanup), and avoids
    # the empty result a sanitized filename can produce. Only the extension,
    # already checked by allowed_file, is kept.
    extension = file.filename.rsplit('.', 1)[1].lower()
    filepath = os.path.join(
        app.config['UPLOAD_FOLDER'], f"{uuid.uuid4().hex}.{extension}"
    )

    try:
        # Saving happens inside the try so a disk error produces the same
        # JSON error response as an inference failure, and cleanup still runs.
        file.save(filepath)

        # 3. --- Perform Inference ---
        # The model was already moved to its device at startup.
        results = model(filepath)

        # 4. --- Process Results to Get ONLY the Top Prediction ---
        result = results[0]
        probs = result.probs
        if probs is None:
            # NOTE(review): presumably the loaded weights are not a
            # classification model when probs is missing - confirm.
            return jsonify({
                "error": "An error occurred during inference: "
                         "the model returned no classification probabilities."
            }), 500

        top1_index = probs.top1
        top1_confidence = float(probs.top1conf)  # Convert tensor to Python float

        # Map the winning index to its label via the model's 'names' mapping
        prediction = {
            "class": model.names[top1_index],
            "confidence": top1_confidence
        }

        # Return the single prediction object as JSON
        return jsonify(prediction)

    except Exception as e:
        # Request boundary: report the failure as JSON instead of propagating
        return jsonify({"error": f"An error occurred during inference: {str(e)}"}), 500
    finally:
        # 5. --- Cleanup ---
        # The path is unique to this request, so removing it cannot affect others.
        if os.path.exists(filepath):
            os.remove(filepath)

if __name__ == '__main__':
    # The server binds to all interfaces (0.0.0.0). Debug mode enables the
    # Werkzeug interactive debugger, which lets anyone who can reach the port
    # execute code, so it is off unless explicitly requested with
    # FLASK_DEBUG=1 (or true/yes/on) in the environment or .env file.
    debug_enabled = os.getenv('FLASK_DEBUG', '').strip().lower() in {'1', 'true', 'yes', 'on'}
    app.run(host='0.0.0.0', port=7860, debug=debug_enabled)