ONNX conversion and 16-bit quantization for https://github.com/danbooru/autotagger

Download the tags file (tags.json) here: https://github.com/danbooru/autotagger/blob/master/data/tags.json

Converted with https://github.com/tkeyo/fastai-onnx/blob/main/fastai_to_onnx.ipynb

Usage example (python source code for inference):
# /// script
# requires-python = ">=3.11"
# dependencies = [
#     "onnxruntime",
#     "pillow",
#     "tqdm",
# ]
# ///
import argparse
from PIL import Image
import onnxruntime as rt
import numpy as np
import json
from tqdm import tqdm

def resize_and_pad(image_path, size: int, typ=np.float32):
    """Load an image and letterbox it onto a black ``size`` x ``size`` canvas.

    The image is shrunk in place with ``thumbnail`` (aspect ratio preserved,
    never upscaled), centered on a zero-filled RGB canvas, and returned as a
    CHW NumPy array of dtype ``typ`` scaled to [0, 1].

    Fix vs. original: ``Image.open`` was never closed, leaking a file handle
    per image; the ``with`` block releases it deterministically.
    """
    with Image.open(image_path) as img:
        # Resize in place while maintaining aspect ratio.
        img.thumbnail((size, size), Image.LANCZOS)
        # Black (zero-filled) square canvas; padding is split evenly.
        canvas = Image.new("RGB", (size, size), (0, 0, 0))
        paste_x = (size - img.width) // 2
        paste_y = (size - img.height) // 2
        canvas.paste(img, (paste_x, paste_y))
    # HWC uint8 -> CHW float, normalized to [0, 1].
    image = np.array(canvas).transpose(2, 0, 1).astype(typ)
    image /= 255
    return image

def sort_and_yield_indices(arr, thresh=0.1):
    """Yield ``(index, value)`` pairs of ``arr`` in descending value order.

    Iteration stops at the first value below ``thresh`` (values are sorted,
    so everything after it is smaller still). Assumes ``arr`` is a 1-D
    NumPy array of scores.

    Fix vs. original: removed the unused ``sorted_arr`` intermediate and a
    stale comment that claimed a hard-coded 0.0001 cutoff.
    """
    # argsort is ascending; reverse for largest-first.
    for index in np.argsort(arr)[::-1]:
        value = arr[index]
        if value < thresh:
            break
        yield (index, value)

def window(array, size):
    """Yield successive ``size``-length slices of ``array``.

    A tqdm progress bar tracks the chunk starts; the final chunk may be
    shorter than ``size``.
    """
    chunk_starts = range(0, len(array), size)
    for start in tqdm(chunk_starts):
        end = start + size
        yield array[start:end]

# import line_profiler
# @line_profiler.profile
def main():
    """CLI entry point: tag a batch of images with an ONNX autotagger model.

    Reads image paths from the command line, skips any already present in the
    data dump (resume support), runs batched inference, and either appends
    escaped-CSV rows to the dump file or prints tag:score pairs to stdout.

    Fixes vs. original: the tags file opened via ``json.load(open(...))`` and
    the append-mode output file were never closed; both are now managed by
    ``with`` blocks so output is flushed even on error.
    """
    parser = argparse.ArgumentParser(description='Process some images.')

    parser.add_argument('-m', '--model-path', required=True, help='Path to the model.')
    parser.add_argument('-q', '--quant', type=int, required=True, choices=[16, 32], help='Quantization size (16 or 32).')
    parser.add_argument('-b', '--batch-size', type=int, default=16, help='Batch size')
    parser.add_argument('-tags', '--tag-path', required=True, help='Path to the tags.')
    parser.add_argument('-thresh', type=float, help='Tag threshold', default=0.1)
    parser.add_argument('-d', '--data-path', type=str, help='Data dump file')
    parser.add_argument('--replace-path', type=str, help='replace path (i.e. `from:to`)')

    parser.add_argument('images', nargs='+', type=str, help='Image to be processed')

    args = parser.parse_args()

    replace_path_from, replace_path_to = None, None
    if args.replace_path:
        replace_path_from, replace_path_to = args.replace_path.split(':')

    import csv
    # Collect paths already recorded in the dump so reruns skip finished work.
    excluded = set()
    if args.data_path:
        with open(args.data_path, 'r') as f:
            csvreader = csv.reader(f, escapechar='\\', quoting=csv.QUOTE_NONE)
            for row in csvreader:
                if replace_path_from:
                    # Dump stores remapped paths; map back to compare with argv.
                    excluded.add(row[0].replace(replace_path_from, replace_path_to))
                else:
                    excluded.add(row[0])
    args.images = [path for path in args.images if path not in excluded]
    print(f"{len(args.images)} images to be processed")
    if not args.images:
        return

    with open(args.tag_path, 'r') as f:
        vocab = json.load(f)
    # argparse `choices` guarantees quant is 16 or 32.
    typ = np.float16 if args.quant == 16 else np.float32
    img_size = 224

    # Initialize the ONNX Runtime inference session.
    sess_opt = rt.SessionOptions()
    sess = rt.InferenceSession(args.model_path, sess_opt)
    input_name = sess.get_inputs()[0].name
    output_name = sess.get_outputs()[0].name

    import concurrent.futures
    import contextlib

    with contextlib.ExitStack() as stack:
        # Append results to the dump when given; otherwise print to stdout.
        outf = stack.enter_context(open(args.data_path, 'a')) if args.data_path else None
        for images in window(args.images, args.batch_size):
            # Decode/resize images in parallel threads (I/O + PIL release the GIL).
            with concurrent.futures.ThreadPoolExecutor() as executor:
                inp_tensors = np.array(list(executor.map(
                    lambda img: resize_and_pad(img, img_size, typ=typ), images)))
            results = sess.run([output_name], {input_name: inp_tensors})[0]

            for scores, path in zip(results, images):
                if outf:
                    if replace_path_from:
                        # Store the remapped path in the dump.
                        path = path.replace(replace_path_to, replace_path_from)
                    # Commas separate tags in the dump, so escape them in paths.
                    outf.write(path.replace(",", "\\,"))
                    for index, _score in sort_and_yield_indices(scores, args.thresh):
                        outf.write(",")
                        outf.write(vocab[index])
                    outf.write("\n")
                else:
                    print(path)
                    for index, score in sort_and_yield_indices(scores, args.thresh):
                        print(f" {vocab[index]}:{score}")

if __name__ == '__main__':
    main()

Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support