Spaces:
Sleeping
Sleeping
| # app.py | |
import os
import uuid

import torch
from dotenv import load_dotenv
from flask import Flask, request, jsonify, render_template
from flask_cors import CORS
from ultralytics import YOLO
from werkzeug.utils import secure_filename
# Pull settings such as MODEL_NAME from a local .env file into os.environ.
load_dotenv()

app = Flask(__name__)
# Accept cross-origin requests on every route.
CORS(app)

# --- Configuration ---
UPLOAD_FOLDER = 'static/uploads'   # scratch space for images being classified
MODELS_FOLDER = 'models'           # directory that holds the weight files
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}

# The weight file name is taken from the environment, defaulting to 'best.pt'.
MODEL_NAME = os.getenv('MODEL_NAME', 'best.pt')
MODEL_PATH = os.path.join(MODELS_FOLDER, MODEL_NAME)

app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER

# Create the working directories up front; exist_ok leaves existing ones alone.
upload_dir = app.config['UPLOAD_FOLDER']
os.makedirs(upload_dir, exist_ok=True)
os.makedirs(MODELS_FOLDER, exist_ok=True)
os.makedirs('templates', exist_ok=True)
del upload_dir
# --- Determine Device and Load YOLO Model ---
# Use CUDA if available, otherwise use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")


def _load_model():
    """Load the YOLO weights at MODEL_PATH onto ``device``.

    Returns:
        The ready-to-use model, or None if the weight file is missing or
        loading/moving it raised. Failures are reported with print().
    """
    if not os.path.exists(MODEL_PATH):
        print(f"Error: Model file not found at {MODEL_PATH}")
        print("Please make sure the model file exists and the MODEL_NAME in your .env file is correct.")
        return None
    try:
        loaded = YOLO(MODEL_PATH)
        # Move the model to the selected device *before* handing it out. If
        # this step fails, the caller must not receive a half-initialised
        # model while the log says loading failed.
        loaded.to(device)
    except Exception as e:  # startup boundary: report and degrade gracefully
        print(f"Error loading YOLO model: {e}")
        return None
    print(f"Successfully loaded model '{MODEL_NAME}' on {device}.")
    return loaded


# Load the model once at startup so every request reuses it. It is None on
# failure; predict() checks for that and answers with a 500 instead of crashing.
model = _load_model()
def allowed_file(filename):
    """Report whether *filename* carries one of the ALLOWED_EXTENSIONS.

    Only the text after the final dot is compared, case-insensitively, so
    "a.b.PNG" is accepted while "README" (no dot) and "photo." are rejected.
    """
    if '.' not in filename:
        return False
    extension = filename.rpartition('.')[2]
    return extension.lower() in ALLOWED_EXTENSIONS
# NOTE(review): the original never registered this view, so Flask could not
# serve it. '/' is assumed to be the intended URL; confirm it against the frontend.
@app.route('/')
def home():
    """Serve the main HTML page (index.html from the app's template folder)."""
    return render_template('index.html')
# NOTE(review): the original never registered this view. '/predict' (POST) is
# assumed; confirm it matches the URL that templates/index.html posts to.
@app.route('/predict', methods=['POST'])
def predict():
    """
    Endpoint to receive an image, run YOLO classification, and return the single best prediction.

    Expects a multipart/form-data request with the image in the form field ``file``.

    Returns:
        200 with ``{"class": <label>, "confidence": <float>}`` on success.
        400 with ``{"error": ...}`` if no file was sent or its extension is not allowed.
        500 with ``{"error": ...}`` if the model is unavailable, does not produce
        classification probabilities, or saving/inference fails.
    """
    if model is None:
        return jsonify({"error": "Model could not be loaded. Please check server logs."}), 500

    # 1. --- File Validation ---
    if 'file' not in request.files:
        return jsonify({"error": "No file part in the request"}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({"error": "No selected file"}), 400
    if not file or not allowed_file(file.filename):
        return jsonify({"error": "File type not allowed"}), 400

    # 2. --- Choose a unique temporary path ---
    # Use a random server-side name instead of the client's filename, for two reasons:
    # - Concurrent uploads named e.g. "photo.jpg" would overwrite each other, and
    #   one request's cleanup would delete the other's file.
    # - secure_filename() can reduce non-ASCII names to a bare "jpg" with no suffix.
    # allowed_file() has already guaranteed a dot and a permitted extension.
    extension = file.filename.rsplit('.', 1)[1].lower()
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], f"{uuid.uuid4().hex}.{extension}")

    try:
        # Save inside the try so a failed or partial write is still cleaned up.
        file.save(filepath)

        # 3. --- Perform Inference ---
        results = model(filepath)

        # 4. --- Keep only the top-1 prediction ---
        result = results[0]
        probs = result.probs
        if probs is None:
            # Ultralytics leaves `probs` unset for non-classification tasks
            # (e.g. detection), so this model cannot answer the request.
            return jsonify({"error": "Loaded model does not produce classification probabilities."}), 500
        top1_index = probs.top1
        top1_confidence = float(probs.top1conf)  # tensor -> Python float for JSON
        class_name = model.names[top1_index]
        return jsonify({"class": class_name, "confidence": top1_confidence})
    except Exception as e:  # request boundary: log the traceback, report JSON
        app.logger.exception("Inference failed for %s", filepath)
        return jsonify({"error": f"An error occurred during inference: {str(e)}"}), 500
    finally:
        # 5. --- Cleanup ---
        if os.path.exists(filepath):
            os.remove(filepath)
if __name__ == '__main__':
    # Do not enable the Werkzeug debugger by default. With host='0.0.0.0' the
    # server is reachable from other machines, and the interactive debugger
    # lets anyone who triggers an error run arbitrary Python on the host.
    # Opt in for local development with FLASK_DEBUG=1.
    debug_mode = os.getenv('FLASK_DEBUG', '0').strip().lower() in {'1', 'true', 'yes', 'on'}
    app.run(host='0.0.0.0', port=7860, debug=debug_mode)