# NOTE: The three lines below ("Spaces: / Sleeping / Sleeping") were Hugging Face
# Spaces page chrome captured during extraction; they are not application source.
import os
import torch
from flask import Flask, request, jsonify, render_template, Response
from flask_cors import CORS
from werkzeug.utils import secure_filename
from ultralytics import YOLO
from dotenv import load_dotenv
import time
import threading
import json
import traceback
# --- NEW: Import database driver ---
import psycopg2
import psycopg2.extras
# Import the new processing logic
from processing import process_images

# Load environment variables from .env file so os.getenv() below sees them.
load_dotenv()
app = Flask(__name__)
# Enable CORS for all routes
CORS(app)

# --- Session Management ---
# SESSIONS maps session_key -> {"files": [paths], "lock": threading.Lock, "processed": bool}.
# SESSIONS_LOCK guards creation/deletion of entries in the SESSIONS dict itself;
# each per-session "lock" guards that session's own state.
SESSIONS = {}
SESSIONS_LOCK = threading.Lock()

# --- Configuration ---
UPLOAD_FOLDER = 'static/uploads'
MODELS_FOLDER = 'models'
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}

# --- Load model names from .env file (with shipped defaults) ---
PARTS_MODEL_NAME = os.getenv('PARTS_MODEL_NAME', 'best_parts_EP336.pt')
DAMAGE_MODEL_NAME = os.getenv('DAMAGE_MODEL_NAME', 'best_new_EP382.pt')

# --- NEW: Load Supabase credentials from .env file (None when unset) ---
SUPABASE_HOST = os.getenv('SUPABASE_HOST')
SUPABASE_PORT = os.getenv('SUPABASE_PORT')
SUPABASE_DB = os.getenv('SUPABASE_DB')
SUPABASE_USER = os.getenv('SUPABASE_USER')
SUPABASE_PASSWORD = os.getenv('SUPABASE_PASSWORD')

# --- NEW: Whitelist of user_info column names; predictions whose part class is
# not listed here are skipped, which prevents SQL injection via column names ---
VALID_COLUMNS = [
    'alloys', 'dashboard', 'driver_front_side', 'driver_rear_side',
    'interior_front', 'passenger_front_side', 'passenger_rear_side',
    'service_history', 'tyres'
]

PARTS_MODEL_PATH = os.path.join(MODELS_FOLDER, PARTS_MODEL_NAME)
DAMAGE_MODEL_PATH = os.path.join(MODELS_FOLDER, DAMAGE_MODEL_NAME)

app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
# Ensure required folders exist on startup.
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(MODELS_FOLDER, exist_ok=True)
os.makedirs('templates', exist_ok=True)

# --- Determine Device ---
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# --- Load YOLO Models ---
# Both handles stay None when loading fails; downstream code must tolerate that.
parts_model, damage_model = None, None
def _load_yolo_model(path, label, model_name):
    """Load one YOLO model onto the selected device.

    The original code duplicated this try/except stanza for each model; the
    duplication is factored out here. All log strings are preserved exactly.

    Args:
        path (str): Filesystem path to the .pt weights file.
        label (str): Capitalized model label, e.g. "Parts" or "Damage",
            used to reconstruct the original log messages.
        model_name (str): Weights filename, used in log messages.

    Returns:
        The loaded YOLO model, or None when the file is missing or loading fails.
    """
    try:
        if not os.path.exists(path):
            print(f"Warning: {label} model file not found at {path}")
            return None
        model = YOLO(path)
        model.to(device)
        print(f"Successfully loaded {label.lower()} model '{model_name}' on {device}.")
        return model
    except Exception as e:
        # Keep startup alive even when a model fails to load; the handle stays None.
        print(f"Error loading {label} Model ({model_name}): {e}")
        return None

# Load Parts Model
parts_model = _load_yolo_model(PARTS_MODEL_PATH, "Parts", PARTS_MODEL_NAME)
# Load Damage Model
damage_model = _load_yolo_model(DAMAGE_MODEL_PATH, "Damage", DAMAGE_MODEL_NAME)
def update_database_for_session(session_key, results):
    """Connect to the Supabase PostgreSQL database and update user_info.

    For each prediction, a column named after the predicted part class is
    updated according to these rules, derived from the current row state:
      - a column already 'correct' is never downgraded;
      - a NULL column takes the new damage status;
      - an 'incorrect' column is upgraded only when the new status is 'correct'.
    All updates are applied in a single UPDATE statement.

    Args:
        session_key (str): The phone_number identifying the row in user_info.
        results (list): Prediction dicts from the model; each is expected to
            carry 'part_prediction' and 'damage_prediction' sub-dicts with a
            'class' key (missing keys are tolerated via .get()).

    Returns:
        None. All errors are logged, never raised to the caller.
    """
    conn = None
    try:
        # Establish connection from the .env-provided credentials.
        conn = psycopg2.connect(
            host=SUPABASE_HOST,
            port=SUPABASE_PORT,
            dbname=SUPABASE_DB,
            user=SUPABASE_USER,
            password=SUPABASE_PASSWORD
        )
        # Dictionary cursor allows access to columns by name. Using the cursor
        # as a context manager closes it on every path — the original only
        # called cur.close() on success and leaked it when a query raised.
        with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
            # 1. Fetch the current state of the row keyed by 'phone_number'.
            cur.execute("SELECT * FROM user_info WHERE phone_number = %s", (session_key,))
            current_row = cur.fetchone()
            if not current_row:
                print(f"Error: No entry found in user_info for phone_number '{session_key}'")
                return

            # 2. Determine what needs to be updated based on the results.
            updates_to_make = {}
            for res in results:
                part_class = res.get('part_prediction', {}).get('class')
                damage_status = res.get('damage_prediction', {}).get('class')
                # Whitelist check: only known column names may reach the SQL below.
                if part_class not in VALID_COLUMNS:
                    print(f"Warning: Skipping invalid part_class '{part_class}' from prediction.")
                    continue
                current_status = current_row[part_class]
                if current_status == 'correct':
                    continue  # never downgrade a confirmed-correct part
                if current_status is None or (current_status == 'incorrect' and damage_status == 'correct'):
                    updates_to_make[part_class] = damage_status

            # 3. If there are updates, build and execute a single UPDATE statement.
            if updates_to_make:
                # Column names were validated against VALID_COLUMNS above, so
                # interpolating them here is safe; values go through %s params.
                set_clauses = ", ".join([f"{col} = %s" for col in updates_to_make.keys()])
                update_values = list(updates_to_make.values())
                update_values.append(session_key)
                update_query = f"UPDATE user_info SET {set_clauses} WHERE phone_number = %s"
                print(f"Executing DB Update for session '{session_key}': {updates_to_make}")
                cur.execute(update_query, tuple(update_values))
                conn.commit()
            else:
                print(f"No database updates required for session '{session_key}'.")
    except (Exception, psycopg2.DatabaseError) as error:
        # Best-effort update: log and swallow so the HTTP response still goes out.
        print(f"Database Error for session '{session_key}': {error}")
        traceback.print_exc()
    finally:
        if conn is not None:
            conn.close()
def allowed_file(filename):
    """Return True when *filename* carries an extension from ALLOWED_EXTENSIONS."""
    _, dot, extension = filename.rpartition('.')
    return bool(dot) and extension.lower() in ALLOWED_EXTENSIONS
# NOTE(review): no @app.route decorator was present, so this view was never
# registered with Flask and '/' returned 404. The root path is the
# conventional choice for the main page — confirm against the deployment.
@app.route('/')
def home():
    """Serve the main HTML page (templates/index.html)."""
    return render_template('index.html')
# NOTE(review): no @app.route decorator was present, so this endpoint was
# unreachable. '/predict' with POST is assumed from the function's role —
# confirm the exact path against the client that calls it.
@app.route('/predict', methods=['POST'])
def predict():
    """
    Endpoint to receive one or more images under a session key.

    The first request for a session waits 10 seconds to aggregate images
    from subsequent requests, then processes them all, pushes the results
    to the database, and returns the prediction list as JSON.
    Subsequent requests for an active session add their images and return
    a JSON status immediately.

    Form fields: 'session_key' (str, required), 'file' (one or more images).
    Returns: 200 with results JSON (first request) or a status JSON
    (subsequent/completed sessions); 400 on validation errors; 500 on
    processing failures.
    """
    # 1. --- Get Session Key and Validate ---
    session_key = request.form.get('session_key')
    if not session_key:
        return jsonify({"error": "No session_key provided in the payload"}), 400

    # 2. --- File Validation ---
    if 'file' not in request.files:
        return jsonify({"error": "No file part in the request"}), 400
    files = request.files.getlist('file')
    if not files or all(f.filename == '' for f in files):
        return jsonify({"error": "No selected files"}), 400

    # 3. --- Session Handling ---
    # SESSIONS_LOCK guards creation/lookup of the session entry itself.
    is_first_request = False
    with SESSIONS_LOCK:
        if session_key not in SESSIONS:
            is_first_request = True
            SESSIONS[session_key] = {
                "files": [],
                "lock": threading.Lock(),
                "processed": False
            }
        session = SESSIONS[session_key]
        if session["processed"]:
            return jsonify({"status": "complete", "message": "This session has already been processed."})

    # 4. --- Save Files for Current Request ---
    saved_filepaths_this_request = []
    for file in files:
        if file and allowed_file(file.filename):
            # Millisecond timestamp keeps filenames unique within a session.
            unique_filename = f"{session_key}_{int(time.time()*1000)}_{secure_filename(file.filename)}"
            filepath = os.path.join(app.config['UPLOAD_FOLDER'], unique_filename)
            file.save(filepath)
            saved_filepaths_this_request.append(filepath)
        else:
            print(f"Skipped invalid file: {file.filename}")
    if not saved_filepaths_this_request:
        return jsonify({"error": "No valid files were uploaded. Allowed types: png, jpg, jpeg"}), 400

    with session["lock"]:
        # The session may have finished processing while we saved files;
        # delete the just-saved files so they are not orphaned on disk.
        if session["processed"]:
            for filepath in saved_filepaths_this_request:
                if os.path.exists(filepath):
                    os.remove(filepath)
            return jsonify({"status": "complete", "message": "This session has already been processed."})
        session["files"].extend(saved_filepaths_this_request)

    # 5. --- Response Logic ---
    if is_first_request:
        try:
            # Only the first request blocks; it owns processing for the session.
            print(f"First request for session '{session_key}'. Waiting 10 seconds...")
            time.sleep(10)
            print(f"Session '{session_key}' wait time over. Processing...")
            with session["lock"]:
                all_filepaths = list(session["files"])
            # process_images (processing.py) returns the list of prediction dicts.
            results = process_images(parts_model, damage_model, all_filepaths)
            # After getting results, push them to the database (best-effort).
            if results:
                print(f"Processing database update for session: {session_key}")
                update_database_for_session(session_key, results)
            with session["lock"]:
                session["processed"] = True
            json_string = json.dumps(results)
            return Response(json_string, mimetype='application/json')
        except Exception as e:
            print(f"An error occurred during processing for session {session_key}: {e}")
            traceback.print_exc()
            return jsonify({"error": f"An error occurred during processing: {str(e)}"}), 500
        finally:
            # Clean up all uploaded files and drop the session entry, on
            # success and failure alike.
            if session_key in SESSIONS:
                with SESSIONS[session_key]["lock"]:
                    all_filepaths_to_delete = list(SESSIONS[session_key]["files"])
                for filepath in all_filepaths_to_delete:
                    if os.path.exists(filepath):
                        os.remove(filepath)
                with SESSIONS_LOCK:
                    del SESSIONS[session_key]
                print(f"Session '{session_key}' cleaned up.")
    else:
        print(f"Subsequent request for session '{session_key}'. Files added. Responding with JSON status.")
        return jsonify({"status": "aggregated", "message": "File has been added to the processing queue."})
if __name__ == '__main__':
    # The original hard-coded debug=True while its own comment recommended
    # debug=False for production — the Werkzeug debugger on a 0.0.0.0 bind
    # allows remote code execution. Debug is now opt-in via FLASK_DEBUG.
    debug_mode = os.getenv('FLASK_DEBUG', '').lower() in ('1', 'true', 'yes')
    app.run(host='0.0.0.0', port=7860, debug=debug_mode)