# onnx conversion + 16-bit quant for https://github.com/danbooru/autotagger
# download tags here: https://github.com/danbooru/autotagger/blob/master/data/tags.json
# Converted with https://github.com/tkeyo/fastai-onnx/blob/main/fastai_to_onnx.ipynb
# Usage example (python source code for inference):
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "onnxruntime",
# "pillow",
# "tqdm",
# ]
# ///
import argparse
from PIL import Image
import onnxruntime as rt
import numpy as np
import json
from tqdm import tqdm
def resize_and_pad(image_path, size: int, typ=np.float32):
    """Load an image, fit it inside a ``size`` x ``size`` square while keeping
    its aspect ratio, pad the remainder with black, and return it as a CHW
    float array scaled to [0, 1].

    Args:
        image_path: Path to an image file (anything PIL can open).
        size: Edge length in pixels of the square output.
        typ: NumPy dtype of the returned array (e.g. ``np.float16`` for the
            16-bit-quantized model).

    Returns:
        ``np.ndarray`` of shape (3, size, size), dtype ``typ``, values in [0, 1].
    """
    # Use a context manager so the underlying file handle is closed
    # deterministically instead of leaking until GC.
    with Image.open(image_path) as img:
        # Shrink in place so the longest edge is at most `size`.
        img.thumbnail((size, size), Image.LANCZOS)
        # Zero-filled (black) square canvas to pad onto.
        canvas = Image.new("RGB", (size, size), (0, 0, 0))
        # Center the thumbnail on the canvas.
        paste_x = (size - img.width) // 2
        paste_y = (size - img.height) // 2
        canvas.paste(img, (paste_x, paste_y))
    # HWC uint8 -> CHW float, normalized to [0, 1].
    image = np.array(canvas).transpose(2, 0, 1).astype(typ)
    image /= 255
    return image
def sort_and_yield_indices(arr, thresh=0.1):
    """Yield ``(index, value)`` pairs of ``arr`` in descending value order,
    stopping at the first value below ``thresh``.

    Args:
        arr: 1-D NumPy array of scores.
        thresh: Minimum score (inclusive) for a pair to be yielded.

    Yields:
        ``(index, value)`` tuples, highest value first.
    """
    # Indices ordered by score, descending.
    order = np.argsort(arr)[::-1]
    for index in order:
        value = arr[index]
        # Everything after this point is <= value, so stop at the first miss.
        if value < thresh:
            break
        yield (index, value)
def window(array, size):
    """Yield consecutive slices of `array` of length `size` (the final slice
    may be shorter), with a tqdm progress bar over the chunks."""
    offsets = range(0, len(array), size)
    for offset in tqdm(offsets):
        yield array[offset : offset + size]
# import line_profiler
# @line_profiler.profile
def main():
    """CLI entry point: batch-tag images with an ONNX autotagger model.

    Reads image paths from the command line, skips any already recorded in
    the optional data file, runs them through the model in batches, and
    either appends `path,tag,tag,...` rows to the data file or prints
    `tag:score` pairs to stdout.
    """
    parser = argparse.ArgumentParser(description='Process some images.')
    parser.add_argument('-m', '--model-path', required=True, help='Path to the model.')
    parser.add_argument('-q', '--quant', type=int, required=True, choices=[16, 32], help='Quantization size (16 or 32).')
    parser.add_argument('-b', '--batch-size', type=int, default=16, help='Batch size')
    parser.add_argument('-tags', '--tag-path', required=True, help='Path to the tags.')
    parser.add_argument('-thresh', type=float, help='Tag threshold', default=0.1)
    parser.add_argument('-d', '--data-path', type=str, help='Data dump file')
    parser.add_argument('--replace-path', type=str, help='replace path (i.e. `from:to`)')
    parser.add_argument('images', nargs='+', type=str, help='Image to be processed')
    args = parser.parse_args()

    # Optional `from:to` rewrite mapping between paths stored in the data
    # file ("from" form) and paths given on the command line ("to" form).
    replace_path_from, replace_path_to = None, None
    if args.replace_path:
        replace_path_from, replace_path_to = args.replace_path.split(':')

    import csv
    # Paths already present in the data file are skipped on this run.
    excluded = set()
    if args.data_path:
        try:
            with open(args.data_path, 'r') as f:
                # Rows are comma-separated with `\,` escapes; first column is the path.
                csvreader = csv.reader(f, escapechar='\\', quoting=csv.QUOTE_NONE)
                for row in csvreader:
                    if replace_path_from:
                        # Convert stored "from" paths to "to" form so they
                        # compare equal to the CLI-supplied paths.
                        excluded.add(row[0].replace(replace_path_from, replace_path_to))
                    else:
                        excluded.add(row[0])
        except FileNotFoundError:
            # First run: the data file does not exist yet. It will be
            # created below when opened in append mode.
            pass

    args.images = [p for p in args.images if p not in excluded]
    print(f"{len(args.images)} images to be processed")
    if not args.images:
        return

    # tags.json: list of tag names indexed by model output position.
    with open(args.tag_path, 'r') as f:
        vocab = json.load(f)

    # argparse `choices` guarantees quant is 16 or 32.
    typ = np.float16 if args.quant == 16 else np.float32
    img_size = 224

    # Initialize the ONNX Runtime inference session.
    sess_opt = rt.SessionOptions()
    sess = rt.InferenceSession(args.model_path, sess_opt)
    input_name = sess.get_inputs()[0].name
    output_name = sess.get_outputs()[0].name

    import concurrent.futures
    outf = open(args.data_path, 'a') if args.data_path else None
    try:
        # Create the thread pool once and reuse it across batches instead of
        # paying pool startup/teardown per batch. Image decoding is I/O-bound,
        # so threads overlap the file reads.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            for images in window(args.images, args.batch_size):
                inp_tensors = np.array(list(executor.map(
                    lambda img: resize_and_pad(img, img_size, typ=typ), images)))
                results = sess.run([output_name], {input_name: inp_tensors})[0]
                for K, path in zip(results, images):
                    if outf:
                        if replace_path_from:
                            # Store paths in the data file's "from" form.
                            path = path.replace(replace_path_to, replace_path_from)
                        # Escape commas so the row survives the csv reader above.
                        outf.write(path.replace(",", "\\,"))
                        for index, v in sort_and_yield_indices(K, args.thresh):
                            outf.write(",")
                            outf.write(vocab[index])
                        outf.write("\n")
                    else:
                        print(path)
                        for index, v in sort_and_yield_indices(K, args.thresh):
                            print(f" {vocab[index]}:{v}")
    finally:
        # Flush and close the data file even if inference fails mid-run.
        if outf:
            outf.close()

if __name__ == '__main__':
    main()