Spaces:
Sleeping
Sleeping
| # app.py | |
| import os | |
| import torch | |
| from flask import Flask, request, jsonify, render_template | |
| from flask_cors import CORS | |
| from werkzeug.utils import secure_filename | |
| from ultralytics import YOLO | |
| from dotenv import load_dotenv | |
# Load environment variables from a .env file (if present).
load_dotenv()

app = Flask(__name__)
# Enable CORS for all routes.
CORS(app)

# --- Configuration ---
# Resolve folders against the application root (the directory Flask derives
# from this module) instead of the current working directory, so uploads and
# templates land where Flask itself looks for static files and templates no
# matter which directory the server is launched from.
BASE_DIR = app.root_path
UPLOAD_FOLDER = os.path.join(BASE_DIR, 'static', 'uploads')
MODELS_FOLDER = os.path.join(BASE_DIR, 'models')
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
# Largest accepted request body, in megabytes. Flask rejects bigger uploads
# with 413 instead of buffering them; overridable via MAX_UPLOAD_MB.
MAX_UPLOAD_MB = int(os.getenv('MAX_UPLOAD_MB', '16'))

# Model file names come from the environment / .env, with fallback defaults.
MODEL_1_NAME = os.getenv('MODEL_1_NAME', 'best.pt')
MODEL_2_NAME = os.getenv('MODEL_2_NAME', 'tyre_alloy.pt')  # Tyre/Alloy model
MODEL_1_PATH = os.path.join(MODELS_FOLDER, MODEL_1_NAME)
MODEL_2_PATH = os.path.join(MODELS_FOLDER, MODEL_2_NAME)

app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
app.config['MAX_CONTENT_LENGTH'] = MAX_UPLOAD_MB * 1024 * 1024
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(MODELS_FOLDER, exist_ok=True)
os.makedirs(os.path.join(BASE_DIR, app.template_folder), exist_ok=True)
# --- Determine Device ---
# Use the GPU when PyTorch reports CUDA support; otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")
# --- Load YOLO models ---
def _load_model(model_path, model_name, label):
    """Load one YOLO model and move it to ``device``.

    Args:
        model_path: Filesystem path of the ``.pt`` weights file.
        model_name: Display name used in log messages.
        label: Human-readable slot name (e.g. ``"Model 1"``) for error logs.

    Returns:
        The loaded model, or ``None`` if the file is missing or loading (or
        moving to the device) fails. Failures are printed rather than raised so
        the server can still start and serve whichever model did load; the
        request handler checks each model for ``None`` before using it.
    """
    if not os.path.exists(model_path):
        print(f"Warning: Model file not found at {model_path}")
        return None
    try:
        model = YOLO(model_path)
        model.to(device)
    except Exception as e:  # best-effort startup: report and keep serving
        print(f"Error loading {label} ({model_name}): {e}")
        return None
    print(f"Successfully loaded model '{model_name}' on {device}.")
    return model


model1 = _load_model(MODEL_1_PATH, MODEL_1_NAME, "Model 1")
model2 = _load_model(MODEL_2_PATH, MODEL_2_NAME, "Model 2")
def allowed_file(filename, allowed_extensions=None):
    """Return True if *filename* ends in an allowed extension.

    The comparison is case-insensitive and looks only at the text after the
    last dot, so ``"photo.JPG"`` is accepted while ``"image.png.exe"`` is not.

    Args:
        filename: Client-supplied file name; may be ``None`` or empty.
        allowed_extensions: Optional collection of lowercase extensions
            (without the dot). Defaults to the module-level
            ``ALLOWED_EXTENSIONS``.

    Returns:
        ``True`` if the extension is allowed, otherwise ``False``.
    """
    # Guard against None/empty names, which would otherwise raise TypeError.
    if not filename or '.' not in filename:
        return False
    if allowed_extensions is None:
        allowed_extensions = ALLOWED_EXTENSIONS
    return filename.rsplit('.', 1)[1].lower() in allowed_extensions
def run_inference(model, filepath):
    """Classify one image and return its top-1 class and confidence.

    Args:
        model: A loaded YOLO classification model, or ``None``.
        filepath: Path of the image file to classify.

    Returns:
        ``{"class": <top-1 class name>, "confidence": <top-1 confidence>}``
        with the confidence as a Python ``float``, or ``None`` when ``model``
        is ``None`` (i.e. it was not loaded).

    Raises:
        ValueError: If the model returns no result, or its result carries no
            classification probabilities (e.g. a non-classification model was
            configured in this slot).
    """
    if model is None:
        return None  # Model was not loaded at startup.

    results = model(filepath)
    if not results:
        raise ValueError("Model returned no results for the uploaded image.")

    probs = results[0].probs
    # Only classification outputs populate ``probs``; without this check a
    # misconfigured model surfaces as an opaque AttributeError on None.
    if probs is None:
        raise ValueError(
            "Model output has no classification probabilities; "
            "is this a classification model?"
        )

    top1_index = probs.top1
    return {
        "class": model.names[top1_index],
        "confidence": float(probs.top1conf),
    }
# The view must be registered with the app; without a route it is unreachable.
@app.route('/')
def home():
    """Serve the main HTML page rendered from the ``index.html`` template."""
    return render_template('index.html')
# NOTE(review): the URL path is assumed; confirm it matches the frontend's request.
@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded image with the model(s) selected by ``model_type``.

    Expects a multipart POST with the image under the ``file`` field and an
    optional ``model_type`` form field:

    * ``"model1"`` (default) -- run the first model only;
    * ``"model2"`` -- run the second model only;
    * ``"combined"`` -- run both and return
      ``{"model1_result": ..., "model2_result": ...}``.

    Returns:
        The JSON prediction(s) with status 200.
        ``{"error": ...}`` with status 400 for a missing or disallowed file or
        an unknown ``model_type``.
        ``{"error": ...}`` with status 500 when a required model is not loaded
        or inference fails.
    """
    import uuid  # only needed here, to build collision-free upload names

    # 1. --- File Validation ---
    if 'file' not in request.files:
        return jsonify({"error": "No file part in the request"}), 400
    file = request.files['file']
    # ``not`` covers both an empty and a missing (None) filename.
    if not file.filename:
        return jsonify({"error": "No selected file"}), 400
    if not allowed_file(file.filename):
        return jsonify({"error": "File type not allowed"}), 400

    # 2. --- Validate the requested model(s) before touching the disk ---
    model_type = request.form.get('model_type', 'model1')  # default to model1
    if model_type == 'model1':
        if model1 is None:
            return jsonify({"error": f"Model '{MODEL_1_NAME}' is not loaded. Check server logs."}), 500
    elif model_type == 'model2':
        if model2 is None:
            return jsonify({"error": f"Model '{MODEL_2_NAME}' is not loaded. Check server logs."}), 500
    elif model_type == 'combined':
        if model1 is None or model2 is None:
            return jsonify({"error": "One or more models required for combined mode are not loaded. Check server logs."}), 500
    else:
        return jsonify({"error": "Invalid model type specified"}), 400

    # 3. --- Choose a unique, server-generated temporary path ---
    # Reusing the client's filename would let two concurrent uploads of, say,
    # "photo.jpg" overwrite (and then delete) each other's file. The validated
    # extension (one of ALLOWED_EXTENSIONS) is kept so the saved file is still
    # recognisable as an image.
    extension = file.filename.rsplit('.', 1)[1].lower()
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], f"{uuid.uuid4().hex}.{extension}")

    # 4. --- Save and perform inference ---
    try:
        file.save(filepath)
        if model_type == 'model1':
            prediction = run_inference(model1, filepath)
        elif model_type == 'model2':
            prediction = run_inference(model2, filepath)
        else:  # 'combined'
            prediction = {
                "model1_result": run_inference(model1, filepath),
                "model2_result": run_inference(model2, filepath),
            }
        return jsonify(prediction)
    except Exception as e:
        # Request boundary: keep the traceback in the server log, report to the client.
        app.logger.exception("Inference failed (model_type=%s)", model_type)
        return jsonify({"error": f"An error occurred during inference: {str(e)}"}), 500
    finally:
        # 5. --- Cleanup ---
        if os.path.exists(filepath):
            os.remove(filepath)
if __name__ == '__main__':
    # Werkzeug's interactive debugger can execute arbitrary code, and its
    # reloader imports this module twice (loading both models twice). Keep it
    # off by default on a server bound to all interfaces; opt in with
    # FLASK_DEBUG=1 for local development.
    debug_mode = os.getenv('FLASK_DEBUG', '0').strip().lower() in ('1', 'true', 'yes')
    port = int(os.getenv('PORT', '7860'))
    app.run(host='0.0.0.0', port=port, debug=debug_mode)