# app.py
import os
import uuid

import torch
from dotenv import load_dotenv
from flask import Flask, request, jsonify, render_template
from flask_cors import CORS
from ultralytics import YOLO
from werkzeug.utils import secure_filename
# Load environment variables from .env file (model names can be overridden there)
load_dotenv()
app = Flask(__name__)
# Enable CORS for all routes so browser frontends on other origins can call the API
CORS(app)
# --- Configuration ---
UPLOAD_FOLDER = 'static/uploads'  # temporary storage for incoming images
MODELS_FOLDER = 'models'  # directory holding the .pt weight files
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}  # image types accepted by /predict
# --- NEW: Load model names from .env file, with fallback defaults ---
MODEL_1_NAME = os.getenv('MODEL_1_NAME', 'best.pt')
MODEL_2_NAME = os.getenv('MODEL_2_NAME', 'tyre_alloy.pt') # New model for Tyre/Alloy
MODEL_1_PATH = os.path.join(MODELS_FOLDER, MODEL_1_NAME)
MODEL_2_PATH = os.path.join(MODELS_FOLDER, MODEL_2_NAME)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
# Create the folders up front so saves/loads don't fail on a fresh checkout
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(MODELS_FOLDER, exist_ok=True)
os.makedirs('templates', exist_ok=True)
# --- Determine Device ---
# Prefer GPU when available; both models are moved to this device at load time
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# --- NEW: Load multiple YOLO Models ---
def _load_model(path, name, ordinal):
    """Load one YOLO model onto `device`.

    Returns the loaded model, or None when the weight file is missing or
    loading fails — the server keeps running and /predict reports the
    missing model per-request instead of crashing at startup.
    """
    if not os.path.exists(path):
        print(f"Warning: Model file not found at {path}")
        return None
    try:
        model = YOLO(path)
        model.to(device)
        print(f"Successfully loaded model '{name}' on {device}.")
        return model
    except Exception as e:
        print(f"Error loading Model {ordinal} ({name}): {e}")
        return None

model1 = _load_model(MODEL_1_PATH, MODEL_1_NAME, 1)
model2 = _load_model(MODEL_2_PATH, MODEL_2_NAME, 2)
def allowed_file(filename, allowed_extensions=None):
    """Return True if *filename* has an extension in *allowed_extensions*.

    The check is case-insensitive and uses only the text after the last dot.
    When *allowed_extensions* is omitted, the module-level ALLOWED_EXTENSIONS
    set is used, so existing single-argument callers behave as before.
    """
    if allowed_extensions is None:
        allowed_extensions = ALLOWED_EXTENSIONS
    return '.' in filename and \
        filename.rsplit('.', 1)[1].lower() in allowed_extensions
def run_inference(model, filepath):
    """Classify the image at *filepath* and return the top-1 prediction.

    Returns None when *model* was never loaded; otherwise a dict with the
    predicted class name and its confidence as a plain float.
    """
    if model is None:
        return None  # model failed to load at startup
    # Single-image call: the first result holds the classification probs.
    top_result = model(filepath)[0]
    probs = top_result.probs
    predicted_index = probs.top1
    return {
        "class": model.names[predicted_index],
        "confidence": float(probs.top1conf),
    }
@app.route('/')
def home():
    """Serve the single-page frontend."""
    index_page = 'index.html'
    return render_template(index_page)
@app.route('/predict', methods=['POST'])
def predict():
    """
    Endpoint to receive an image and run classification based on the requested model type.

    Form fields:
        file:       the image upload (png/jpg/jpeg).
        model_type: 'model1', 'model2', or 'combined' (defaults to 'model1').

    Returns JSON with the prediction(s), or a JSON error with status 400/500.
    The uploaded file is always deleted afterwards.
    """
    # 1. --- File Validation ---
    if 'file' not in request.files:
        return jsonify({"error": "No file part in the request"}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({"error": "No selected file"}), 400
    if not file or not allowed_file(file.filename):
        return jsonify({"error": "File type not allowed"}), 400
    # --- NEW: Get the model type from the form data ---
    model_type = request.form.get('model_type', 'model1')  # default to model1
    # 2. --- Save the File Temporarily ---
    filename = secure_filename(file.filename)
    if not filename:
        # secure_filename() can reduce a hostile name (e.g. '../../x') to '',
        # which would make file.save() target the upload directory itself.
        return jsonify({"error": "File type not allowed"}), 400
    # Prefix with a UUID so concurrent uploads of the same filename don't
    # overwrite each other while another request is mid-inference.
    filepath = os.path.join(app.config['UPLOAD_FOLDER'],
                            f"{uuid.uuid4().hex}_{filename}")
    file.save(filepath)
    # 3. --- Perform Inference based on model_type ---
    try:
        if model_type == 'model1':
            if model1 is None:
                return jsonify({"error": f"Model '{MODEL_1_NAME}' is not loaded. Check server logs."}), 500
            return jsonify(run_inference(model1, filepath))
        elif model_type == 'model2':
            if model2 is None:
                return jsonify({"error": f"Model '{MODEL_2_NAME}' is not loaded. Check server logs."}), 500
            return jsonify(run_inference(model2, filepath))
        elif model_type == 'combined':
            if model1 is None or model2 is None:
                return jsonify({"error": "One or more models required for combined mode are not loaded. Check server logs."}), 500
            return jsonify({
                "model1_result": run_inference(model1, filepath),
                "model2_result": run_inference(model2, filepath),
            })
        else:
            return jsonify({"error": "Invalid model type specified"}), 400
    except Exception as e:
        return jsonify({"error": f"An error occurred during inference: {str(e)}"}), 500
    finally:
        # 4. --- Cleanup: never leave uploads on disk, even on error paths ---
        if os.path.exists(filepath):
            os.remove(filepath)
if __name__ == '__main__':
    # Port 7860 is the Hugging Face Spaces convention; allow an override via
    # the standard PORT environment variable for other deployments.
    # NOTE(review): debug=True enables the Werkzeug reloader and interactive
    # debugger — it must be turned off for any public deployment.
    app.run(host='0.0.0.0', port=int(os.getenv('PORT', 7860)), debug=True)