# app.py
"""Flask service that classifies an uploaded image with a YOLO model.

Routes:
    GET  /         -> serves templates/index.html
    POST /predict  -> multipart upload (field "file"); returns the top-1 class
                      and its confidence as JSON.
"""
import os
import threading
import uuid

import torch
from flask import Flask, request, jsonify, render_template
from flask_cors import CORS
from werkzeug.utils import secure_filename
from ultralytics import YOLO
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

app = Flask(__name__)
# Enable CORS for all routes
CORS(app)

# --- Configuration ---
UPLOAD_FOLDER = 'static/uploads'
MODELS_FOLDER = 'models'  # New folder for models
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}

# Load model name from .env file, with a fallback default
MODEL_NAME = os.getenv('MODEL_NAME', 'best.pt')
MODEL_PATH = os.path.join(MODELS_FOLDER, MODEL_NAME)

app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(MODELS_FOLDER, exist_ok=True)  # Ensure models folder exists
os.makedirs('templates', exist_ok=True)  # Ensure templates folder exists

# --- Determine Device and Load YOLO Model ---
# Use CUDA if available, otherwise use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the model once when the application starts for efficiency.
# On any failure `model` stays None and /predict reports a 500.
model = None
try:
    if not os.path.exists(MODEL_PATH):
        print(f"Error: Model file not found at {MODEL_PATH}")
        print("Please make sure the model file exists and the MODEL_NAME in your .env file is correct.")
    else:
        model = YOLO(MODEL_PATH)
        model.to(device)  # Move model to the selected device
        print(f"Successfully loaded model '{MODEL_NAME}' on {device}.")
except Exception as e:
    print(f"Error loading YOLO model: {e}")

# Ultralytics documents that a single YOLO model instance should not be
# shared across threads for concurrent inference. The Flask server may serve
# requests on several threads, so calls into the shared `model` are serialized.
_inference_lock = threading.Lock()


def allowed_file(filename):
    """Checks if a file's extension is in the ALLOWED_EXTENSIONS set."""
    return '.' in filename and \
        filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS


def _top1_prediction(filepath):
    """Run the shared model on *filepath* and return its top-1 prediction.

    Returns:
        dict with keys "class" (class name from ``model.names``) and
        "confidence" (Python float).

    Raises:
        ValueError: if the model output has no classification probabilities
            (``result.probs`` is None, e.g. a non-classification model).
        Exception: anything raised by the model itself is propagated.
    """
    with _inference_lock:
        results = model(filepath)

    # Get the first result object from the list
    result = results[0]
    probs = result.probs
    if probs is None:
        raise ValueError(
            "model output has no classification probabilities; "
            "a YOLO classification model is required"
        )

    top1_index = probs.top1
    top1_confidence = float(probs.top1conf)  # Convert tensor to Python float
    return {
        "class": model.names[top1_index],
        "confidence": top1_confidence,
    }


@app.route('/')
def home():
    """Serve the main HTML page."""
    return render_template('index.html')


@app.route('/predict', methods=['POST'])
def predict():
    """
    Endpoint to receive an image, run YOLO classification, and return the single best prediction.

    Expects a multipart form with a "file" field whose extension is in
    ALLOWED_EXTENSIONS. Responds with {"class": ..., "confidence": ...} on
    success, or {"error": ...} with status 400 (bad request) / 500 (server or
    inference failure).
    """
    if model is None:
        return jsonify({"error": "Model could not be loaded. Please check server logs."}), 500

    # 1. --- File Validation ---
    if 'file' not in request.files:
        return jsonify({"error": "No file part in the request"}), 400

    file = request.files['file']

    if file.filename == '':
        return jsonify({"error": "No selected file"}), 400

    if not file or not allowed_file(file.filename):
        return jsonify({"error": "File type not allowed"}), 400

    # 2. --- Save the File Temporarily ---
    # Use a server-generated unique name rather than the client's filename:
    # two concurrent uploads named e.g. "image.jpg" would otherwise overwrite
    # each other (and one request's cleanup could delete the other's file).
    # secure_filename() can also return "" or drop the extension entirely,
    # e.g. for non-ASCII names. The extension itself was validated above.
    extension = file.filename.rsplit('.', 1)[1].lower()
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], f"{uuid.uuid4().hex}.{extension}")

    try:
        try:
            file.save(filepath)
        except OSError as e:
            return jsonify({"error": f"Could not save the uploaded file: {e}"}), 500

        # 3./4. --- Perform Inference and keep ONLY the Top Prediction ---
        try:
            prediction = _top1_prediction(filepath)
        except Exception as e:
            return jsonify({"error": f"An error occurred during inference: {str(e)}"}), 500

        # Return the single prediction object as JSON
        return jsonify(prediction)
    finally:
        # 5. --- Cleanup ---
        if os.path.exists(filepath):
            os.remove(filepath)


if __name__ == '__main__':
    # The Werkzeug interactive debugger allows arbitrary code execution, so
    # never enable it by default on a server bound to all interfaces.
    # Opt in for local development with FLASK_DEBUG=1.
    _debug = os.getenv('FLASK_DEBUG', '').strip().lower() in {'1', 'true', 'yes'}
    app.run(host='0.0.0.0', port=7860, debug=_debug)