# Uploaded by SakibAhmed in commit 7616805 ("Upload 8 files", verified).
import json
import os
import time
import traceback
import uuid

import torch
from dotenv import load_dotenv
from flask import Flask, request, jsonify, render_template, Response
from flask_cors import CORS
from ultralytics import YOLO
from werkzeug.utils import secure_filename

# Import the processing logic
from processing import process_images
# Populate os.environ from a local .env file before any setting below is read.
load_dotenv()

app = Flask(__name__)
CORS(app)  # cross-origin requests are accepted on every route

# --- Configuration ---
UPLOAD_FOLDER = 'static/uploads'
MODELS_FOLDER = 'models'
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}

# Weight file names; each may be overridden through an environment variable.
PARTS_MODEL_NAME = os.getenv('PARTS_MODEL_NAME', 'best_parts_EP336.pt')
DAMAGE_MODEL_NAME = os.getenv('DAMAGE_MODEL_NAME', 'best_new_EP382.pt')

# Weight files are looked up inside MODELS_FOLDER.
PARTS_MODEL_PATH, DAMAGE_MODEL_PATH = (
    os.path.join(MODELS_FOLDER, weights_name)
    for weights_name in (PARTS_MODEL_NAME, DAMAGE_MODEL_NAME)
)

app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER

# Create every directory the app writes to or serves from, if missing.
for required_dir in (app.config['UPLOAD_FOLDER'], MODELS_FOLDER, 'templates'):
    os.makedirs(required_dir, exist_ok=True)

# --- Determine Device ---
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")
# --- Load YOLO Models ---
def _load_model(model_path, model_name, label):
    """Load one YOLO weights file and move it onto the module-level ``device``.

    Args:
        model_path: Filesystem path of the weights file.
        model_name: Weights file name, used only in log messages.
        label: Capitalized model kind (e.g. "Parts", "Damage") used in log
            messages.

    Returns:
        The loaded model, or None when the file is missing or loading/moving
        it fails. Failures are printed rather than raised so the web app can
        still start without the model.
    """
    if not os.path.exists(model_path):
        print(f"Warning: {label} model file not found at {model_path}")
        return None
    try:
        model = YOLO(model_path)
        model.to(device)
    except Exception as e:
        # Report and fall back to None instead of publishing a model that
        # loaded but never reached the target device.
        print(f"Error loading {label} Model ({model_name}): {e}")
        return None
    print(f"Successfully loaded {label.lower()} model '{model_name}' on {device}.")
    return model


# Both globals are None when the corresponding model could not be loaded.
parts_model = _load_model(PARTS_MODEL_PATH, PARTS_MODEL_NAME, "Parts")
damage_model = _load_model(DAMAGE_MODEL_PATH, DAMAGE_MODEL_NAME, "Damage")
def allowed_file(filename):
    """Report whether *filename* ends in one of the ALLOWED_EXTENSIONS.

    Only the text after the last dot is compared, case-insensitively;
    a name without any dot is rejected.
    """
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[-1]
    return extension.lower() in ALLOWED_EXTENSIONS
@app.route('/')
def home():
    """Render the ``index.html`` template as the site's landing page."""
    page = render_template('index.html')
    return page
@app.route('/predict', methods=['POST'])
def predict():
    """
    Endpoint to receive one or more images, process them immediately,
    and return the prediction results.

    Expects a multipart POST with:
      - ``session_key`` form field: required; used for log messages and as a
        (sanitized) upload filename prefix, it does not control any logic.
      - ``file`` part(s): images whose extension passes ``allowed_file``.

    Returns:
        The JSON-serialized result of ``process_images`` on success;
        ``{"error": ...}`` with 400 for a missing key / no valid files;
        ``{"error": ...}`` with 500 if saving or processing raises.
    Every file saved by this request is removed before the response is sent.
    """
    # 1. --- Get Session Key and Validate ---
    session_key = request.form.get('session_key')
    if not session_key:
        return jsonify({"error": "No session_key provided in the payload"}), 400
    # The key is client-controlled and becomes part of a filesystem path, so
    # sanitize it to prevent path traversal (e.g. "../../somewhere").
    safe_session_key = secure_filename(session_key)

    # 2. --- File Validation ---
    if 'file' not in request.files:
        return jsonify({"error": "No file part in the request"}), 400
    files = request.files.getlist('file')
    if not files or all(f.filename == '' for f in files):
        return jsonify({"error": "No selected files"}), 400

    saved_filepaths = []
    try:
        # 3. --- Save Files and Prepare for Processing ---
        # Saving happens inside the try so the finally-cleanup also covers
        # files already written when a later save fails.
        for file in files:
            if not (file and allowed_file(file.filename)):
                print(f"Skipped invalid file: {file.filename}")
                continue
            safe_name = secure_filename(file.filename)
            extension = file.filename.rsplit('.', 1)[1].lower()
            if not safe_name.lower().endswith(f".{extension}"):
                # secure_filename drops non-ASCII characters, which can reduce
                # a name like "照片.jpg" to just "jpg"; keep the validated
                # extension in case downstream loading relies on the suffix.
                safe_name = f"{safe_name}.{extension}"
            # The uuid keeps names unique even when several uploads share a
            # name and arrive within the same millisecond.
            unique_filename = (
                f"{safe_session_key}_{int(time.time() * 1000)}_"
                f"{uuid.uuid4().hex}_{safe_name}"
            )
            filepath = os.path.join(app.config['UPLOAD_FOLDER'], unique_filename)
            file.save(filepath)
            saved_filepaths.append(filepath)

        if not saved_filepaths:
            return jsonify({"error": "No valid files were uploaded. Allowed types: png, jpg, jpeg"}), 400

        # 4. --- Run Prediction ---
        print(f"Processing {len(saved_filepaths)} file(s) for session '{session_key}'...")
        results = process_images(parts_model, damage_model, saved_filepaths)
        print(f"Processing complete for session '{session_key}'.")
        return Response(json.dumps(results), mimetype='application/json')
    except Exception as e:
        print(f"An error occurred during processing for session {session_key}: {e}")
        traceback.print_exc()
        return jsonify({"error": f"An error occurred during processing: {str(e)}"}), 500
    finally:
        # 5. --- Clean up the saved files ---
        for filepath in saved_filepaths:
            try:
                os.remove(filepath)
            except FileNotFoundError:
                pass  # already gone; nothing to clean up
            except OSError as e:
                print(f"Error cleaning up file {filepath}: {e}")
if __name__ == '__main__':
    # The Werkzeug interactive debugger can execute arbitrary code, so it must
    # not be on by default for a server bound to every interface (0.0.0.0).
    # Opt in with FLASK_DEBUG=1; "0"/"false"/"no" (or unset) keep it off,
    # mirroring Flask's own interpretation of that variable.
    debug_flag = os.getenv('FLASK_DEBUG', '')
    debug_mode = bool(debug_flag) and debug_flag.lower() not in {'0', 'false', 'no'}
    app.run(host='0.0.0.0', port=7860, debug=debug_mode)