# app.py
import os
import uuid

import torch
from dotenv import load_dotenv
from flask import Flask, request, jsonify, render_template
from flask_cors import CORS
from ultralytics import YOLO
from werkzeug.utils import secure_filename
# Load environment variables from a .env file so MODEL_NAME can be
# configured per-deployment without code changes.
load_dotenv()
app = Flask(__name__)
# Enable CORS for all routes so the API can be called from a frontend
# served from a different origin.
CORS(app)
# --- Configuration ---
UPLOAD_FOLDER = 'static/uploads'  # temporary storage for uploaded images
MODELS_FOLDER = 'models'  # folder the model weights are loaded from
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}  # image types accepted by /predict
# Load model name from .env file, with a fallback default
MODEL_NAME = os.getenv('MODEL_NAME', 'best.pt')
MODEL_PATH = os.path.join(MODELS_FOLDER, MODEL_NAME)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
# Create required folders up front; exist_ok makes restarts idempotent.
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(MODELS_FOLDER, exist_ok=True)  # Ensure models folder exists
os.makedirs('templates', exist_ok=True)  # Ensure templates folder exists
# --- Determine Device and Load YOLO Model ---
# Use CUDA if available, otherwise use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load the model once when the application starts for efficiency.
# `model` stays None when loading fails; the /predict route checks for
# this and returns a 500 instead of crashing at request time.
model = None
try:
    if not os.path.exists(MODEL_PATH):
        print(f"Error: Model file not found at {MODEL_PATH}")
        print("Please make sure the model file exists and the MODEL_NAME in your .env file is correct.")
    else:
        model = YOLO(MODEL_PATH)
        model.to(device)  # Move model to the selected device
        print(f"Successfully loaded model '{MODEL_NAME}' on {device}.")
except Exception as e:
    # Best-effort startup: log and continue so the server can still serve
    # the frontend and report a clear error from /predict.
    print(f"Error loading YOLO model: {e}")
def allowed_file(filename, allowed_extensions=None):
    """Check whether an uploaded file's extension is permitted.

    Args:
        filename: The (untrusted) name of the uploaded file.
        allowed_extensions: Optional set of lowercase extensions without
            the leading dot. Defaults to the module-level
            ALLOWED_EXTENSIONS, preserving the original behavior.

    Returns:
        bool: True when the filename contains a dot and its final
        extension (compared case-insensitively) is in the allowed set.
    """
    if allowed_extensions is None:
        allowed_extensions = ALLOWED_EXTENSIONS
    # rsplit keeps only the text after the LAST dot, so "a.tar.jpg" -> "jpg".
    return '.' in filename and \
        filename.rsplit('.', 1)[1].lower() in allowed_extensions
@app.route('/')
def home():
    """Render and return the application's main page (templates/index.html)."""
    index_template = 'index.html'
    return render_template(index_template)
@app.route('/predict', methods=['POST'])
def predict():
    """
    Receive an uploaded image, run YOLO classification on it, and return
    the single best prediction as JSON: {"class": str, "confidence": float}.

    Responses:
        200 with the top-1 prediction on success.
        400 when the request has no usable image file.
        500 when the model is unavailable, is not a classification model,
            or inference raises.
    """
    if model is None:
        return jsonify({"error": "Model could not be loaded. Please check server logs."}), 500
    # 1. --- File Validation ---
    if 'file' not in request.files:
        return jsonify({"error": "No file part in the request"}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({"error": "No selected file"}), 400
    if not file or not allowed_file(file.filename):
        return jsonify({"error": "File type not allowed"}), 400
    # 2. --- Save the File Temporarily ---
    # Prefix with a UUID so two concurrent uploads of identically-named
    # files cannot overwrite each other's temporary copy (or have it
    # deleted by the other request's cleanup step mid-inference).
    filename = f"{uuid.uuid4().hex}_{secure_filename(file.filename)}"
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    file.save(filepath)
    # 3. --- Perform Inference ---
    try:
        # The model was moved to the correct device at startup.
        results = model(filepath)
        # 4. --- Process Results to Get ONLY the Top Prediction ---
        result = results[0]
        probs = result.probs
        # probs is None when the loaded weights are not a classification
        # model (e.g. detection weights) -- return a clear error instead
        # of letting an AttributeError surface as a generic 500.
        if probs is None:
            return jsonify({"error": "Loaded model did not produce classification probabilities."}), 500
        top1_index = probs.top1
        top1_confidence = float(probs.top1conf)  # Convert tensor to Python float
        # Map the class index to a human-readable label.
        class_name = model.names[top1_index]
        prediction = {
            "class": class_name,
            "confidence": top1_confidence
        }
        return jsonify(prediction)
    except Exception as e:
        return jsonify({"error": f"An error occurred during inference: {str(e)}"}), 500
    finally:
        # 5. --- Cleanup: always remove the temporary upload ---
        if os.path.exists(filepath):
            os.remove(filepath)
if __name__ == '__main__':
    # Debug mode exposes the Werkzeug interactive debugger, which permits
    # arbitrary code execution -- never enable it on a host bound to
    # 0.0.0.0. It is now opt-in via the FLASK_DEBUG environment variable.
    debug_mode = os.getenv('FLASK_DEBUG', '0').lower() in ('1', 'true', 'yes')
    app.run(host='0.0.0.0', port=7860, debug=debug_mode)