"""Flask service that aggregates image uploads per session, runs the YOLO parts
and damage models over them, and writes per-part results to the Supabase
``user_info`` table."""
import json
import os
import threading
import time
import traceback
import uuid

import psycopg2
import psycopg2.extras
import torch
from dotenv import load_dotenv
from flask import Flask, Response, jsonify, render_template, request
from flask_cors import CORS
from ultralytics import YOLO
from werkzeug.utils import secure_filename

# Project-local image processing pipeline.
from processing import process_images

# Load environment variables from .env before any os.getenv() lookups below.
load_dotenv()

app = Flask(__name__)
# Enable CORS for all routes.
CORS(app)

# --- Session management ---
# session_key -> {"files": [saved paths], "lock": threading.Lock(), "processed": bool}
SESSIONS = {}
# Guards adding/removing entries in SESSIONS; per-session state uses the session's own lock.
SESSIONS_LOCK = threading.Lock()

# --- Configuration ---
UPLOAD_FOLDER = 'static/uploads'
MODELS_FOLDER = 'models'
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
# How long the first request of a session waits for sibling uploads before processing.
SESSION_WAIT_SECONDS = 10

# Model file names (overridable via .env).
PARTS_MODEL_NAME = os.getenv('PARTS_MODEL_NAME', 'best_parts_EP336.pt')
DAMAGE_MODEL_NAME = os.getenv('DAMAGE_MODEL_NAME', 'best_new_EP382.pt')

# Supabase (PostgreSQL) connection settings from .env.
SUPABASE_HOST = os.getenv('SUPABASE_HOST')
SUPABASE_PORT = os.getenv('SUPABASE_PORT')
SUPABASE_DB = os.getenv('SUPABASE_DB')
SUPABASE_USER = os.getenv('SUPABASE_USER')
SUPABASE_PASSWORD = os.getenv('SUPABASE_PASSWORD')

# Whitelist of user_info columns that predictions may update. Column names are
# interpolated into the UPDATE statement, so only these values are ever allowed.
VALID_COLUMNS = [
    'alloys', 'dashboard', 'driver_front_side', 'driver_rear_side',
    'interior_front', 'passenger_front_side', 'passenger_rear_side',
    'service_history', 'tyres',
]

PARTS_MODEL_PATH = os.path.join(MODELS_FOLDER, PARTS_MODEL_NAME)
DAMAGE_MODEL_PATH = os.path.join(MODELS_FOLDER, DAMAGE_MODEL_NAME)

app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(MODELS_FOLDER, exist_ok=True)
os.makedirs('templates', exist_ok=True)

# --- Determine device ---
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")


def _load_model(model_path, model_name, label):
    """Load a YOLO model from *model_path* and move it to ``device``.

    Returns None when the file is missing or YOLO() itself fails. If only the
    move to ``device`` fails, the already-constructed model is still returned
    (the error is logged). *label* ("Parts"/"Damage") is used only in log lines.
    """
    model = None
    try:
        if not os.path.exists(model_path):
            print(f"Warning: {label} model file not found at {model_path}")
        else:
            model = YOLO(model_path)
            model.to(device)
            print(f"Successfully loaded {label.lower()} model '{model_name}' on {device}.")
    except Exception as e:
        print(f"Error loading {label} Model ({model_name}): {e}")
    return model


# --- Load YOLO models (None when unavailable) ---
parts_model = _load_model(PARTS_MODEL_PATH, PARTS_MODEL_NAME, "Parts")
damage_model = _load_model(DAMAGE_MODEL_PATH, DAMAGE_MODEL_NAME, "Damage")


def update_database_for_session(session_key, results):
    """Merge prediction results into the session's row of the ``user_info`` table.

    The row is looked up by ``phone_number = session_key``. For each result whose
    ``part_prediction['class']`` is a whitelisted column, that column is set to
    ``damage_prediction['class']`` when it is currently NULL, or upgraded from
    'incorrect' to 'correct'. A 'correct' value is never overwritten, including by
    a later result in the same batch. All changes go out in a single UPDATE.

    This is a best-effort side effect of /predict: errors are logged, not raised.

    Args:
        session_key (str): Value matched against ``user_info.phone_number``.
        results (list): Prediction dicts as returned by ``process_images``.
    """
    conn = None
    try:
        conn = psycopg2.connect(
            host=SUPABASE_HOST,
            port=SUPABASE_PORT,
            dbname=SUPABASE_DB,
            user=SUPABASE_USER,
            password=SUPABASE_PASSWORD,
        )
        # DictCursor lets us read the row's columns by name.
        with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
            cur.execute("SELECT * FROM user_info WHERE phone_number = %s", (session_key,))
            current_row = cur.fetchone()
            if not current_row:
                print(f"Error: No entry found in user_info for phone_number '{session_key}'")
                return

            updates_to_make = {}
            for res in results:
                # `or {}` also tolerates keys that are present but set to None.
                part_class = (res.get('part_prediction') or {}).get('class')
                damage_status = (res.get('damage_prediction') or {}).get('class')

                if part_class not in VALID_COLUMNS:
                    print(f"Warning: Skipping invalid part_class '{part_class}' from prediction.")
                    continue
                if damage_status is None:
                    # Nothing to record; writing NULL would never change the row.
                    continue

                # Status as it will be after the updates already queued in this
                # batch, so a later result cannot undo an earlier 'correct'.
                effective_status = updates_to_make.get(part_class, current_row[part_class])
                if effective_status == 'correct':
                    continue
                if effective_status is None or (
                    effective_status == 'incorrect' and damage_status == 'correct'
                ):
                    updates_to_make[part_class] = damage_status

            if not updates_to_make:
                print(f"No database updates required for session '{session_key}'.")
                return

            # Column names come only from VALID_COLUMNS, so interpolating them is
            # safe; every value is passed as a bound parameter.
            set_clauses = ", ".join(f"{col} = %s" for col in updates_to_make)
            update_query = f"UPDATE user_info SET {set_clauses} WHERE phone_number = %s"
            print(f"Executing DB Update for session '{session_key}': {updates_to_make}")
            cur.execute(update_query, (*updates_to_make.values(), session_key))
            conn.commit()
    except Exception as error:
        print(f"Database Error for session '{session_key}': {error}")
        traceback.print_exc()
    finally:
        if conn is not None:
            conn.close()


def allowed_file(filename):
    """Return True if *filename* has an extension listed in ALLOWED_EXTENSIONS."""
    if not filename or '.' not in filename:
        return False
    return filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS


def _remove_files(filepaths):
    """Best-effort delete of *filepaths*; missing files are ignored, other OS errors logged."""
    for filepath in filepaths:
        try:
            os.remove(filepath)
        except FileNotFoundError:
            pass
        except OSError as e:
            print(f"Warning: could not remove '{filepath}': {e}")


def _build_upload_path(session_key, filename):
    """Return a unique path inside UPLOAD_FOLDER for an upload.

    The client-supplied session key and filename both go through
    ``secure_filename`` so neither can escape the upload folder (e.g. '../x' or
    '/abs/path'). A random token keeps same-named files in one request from
    overwriting each other. The (already validated) extension is re-appended if
    ``secure_filename`` stripped it, e.g. for a name made of non-ASCII characters.
    """
    extension = filename.rsplit('.', 1)[1].lower()
    safe_name = secure_filename(filename)
    if not safe_name.lower().endswith('.' + extension):
        safe_name = f"{safe_name or 'upload'}.{extension}"
    unique_filename = (
        f"{secure_filename(session_key)}_{int(time.time() * 1000)}_"
        f"{uuid.uuid4().hex[:8]}_{safe_name}"
    )
    return os.path.join(app.config['UPLOAD_FOLDER'], unique_filename)


def _save_uploads(files, session_key):
    """Save every allowed file and return the saved paths.

    Files with a disallowed or empty name are skipped (and logged). If saving
    fails partway, the files saved so far are removed before re-raising.
    """
    saved = []
    try:
        for file in files:
            if file and allowed_file(file.filename):
                filepath = _build_upload_path(session_key, file.filename)
                file.save(filepath)
                saved.append(filepath)
            else:
                print(f"Skipped invalid file: {file.filename}")
    except Exception:
        _remove_files(saved)
        raise
    return saved


def _already_processed_response():
    """JSON reply for uploads that arrive after their session has closed."""
    return jsonify({"status": "complete", "message": "This session has already been processed."})


def _process_session(session_key, session):
    """Run the aggregation window for a session's first request and return its response.

    Waits SESSION_WAIT_SECONDS for sibling uploads, then closes the session, runs
    the models on every collected file, updates the database, and returns the
    results as JSON (500 on failure). The session's files and its SESSIONS entry
    are always removed afterwards.
    """
    try:
        print(f"First request for session '{session_key}'. Waiting {SESSION_WAIT_SECONDS} seconds...")
        time.sleep(SESSION_WAIT_SECONDS)
        print(f"Session '{session_key}' wait time over. Processing...")

        # Close the session and snapshot its files atomically. Uploads that
        # arrive after this point are told the session is done instead of being
        # acknowledged as queued and then silently dropped.
        with session["lock"]:
            session["processed"] = True
            all_filepaths = list(session["files"])

        results = process_images(parts_model, damage_model, all_filepaths)

        if results:
            print(f"Processing database update for session: {session_key}")
            update_database_for_session(session_key, results)

        return Response(json.dumps(results), mimetype='application/json')
    except Exception as e:
        print(f"An error occurred during processing for session {session_key}: {e}")
        traceback.print_exc()
        return jsonify({"error": f"An error occurred during processing: {str(e)}"}), 500
    finally:
        # Mark closed here too (processing may have failed before the snapshot),
        # so concurrent requests still holding this session delete their own
        # files rather than adding them to a session nobody will clean up.
        with session["lock"]:
            session["processed"] = True
            leftover_filepaths = list(session["files"])
        _remove_files(leftover_filepaths)
        with SESSIONS_LOCK:
            if SESSIONS.get(session_key) is session:
                del SESSIONS[session_key]
        print(f"Session '{session_key}' cleaned up.")


@app.route('/')
def home():
    """Serve the main HTML page."""
    return render_template('index.html')


@app.route('/predict', methods=['POST'])
def predict():
    """Receive one or more images under a session key.

    The first request for a session waits SESSION_WAIT_SECONDS to aggregate
    images from subsequent requests, then processes them all and returns the
    results as JSON. Subsequent requests add their images and return
    ``{"status": "aggregated"}``. Requests arriving after the session has closed
    get ``{"status": "complete"}`` and their files are discarded.
    """
    # 1. Validate the session key.
    session_key = request.form.get('session_key')
    if not session_key:
        return jsonify({"error": "No session_key provided in the payload"}), 400

    # 2. Validate the uploaded files.
    if 'file' not in request.files:
        return jsonify({"error": "No file part in the request"}), 400
    files = request.files.getlist('file')
    if not files or all(f.filename == '' for f in files):
        return jsonify({"error": "No selected files"}), 400

    # 3. Reject uploads to an already-closed session before touching the disk.
    with SESSIONS_LOCK:
        existing = SESSIONS.get(session_key)
    if existing is not None and existing["processed"]:
        return _already_processed_response()

    # 4. Save this request's files.
    saved_filepaths_this_request = _save_uploads(files, session_key)
    if not saved_filepaths_this_request:
        return jsonify({"error": "No valid files were uploaded. Allowed types: png, jpg, jpeg"}), 400

    # 5. Join the session, or open it if this is the first request for the key.
    #    A session is only registered once files were saved, so a rejected first
    #    request can no longer leave an orphaned, never-processed session behind.
    with SESSIONS_LOCK:
        session = SESSIONS.get(session_key)
        is_first_request = session is None
        if is_first_request:
            session = {
                "files": list(saved_filepaths_this_request),
                "lock": threading.Lock(),
                "processed": False,
            }
            SESSIONS[session_key] = session

    if is_first_request:
        return _process_session(session_key, session)

    with session["lock"]:
        if session["processed"]:
            # The session closed while this request was saving its files.
            _remove_files(saved_filepaths_this_request)
            return _already_processed_response()
        session["files"].extend(saved_filepaths_this_request)

    print(f"Subsequent request for session '{session_key}'. Files added. Responding with JSON status.")
    return jsonify({"status": "aggregated", "message": "File has been added to the processing queue."})


if __name__ == '__main__':
    # Debug mode stays off unless explicitly requested: with host='0.0.0.0' the
    # Werkzeug debugger would be reachable from the network, and it can execute
    # arbitrary Python code. Set FLASK_DEBUG=1 for local development only.
    debug_mode = os.getenv('FLASK_DEBUG', '0').lower() in ('1', 'true', 'yes')
    app.run(host='0.0.0.0', port=7860, debug=debug_mode)