# Provenance: downloaded from a Hugging Face Space by SakibAhmed
# (commit 3e9ac54, "Upload 2 files", 5.19 kB) — viewer chrome removed.
# app.py
import os
import torch
from flask import Flask, request, jsonify, render_template
from flask_cors import CORS
from werkzeug.utils import secure_filename
from ultralytics import YOLO
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()
app = Flask(__name__)
# Enable CORS for all routes
CORS(app)
# --- Configuration ---
# Uploads are transient: /predict saves each image here and deletes it after inference.
UPLOAD_FOLDER = 'static/uploads'
MODELS_FOLDER = 'models'
# Only common raster image extensions are accepted (checked by allowed_file()).
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
# --- NEW: Load model names from .env file, with fallback defaults ---
MODEL_1_NAME = os.getenv('MODEL_1_NAME', 'best.pt')
MODEL_2_NAME = os.getenv('MODEL_2_NAME', 'tyre_alloy.pt') # New model for Tyre/Alloy
MODEL_1_PATH = os.path.join(MODELS_FOLDER, MODEL_1_NAME)
MODEL_2_PATH = os.path.join(MODELS_FOLDER, MODEL_2_NAME)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
# Create required directories up front so saves/loads never fail on a missing folder.
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(MODELS_FOLDER, exist_ok=True)
os.makedirs('templates', exist_ok=True)
# --- Determine Device ---
# Prefer GPU when available; YOLO inference falls back to CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# --- NEW: Load multiple YOLO Models ---
def _load_model(model_path, model_name, label):
    """Load one YOLO model onto `device`; return it, or None on failure.

    A missing file or a load error is logged and swallowed so the app can
    still start; /predict reports the unavailable model per-request.
    """
    try:
        if not os.path.exists(model_path):
            print(f"Warning: Model file not found at {model_path}")
            return None
        model = YOLO(model_path)
        model.to(device)
        print(f"Successfully loaded model '{model_name}' on {device}.")
        return model
    except Exception as e:
        print(f"Error loading {label} ({model_name}): {e}")
        return None

model1 = _load_model(MODEL_1_PATH, MODEL_1_NAME, "Model 1")
model2 = _load_model(MODEL_2_PATH, MODEL_2_NAME, "Model 2")
def allowed_file(filename):
    """Return True when *filename* carries an extension listed in ALLOWED_EXTENSIONS."""
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in ALLOWED_EXTENSIONS
def run_inference(model, filepath):
    """Classify the image at *filepath* and return the top-1 prediction.

    Returns None when *model* is not loaded; otherwise a dict holding the
    predicted class name and its confidence as a float.
    """
    if model is None:
        return None

    prediction = model(filepath)[0]
    probs = prediction.probs
    best_index = probs.top1
    return {
        "class": model.names[best_index],
        "confidence": float(probs.top1conf),
    }
@app.route('/')
def home():
    """Render and return the landing page template."""
    index_template = 'index.html'
    return render_template(index_template)
@app.route('/predict', methods=['POST'])
def predict():
    """
    Classify an uploaded image with the model chosen via the 'model_type'
    form field ('model1', 'model2', or 'combined'; defaults to 'model1').

    Returns JSON: a single prediction dict, a dict combining both model
    results, or an {"error": ...} payload with a 400/500 status code.
    """
    import uuid  # local import: used only to build collision-free temp names

    # 1. --- File Validation ---
    if 'file' not in request.files:
        return jsonify({"error": "No file part in the request"}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({"error": "No selected file"}), 400
    if not file or not allowed_file(file.filename):
        return jsonify({"error": "File type not allowed"}), 400
    # --- NEW: Get the model type from the form data ---
    model_type = request.form.get('model_type', 'model1')  # default to model1

    # 2. --- Save the File Temporarily ---
    # Use a unique server-side name so concurrent requests uploading the same
    # filename never overwrite each other — or delete each other's file via
    # the finally-cleanup below. secure_filename() can also return '' for
    # hostile names, so only its extension is reused.
    extension = os.path.splitext(secure_filename(file.filename))[1]
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], f"{uuid.uuid4().hex}{extension}")
    file.save(filepath)

    # 3. --- Perform Inference based on model_type ---
    try:
        if model_type == 'model1':
            if model1 is None:
                return jsonify({"error": f"Model '{MODEL_1_NAME}' is not loaded. Check server logs."}), 500
            return jsonify(run_inference(model1, filepath))
        elif model_type == 'model2':
            if model2 is None:
                return jsonify({"error": f"Model '{MODEL_2_NAME}' is not loaded. Check server logs."}), 500
            return jsonify(run_inference(model2, filepath))
        elif model_type == 'combined':
            if model1 is None or model2 is None:
                return jsonify({"error": "One or more models required for combined mode are not loaded. Check server logs."}), 500
            return jsonify({
                "model1_result": run_inference(model1, filepath),
                "model2_result": run_inference(model2, filepath),
            })
        else:
            return jsonify({"error": "Invalid model type specified"}), 400
    except Exception as e:
        return jsonify({"error": f"An error occurred during inference: {str(e)}"}), 500
    finally:
        # 4. --- Cleanup: the upload is transient, always remove it ---
        if os.path.exists(filepath):
            os.remove(filepath)
if __name__ == '__main__':
    # debug=True enables the Werkzeug interactive debugger, which allows
    # arbitrary code execution if the port is exposed — keep it opt-in via
    # the environment rather than hardcoded on.
    debug_mode = os.getenv('FLASK_DEBUG', '0').lower() in ('1', 'true', 'yes')
    app.run(host='0.0.0.0', port=int(os.getenv('PORT', '7860')), debug=debug_mode)