File size: 5,194 Bytes
80c4760
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e9ac54
80c4760
 
3e9ac54
 
 
 
 
 
80c4760
 
 
3e9ac54
 
80c4760
3e9ac54
80c4760
 
 
3e9ac54
 
 
 
80c4760
3e9ac54
 
80c4760
3e9ac54
 
 
80c4760
3e9ac54
 
 
 
 
 
 
 
 
 
 
 
 
80c4760
 
 
 
 
 
3e9ac54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80c4760
 
 
 
 
 
 
 
3e9ac54
80c4760
 
 
 
 
 
 
 
 
 
3e9ac54
 
 
80c4760
 
 
 
 
3e9ac54
80c4760
3e9ac54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80c4760
 
 
 
3e9ac54
80c4760
 
 
 
854a9f8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# app.py

import os
import uuid

import torch
from dotenv import load_dotenv
from flask import Flask, request, jsonify, render_template
from flask_cors import CORS
from ultralytics import YOLO
from werkzeug.utils import secure_filename

# Read the .env file first so the os.getenv() lookups below can see its values.
load_dotenv()

app = Flask(__name__)

# Allow cross-origin requests on every route.
CORS(app)

# --- Configuration ---
UPLOAD_FOLDER = 'static/uploads'
MODELS_FOLDER = 'models'
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}

# --- Model file names: taken from the environment, with fallback defaults ---
MODEL_1_NAME = os.getenv('MODEL_1_NAME', 'best.pt')
MODEL_2_NAME = os.getenv('MODEL_2_NAME', 'tyre_alloy.pt')  # model for Tyre/Alloy

MODEL_1_PATH = os.path.join(MODELS_FOLDER, MODEL_1_NAME)
MODEL_2_PATH = os.path.join(MODELS_FOLDER, MODEL_2_NAME)

app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER

# Create the working directories up front; existing ones are left alone.
for _required_dir in (app.config['UPLOAD_FOLDER'], MODELS_FOLDER, 'templates'):
    os.makedirs(_required_dir, exist_ok=True)

# --- Determine Device ---
# Default to the CPU and switch to CUDA only when torch reports it as available.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")

# --- Load the YOLO models ---
def _load_model(label, model_name, model_path):
    """Load one YOLO model from disk and move it to the selected device.

    Args:
        label: Human-readable tag used in the error log line (e.g. "Model 1").
        model_name: File name of the model, used in log messages.
        model_path: Path of the model file to load.

    Returns:
        The loaded model, or None when the file is missing or any step of
        loading fails. Failures are printed rather than raised so the app can
        still start and report the missing model per request.
    """
    if not os.path.exists(model_path):
        print(f"Warning: Model file not found at {model_path}")
        return None

    try:
        model = YOLO(model_path)
        model.to(device)
    except Exception as e:
        # Returning None here (instead of a model that was constructed but
        # never reached the target device) keeps the logged error and the
        # "is it loaded?" checks elsewhere in agreement.
        print(f"Error loading {label} ({model_name}): {e}")
        return None

    print(f"Successfully loaded model '{model_name}' on {device}.")
    return model


model1 = _load_model("Model 1", MODEL_1_NAME, MODEL_1_PATH)
model2 = _load_model("Model 2", MODEL_2_NAME, MODEL_2_PATH)


def allowed_file(filename):
    """Tell whether *filename* ends in one of the permitted extensions.

    The text after the last dot is compared case-insensitively against
    ALLOWED_EXTENSIONS; a name that contains no dot is rejected.
    """
    _stem, dot, extension = filename.rpartition('.')
    if not dot:
        return False
    return extension.lower() in ALLOWED_EXTENSIONS

def run_inference(model, filepath):
    """Run a classification model on one image and summarise its top prediction.

    Args:
        model: A loaded model, or None when it could not be loaded. It is
            called with the image path, and its ``names`` mapping translates
            the winning class index into a label.
        filepath: Path of the image file to classify.

    Returns:
        None when ``model`` is None; otherwise a dict with the keys
        ``"class"`` (label of the top-1 class) and ``"confidence"`` (its
        score converted to a float).

    Raises:
        ValueError: If the first result carries no class probabilities
            (``probs`` is None), i.e. the model did not produce a
            classification output.
    """
    if model is None:
        return None  # Nothing to run when the model isn't loaded

    results = model(filepath)
    result = results[0]  # only the first result is used
    probs = result.probs
    if probs is None:
        # Without this guard the failure surfaces as an opaque
        # "'NoneType' object has no attribute 'top1'".
        raise ValueError(
            "Model returned no class probabilities; a classification model is required."
        )

    top1_index = probs.top1
    top1_confidence = float(probs.top1conf)

    return {
        "class": model.names[top1_index],
        "confidence": top1_confidence,
    }

@app.route('/')
def home():
    """Render the ``index.html`` template as the landing page."""
    page = render_template('index.html')
    return page

@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded image with the requested model(s).

    Expects a multipart POST with:
        file: the image; its name must pass ``allowed_file``.
        model_type: optional form field, one of ``model1`` (default),
            ``model2`` or ``combined``.

    Returns:
        A JSON response. On success it is the ``run_inference`` result of the
        chosen model, or for ``combined`` a dict with ``model1_result`` and
        ``model2_result``. Failures are ``{"error": ...}`` with status 400 for
        a bad request (missing, unnamed or disallowed file; unknown model
        type) and 500 when a required model is not loaded or processing
        raises.
    """
    # 1. --- File Validation ---
    if 'file' not in request.files:
        return jsonify({"error": "No file part in the request"}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({"error": "No selected file"}), 400
    if not file or not allowed_file(file.filename):
        return jsonify({"error": "File type not allowed"}), 400

    # 2. --- Resolve the requested model(s) before touching the disk ---
    model_type = request.form.get('model_type', 'model1')  # default to model1
    single_models = {
        'model1': (model1, MODEL_1_NAME),
        'model2': (model2, MODEL_2_NAME),
    }
    selected_model = None
    if model_type in single_models:
        selected_model, selected_name = single_models[model_type]
        if selected_model is None:
            return jsonify({"error": f"Model '{selected_name}' is not loaded. Check server logs."}), 500
    elif model_type == 'combined':
        if model1 is None or model2 is None:
            return jsonify({"error": "One or more models required for combined mode are not loaded. Check server logs."}), 500
    else:
        return jsonify({"error": "Invalid model type specified"}), 400

    # 3. --- Choose a temporary path ---
    # A server-generated name keeps concurrent uploads that share a client
    # filename from overwriting or deleting each other's file, and keeps the
    # already-validated extension even when the client name has characters
    # that sanitising would strip.
    extension = file.filename.rsplit('.', 1)[1].lower()
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], f"{uuid.uuid4().hex}.{extension}")

    # 4. --- Save, run inference, and always clean up ---
    try:
        # Saving happens inside the try so a failed or partial write is also
        # reported as JSON and removed by the finally clause.
        file.save(filepath)

        if model_type == 'combined':
            prediction = {
                "model1_result": run_inference(model1, filepath),
                "model2_result": run_inference(model2, filepath),
            }
        else:
            prediction = run_inference(selected_model, filepath)
        return jsonify(prediction)

    except Exception as e:
        return jsonify({"error": f"An error occurred during inference: {str(e)}"}), 500
    finally:
        # 5. --- Cleanup ---
        if os.path.exists(filepath):
            os.remove(filepath)

if __name__ == '__main__':
    # The interactive debugger lets anyone who can reach the server execute
    # code, and this server listens on all interfaces, so debug mode is
    # opt-in: set FLASK_DEBUG=1 (in the environment or .env) for development.
    debug_enabled = os.getenv('FLASK_DEBUG', '').strip().lower() in {'1', 'true', 'yes', 'on'}
    app.run(host='0.0.0.0', port=7860, debug=debug_enabled)