"""
ONNX-based batch image processing for the Image Tagger application.
"""

import os
import json
import time
import traceback
import numpy as np
import glob
import onnxruntime as ort
from PIL import Image
import torchvision.transforms as transforms
from concurrent.futures import ThreadPoolExecutor

# Lower-case file extensions treated as images by batch_process_images_onnx.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

# The transform pipeline holds no per-image state, so a single instance is
# shared instead of being rebuilt for every image.
_TO_TENSOR = transforms.Compose([
    transforms.ToTensor(),
])


def preprocess_image(image_path, image_size=512):
    """Load an image and turn it into a padded, square model input.

    The image is converted to RGB if needed, resized with LANCZOS so that its
    longer side equals ``image_size`` (aspect ratio preserved), and centred
    on a black ``image_size`` x ``image_size`` canvas.

    Args:
        image_path: Path to the image file.
        image_size: Side length, in pixels, of the square output.

    Returns:
        numpy array obtained from ``transforms.ToTensor()`` applied to the
        padded RGB image.

    Raises:
        ValueError: If ``image_path`` does not exist.
        Exception: If the image cannot be opened or processed; the original
            error is chained as ``__cause__``.
    """
    if not os.path.exists(image_path):
        raise ValueError(f"Image not found at path: {image_path}")

    try:
        with Image.open(image_path) as img:
            # Convert every non-RGB mode up front so that LANCZOS resampling is
            # applied to RGB data (PIL would otherwise resize e.g. '1'/'P'
            # images with NEAREST and only convert them on paste).
            if img.mode != 'RGB':
                img = img.convert('RGB')

            width, height = img.size
            aspect_ratio = width / height

            # Fit the longer side to image_size, keeping the aspect ratio.
            if aspect_ratio > 1:
                new_width = image_size
                new_height = int(new_width / aspect_ratio)
            else:
                new_height = image_size
                new_width = int(new_height * aspect_ratio)
            # Very thin images would otherwise round a side down to 0, which
            # PIL's resize rejects.
            new_width = max(1, new_width)
            new_height = max(1, new_height)

            img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

            # Centre the resized image on a black square canvas.
            new_image = Image.new('RGB', (image_size, image_size), (0, 0, 0))
            paste_x = (image_size - new_width) // 2
            paste_y = (image_size - new_height) // 2
            new_image.paste(img, (paste_x, paste_y))

        img_tensor = _TO_TENSOR(new_image)
        return img_tensor.numpy()
    except Exception as e:
        raise Exception(f"Error processing {image_path}: {str(e)}") from e


def _metadata_path_for(model_path):
    """Return the ``<model>_metadata.json`` path that sits next to ``model_path``."""
    # splitext only strips the final extension, unlike str.replace which
    # rewrote every '.onnx' substring anywhere in the path.
    return os.path.splitext(model_path)[0] + '_metadata.json'


def process_single_image_onnx(image_path, model_path, metadata, threshold_profile="Overall",
                              active_threshold=0.35, active_category_thresholds=None,
                              min_confidence=0.1):
    """
    Process a single image using ONNX model

    The ONNXImageTagger is cached on this function (``.tagger``) and reused
    for later calls with the same ``model_path``; a different ``model_path``
    builds and caches a new tagger.

    Args:
        image_path: Path to the image file
        model_path: Path to the ONNX model file; metadata is loaded from
            ``<model path without extension>_metadata.json``
        metadata: Model metadata dictionary (currently unused; kept for
            interface compatibility)
        threshold_profile: The threshold profile being used (currently unused)
        active_threshold: Overall threshold value
        active_category_thresholds: Category-specific thresholds
        min_confidence: Minimum confidence to include in results

    Returns:
        Dictionary with tags and probabilities. On success it is the result
        from ``ONNXImageTagger.predict_batch`` plus ``'inference_time'``
        (seconds, including preprocessing). On failure it has
        ``'success': False`` and an ``'error'`` message; errors are printed,
        not raised.
    """
    try:
        tagger = getattr(process_single_image_onnx, 'tagger', None)
        # Only reuse the cached tagger if it was built for this model;
        # otherwise a call with a different model would silently keep using
        # the first one.
        if tagger is None or getattr(tagger, 'model_path', None) != model_path:
            tagger = ONNXImageTagger(model_path, _metadata_path_for(model_path))
            process_single_image_onnx.tagger = tagger

        start_time = time.time()
        img_array = preprocess_image(image_path)

        results = tagger.predict_batch(
            [img_array],
            threshold=active_threshold,
            category_thresholds=active_category_thresholds,
            min_confidence=min_confidence
        )
        inference_time = time.time() - start_time

        if results:
            result = results[0]
            result['inference_time'] = inference_time
            return result
        return {
            'success': False,
            'error': 'Failed to process image',
            'all_tags': [],
            'all_probs': {},
            'tags': {}
        }
    except Exception as e:
        print(f"Error in process_single_image_onnx: {str(e)}")
        traceback.print_exc()
        return {
            'success': False,
            'error': str(e),
            'all_tags': [],
            'all_probs': {},
            'tags': {}
        }


def preprocess_images_parallel(image_paths, image_size=512, max_workers=8):
    """Preprocess several images concurrently with a thread pool.

    Args:
        image_paths: Paths of the images to preprocess.
        image_size: Passed through to ``preprocess_image``.
        max_workers: Maximum number of worker threads.

    Returns:
        ``(processed_images, valid_paths)``: the arrays of the images that
        were preprocessed successfully and their paths, in input order.
        Images that fail are reported on stdout and left out of both lists.
    """
    processed_images = []
    valid_paths = []

    def process_single_image(path):
        try:
            return preprocess_image(path, image_size), path
        except Exception as e:
            print(f"Error processing {path}: {str(e)}")
            return None, path

    # executor.map yields results in input order.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        results = list(executor.map(process_single_image, image_paths))

    for img_array, path in results:
        if img_array is not None:
            processed_images.append(img_array)
            valid_paths.append(path)

    return processed_images, valid_paths


def apply_category_limits(result, category_limits):
    """
    Apply category limits to a result dictionary.

    ``result['tags']`` is modified in place and ``result['all_tags']`` is
    rebuilt from it. Each category's list is assumed to be sorted by
    probability (highest first), as ``ONNXImageTagger.predict_batch``
    produces, so truncation keeps the most confident tags.

    Args:
        result: Result dictionary containing tags and all_tags
        category_limits: Dictionary mapping categories to their tag limits
            (0 = exclude category, -1 = no limit/include all)

    Returns:
        Updated result dictionary with limits applied (the same object)
    """
    if not category_limits or not result.get('success'):
        return result

    filtered_tags = result['tags']

    # Iterate over a snapshot because categories may be deleted.
    for category, cat_tags in list(filtered_tags.items()):
        # Categories without an explicit limit are left untouched.
        limit = category_limits.get(category, -1)

        if limit == 0:
            del filtered_tags[category]
        elif limit > 0 and len(cat_tags) > limit:
            filtered_tags[category] = cat_tags[:limit]

    result['tags'] = filtered_tags
    result['all_tags'] = [tag for cat_tags in filtered_tags.values() for tag, _ in cat_tags]
    return result


def _sigmoid(logits):
    """Element-wise logistic sigmoid.

    Overflow in ``exp`` for very negative logits saturates to a probability of
    0, so that warning is silenced.
    """
    with np.errstate(over='ignore'):
        return 1.0 / (1.0 + np.exp(-logits))


class ONNXImageTagger:
    """ONNX-based image tagger for fast batch inference.

    Attributes:
        model_path: Path of the loaded ONNX model.
        session: ``onnxruntime.InferenceSession`` running the model.
        metadata: Parsed metadata JSON; ``predict_batch`` reads its
            ``'idx_to_tag'`` and ``'tag_to_category'`` mappings.
        input_name: Name of the model's first input.
    """

    def __init__(self, model_path, metadata_path):
        """Load the ONNX model (CUDA preferred, CPU fallback) and its metadata JSON."""
        self.model_path = model_path
        try:
            self.session = ort.InferenceSession(
                model_path,
                providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
            )
            print(f"Using providers: {self.session.get_providers()}")
        except Exception as e:
            print(f"CUDA not available, using CPU: {e}")
            self.session = ort.InferenceSession(
                model_path,
                providers=['CPUExecutionProvider']
            )
            print(f"Using providers: {self.session.get_providers()}")

        # JSON is UTF-8 by specification and tag names may be non-ASCII, so
        # do not depend on the platform default encoding (e.g. cp1252).
        with open(metadata_path, 'r', encoding='utf-8') as f:
            self.metadata = json.load(f)

        self.input_name = self.session.get_inputs()[0].name
        print(f"Model loaded successfully. Input name: {self.input_name}")

    def predict_batch(self, image_arrays, threshold=0.325, category_thresholds=None, min_confidence=0.1):
        """Run batch inference on preprocessed image arrays.

        Args:
            image_arrays: Sequence of arrays from ``preprocess_image``; they
                are stacked along a new first axis to form the batch.
            threshold: Probability threshold for categories without an entry
                in ``category_thresholds``.
            category_thresholds: Optional ``{category: threshold}`` overrides.
            min_confidence: Tags below this probability are dropped entirely.

        Returns:
            One dict per image with ``'tags'`` (``{category: [(tag, prob)]}``
            above threshold), ``'all_probs'`` (same shape, everything at or
            above ``min_confidence``), ``'all_tags'`` (flat list of tag names
            in ``'tags'``) and ``'success': True``. Lists are sorted by
            probability, highest first. An empty input returns ``[]``.
        """
        if len(image_arrays) == 0:
            # np.stack rejects an empty sequence and the timing print below
            # would divide by zero.
            return []

        batch_input = np.stack(image_arrays)

        start_time = time.time()
        outputs = self.session.run(None, {self.input_name: batch_input})
        inference_time = time.time() - start_time
        print(f"Batch inference completed in {inference_time:.4f} seconds ({inference_time/len(image_arrays):.4f} s/image)")

        # Use the second output when the model provides one, else the first.
        refined_probs = _sigmoid(outputs[1] if len(outputs) > 1 else outputs[0])

        idx_to_tag = self.metadata['idx_to_tag']
        tag_to_category = self.metadata['tag_to_category']

        batch_results = []
        for probs in refined_probs:
            # Select candidate indices in one vectorized pass. The comparison
            # is done in float64 so it matches comparing float(probs[idx]).
            keep_indices = np.nonzero(probs.astype(np.float64) >= min_confidence)[0]

            all_probs = {}
            for idx in keep_indices:
                idx = int(idx)
                tag_name = idx_to_tag.get(str(idx), f"unknown-{idx}")
                category = tag_to_category.get(tag_name, "general")
                all_probs.setdefault(category, []).append((tag_name, float(probs[idx])))

            # Highest probability first; the stable sort keeps index order on ties.
            for cat_tags in all_probs.values():
                cat_tags.sort(key=lambda item: item[1], reverse=True)

            tags = {}
            for category, cat_tags in all_probs.items():
                if category_thresholds and category in category_thresholds:
                    cat_threshold = category_thresholds[category]
                else:
                    cat_threshold = threshold
                tags[category] = [(tag, prob) for tag, prob in cat_tags if prob >= cat_threshold]

            all_tags = [tag for cat_tags in tags.values() for tag, _ in cat_tags]

            batch_results.append({
                'tags': tags,
                'all_probs': all_probs,
                'all_tags': all_tags,
                'success': True
            })

        return batch_results


def _find_image_files(folder_path):
    """Return the sorted image file paths directly inside ``folder_path``.

    Extensions are matched case-insensitively against IMAGE_EXTENSIONS, each
    file appears once on every platform, hidden (dot) files and directories
    are skipped, and an unreadable or missing folder yields ``[]``.
    """
    try:
        names = os.listdir(folder_path)
    except OSError:
        return []

    image_files = []
    for name in sorted(names):
        # glob('*.ext') never matched dot-files (e.g. macOS '._x.jpg'); keep that.
        if name.startswith('.'):
            continue
        if os.path.splitext(name)[1].lower() not in IMAGE_EXTENSIONS:
            continue
        path = os.path.join(folder_path, name)
        if os.path.isfile(path):
            image_files.append(path)
    return image_files


def _finalize_result(image_path, result, category_limits, save_dir, save_fn):
    """Apply category limits to one image's result, save its tags via ``save_fn``, and return it."""
    print(f"Before limiting - Tags for {os.path.basename(image_path)}: {len(result['all_tags'])} tags")
    print(f"Category limits applied: {category_limits}")

    if category_limits and result['success']:
        before_counts = {cat: len(tags) for cat, tags in result['tags'].items()}
        result = apply_category_limits(result, category_limits)
        after_counts = {cat: len(tags) for cat, tags in result['tags'].items()}

        print(f"Before limits: {before_counts}")
        print(f"After limits: {after_counts}")
        print(f"After limiting - Tags for {os.path.basename(image_path)}: {len(result['all_tags'])} tags")

    if result['success']:
        output_path = save_fn(
            image_path=image_path,
            all_tags=result['all_tags'],
            custom_dir=save_dir,
            overwrite=True
        )
        result['output_path'] = str(output_path)

    return result


def _process_image_individually(image_path, tagger, active_threshold, active_category_thresholds,
                                min_confidence, category_limits, save_dir, save_fn):
    """Preprocess, tag, limit and save one image on its own (batch fallback); never raises."""
    try:
        img_array = preprocess_image(image_path)
        single_results = tagger.predict_batch(
            [img_array],
            threshold=active_threshold,
            category_thresholds=active_category_thresholds,
            min_confidence=min_confidence
        )
        if not single_results:
            return {
                'success': False,
                'error': 'Failed to process image',
                'all_tags': []
            }
        return _finalize_result(image_path, single_results[0], category_limits, save_dir, save_fn)
    except Exception as img_e:
        print(f"Error processing single image {image_path}: {str(img_e)}")
        return {
            'success': False,
            'error': str(img_e),
            'all_tags': []
        }


def batch_process_images_onnx(folder_path, model_path, metadata_path, threshold_profile,
                              active_threshold, active_category_thresholds, save_dir=None,
                              progress_callback=None, min_confidence=0.1, batch_size=16,
                              category_limits=None):
    """
    Process all images in a folder using the ONNX model.

    Images are preprocessed and tagged in batches. If a batch fails to
    preprocess or infer as a whole, each of its images is retried on its own.
    Per-image failures (including images that fail preprocessing) are
    recorded in ``results`` with ``'success': False`` instead of being dropped.

    Args:
        folder_path: Path to folder containing images (.jpg/.jpeg/.png, any case)
        model_path: Path to the ONNX model file
        metadata_path: Path to the model metadata file
        threshold_profile: Selected threshold profile (currently unused)
        active_threshold: Overall threshold value
        active_category_thresholds: Category-specific thresholds
        save_dir: Directory to save tag files (if None uses default)
        progress_callback: Optional callback ``(done, total, current_path)``
        min_confidence: Minimum confidence threshold
        batch_size: Number of images to process at once (must be >= 1)
        category_limits: Dictionary mapping categories to their tag limits
            (0 = exclude category, -1 = no limit/include all)

    Returns:
        Dictionary with results for each image. ``'processed'`` is the number
        of entries in ``'results'``, failures included.

    Raises:
        ValueError: If ``batch_size`` is less than 1 and there are images to process.
    """
    from utils.file_utils import save_tags_to_file  # Import here to avoid circular imports

    image_files = _find_image_files(folder_path)

    if not image_files:
        return {
            'success': False,
            'error': f"No images found in {folder_path}",
            'results': {}
        }

    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    if save_dir is None:
        app_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        save_dir = os.path.join(app_dir, "saved_tags")
    os.makedirs(save_dir, exist_ok=True)

    tagger = ONNXImageTagger(model_path, metadata_path)

    results = {}
    total_images = len(image_files)
    num_batches = (total_images + batch_size - 1) // batch_size
    processed = 0
    start_time = time.time()

    for i in range(0, total_images, batch_size):
        batch_start = time.time()
        batch_files = image_files[i:i + batch_size]
        batch_size_actual = len(batch_files)

        if progress_callback:
            progress_callback(processed, total_images, batch_files[0])

        print(f"Processing batch {i // batch_size + 1}/{num_batches}: {batch_size_actual} images")

        try:
            processed_images, valid_paths = preprocess_images_parallel(batch_files)
            if processed_images:
                batch_results = tagger.predict_batch(
                    processed_images,
                    threshold=active_threshold,
                    category_thresholds=active_category_thresholds,
                    min_confidence=min_confidence
                )
            else:
                batch_results = []
        except Exception as e:
            print(f"Error processing batch: {str(e)}")
            traceback.print_exc()

            # Retry every image of the failed batch on its own.
            for image_path in batch_files:
                if progress_callback:
                    progress_callback(processed, total_images, image_path)
                results[image_path] = _process_image_individually(
                    image_path, tagger, active_threshold, active_category_thresholds,
                    min_confidence, category_limits, save_dir, save_tags_to_file
                )
                processed += 1
            continue

        # Record images that preprocess_images_parallel had to skip.
        valid_set = set(valid_paths)
        for image_path in batch_files:
            if image_path not in valid_set:
                results[image_path] = {
                    'success': False,
                    'error': 'Failed to preprocess image',
                    'all_tags': []
                }

        for j, (image_path, result) in enumerate(zip(valid_paths, batch_results)):
            if progress_callback:
                progress_callback(processed + j, total_images, image_path)
            try:
                results[image_path] = _finalize_result(
                    image_path, result, category_limits, save_dir, save_tags_to_file
                )
            except Exception as save_e:
                # A failure while limiting/saving one image must not force the
                # whole batch to be re-inferred.
                print(f"Error processing single image {image_path}: {str(save_e)}")
                results[image_path] = {
                    'success': False,
                    'error': str(save_e),
                    'all_tags': []
                }

        processed += batch_size_actual

        batch_time = time.time() - batch_start
        print(f"Batch processed in {batch_time:.2f} seconds ({batch_time/batch_size_actual:.2f} seconds per image)")

    if progress_callback:
        progress_callback(total_images, total_images, None)

    end_time = time.time()
    total_time = end_time - start_time
    print(f"Batch processing finished. Total time: {total_time:.2f} seconds, Average: {total_time/total_images:.2f} seconds per image")

    return {
        'success': True,
        'total': total_images,
        'processed': len(results),
        'results': results,
        'save_dir': save_dir,
        'time_elapsed': end_time - start_time
    }