# app.py
import logging
import os
import threading
import uuid

import torch
from dotenv import load_dotenv
from flask import Flask, jsonify, render_template, request
from flask_cors import CORS
from ultralytics import YOLO
from werkzeug.utils import secure_filename  # noqa: F401  (kept for compatibility; uploads now use generated names)

# Load environment variables from .env file
load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)
# Enable CORS for all routes
CORS(app)

# --- Configuration ---
UPLOAD_FOLDER = 'static/uploads'
MODELS_FOLDER = 'models'
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}

# Model file names come from the environment (.env), with fallback defaults.
MODEL_1_NAME = os.getenv('MODEL_1_NAME', 'best.pt')
MODEL_2_NAME = os.getenv('MODEL_2_NAME', 'tyre_alloy.pt')  # Tyre/Alloy model
MODEL_1_PATH = os.path.join(MODELS_FOLDER, MODEL_1_NAME)
MODEL_2_PATH = os.path.join(MODELS_FOLDER, MODEL_2_NAME)

app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(MODELS_FOLDER, exist_ok=True)
os.makedirs('templates', exist_ok=True)

# --- Determine Device ---
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info("Using device: %s", device)


def _load_model(path, name, label):
    """Load a YOLO checkpoint from *path* onto ``device``.

    Args:
        path: Filesystem path of the ``.pt`` checkpoint.
        name: Checkpoint file name, used only in log messages.
        label: Human-readable label ("Model 1", ...), used only in log messages.

    Returns:
        The loaded ``YOLO`` model, or None if the file is missing or loading
        fails. Failures are logged rather than raised so the server can still
        start and report the missing model per request.
    """
    if not os.path.exists(path):
        logger.warning("Model file not found at %s", path)
        return None
    try:
        model = YOLO(path)
        model.to(device)
    except Exception:  # startup boundary: log with traceback, keep serving
        logger.exception("Error loading %s (%s)", label, name)
        return None
    logger.info("Successfully loaded model '%s' on %s.", name, device)
    return model


# --- Load YOLO models (None when unavailable) ---
model1 = _load_model(MODEL_1_PATH, MODEL_1_NAME, "Model 1")
model2 = _load_model(MODEL_2_PATH, MODEL_2_NAME, "Model 2")

# Flask may serve requests on several threads, and Ultralytics documents that a
# single YOLO instance is not safe to share between threads; serialise calls.
_inference_lock = threading.Lock()


def allowed_file(filename):
    """Return True if *filename*'s extension (case-insensitive) is in ALLOWED_EXTENSIONS."""
    return '.' in filename and \
        filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS


def run_inference(model, filepath):
    """Classify the image at *filepath* and return the top-1 prediction.

    Args:
        model: A loaded ``YOLO`` classification model, or None.
        filepath: Path of the image file to classify.

    Returns:
        ``{"class": <label>, "confidence": <float>}``, or None if *model* is None.

    Raises:
        ValueError: If the model output has no class probabilities (for
            example a detection checkpoint was configured by mistake).
    """
    if model is None:
        return None  # model isn't loaded

    with _inference_lock:
        results = model(filepath)

    probs = results[0].probs
    if probs is None:
        raise ValueError(
            "Loaded model is not a classification model "
            "(its output has no class probabilities)."
        )
    return {
        "class": model.names[probs.top1],
        "confidence": float(probs.top1conf),
    }


@app.route('/')
def home():
    """Serve the main HTML page."""
    return render_template('index.html')


@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded image with the model(s) selected by ``model_type``.

    Form fields:
        file: The image upload (png/jpg/jpeg).
        model_type: 'model1' (default), 'model2', or 'combined'.

    Returns:
        JSON with the prediction (or both predictions for 'combined'), or a
        JSON ``{"error": ...}`` with status 400 (bad request) or 500 (model
        not loaded / inference failure).
    """
    # 1. --- File Validation ---
    if 'file' not in request.files:
        return jsonify({"error": "No file part in the request"}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({"error": "No selected file"}), 400
    if not file or not allowed_file(file.filename):
        return jsonify({"error": "File type not allowed"}), 400

    model_type = request.form.get('model_type', 'model1')  # default to model1

    # 2. --- Validate the model selection before touching the disk ---
    if model_type == 'model1':
        if model1 is None:
            return jsonify({"error": f"Model '{MODEL_1_NAME}' is not loaded. Check server logs."}), 500
    elif model_type == 'model2':
        if model2 is None:
            return jsonify({"error": f"Model '{MODEL_2_NAME}' is not loaded. Check server logs."}), 500
    elif model_type == 'combined':
        if model1 is None or model2 is None:
            return jsonify({"error": "One or more models required for combined mode are not loaded. Check server logs."}), 500
    else:
        return jsonify({"error": "Invalid model type specified"}), 400

    # 3. --- Save under a unique name ---
    # A generated name avoids concurrent requests with the same upload name
    # overwriting (and then deleting) each other's files, and avoids the
    # empty-name case where secure_filename() strips everything.
    extension = file.filename.rsplit('.', 1)[1].lower()
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], f"{uuid.uuid4().hex}.{extension}")

    # 4. --- Perform Inference based on model_type ---
    try:
        file.save(filepath)
        if model_type == 'model1':
            return jsonify(run_inference(model1, filepath))
        if model_type == 'model2':
            return jsonify(run_inference(model2, filepath))
        combined_prediction = {
            "model1_result": run_inference(model1, filepath),
            "model2_result": run_inference(model2, filepath),
        }
        return jsonify(combined_prediction)
    except Exception as e:  # request boundary: report as JSON, keep traceback in logs
        logger.exception("Inference failed (model_type=%s)", model_type)
        return jsonify({"error": f"An error occurred during inference: {str(e)}"}), 500
    finally:
        # 5. --- Cleanup --- must never raise and mask the response above.
        try:
            os.remove(filepath)
        except FileNotFoundError:
            pass  # the save failed before the file was created
        except OSError:
            logger.warning("Could not remove temporary upload %s", filepath, exc_info=True)


if __name__ == '__main__':
    # The Werkzeug debugger allows arbitrary code execution, so it must be
    # opt-in (FLASK_DEBUG=1) rather than always on for a 0.0.0.0 bind.
    debug_mode = os.getenv('FLASK_DEBUG', '0').strip().lower() in ('1', 'true', 'yes')
    app.run(host='0.0.0.0', port=int(os.getenv('PORT', '7860')), debug=debug_mode)