DraconicDragon commited on
Commit
d8c89a8
·
verified ·
1 Parent(s): 8e0b817

attempt to make space run on cpu

Browse files

removed GPU decorator from predict_onnx()

dont try to load model every time, use global onnx session var

explicitly setting queue params demo.queue(concurrency_count=1, max_size=None, status_update_rate=1)

Files changed (1) hide show
  1. app.py +30 -5
app.py CHANGED
@@ -286,9 +286,13 @@ g_idx_to_tag = None
286
  g_tag_to_category = None
287
  g_current_model = None
288
 
 
 
 
289
  # --- Initialization Function ---
290
  def initialize_onnx_paths(model_choice=DEFAULT_MODEL):
291
  global g_onnx_model_path, g_tag_mapping_path, g_labels_data, g_idx_to_tag, g_tag_to_category, g_current_model
 
292
 
293
  if not model_choice in MODEL_OPTIONS:
294
  print(f"Invalid model choice: {model_choice}, falling back to default: {DEFAULT_MODEL}")
@@ -301,19 +305,40 @@ def initialize_onnx_paths(model_choice=DEFAULT_MODEL):
301
 
302
  print(f"Initializing ONNX paths and labels for model: {model_choice}...")
303
  hf_token = os.environ.get("HF_TOKEN")
 
304
  try:
305
  print(f"Attempting to download ONNX model: {onnx_filename}")
306
- g_onnx_model_path = hf_hub_download(repo_id=REPO_ID, filename=onnx_filename, cache_dir=CACHE_DIR, token=hf_token, force_download=False)
 
 
 
 
 
 
307
  print(f"ONNX model path: {g_onnx_model_path}")
308
 
309
  print(f"Attempting to download Tag mapping: {tag_mapping_filename}")
310
- g_tag_mapping_path = hf_hub_download(repo_id=REPO_ID, filename=tag_mapping_filename, cache_dir=CACHE_DIR, token=hf_token, force_download=False)
 
 
 
 
 
 
311
  print(f"Tag mapping path: {g_tag_mapping_path}")
312
 
313
  print("Loading labels from mapping...")
314
  g_labels_data, g_idx_to_tag, g_tag_to_category = load_tag_mapping(g_tag_mapping_path)
315
  print(f"Labels loaded. Count: {len(g_labels_data.names)}")
316
-
 
 
 
 
 
 
 
 
317
  return True
318
 
319
  except Exception as e:
@@ -341,7 +366,6 @@ def change_model(model_choice):
341
  return f"Error changing model: {str(e)}"
342
 
343
  # --- Main Prediction Function (ONNX) ---
344
- @spaces.GPU()
345
  def predict_onnx(image_input, model_choice, gen_threshold, char_threshold, output_mode):
346
  print(f"--- predict_onnx function started (GPU worker) with model {model_choice} ---")
347
 
@@ -371,7 +395,7 @@ def predict_onnx(image_input, model_choice, gen_threshold, char_threshold, outpu
371
  providers.append('CUDAExecutionProvider')
372
  providers.append('CPUExecutionProvider')
373
  print(f"Attempting to load session with providers: {providers}")
374
- session = ort.InferenceSession(g_onnx_model_path, providers=providers)
375
  print(f"ONNX session loaded using: {session.get_providers()[0]}")
376
  except Exception as e:
377
  message = f"Error loading ONNX session in worker: {e}"
@@ -544,5 +568,6 @@ if __name__ == "__main__":
544
  if not os.environ.get("HF_TOKEN"): print("Warning: HF_TOKEN environment variable not set.")
545
  # Initialize paths and labels at startup (with default model)
546
  initialize_onnx_paths(DEFAULT_MODEL)
 
547
  # Launch Gradio app
548
  demo.launch()
 
286
  g_tag_to_category = None
287
  g_current_model = None
288
 
289
+ # --- Global ONNX session ---
290
+ g_session = None
291
+
292
  # --- Initialization Function ---
293
  def initialize_onnx_paths(model_choice=DEFAULT_MODEL):
294
  global g_onnx_model_path, g_tag_mapping_path, g_labels_data, g_idx_to_tag, g_tag_to_category, g_current_model
295
+ global g_session
296
 
297
  if not model_choice in MODEL_OPTIONS:
298
  print(f"Invalid model choice: {model_choice}, falling back to default: {DEFAULT_MODEL}")
 
305
 
306
  print(f"Initializing ONNX paths and labels for model: {model_choice}...")
307
  hf_token = os.environ.get("HF_TOKEN")
308
+
309
  try:
310
  print(f"Attempting to download ONNX model: {onnx_filename}")
311
+ g_onnx_model_path = hf_hub_download(
312
+ repo_id=REPO_ID,
313
+ filename=onnx_filename,
314
+ cache_dir=CACHE_DIR,
315
+ token=hf_token,
316
+ force_download=False
317
+ )
318
  print(f"ONNX model path: {g_onnx_model_path}")
319
 
320
  print(f"Attempting to download Tag mapping: {tag_mapping_filename}")
321
+ g_tag_mapping_path = hf_hub_download(
322
+ repo_id=REPO_ID,
323
+ filename=tag_mapping_filename,
324
+ cache_dir=CACHE_DIR,
325
+ token=hf_token,
326
+ force_download=False
327
+ )
328
  print(f"Tag mapping path: {g_tag_mapping_path}")
329
 
330
  print("Loading labels from mapping...")
331
  g_labels_data, g_idx_to_tag, g_tag_to_category = load_tag_mapping(g_tag_mapping_path)
332
  print(f"Labels loaded. Count: {len(g_labels_data.names)}")
333
+
334
+ # Load ONNX session ONCE here
335
+ print("Creating ONNX Runtime session (CPUExecutionProvider)...")
336
+ g_session = ort.InferenceSession(
337
+ g_onnx_model_path,
338
+ providers=["CPUExecutionProvider"]
339
+ )
340
+ print("ONNX Runtime session ready.")
341
+
342
  return True
343
 
344
  except Exception as e:
 
366
  return f"Error changing model: {str(e)}"
367
 
368
  # --- Main Prediction Function (ONNX) ---
 
369
  def predict_onnx(image_input, model_choice, gen_threshold, char_threshold, output_mode):
370
  print(f"--- predict_onnx function started (GPU worker) with model {model_choice} ---")
371
 
 
395
  providers.append('CUDAExecutionProvider')
396
  providers.append('CPUExecutionProvider')
397
  print(f"Attempting to load session with providers: {providers}")
398
+ session = g_session
399
  print(f"ONNX session loaded using: {session.get_providers()[0]}")
400
  except Exception as e:
401
  message = f"Error loading ONNX session in worker: {e}"
 
568
  if not os.environ.get("HF_TOKEN"): print("Warning: HF_TOKEN environment variable not set.")
569
  # Initialize paths and labels at startup (with default model)
570
  initialize_onnx_paths(DEFAULT_MODEL)
571
+ demo.queue(concurrency_count=1, max_size=None, status_update_rate=1)
572
  # Launch Gradio app
573
  demo.launch()