DraconicDragon commited on
Commit
bd30518
·
verified ·
1 Parent(s): 83704ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -57
app.py CHANGED
@@ -310,11 +310,13 @@ g_current_model = None
310
 
311
  # --- Global ONNX session ---
312
  g_session = None
 
 
313
 
314
  # --- Initialization Function ---
315
  def initialize_onnx_paths(model_choice=DEFAULT_MODEL):
316
  global g_onnx_model_path, g_tag_mapping_path, g_labels_data, g_idx_to_tag, g_tag_to_category, g_current_model
317
- global g_session
318
 
319
  if not model_choice in MODEL_OPTIONS:
320
  print(f"Invalid model choice: {model_choice}, falling back to default: {DEFAULT_MODEL}")
@@ -325,7 +327,7 @@ def initialize_onnx_paths(model_choice=DEFAULT_MODEL):
325
  onnx_filename = MODEL_OPTIONS[model_choice]
326
  tag_mapping_filename = f"{model_dir}/tag_mapping.json"
327
 
328
- print(f"Initializing ONNX paths and labels for model: {model_choice}...")
329
  hf_token = os.environ.get("HF_TOKEN")
330
 
331
  try:
@@ -353,13 +355,24 @@ def initialize_onnx_paths(model_choice=DEFAULT_MODEL):
353
  g_labels_data, g_idx_to_tag, g_tag_to_category = load_tag_mapping(g_tag_mapping_path)
354
  print(f"Labels loaded. Count: {len(g_labels_data.names)}")
355
 
356
- # Load ONNX session ONCE here
357
- print("Creating ONNX Runtime session (CPUExecutionProvider)...")
358
- g_session = ort.InferenceSession(
359
- g_onnx_model_path,
360
- providers=["CPUExecutionProvider"]
361
- )
362
- print("ONNX Runtime session ready.")
 
 
 
 
 
 
 
 
 
 
 
363
 
364
  return True
365
 
@@ -373,9 +386,29 @@ def initialize_onnx_paths(model_choice=DEFAULT_MODEL):
373
  g_idx_to_tag = None
374
  g_tag_to_category = None
375
  g_current_model = None
376
- # Raise Gradio error to make it visible in the UI
 
 
377
  raise gr.Error(f"Initialization failed: {e}. Check logs and HF_TOKEN.")
378
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
379
  # Function to handle model change
380
  def change_model(model_choice):
381
  try:
@@ -388,8 +421,10 @@ def change_model(model_choice):
388
  return f"Error changing model: {str(e)}"
389
 
390
  # --- Main Prediction Function (ONNX) ---
 
391
  def predict_onnx(image_input, model_choice, gen_threshold, char_threshold, output_mode):
392
- print(f"--- predict_onnx function started (GPU worker) with model {model_choice} ---")
 
393
 
394
  # Ensure current model matches selected model
395
  global g_current_model
@@ -404,25 +439,12 @@ def predict_onnx(image_input, model_choice, gen_threshold, char_threshold, outpu
404
  if g_onnx_model_path is None or g_labels_data is None:
405
  message = "Error: Paths or labels not initialized. Check startup logs."
406
  print(message)
407
- # Return error message and None for the image output
408
  return message, None
409
 
410
- # --- 2. Load ONNX Session (inside worker) ---
411
- session = None
412
- try:
413
- print(f"Loading ONNX session from: {g_onnx_model_path}")
414
- available_providers = ort.get_available_providers()
415
- providers = []
416
- if 'CUDAExecutionProvider' in available_providers:
417
- providers.append('CUDAExecutionProvider')
418
- providers.append('CPUExecutionProvider')
419
- print(f"Attempting to load session with providers: {providers}")
420
- session = g_session
421
- print(f"ONNX session loaded using: {session.get_providers()[0]}")
422
- except Exception as e:
423
- message = f"Error loading ONNX session in worker: {e}"
424
  print(message)
425
- import traceback; traceback.print_exc()
426
  return message, None
427
 
428
  # --- 3. Process Input Image ---
@@ -433,26 +455,23 @@ def predict_onnx(image_input, model_choice, gen_threshold, char_threshold, outpu
433
  try:
434
  # Handle different input types (PIL, numpy, URL, file path)
435
  if isinstance(image_input, str):
436
- if image_input.startswith("http"): # URL
437
  response = requests.get(image_input, timeout=10)
438
  response.raise_for_status()
439
  image = Image.open(io.BytesIO(response.content))
440
- elif os.path.exists(image_input): # File path
441
  image = Image.open(image_input)
442
  else:
443
- raise ValueError(f"Invalid image input string: {image_input}")
444
  elif isinstance(image_input, np.ndarray):
445
- image = Image.fromarray(image_input)
446
  elif isinstance(image_input, Image.Image):
447
- image = image_input # Already a PIL image
448
  else:
449
- raise TypeError(f"Unsupported image input type: {type(image_input)}")
450
 
451
  # Preprocess the PIL image
452
  original_pil_image, input_tensor = preprocess_image(image)
453
-
454
- # Ensure input tensor is float32, as expected by most ONNX models
455
- # (even if the model internally uses float16)
456
  input_tensor = input_tensor.astype(np.float32)
457
 
458
  except Exception as e:
@@ -462,49 +481,51 @@ def predict_onnx(image_input, model_choice, gen_threshold, char_threshold, outpu
462
 
463
  # --- 4. Run Inference ---
464
  try:
465
- input_name = session.get_inputs()[0].name
466
- output_name = session.get_outputs()[0].name
467
- print(f"Running inference with input '{input_name}', output '{output_name}'")
468
  start_time = time.time()
469
- outputs = session.run([output_name], {input_name: input_tensor})[0]
 
 
 
 
 
 
 
 
 
 
470
  inference_time = time.time() - start_time
471
- print(f"Inference completed in {inference_time:.3f} seconds")
472
 
473
  # Check for NaN/Inf in outputs
474
  if np.isnan(outputs).any() or np.isinf(outputs).any():
475
  print("Warning: NaN or Inf detected in model output. Clamping...")
476
- outputs = np.nan_to_num(outputs, nan=0.0, posinf=1.0, neginf=0.0) # Clamp to 0-1 range
477
 
478
- # Apply sigmoid (outputs are likely logits)
479
- # Use a stable sigmoid implementation
480
  def stable_sigmoid(x):
481
- return 1 / (1 + np.exp(-np.clip(x, -30, 30))) # Clip to avoid overflow
482
- probs = stable_sigmoid(outputs[0]) # Assuming batch size 1
483
 
484
  except Exception as e:
485
- message = f"Error during ONNX inference: {e}"
486
  print(message)
487
  import traceback; traceback.print_exc()
488
  return message, None
489
- finally:
490
- # Clean up session if needed (might reduce memory usage between clicks)
491
- del session
492
 
493
  # --- 5. Post-process and Format Output ---
494
  try:
495
  print("Post-processing results...")
496
- # Use the correct global variable for labels
497
  predictions = get_tags(probs, g_labels_data, gen_threshold, char_threshold)
498
 
499
  # Format output text string
500
  output_tags = []
501
  if predictions.get("rating"): output_tags.append(predictions["rating"][0][0].replace("_", " "))
502
  if predictions.get("quality"): output_tags.append(predictions["quality"][0][0].replace("_", " "))
503
- # Add other categories, respecting order and filtering meta if needed
504
  for category in ["artist", "character", "copyright", "general", "meta", "model"]:
505
  tags_in_category = predictions.get(category, [])
506
  for tag, prob in tags_in_category:
507
- # Basic meta tag filtering for text output
508
  if category == "meta" and any(p in tag.lower() for p in ['id', 'commentary', 'request', 'mismatch']):
509
  continue
510
  output_tags.append(tag.replace("_", " "))
@@ -514,12 +535,8 @@ def predict_onnx(image_input, model_choice, gen_threshold, char_threshold, outpu
514
  viz_image = None
515
  if output_mode == "Tags + Visualization":
516
  print("Generating visualization...")
517
- # Pass the correct threshold for display title (can pass both if needed)
518
- # For simplicity, passing gen_threshold as a representative value
519
  viz_image = visualize_predictions(original_pil_image, predictions, gen_threshold)
520
  print("Visualization generated.")
521
- else:
522
- print("Visualization skipped.")
523
 
524
  print("Prediction complete.")
525
  return output_text, viz_image
@@ -540,6 +557,7 @@ footer { display: none !important; }
540
 
541
  with gr.Blocks(css=css) as demo:
542
  gr.Markdown("# CL EVA02 ONNX Tagger (CPU)")
 
543
  gr.Markdown("This space is a duplicate of https://huggingface.co/spaces/cella110n/cl_tagger running on CPU and uses the [non-gated releases](https://huggingface.co/cella110n/cl_tagger) of cl-tagger.")
544
  gr.Markdown("Upload an image or paste an image URL to predict tags using the CL EVA02 Tagger model (ONNX), fine-tuned from [SmilingWolf/wd-eva02-large-tagger-v3](https://huggingface.co/SmilingWolf/wd-eva02-large-tagger-v3).")
545
 
 
310
 
311
  # --- Global ONNX session ---
312
  g_session = None
313
+ g_use_openvino = False
314
+ g_execution_provider = None
315
 
316
  # --- Initialization Function ---
317
  def initialize_onnx_paths(model_choice=DEFAULT_MODEL):
318
  global g_onnx_model_path, g_tag_mapping_path, g_labels_data, g_idx_to_tag, g_tag_to_category, g_current_model
319
+ global g_session, g_use_openvino, g_execution_provider
320
 
321
  if not model_choice in MODEL_OPTIONS:
322
  print(f"Invalid model choice: {model_choice}, falling back to default: {DEFAULT_MODEL}")
 
327
  onnx_filename = MODEL_OPTIONS[model_choice]
328
  tag_mapping_filename = f"{model_dir}/tag_mapping.json"
329
 
330
+ print(f"Initializing paths and labels for model: {model_choice}...")
331
  hf_token = os.environ.get("HF_TOKEN")
332
 
333
  try:
 
355
  g_labels_data, g_idx_to_tag, g_tag_to_category = load_tag_mapping(g_tag_mapping_path)
356
  print(f"Labels loaded. Count: {len(g_labels_data.names)}")
357
 
358
+ # Try OpenVINO first, then fall back to ONNX Runtime
359
+ print("Attempting to initialize inference runtime...")
360
+ try:
361
+ import openvino as ov
362
+
363
+ print("OpenVINO available, attempting to load model...")
364
+ core = ov.Core()
365
+ model = core.read_model(g_onnx_model_path)
366
+ g_session = core.compile_model(model, "CPU")
367
+ g_use_openvino = True
368
+ g_execution_provider = "CPU – OpenVINO™"
369
+ print("Successfully initialized with OpenVINO runtime")
370
+ except ImportError:
371
+ print("OpenVINO not available, falling back to ONNX Runtime CPU")
372
+ _init_onnx_runtime()
373
+ except Exception as e:
374
+ print(f"OpenVINO initialization failed: {e}, falling back to ONNX Runtime CPU")
375
+ _init_onnx_runtime()
376
 
377
  return True
378
 
 
386
  g_idx_to_tag = None
387
  g_tag_to_category = None
388
  g_current_model = None
389
+ g_session = None
390
+ g_use_openvino = False
391
+ g_execution_provider = None
392
  raise gr.Error(f"Initialization failed: {e}. Check logs and HF_TOKEN.")
393
 
394
+ def _init_onnx_runtime():
395
+ """Initialize ONNX Runtime with CPU as fallback"""
396
+ global g_session, g_use_openvino, g_execution_provider
397
+
398
+ sess_options = ort.SessionOptions()
399
+ sess_options.log_severity_level = 3 # Only show errors
400
+
401
+ providers = ["CPUExecutionProvider"]
402
+ g_session = ort.InferenceSession(
403
+ g_onnx_model_path,
404
+ sess_options=sess_options,
405
+ providers=providers
406
+ )
407
+ g_use_openvino = False
408
+ g_execution_provider = g_session.get_providers()[0]
409
+ print(f"ONNX Runtime session ready with {g_execution_provider}")
410
+
411
+
412
  # Function to handle model change
413
  def change_model(model_choice):
414
  try:
 
421
  return f"Error changing model: {str(e)}"
422
 
423
  # --- Main Prediction Function (ONNX) ---
424
+ # --- Main Prediction Function (ONNX/OpenVINO) ---
425
  def predict_onnx(image_input, model_choice, gen_threshold, char_threshold, output_mode):
426
+ print(f"--- predict_onnx function started with model {model_choice} ---")
427
+ print(f"Using runtime: {g_execution_provider}")
428
 
429
  # Ensure current model matches selected model
430
  global g_current_model
 
439
  if g_onnx_model_path is None or g_labels_data is None:
440
  message = "Error: Paths or labels not initialized. Check startup logs."
441
  print(message)
 
442
  return message, None
443
 
444
+ # --- 2. Check session is available ---
445
+ if g_session is None:
446
+ message = "Error: Inference session not initialized."
 
 
 
 
 
 
 
 
 
 
 
447
  print(message)
 
448
  return message, None
449
 
450
  # --- 3. Process Input Image ---
 
455
  try:
456
  # Handle different input types (PIL, numpy, URL, file path)
457
  if isinstance(image_input, str):
458
+ if image_input.startswith("http"):
459
  response = requests.get(image_input, timeout=10)
460
  response.raise_for_status()
461
  image = Image.open(io.BytesIO(response.content))
462
+ elif os.path.exists(image_input):
463
  image = Image.open(image_input)
464
  else:
465
+ raise ValueError(f"Invalid image input string: {image_input}")
466
  elif isinstance(image_input, np.ndarray):
467
+ image = Image.fromarray(image_input)
468
  elif isinstance(image_input, Image.Image):
469
+ image = image_input
470
  else:
471
+ raise TypeError(f"Unsupported image input type: {type(image_input)}")
472
 
473
  # Preprocess the PIL image
474
  original_pil_image, input_tensor = preprocess_image(image)
 
 
 
475
  input_tensor = input_tensor.astype(np.float32)
476
 
477
  except Exception as e:
 
481
 
482
  # --- 4. Run Inference ---
483
  try:
484
+ print(f"Running inference with {'OpenVINO' if g_use_openvino else 'ONNX Runtime'}")
 
 
485
  start_time = time.time()
486
+
487
+ if g_use_openvino:
488
+ # OpenVINO inference
489
+ results = g_session(input_tensor)
490
+ outputs = list(results.values())[0]
491
+ else:
492
+ # ONNX Runtime inference
493
+ input_name = g_session.get_inputs()[0].name
494
+ output_name = g_session.get_outputs()[0].name
495
+ outputs = g_session.run([output_name], {input_name: input_tensor})[0]
496
+
497
  inference_time = time.time() - start_time
498
+ print(f"Inference completed in {inference_time:.3f} seconds using {g_execution_provider}")
499
 
500
  # Check for NaN/Inf in outputs
501
  if np.isnan(outputs).any() or np.isinf(outputs).any():
502
  print("Warning: NaN or Inf detected in model output. Clamping...")
503
+ outputs = np.nan_to_num(outputs, nan=0.0, posinf=1.0, neginf=0.0)
504
 
505
+ # Apply sigmoid
 
506
  def stable_sigmoid(x):
507
+ return 1 / (1 + np.exp(-np.clip(x, -30, 30)))
508
+ probs = stable_sigmoid(outputs[0])
509
 
510
  except Exception as e:
511
+ message = f"Error during inference: {e}"
512
  print(message)
513
  import traceback; traceback.print_exc()
514
  return message, None
 
 
 
515
 
516
  # --- 5. Post-process and Format Output ---
517
  try:
518
  print("Post-processing results...")
 
519
  predictions = get_tags(probs, g_labels_data, gen_threshold, char_threshold)
520
 
521
  # Format output text string
522
  output_tags = []
523
  if predictions.get("rating"): output_tags.append(predictions["rating"][0][0].replace("_", " "))
524
  if predictions.get("quality"): output_tags.append(predictions["quality"][0][0].replace("_", " "))
525
+
526
  for category in ["artist", "character", "copyright", "general", "meta", "model"]:
527
  tags_in_category = predictions.get(category, [])
528
  for tag, prob in tags_in_category:
 
529
  if category == "meta" and any(p in tag.lower() for p in ['id', 'commentary', 'request', 'mismatch']):
530
  continue
531
  output_tags.append(tag.replace("_", " "))
 
535
  viz_image = None
536
  if output_mode == "Tags + Visualization":
537
  print("Generating visualization...")
 
 
538
  viz_image = visualize_predictions(original_pil_image, predictions, gen_threshold)
539
  print("Visualization generated.")
 
 
540
 
541
  print("Prediction complete.")
542
  return output_text, viz_image
 
557
 
558
  with gr.Blocks(css=css) as demo:
559
  gr.Markdown("# CL EVA02 ONNX Tagger (CPU)")
560
+ gr.Markdown("OpenVINO™ is used for accelerated CPU inference when available, with ONNX Runtime as fallback.")
561
  gr.Markdown("This space is a duplicate of https://huggingface.co/spaces/cella110n/cl_tagger running on CPU and uses the [non-gated releases](https://huggingface.co/cella110n/cl_tagger) of cl-tagger.")
562
  gr.Markdown("Upload an image or paste an image URL to predict tags using the CL EVA02 Tagger model (ONNX), fine-tuned from [SmilingWolf/wd-eva02-large-tagger-v3](https://huggingface.co/SmilingWolf/wd-eva02-large-tagger-v3).")
563