SakibAhmed commited on
Commit
809c35e
·
verified ·
1 Parent(s): 9fbd05b

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +192 -34
  2. requirements.txt +8 -7
app.py CHANGED
@@ -1,12 +1,19 @@
1
- # app.py
2
-
3
  import os
4
  import torch
5
- from flask import Flask, request, jsonify, render_template
6
  from flask_cors import CORS
7
  from werkzeug.utils import secure_filename
8
  from ultralytics import YOLO
9
  from dotenv import load_dotenv
 
 
 
 
 
 
 
 
 
10
 
11
  # Import the new processing logic
12
  from processing import process_images
@@ -19,16 +26,34 @@ app = Flask(__name__)
19
  # Enable CORS for all routes
20
  CORS(app)
21
 
 
 
 
 
22
  # --- Configuration ---
23
  UPLOAD_FOLDER = 'static/uploads'
24
  MODELS_FOLDER = 'models'
25
  ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
26
 
27
  # --- Load model names from .env file ---
28
- # Updated names to be more descriptive
29
  PARTS_MODEL_NAME = os.getenv('PARTS_MODEL_NAME', 'best_parts_EP336.pt')
30
  DAMAGE_MODEL_NAME = os.getenv('DAMAGE_MODEL_NAME', 'best_new_EP382.pt')
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  PARTS_MODEL_PATH = os.path.join(MODELS_FOLDER, PARTS_MODEL_NAME)
33
  DAMAGE_MODEL_PATH = os.path.join(MODELS_FOLDER, DAMAGE_MODEL_NAME)
34
 
@@ -67,6 +92,81 @@ except Exception as e:
67
  print(f"Error loading Damage Model ({DAMAGE_MODEL_NAME}): {e}")
68
 
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  def allowed_file(filename):
71
  """Checks if a file's extension is in the ALLOWED_EXTENSIONS set."""
72
  return '.' in filename and \
@@ -80,50 +180,108 @@ def home():
80
  @app.route('/predict', methods=['POST'])
81
  def predict():
82
  """
83
- Endpoint to receive one or more images, run the two-step prediction,
84
- and return the combined results.
 
 
85
  """
86
- # 1. --- File Validation for Multiple Files ---
 
 
 
 
 
87
  if 'file' not in request.files:
88
  return jsonify({"error": "No file part in the request"}), 400
89
-
90
  files = request.files.getlist('file')
91
-
92
  if not files or all(f.filename == '' for f in files):
93
  return jsonify({"error": "No selected files"}), 400
94
 
95
- saved_filepaths = []
96
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  for file in files:
98
  if file and allowed_file(file.filename):
99
- filename = secure_filename(file.filename)
100
- filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
101
  file.save(filepath)
102
- saved_filepaths.append(filepath)
103
  else:
104
- # You might want to log this or inform the user about skipped files
105
  print(f"Skipped invalid file: {file.filename}")
106
 
107
- if not saved_filepaths:
108
  return jsonify({"error": "No valid files were uploaded. Allowed types: png, jpg, jpeg"}), 400
109
-
110
- # 2. --- Perform Inference ---
111
- try:
112
- # Pass the models and file paths to the processing function
113
- results = process_images(parts_model, damage_model, saved_filepaths)
114
- return jsonify(results)
115
-
116
- except Exception as e:
117
- # Log the full error for debugging
118
- print(f"An error occurred during processing: {e}")
119
- import traceback
120
- traceback.print_exc()
121
- return jsonify({"error": f"An error occurred during processing: {str(e)}"}), 500
122
- finally:
123
- # 3. --- Cleanup ---
124
- for filepath in saved_filepaths:
125
- if os.path.exists(filepath):
126
- os.remove(filepath)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
  if __name__ == '__main__':
129
  # Setting debug=False is recommended for production
 
 
 
1
  import os
2
  import torch
3
+ from flask import Flask, request, jsonify, render_template, Response
4
  from flask_cors import CORS
5
  from werkzeug.utils import secure_filename
6
  from ultralytics import YOLO
7
  from dotenv import load_dotenv
8
+ import time
9
+ import threading
10
+ import json
11
+ import traceback
12
+
13
+ # --- NEW: Import database driver ---
14
+ import psycopg2
15
+ import psycopg2.extras
16
+
17
 
18
  # Import the new processing logic
19
  from processing import process_images
 
26
  # Enable CORS for all routes
27
  CORS(app)
28
 
29
+ # --- Session Management ---
30
+ SESSIONS = {}
31
+ SESSIONS_LOCK = threading.Lock()
32
+
33
  # --- Configuration ---
34
  UPLOAD_FOLDER = 'static/uploads'
35
  MODELS_FOLDER = 'models'
36
  ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
37
 
38
  # --- Load model names from .env file ---
 
39
  PARTS_MODEL_NAME = os.getenv('PARTS_MODEL_NAME', 'best_parts_EP336.pt')
40
  DAMAGE_MODEL_NAME = os.getenv('DAMAGE_MODEL_NAME', 'best_new_EP382.pt')
41
 
42
+ # --- NEW: Load Supabase credentials from .env file ---
43
+ SUPABASE_HOST = os.getenv('SUPABASE_HOST')
44
+ SUPABASE_PORT = os.getenv('SUPABASE_PORT')
45
+ SUPABASE_DB = os.getenv('SUPABASE_DB')
46
+ SUPABASE_USER = os.getenv('SUPABASE_USER')
47
+ SUPABASE_PASSWORD = os.getenv('SUPABASE_PASSWORD')
48
+
49
+ # --- NEW: Define valid table columns to prevent SQL injection ---
50
+ VALID_COLUMNS = [
51
+ 'alloys', 'dashboard', 'driver_front_side', 'driver_rear_side',
52
+ 'interior_front', 'passenger_front_side', 'passenger_rear_side',
53
+ 'service_history', 'tyres'
54
+ ]
55
+
56
+
57
  PARTS_MODEL_PATH = os.path.join(MODELS_FOLDER, PARTS_MODEL_NAME)
58
  DAMAGE_MODEL_PATH = os.path.join(MODELS_FOLDER, DAMAGE_MODEL_NAME)
59
 
 
92
  print(f"Error loading Damage Model ({DAMAGE_MODEL_NAME}): {e}")
93
 
94
 
95
+ # --- NEW: Database Update Logic ---
96
+ # --- CORRECTED: Database Update Logic ---
97
+ def update_database_for_session(session_key, results):
98
+ """
99
+ Connects to the Supabase database and updates the user_info table.
100
+
101
+ Args:
102
+ session_key (str): The session key to identify the row in user_info.
103
+ results (list): A list of prediction dictionaries from the model.
104
+ """
105
+ conn = None
106
+ try:
107
+ # Establish connection
108
+ conn = psycopg2.connect(
109
+ host=SUPABASE_HOST,
110
+ port=SUPABASE_PORT,
111
+ dbname=SUPABASE_DB,
112
+ user=SUPABASE_USER,
113
+ password=SUPABASE_PASSWORD
114
+ )
115
+ # Use a dictionary cursor to access columns by name
116
+ cur = conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
117
+
118
+ # 1. Fetch the current state of the row using the correct column 'phone_number'
119
+ # --- FIX APPLIED HERE ---
120
+ cur.execute("SELECT * FROM user_info WHERE phone_number = %s", (session_key,))
121
+ current_row = cur.fetchone()
122
+
123
+ if not current_row:
124
+ print(f"Error: No entry found in user_info for phone_number '{session_key}'")
125
+ return
126
+
127
+ updates_to_make = {}
128
+ # 2. Determine what needs to be updated based on the results
129
+ for res in results:
130
+ part_class = res.get('part_prediction', {}).get('class')
131
+ damage_status = res.get('damage_prediction', {}).get('class')
132
+
133
+ if part_class not in VALID_COLUMNS:
134
+ print(f"Warning: Skipping invalid part_class '{part_class}' from prediction.")
135
+ continue
136
+
137
+ current_status = current_row[part_class]
138
+
139
+ if current_status == 'correct':
140
+ continue
141
+ if current_status is None or (current_status == 'incorrect' and damage_status == 'correct'):
142
+ updates_to_make[part_class] = damage_status
143
+
144
+ # 3. If there are updates, build and execute a single UPDATE statement
145
+ if updates_to_make:
146
+ set_clauses = ", ".join([f"{col} = %s" for col in updates_to_make.keys()])
147
+ update_values = list(updates_to_make.values())
148
+ update_values.append(session_key)
149
+
150
+ # --- FIX APPLIED HERE ---
151
+ update_query = f"UPDATE user_info SET {set_clauses} WHERE phone_number = %s"
152
+
153
+ print(f"Executing DB Update for session '{session_key}': {updates_to_make}")
154
+ cur.execute(update_query, tuple(update_values))
155
+
156
+ conn.commit()
157
+ else:
158
+ print(f"No database updates required for session '{session_key}'.")
159
+
160
+ cur.close()
161
+
162
+ except (Exception, psycopg2.DatabaseError) as error:
163
+ print(f"Database Error for session '{session_key}': {error}")
164
+ traceback.print_exc()
165
+ finally:
166
+ if conn is not None:
167
+ conn.close()
168
+
169
+
170
  def allowed_file(filename):
171
  """Checks if a file's extension is in the ALLOWED_EXTENSIONS set."""
172
  return '.' in filename and \
 
180
  @app.route('/predict', methods=['POST'])
181
  def predict():
182
  """
183
+ Endpoint to receive one or more images under a session key.
184
+ The first request for a session waits 10 seconds to aggregate images
185
+ from subsequent requests, then processes them all.
186
+ Subsequent requests for an active session add their images and return a JSON status.
187
  """
188
+ # 1. --- Get Session Key and Validate ---
189
+ session_key = request.form.get('session_key')
190
+ if not session_key:
191
+ return jsonify({"error": "No session_key provided in the payload"}), 400
192
+
193
+ # 2. --- File Validation ---
194
  if 'file' not in request.files:
195
  return jsonify({"error": "No file part in the request"}), 400
196
+
197
  files = request.files.getlist('file')
 
198
  if not files or all(f.filename == '' for f in files):
199
  return jsonify({"error": "No selected files"}), 400
200
 
201
+ # 3. --- Session Handling ---
202
+ is_first_request = False
203
+ with SESSIONS_LOCK:
204
+ if session_key not in SESSIONS:
205
+ is_first_request = True
206
+ SESSIONS[session_key] = {
207
+ "files": [],
208
+ "lock": threading.Lock(),
209
+ "processed": False
210
+ }
211
+
212
+ session = SESSIONS[session_key]
213
+
214
+ if session["processed"]:
215
+ return jsonify({"status": "complete", "message": "This session has already been processed."})
216
+
217
+ # 4. --- Save Files for Current Request ---
218
+ saved_filepaths_this_request = []
219
  for file in files:
220
  if file and allowed_file(file.filename):
221
+ unique_filename = f"{session_key}_{int(time.time()*1000)}_{secure_filename(file.filename)}"
222
+ filepath = os.path.join(app.config['UPLOAD_FOLDER'], unique_filename)
223
  file.save(filepath)
224
+ saved_filepaths_this_request.append(filepath)
225
  else:
 
226
  print(f"Skipped invalid file: {file.filename}")
227
 
228
+ if not saved_filepaths_this_request:
229
  return jsonify({"error": "No valid files were uploaded. Allowed types: png, jpg, jpeg"}), 400
230
+
231
+ with session["lock"]:
232
+ if session["processed"]:
233
+ for filepath in saved_filepaths_this_request:
234
+ if os.path.exists(filepath):
235
+ os.remove(filepath)
236
+ return jsonify({"status": "complete", "message": "This session has already been processed."})
237
+ session["files"].extend(saved_filepaths_this_request)
238
+
239
+ # 5. --- Response Logic ---
240
+ if is_first_request:
241
+ try:
242
+ print(f"First request for session '{session_key}'. Waiting 10 seconds...")
243
+ time.sleep(10)
244
+ print(f"Session '{session_key}' wait time over. Processing...")
245
+
246
+ with session["lock"]:
247
+ all_filepaths = list(session["files"])
248
+
249
+ # This is your existing function that returns the list of dictionaries
250
+ results = process_images(parts_model, damage_model, all_filepaths)
251
+
252
+ # --- *** NEW: DATABASE UPDATE STEP *** ---
253
+ # After getting results, update the database
254
+ if results:
255
+ print(f"Processing database update for session: {session_key}")
256
+ update_database_for_session(session_key, results)
257
+ # --- *** END OF NEW STEP *** ---
258
+
259
+ with session["lock"]:
260
+ session["processed"] = True
261
+
262
+ json_string = json.dumps(results)
263
+ return Response(json_string, mimetype='application/json')
264
+
265
+ except Exception as e:
266
+ print(f"An error occurred during processing for session {session_key}: {e}")
267
+ traceback.print_exc()
268
+ return jsonify({"error": f"An error occurred during processing: {str(e)}"}), 500
269
+ finally:
270
+ if session_key in SESSIONS:
271
+ with SESSIONS[session_key]["lock"]:
272
+ all_filepaths_to_delete = list(SESSIONS[session_key]["files"])
273
+
274
+ for filepath in all_filepaths_to_delete:
275
+ if os.path.exists(filepath):
276
+ os.remove(filepath)
277
+
278
+ with SESSIONS_LOCK:
279
+ del SESSIONS[session_key]
280
+ print(f"Session '{session_key}' cleaned up.")
281
+ else:
282
+ print(f"Subsequent request for session '{session_key}'. Files added. Responding with JSON status.")
283
+ return jsonify({"status": "aggregated", "message": "File has been added to the processing queue."})
284
+
285
 
286
  if __name__ == '__main__':
287
  # Setting debug=False is recommended for production
requirements.txt CHANGED
@@ -1,7 +1,8 @@
1
- Flask==3.1.1
2
- flask_cors==5.0.1
3
- python-dotenv==1.1.0
4
- torch
5
- ultralytics==8.3.151
6
- Werkzeug==3.1.3
7
- opencv-python-headless==4.10.0.84
 
 
1
+ Flask==3.1.1
2
+ flask_cors==5.0.1
3
+ python-dotenv==1.1.0
4
+ torch
5
+ ultralytics==8.3.151
6
+ Werkzeug==3.1.3
7
+ opencv-python-headless==4.10.0.84
8
+ psycopg2-binary==2.9.10