import onnxruntime as ort
import numpy as np
import json
from PIL import Image


def preprocess_image(img_path, target_size=512, keep_aspect=True):
    """
    Load an image and convert it into a model-ready input tensor.

    The image is converted to RGB and brought to (target_size, target_size),
    either by an aspect-preserving resize followed by centred black padding
    (keep_aspect=True) or by a direct, possibly distorting, resize
    (keep_aspect=False). Pixel values are scaled to [0, 1].

    Args:
        img_path: Path (or file object) accepted by ``PIL.Image.open``.
        target_size: Side length of the square output image.
        keep_aspect: If True, preserve the aspect ratio and pad with black.

    Returns:
        float32 array of shape (1, 3, target_size, target_size).
    """
    # Open via a context manager so the file handle is released promptly;
    # Pillow keeps it open after load() for multi-frame formats.
    with Image.open(img_path) as src:
        img = src.convert("RGB")

    if keep_aspect:
        w, h = img.size
        aspect = w / h
        # Clamp the shorter side to at least 1 px: for extreme aspect ratios
        # the truncated value would otherwise be 0, which resize() rejects.
        if aspect > 1:
            new_w = target_size
            new_h = max(1, int(new_w / aspect))
        else:
            new_h = target_size
            new_w = max(1, int(new_h * aspect))
        img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)

        # Centre the resized image on a black square canvas.
        background = Image.new("RGB", (target_size, target_size), (0, 0, 0))
        paste_x = (target_size - new_w) // 2
        paste_y = (target_size - new_h) // 2
        background.paste(img, (paste_x, paste_y))
        img = background
    else:
        # Direct resize to target_size x target_size (aspect not preserved).
        img = img.resize((target_size, target_size), Image.Resampling.LANCZOS)

    arr = np.asarray(img, dtype=np.float32) / 255.0  # HWC, scaled to [0, 1]
    arr = np.transpose(arr, (2, 0, 1))  # HWC -> CHW
    return np.expand_dims(arr, axis=0)  # add batch dim -> (1, 3, H, W)


def load_thresholds(threshold_json_path, mode="balanced"):
    """
    Load per-category probability thresholds for one operating mode.

    The JSON file is read as ``data["overall"][mode]["threshold"]`` and,
    optionally, ``data["categories"][<category>][mode]["threshold"]``.

    Args:
        threshold_json_path: Path to the thresholds JSON file.
        mode: Operating mode to select, e.g. "balanced", "high_precision"
            or "high_recall".

    Returns:
        (thresholds_by_category, fallback_threshold) where
            thresholds_by_category (dict): category name -> threshold, e.g.
                {"general": 0.328..., "character": 0.304..., ...}
            fallback_threshold (float): the "overall" threshold for ``mode``,
                used for any category without its own entry.

    Raises:
        KeyError: if ``mode`` is missing from the "overall" section.
    """
    with open(threshold_json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    fallback_threshold = data["overall"][mode]["threshold"]

    thresholds_by_category = {}
    if "categories" in data:
        for cat_name, cat_modes in data["categories"].items():
            # Use the category's own threshold for this mode when present;
            # otherwise fall back to the overall threshold.
            if mode in cat_modes and "threshold" in cat_modes[mode]:
                thresholds_by_category[cat_name] = cat_modes[mode]["threshold"]
            else:
                thresholds_by_category[cat_name] = fallback_threshold

    return thresholds_by_category, fallback_threshold


def _build_threshold_vector(n_tags, idx_to_tag, tag_to_category,
                            thresholds_by_category, fallback_threshold):
    """Return an (n_tags,) float64 array holding each tag's category threshold."""
    thresholds = np.empty(n_tags, dtype=np.float64)
    for idx in range(n_tags):
        tag_name = idx_to_tag[str(idx)]  # metadata keys are stringified indices
        # Tags with no category entry are treated as "general".
        category = tag_to_category.get(tag_name, "general")
        thresholds[idx] = thresholds_by_category.get(category, fallback_threshold)
    return thresholds


def onnx_inference(
    img_paths,
    onnx_path="camie_refined_no_flash.onnx",
    metadata_file="metadata.json",
    threshold_json_path="thresholds.json",
    mode="balanced",
    target_size=512,
    keep_aspect=True,
    providers=None,
):
    """
    Run the ONNX tagger on a batch of images and apply category thresholds.

    Args:
        img_paths: Iterable of image paths.
        onnx_path: Path to the exported ONNX model file.
        metadata_file: Path to metadata.json; must contain "idx_to_tag" and may
            contain "tag_to_category".
        threshold_json_path: Path to thresholds.json with category-wise
            threshold info.
        mode: "balanced", "high_precision", or "high_recall".
        target_size: Final size of preprocessed images (512 by default).
        keep_aspect: If True, preserve aspect ratio when resizing and pad
            with black.
        providers: ONNX Runtime execution providers; defaults to
            ["CPUExecutionProvider"]. Pass e.g. ["CUDAExecutionProvider"]
            for GPU inference.

    Returns:
        A list with one dict per input image (empty list for no images):
            {
                "initial_logits": np.ndarray of shape (N_tags,),
                "refined_logits": np.ndarray of shape (N_tags,),
                "predicted_indices": list of int tag indices whose refined
                    probability is >= their category threshold (ascending),
                "predicted_tags": list of the corresponding tag strings,
            }
    """
    img_paths = list(img_paths)
    if not img_paths:
        # np.concatenate cannot build an empty batch; there is nothing to tag.
        return []

    # 1) Initialize ONNX runtime session
    if providers is None:
        providers = ["CPUExecutionProvider"]
    session = ort.InferenceSession(onnx_path, providers=providers)

    # 2) Pre-load metadata and thresholds
    with open(metadata_file, "r", encoding="utf-8") as f:
        metadata = json.load(f)
    idx_to_tag = metadata["idx_to_tag"]  # e.g. {"0": "brown_hair", ...}
    tag_to_category = metadata.get("tag_to_category", {})
    thresholds_by_category, fallback_threshold = load_thresholds(
        threshold_json_path, mode
    )

    # 3) Preprocess every image and stack into (batch_size, 3, H, W)
    batch_input = np.concatenate(
        [
            preprocess_image(p, target_size=target_size, keep_aspect=keep_aspect)
            for p in img_paths
        ],
        axis=0,
    )

    # 4) Run inference; the model is expected to return
    #    [initial_logits, refined_logits], each (batch_size, N_tags).
    input_name = session.get_inputs()[0].name
    initial_preds, refined_preds = session.run(None, {input_name: batch_input})

    # 5) Per-tag thresholds depend only on metadata, so build them once for
    #    the whole batch instead of re-deriving them for every image.
    thresholds = _build_threshold_vector(
        refined_preds.shape[1], idx_to_tag, tag_to_category,
        thresholds_by_category, fallback_threshold,
    )

    # Sigmoid. exp() overflows to inf for very negative logits, which still
    # yields the correct probability of 0, so the warning is silenced.
    with np.errstate(over="ignore"):
        refined_probs = 1.0 / (1.0 + np.exp(-refined_preds))
    above = refined_probs >= thresholds  # broadcasts over the batch axis

    batch_results = []
    for init_logit, ref_logit, mask in zip(initial_preds, refined_preds, above):
        predicted_indices = np.flatnonzero(mask).tolist()
        batch_results.append(
            {
                "initial_logits": init_logit,
                "refined_logits": ref_logit,
                "predicted_indices": predicted_indices,
                "predicted_tags": [idx_to_tag[str(i)] for i in predicted_indices],
            }
        )
    return batch_results


if __name__ == "__main__":
    # Example usage
    images = ["images.png"]
    results = onnx_inference(
        img_paths=images,
        onnx_path="camie_refined_no_flash_v15.onnx",
        metadata_file="metadata.json",
        threshold_json_path="thresholds.json",
        mode="balanced",  # or "high_precision", "high_recall"
        target_size=512,
        keep_aspect=True,
    )

    for img_path, res in zip(images, results):
        print(f"Image: {img_path}")
        print(f" # of predicted tags above threshold: {len(res['predicted_indices'])}")
        # Show the first 10 predicted tags (fewer if not that many exist)
        sample_tags = res["predicted_tags"][:10]
        print(" Sample predicted tags:", sample_tags)
        print()