doniramdani820 committed on
Commit
130249d
·
verified ·
1 Parent(s): 1c55547

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +52 -43
server.py CHANGED
@@ -37,15 +37,10 @@ def verify_api_key(request):
37
  # ============================================================================
38
  # 🚀 PERFORMANCE OPTIMIZATION CONFIGURATION
39
  # ============================================================================
40
- # CPU Optimization
41
  MAX_WORKERS = min(4, (os.cpu_count() or 1) + 1)
42
  executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
43
-
44
- # Memory Management
45
- MEMORY_THRESHOLD = 80 # Percentage
46
- MAX_CACHE_SIZE = 50 # Maximum number of cached results
47
-
48
- # Model Loading Optimization
49
  os.environ['OMP_NUM_THREADS'] = '2'
50
  os.environ['OPENBLAS_NUM_THREADS'] = '2'
51
  os.environ['MKL_NUM_THREADS'] = '2'
@@ -104,7 +99,6 @@ prediction_cache = {}
104
  cache_lock = threading.Lock()
105
 
106
  def cache_prediction(key, result):
107
- """Cache prediction result with memory management"""
108
  with cache_lock:
109
  if len(prediction_cache) >= MAX_CACHE_SIZE:
110
  oldest_keys = list(prediction_cache.keys())[:MAX_CACHE_SIZE//2]
@@ -113,7 +107,6 @@ def cache_prediction(key, result):
113
  prediction_cache[key] = result
114
 
115
  def get_cached_prediction(key):
116
- """Get cached prediction result"""
117
  with cache_lock:
118
  return prediction_cache.get(key)
119
 
@@ -139,7 +132,6 @@ CLASS_ALIASES = {
139
  # ============================================================================
140
  @lru_cache(maxsize=256)
141
  def normalize_text(text):
142
- """Cached text normalization"""
143
  if not text: return ""
144
  text = text.lower().strip()
145
  text = re.sub(r'[^\w\s]', ' ', text)
@@ -148,7 +140,6 @@ def normalize_text(text):
148
 
149
  @lru_cache(maxsize=256)
150
  def find_class_match(input_text):
151
- """Cached class matching"""
152
  if not input_text: return None
153
  normalized_input = normalize_text(input_text)
154
  for canonical_name, aliases in CLASS_ALIASES.items():
@@ -161,15 +152,13 @@ def find_class_match(input_text):
161
  return None
162
 
163
  # ============================================================================
164
- # πŸ“ MODEL LOADING & INITIALIZATION (OPTIMIZED FOR GUNICORN --PRELOAD)
165
  # ============================================================================
166
- logging.info("πŸ”„ Initializing models at module level for Gunicorn --preload...")
167
-
168
- models = {}
169
  model_class_maps = {}
170
 
171
  def load_yaml_classes(yaml_path):
172
- """Load classes from YAML file"""
173
  try:
174
  with open(yaml_path, 'r', encoding='utf-8') as file:
175
  data = yaml.safe_load(file)
@@ -179,9 +168,6 @@ def load_yaml_classes(yaml_path):
179
  return {}
180
 
181
  try:
182
- models['3x3'] = YOLO('best.onnx', task='classify')
183
- models['4x4'] = YOLO('best4x4.onnx', task='segment')
184
-
185
  for model_type, yaml_file in [('3x3', 'data.yaml'), ('4x4', 'data4x4.yaml')]:
186
  class_map = {}
187
  yaml_classes = load_yaml_classes(yaml_file)
@@ -192,18 +178,48 @@ try:
192
  else:
193
  class_map[class_name.lower()] = class_id
194
  model_class_maps[model_type] = class_map
195
-
196
- logging.info("βœ… Models and class maps initialized successfully!")
197
-
198
  except Exception as e:
199
- logging.error(f"❌ FATAL: Failed to initialize models at module level: {e}")
200
  raise
201
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  # ============================================================================
203
  # πŸ–ΌοΈ OPTIMIZED IMAGE PROCESSING
204
  # ============================================================================
205
  def decode_image_optimized(base64_string):
206
- """Optimized image decoding"""
207
  try:
208
  image_data = base64.b64decode(base64_string.split(',')[1])
209
  image = Image.open(io.BytesIO(image_data)).convert("RGB")
@@ -213,7 +229,6 @@ def decode_image_optimized(base64_string):
213
  return None
214
 
215
  def divide_image_into_4x4_grid(image_cv2):
216
- """Optimized grid division"""
217
  height, width = image_cv2.shape[:2]
218
  grid_height, grid_width = height // 4, width // 4
219
  grid_images, grid_coordinates = [], []
@@ -226,7 +241,6 @@ def divide_image_into_4x4_grid(image_cv2):
226
  return grid_images, grid_coordinates
227
 
228
  def is_object_in_grid_cell(mask_contour, grid_coords, min_coverage_percentage=MIN_COVERAGE_PERCENTAGE):
229
- """Optimized object detection in grid cell"""
230
  x1, y1, x2, y2 = grid_coords
231
  grid_width, grid_height = x2 - x1, y2 - y1
232
  grid_area = grid_width * grid_height
@@ -246,7 +260,6 @@ def is_object_in_grid_cell(mask_contour, grid_coords, min_coverage_percentage=MI
246
  # 🔧 UTILITY FUNCTIONS
247
  # ============================================================================
248
  def get_target_class_index(input_title, model_type):
249
- """Get target class index with caching"""
250
  model_classes = model_class_maps.get(model_type, {})
251
  if not input_title or not model_classes: return None
252
  canonical_name = find_class_match(input_title)
@@ -256,7 +269,6 @@ def get_target_class_index(input_title, model_type):
256
  return model_classes.get(normalized_input)
257
 
258
  def memory_cleanup():
259
- """Perform memory cleanup"""
260
  gc.collect()
261
  current_memory = psutil.virtual_memory().percent
262
  if current_memory > MEMORY_THRESHOLD:
@@ -267,7 +279,6 @@ def memory_cleanup():
267
  # ============================================================================
268
  @app.before_request
269
  def check_api_key():
270
- """Verify API key for all requests except health check"""
271
  if request.endpoint in ['health', 'stats']: return
272
  if not verify_api_key(request):
273
  return jsonify({"error": "Invalid or missing API key"}), 401
@@ -278,8 +289,10 @@ def check_api_key():
278
  @app.route('/health', methods=['GET'])
279
  def health():
280
  return jsonify({
281
- "status": "healthy", "models_loaded": len(models),
282
- "memory_usage": psutil.virtual_memory().percent, "cpu_usage": psutil.cpu_percent()
 
 
283
  })
284
 
285
  @app.route('/stats', methods=['GET'])
@@ -288,15 +301,15 @@ def stats():
288
 
289
  @app.route('/predict', methods=['POST'])
290
  def predict():
291
- """Optimized 3x3 prediction endpoint"""
292
  import time
293
  start_time = time.time()
294
-
295
  try:
296
  data = request.get_json(silent=True)
297
  if not data: return jsonify({"error": "Invalid request body"}), 400
298
- model = models.get('3x3')
 
299
  if not model: return jsonify({"error": "3x3 model not loaded"}), 500
 
300
  input_title = data.get('title', '')
301
  target_class_index = get_target_class_index(input_title, '3x3')
302
  if target_class_index is None:
@@ -314,12 +327,9 @@ def predict():
314
  image = decode_image_optimized(item['base64'])
315
  if image is None: return None
316
  results = model(image, verbose=False)
317
-
318
- # Robust validation to prevent 'NoneType' error
319
  if not results: return None
320
  res = results[0]
321
  if res.probs is None or res.probs.data is None: return None
322
-
323
  confidence = res.probs.data[target_class_index].item()
324
  return {'index': item['index'], 'confidence': confidence, 'selected': confidence >= CONFIDENCE_THRESHOLD_3X3}
325
  except Exception as e:
@@ -338,20 +348,20 @@ def predict():
338
  if psutil.virtual_memory().percent > MEMORY_THRESHOLD: memory_cleanup()
339
  return jsonify(response)
340
  except Exception as e:
341
- logging.error(f"Error in /predict: {e}")
342
  return jsonify({"error": "Internal server error"}), 500
343
 
344
  @app.route('/predict_4x4', methods=['POST'])
345
  def predict_4x4():
346
- """Optimized 4x4 prediction endpoint"""
347
  import time
348
  start_time = time.time()
349
-
350
  try:
351
  data = request.get_json(silent=True)
352
  if not data: return jsonify({"error": "Invalid request body"}), 400
353
- model = models.get('4x4')
 
354
  if not model: return jsonify({"error": "4x4 model not loaded"}), 500
 
355
  input_title = data.get('title', '')
356
  target_class_index = get_target_class_index(input_title, '4x4')
357
  if target_class_index is None:
@@ -367,7 +377,6 @@ def predict_4x4():
367
  image_pil = decode_image_optimized(data['image_b64'])
368
  if image_pil is None: return jsonify({"error": "Invalid image data"}), 400
369
  image_cv2 = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
370
- grid_images, grid_coordinates = divide_image_into_4x4_grid(image_cv2)
371
 
372
  results = model(image_cv2, verbose=False)
373
  indices_to_click = []
@@ -390,7 +399,7 @@ def predict_4x4():
390
  if psutil.virtual_memory().percent > MEMORY_THRESHOLD: memory_cleanup()
391
  return jsonify(response)
392
  except Exception as e:
393
- logging.error(f"Error in /predict_4x4: {e}")
394
  return jsonify({"error": "Internal server error"}), 500
395
 
396
  @app.route('/classes', methods=['GET'])
 
37
  # ============================================================================
38
  # 🚀 PERFORMANCE OPTIMIZATION CONFIGURATION
39
  # ============================================================================
 
40
  MAX_WORKERS = min(4, (os.cpu_count() or 1) + 1)
41
  executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
42
+ MEMORY_THRESHOLD = 80
43
+ MAX_CACHE_SIZE = 50
 
 
 
 
44
  os.environ['OMP_NUM_THREADS'] = '2'
45
  os.environ['OPENBLAS_NUM_THREADS'] = '2'
46
  os.environ['MKL_NUM_THREADS'] = '2'
 
99
  cache_lock = threading.Lock()
100
 
101
  def cache_prediction(key, result):
 
102
  with cache_lock:
103
  if len(prediction_cache) >= MAX_CACHE_SIZE:
104
  oldest_keys = list(prediction_cache.keys())[:MAX_CACHE_SIZE//2]
 
107
  prediction_cache[key] = result
108
 
109
  def get_cached_prediction(key):
 
110
  with cache_lock:
111
  return prediction_cache.get(key)
112
 
 
132
  # ============================================================================
133
  @lru_cache(maxsize=256)
134
  def normalize_text(text):
 
135
  if not text: return ""
136
  text = text.lower().strip()
137
  text = re.sub(r'[^\w\s]', ' ', text)
 
140
 
141
  @lru_cache(maxsize=256)
142
  def find_class_match(input_text):
 
143
  if not input_text: return None
144
  normalized_input = normalize_text(input_text)
145
  for canonical_name, aliases in CLASS_ALIASES.items():
 
152
  return None
153
 
154
  # ============================================================================
155
+ # πŸ“ MODEL LOADING & INITIALIZATION (STRATEGI: LAZY PER-WORKER)
156
  # ============================================================================
157
+ _worker_models = {}
158
+ _model_lock = threading.Lock()
 
159
  model_class_maps = {}
160
 
161
  def load_yaml_classes(yaml_path):
 
162
  try:
163
  with open(yaml_path, 'r', encoding='utf-8') as file:
164
  data = yaml.safe_load(file)
 
168
  return {}
169
 
170
  try:
 
 
 
171
  for model_type, yaml_file in [('3x3', 'data.yaml'), ('4x4', 'data4x4.yaml')]:
172
  class_map = {}
173
  yaml_classes = load_yaml_classes(yaml_file)
 
178
  else:
179
  class_map[class_name.lower()] = class_id
180
  model_class_maps[model_type] = class_map
181
+ logging.info("βœ… Class maps loaded successfully!")
 
 
182
  except Exception as e:
183
+ logging.error(f"❌ FATAL: Failed to initialize class maps: {e}")
184
  raise
185
 
186
+
187
def get_model(model_type: str):
    """
    Return the model for ``model_type``, loading it on first use in this process.

    Loaded models are kept in the module-level ``_worker_models`` dict, so each
    process pays the load cost at most once per model type. The cache is checked
    once without the lock and again after acquiring ``_model_lock``, so that
    concurrent callers do not load the same model twice.

    Args:
        model_type: '3x3' (classification model, 'best.onnx') or
            '4x4' (segmentation model, 'best4x4.onnx').

    Returns:
        The loaded YOLO model, or None when ``model_type`` is unknown or the
        model fails to load. Failures are not cached, so a later call retries.
    """
    try:
        # Fast path: already loaded in this process, no locking needed.
        return _worker_models[model_type]
    except KeyError:
        pass

    # Model file and task for every supported model type.
    known_models = {
        '3x3': ('best.onnx', 'classify'),
        '4x4': ('best4x4.onnx', 'segment'),
    }

    with _model_lock:
        # Another thread may have finished loading while we waited for the lock.
        if model_type in _worker_models:
            return _worker_models[model_type]

        logging.info(f"WORKER_INIT: Loading model '{model_type}' for worker PID: {os.getpid()}...")

        spec = known_models.get(model_type)
        if spec is None:
            logging.error(f"Attempted to load unknown model type: {model_type}")
            return None
        model_path, task_type = spec

        try:
            loaded = YOLO(model_path, task=task_type)
            _worker_models[model_type] = loaded
            logging.info(f"WORKER_INIT: Model '{model_type}' loaded successfully for worker PID: {os.getpid()}.")
        except Exception as exc:
            logging.error(f"WORKER_INIT: Failed to load model '{model_path}' for worker PID: {os.getpid()}: {exc}")
            return None
        return loaded
218
+
219
  # ============================================================================
220
  # πŸ–ΌοΈ OPTIMIZED IMAGE PROCESSING
221
  # ============================================================================
222
  def decode_image_optimized(base64_string):
 
223
  try:
224
  image_data = base64.b64decode(base64_string.split(',')[1])
225
  image = Image.open(io.BytesIO(image_data)).convert("RGB")
 
229
  return None
230
 
231
  def divide_image_into_4x4_grid(image_cv2):
 
232
  height, width = image_cv2.shape[:2]
233
  grid_height, grid_width = height // 4, width // 4
234
  grid_images, grid_coordinates = [], []
 
241
  return grid_images, grid_coordinates
242
 
243
  def is_object_in_grid_cell(mask_contour, grid_coords, min_coverage_percentage=MIN_COVERAGE_PERCENTAGE):
 
244
  x1, y1, x2, y2 = grid_coords
245
  grid_width, grid_height = x2 - x1, y2 - y1
246
  grid_area = grid_width * grid_height
 
260
  # 🔧 UTILITY FUNCTIONS
261
  # ============================================================================
262
  def get_target_class_index(input_title, model_type):
 
263
  model_classes = model_class_maps.get(model_type, {})
264
  if not input_title or not model_classes: return None
265
  canonical_name = find_class_match(input_title)
 
269
  return model_classes.get(normalized_input)
270
 
271
  def memory_cleanup():
 
272
  gc.collect()
273
  current_memory = psutil.virtual_memory().percent
274
  if current_memory > MEMORY_THRESHOLD:
 
279
  # ============================================================================
280
  @app.before_request
281
  def check_api_key():
 
282
  if request.endpoint in ['health', 'stats']: return
283
  if not verify_api_key(request):
284
  return jsonify({"error": "Invalid or missing API key"}), 401
 
289
  @app.route('/health', methods=['GET'])
290
  def health():
291
  return jsonify({
292
+ "status": "healthy",
293
+ "models_loaded_in_worker": len(_worker_models),
294
+ "memory_usage": psutil.virtual_memory().percent,
295
+ "cpu_usage": psutil.cpu_percent()
296
  })
297
 
298
  @app.route('/stats', methods=['GET'])
 
301
 
302
  @app.route('/predict', methods=['POST'])
303
  def predict():
 
304
  import time
305
  start_time = time.time()
 
306
  try:
307
  data = request.get_json(silent=True)
308
  if not data: return jsonify({"error": "Invalid request body"}), 400
309
+
310
+ model = get_model('3x3')
311
  if not model: return jsonify({"error": "3x3 model not loaded"}), 500
312
+
313
  input_title = data.get('title', '')
314
  target_class_index = get_target_class_index(input_title, '3x3')
315
  if target_class_index is None:
 
327
  image = decode_image_optimized(item['base64'])
328
  if image is None: return None
329
  results = model(image, verbose=False)
 
 
330
  if not results: return None
331
  res = results[0]
332
  if res.probs is None or res.probs.data is None: return None
 
333
  confidence = res.probs.data[target_class_index].item()
334
  return {'index': item['index'], 'confidence': confidence, 'selected': confidence >= CONFIDENCE_THRESHOLD_3X3}
335
  except Exception as e:
 
348
  if psutil.virtual_memory().percent > MEMORY_THRESHOLD: memory_cleanup()
349
  return jsonify(response)
350
  except Exception as e:
351
+ logging.error(f"Error in /predict: {e}", exc_info=True)
352
  return jsonify({"error": "Internal server error"}), 500
353
 
354
  @app.route('/predict_4x4', methods=['POST'])
355
  def predict_4x4():
 
356
  import time
357
  start_time = time.time()
 
358
  try:
359
  data = request.get_json(silent=True)
360
  if not data: return jsonify({"error": "Invalid request body"}), 400
361
+
362
+ model = get_model('4x4')
363
  if not model: return jsonify({"error": "4x4 model not loaded"}), 500
364
+
365
  input_title = data.get('title', '')
366
  target_class_index = get_target_class_index(input_title, '4x4')
367
  if target_class_index is None:
 
377
  image_pil = decode_image_optimized(data['image_b64'])
378
  if image_pil is None: return jsonify({"error": "Invalid image data"}), 400
379
  image_cv2 = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
 
380
 
381
  results = model(image_cv2, verbose=False)
382
  indices_to_click = []
 
399
  if psutil.virtual_memory().percent > MEMORY_THRESHOLD: memory_cleanup()
400
  return jsonify(response)
401
  except Exception as e:
402
+ logging.error(f"Error in /predict_4x4: {e}", exc_info=True)
403
  return jsonify({"error": "Internal server error"}), 500
404
 
405
  @app.route('/classes', methods=['GET'])