# Uploaded by SakibAhmed ("Upload 2 files"), commit 809c35e (verified).
import json
import os
import threading
import time
import traceback
import uuid

import psycopg2  # PostgreSQL (Supabase) database driver
import psycopg2.extras
import torch
from dotenv import load_dotenv
from flask import Flask, request, jsonify, render_template, Response
from flask_cors import CORS
from ultralytics import YOLO
from werkzeug.utils import secure_filename

# Local image-processing logic
from processing import process_images
# Load environment variables from .env file
load_dotenv()
app = Flask(__name__)
# Enable CORS for all routes
CORS(app)

# --- Session Management ---
# session_key -> {"files": [saved upload paths], "lock": threading.Lock(), "processed": bool}
SESSIONS = {}
# Guards creation and removal of entries in SESSIONS.
SESSIONS_LOCK = threading.Lock()

# --- Configuration ---
UPLOAD_FOLDER = 'static/uploads'
MODELS_FOLDER = 'models'
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}

# --- Model file names (overridable via .env) ---
PARTS_MODEL_NAME = os.getenv('PARTS_MODEL_NAME', 'best_parts_EP336.pt')
DAMAGE_MODEL_NAME = os.getenv('DAMAGE_MODEL_NAME', 'best_new_EP382.pt')

# --- Supabase (PostgreSQL) connection settings, read from .env ---
SUPABASE_HOST = os.getenv('SUPABASE_HOST')
SUPABASE_PORT = os.getenv('SUPABASE_PORT')
SUPABASE_DB = os.getenv('SUPABASE_DB')
SUPABASE_USER = os.getenv('SUPABASE_USER')
SUPABASE_PASSWORD = os.getenv('SUPABASE_PASSWORD')

# Surface misconfiguration at startup instead of only when the first request
# tries to write to the database.
_missing_db_settings = [
    setting_name
    for setting_name, setting_value in (
        ('SUPABASE_HOST', SUPABASE_HOST),
        ('SUPABASE_PORT', SUPABASE_PORT),
        ('SUPABASE_DB', SUPABASE_DB),
        ('SUPABASE_USER', SUPABASE_USER),
        ('SUPABASE_PASSWORD', SUPABASE_PASSWORD),
    )
    if not setting_value
]
if _missing_db_settings:
    print(f"Warning: database settings not set: {', '.join(_missing_db_settings)}. "
          "Database updates may fail.")

# Columns of user_info that predictions are allowed to write. Column names
# cannot be passed as SQL parameters, so only names in this whitelist are ever
# interpolated into the UPDATE statement (prevents SQL injection). A frozenset
# gives O(1) membership tests and cannot be modified at runtime.
VALID_COLUMNS = frozenset({
    'alloys', 'dashboard', 'driver_front_side', 'driver_rear_side',
    'interior_front', 'passenger_front_side', 'passenger_rear_side',
    'service_history', 'tyres',
})

PARTS_MODEL_PATH = os.path.join(MODELS_FOLDER, PARTS_MODEL_NAME)
DAMAGE_MODEL_PATH = os.path.join(MODELS_FOLDER, DAMAGE_MODEL_NAME)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER

# These relative paths resolve against the current working directory.
# NOTE(review): Flask looks up templates relative to app.root_path, so the
# 'templates' directory created here only matches when launched from the app dir.
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(MODELS_FOLDER, exist_ok=True)
os.makedirs('templates', exist_ok=True)
# --- Determine Device ---
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")


def _load_yolo_model(model_path, model_name, label):
    """Load a YOLO model from *model_path* and move it to the global ``device``.

    The model is returned only after ``.to(device)`` has succeeded. A failure
    part-way through therefore never leaves a half-initialised model assigned.

    Args:
        model_path (str): Path to the weights file.
        model_name (str): Model name shown in log messages.
        label (str): Capitalised model kind for log messages, e.g. "Parts".

    Returns:
        The loaded YOLO model, or None if the file is missing or loading
        failed. The reason is printed.
    """
    if not os.path.exists(model_path):
        print(f"Warning: {label} model file not found at {model_path}")
        return None
    try:
        model = YOLO(model_path)
        model.to(device)
    except Exception as e:
        print(f"Error loading {label} Model ({model_name}): {e}")
        return None
    print(f"Successfully loaded {label.lower()} model '{model_name}' on {device}.")
    return model


# --- Load YOLO Models ---
# Either model may be None if its weights are missing or fail to load.
parts_model = _load_yolo_model(PARTS_MODEL_PATH, PARTS_MODEL_NAME, "Parts")
damage_model = _load_yolo_model(DAMAGE_MODEL_PATH, DAMAGE_MODEL_NAME, "Damage")
# --- Database Update Logic ---
def update_database_for_session(session_key, results):
    """Merge model predictions into the ``user_info`` row for a session.

    The row is selected with ``phone_number = session_key``. For each
    prediction whose part class is a column in VALID_COLUMNS, that column is
    updated as follows:

    * a column already ``'correct'`` is never changed;
    * an empty (NULL) column takes the predicted damage status;
    * an ``'incorrect'`` column is upgraded only by a ``'correct'`` prediction.

    Predictions are applied in order, as if each were written before the next
    is examined. A later prediction in the same batch therefore cannot
    downgrade a column that an earlier one set to ``'correct'``. All changes
    go out in one UPDATE and are committed together.

    Any error is printed with its traceback and swallowed, so database
    problems do not propagate to the caller.

    Args:
        session_key (str): Value matched against ``user_info.phone_number``.
        results (list): Prediction dicts. Each is read via
            ``res['part_prediction']['class']`` and
            ``res['damage_prediction']['class']``; missing or None entries
            are tolerated.
    """
    conn = None
    try:
        conn = psycopg2.connect(
            host=SUPABASE_HOST,
            port=SUPABASE_PORT,
            dbname=SUPABASE_DB,
            user=SUPABASE_USER,
            password=SUPABASE_PASSWORD,
        )
        # DictCursor lets us read columns by name; the context manager closes
        # the cursor even when a query raises.
        with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
            # 1. Fetch the current state of the row.
            cur.execute("SELECT * FROM user_info WHERE phone_number = %s", (session_key,))
            current_row = cur.fetchone()
            if not current_row:
                print(f"Error: No entry found in user_info for phone_number '{session_key}'")
                return

            # 2. Work out the new value of each affected column.
            updates_to_make = {}
            for res in results or []:
                part_class = (res.get('part_prediction') or {}).get('class')
                damage_status = (res.get('damage_prediction') or {}).get('class')
                if part_class not in VALID_COLUMNS:
                    print(f"Warning: Skipping invalid part_class '{part_class}' from prediction.")
                    continue
                if damage_status is None:
                    # Nothing to record; writing NULL could never change the outcome.
                    continue
                # An update already decided in this batch wins over the stored value.
                current_status = updates_to_make.get(part_class, current_row[part_class])
                if current_status == 'correct':
                    continue
                if current_status is None or (current_status == 'incorrect' and damage_status == 'correct'):
                    updates_to_make[part_class] = damage_status

            if not updates_to_make:
                print(f"No database updates required for session '{session_key}'.")
                return

            # 3. Apply every change in a single UPDATE. Interpolating the
            # column names is safe: each one passed the VALID_COLUMNS check above.
            set_clauses = ", ".join(f"{col} = %s" for col in updates_to_make)
            update_query = f"UPDATE user_info SET {set_clauses} WHERE phone_number = %s"
            print(f"Executing DB Update for session '{session_key}': {updates_to_make}")
            cur.execute(update_query, (*updates_to_make.values(), session_key))
        conn.commit()
    except Exception as error:
        # Uncommitted work is discarded when the connection closes below.
        print(f"Database Error for session '{session_key}': {error}")
        traceback.print_exc()
    finally:
        if conn is not None:
            conn.close()
def allowed_file(filename, allowed_extensions=None):
    """Return True if *filename* ends in an allowed extension.

    Only the text after the last '.' is compared, case-insensitively. So
    'photo.JPG' is accepted and 'image.png.exe' is not.

    Args:
        filename (str | None): Client-supplied filename; None or '' is rejected.
        allowed_extensions (Collection[str] | None): Lower-case extensions
            without the dot. Defaults to the module-level ALLOWED_EXTENSIONS.

    Returns:
        bool: Whether the file type is permitted.
    """
    if allowed_extensions is None:
        allowed_extensions = ALLOWED_EXTENSIONS
    # Guard against a missing filename, which would otherwise raise TypeError.
    if not filename or '.' not in filename:
        return False
    return filename.rsplit('.', 1)[1].lower() in allowed_extensions
@app.route('/')
def home():
    """Render the landing page from the app's template folder (index.html)."""
    landing_page = 'index.html'
    return render_template(landing_page)
# Seconds the first request of a session waits for sibling uploads to arrive.
SESSION_WAIT_SECONDS = 10


def _discard_files(filepaths):
    """Delete every path in *filepaths*; a missing file is ignored and other OS errors are logged."""
    for filepath in filepaths:
        try:
            os.remove(filepath)
        except FileNotFoundError:
            pass
        except OSError as error:
            print(f"Could not delete file '{filepath}': {error}")


def _save_uploaded_files(session_key, files):
    """Save the allowed uploads in *files* to UPLOAD_FOLDER and return their paths.

    Each stored name combines the sanitised session key, a random UUID and
    the sanitised original name. Two uploads with the same name, even in one
    request, therefore never overwrite each other. The client-supplied
    session key also cannot steer the path outside the upload folder.
    If a save fails, files already written by this call are removed and
    the error is re-raised.
    """
    safe_key = secure_filename(session_key)
    saved = []
    try:
        for file in files:
            if file and allowed_file(file.filename):
                unique_filename = f"{safe_key}_{uuid.uuid4().hex}_{secure_filename(file.filename)}"
                filepath = os.path.join(app.config['UPLOAD_FOLDER'], unique_filename)
                file.save(filepath)
                saved.append(filepath)
            else:
                print(f"Skipped invalid file: {file.filename}")
    except Exception:
        _discard_files(saved)
        raise
    return saved


def _cleanup_session(session_key, session):
    """Close *session*, delete all of its files and remove it from SESSIONS."""
    with session["lock"]:
        # Closing here, even after a processing failure, makes any late
        # request still holding this session object reject its files.
        # Otherwise they would be appended to a session nobody will process.
        session["processed"] = True
        filepaths = list(session["files"])
    _discard_files(filepaths)
    with SESSIONS_LOCK:
        if SESSIONS.get(session_key) is session:
            del SESSIONS[session_key]
    print(f"Session '{session_key}' cleaned up.")


@app.route('/predict', methods=['POST'])
def predict():
    """Receive one or more images under a session key.

    Form fields: ``session_key`` (required) and one or more ``file`` parts
    (png/jpg/jpeg).

    The first request for a session creates it and waits
    SESSION_WAIT_SECONDS so that subsequent requests can add images. It then
    closes the session, runs ``process_images`` over every aggregated file,
    updates the database and returns the results as JSON.

    Subsequent requests add their files and return ``{"status": "aggregated"}``.
    Once the session is closed, new uploads are discarded and answered with
    ``{"status": "complete"}``; they are not acknowledged and then silently
    dropped. The session and its files are always removed afterwards.
    """
    # 1. --- Get Session Key and Validate ---
    session_key = request.form.get('session_key')
    if not session_key:
        return jsonify({"error": "No session_key provided in the payload"}), 400

    # 2. --- File Validation ---
    if 'file' not in request.files:
        return jsonify({"error": "No file part in the request"}), 400
    files = request.files.getlist('file')
    if not files or all(not f.filename for f in files):
        return jsonify({"error": "No selected files"}), 400

    # 3. --- Fast path: skip saving files for a session that is already closed ---
    # (An unlocked read is fine here; the authoritative check happens in step 5.)
    with SESSIONS_LOCK:
        existing_session = SESSIONS.get(session_key)
    if existing_session is not None and existing_session["processed"]:
        return jsonify({"status": "complete", "message": "This session has already been processed."})

    # 4. --- Save Files for Current Request ---
    # Files are saved before the session is registered. A request with no
    # valid files therefore never creates a session that nobody would process.
    saved_filepaths_this_request = _save_uploaded_files(session_key, files)
    if not saved_filepaths_this_request:
        return jsonify({"error": "No valid files were uploaded. Allowed types: png, jpg, jpeg"}), 400

    # 5. --- Session Handling: the request that creates the session processes it ---
    with SESSIONS_LOCK:
        session = SESSIONS.get(session_key)
        is_first_request = session is None
        if is_first_request:
            session = {
                "files": [],
                "lock": threading.Lock(),
                "processed": False,  # True once the session stops accepting files
            }
            SESSIONS[session_key] = session

    with session["lock"]:
        accepted = not session["processed"]
        if accepted:
            session["files"].extend(saved_filepaths_this_request)
    if not accepted:
        _discard_files(saved_filepaths_this_request)
        return jsonify({"status": "complete", "message": "This session has already been processed."})

    # 6. --- Response Logic ---
    if not is_first_request:
        print(f"Subsequent request for session '{session_key}'. Files added. Responding with JSON status.")
        return jsonify({"status": "aggregated", "message": "File has been added to the processing queue."})

    try:
        print(f"First request for session '{session_key}'. Waiting {SESSION_WAIT_SECONDS} seconds...")
        time.sleep(SESSION_WAIT_SECONDS)
        print(f"Session '{session_key}' wait time over. Processing...")
        # Snapshot and close in one critical section. Any file arriving from
        # now on is rejected, rather than accepted and never processed.
        with session["lock"]:
            session["processed"] = True
            all_filepaths = list(session["files"])

        results = process_images(parts_model, damage_model, all_filepaths)

        if results:
            print(f"Processing database update for session: {session_key}")
            update_database_for_session(session_key, results)

        return Response(json.dumps(results), mimetype='application/json')
    except Exception as e:
        print(f"An error occurred during processing for session {session_key}: {e}")
        traceback.print_exc()
        return jsonify({"error": f"An error occurred during processing: {str(e)}"}), 500
    finally:
        _cleanup_session(session_key, session)
if __name__ == '__main__':
    # The Werkzeug interactive debugger allows arbitrary code execution from
    # the browser. It must not be on by default while listening on 0.0.0.0.
    # Opt in for local development with FLASK_DEBUG=1 (or "true"/"yes").
    debug_mode = os.getenv('FLASK_DEBUG', '0').strip().lower() in ('1', 'true', 'yes')
    app.run(host='0.0.0.0', port=7860, debug=debug_mode)