| import gradio as gr |
| import huggingface_hub |
| from PIL import Image |
| from pathlib import Path |
| import onnxruntime as rt |
| import numpy as np |
| import csv |
|
|
|
|
# Hugging Face Hub repository that holds the ONNX model and its tag list.
MODEL_REPO = 'toynya/Z3D-E621-Convnext'

# A tag is included in the output tag string only when its score exceeds this.
THRESHOLD = 0.5

# Text shown under the title in the web UI.
DESCRIPTION = """
This is a demo of https://huggingface.co/toynya/Z3D-E621-Convnext
I am not affiliated with the model author in any way, this is just a useful tool requested by a user.
"""
|
|
|
|
def prepare_image(image: Image.Image, target_size: int):
    """Turn a PIL image into a single-item model input batch.

    The image is centred on a white square canvas whose side equals the
    image's longer dimension, resized to ``target_size`` x ``target_size``
    with bicubic resampling when the canvas is not already that size,
    converted to ``float32`` and has its channel axis reversed (RGB -> BGR).

    Args:
        image: Source image.
        target_size: Side length, in pixels, of the square output.

    Returns:
        A ``float32`` array of shape ``(1, target_size, target_size, 3)``
        with the pixel values left unscaled.
    """
    width, height = image.size
    side = max(width, height)

    # Equal margins on both sides keep the picture centred on the canvas.
    offset = ((side - width) // 2, (side - height) // 2)
    canvas = Image.new("RGB", (side, side), (255, 255, 255))
    canvas.paste(image, offset)

    if side != target_size:
        canvas = canvas.resize((target_size, target_size), Image.BICUBIC)

    pixels = np.asarray(canvas, dtype=np.float32)

    # Reverse the last axis (RGB -> BGR), then add the leading batch axis.
    return np.expand_dims(pixels[:, :, ::-1], axis=0)
|
|
|
|
def predict(image: Image.Image):
    """Run the tagger on one image and report its tags.

    Args:
        image: Image handed over by the Gradio input component. It is
            ``None`` when the user submits without providing an image.

    Returns:
        A ``(tag_string, scores)`` tuple. ``tag_string`` is a ``', '``-joined
        string of every tag whose score is greater than ``THRESHOLD``;
        ``scores`` maps each tag name to its score as a plain ``float``.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Give the user a readable message instead of an AttributeError
        # from deep inside the preprocessing code.
        raise gr.Error("Please provide an image first.")

    # Preprocess exactly once; padding and resizing are the expensive part.
    # NOTE(review): 448 is presumably the model's input resolution - confirm.
    image_array = prepare_image(image, 448)
    input_name = 'input_1:0'
    output_name = 'predictions_sigmoid'

    result = session.run([output_name], {input_name: image_array})
    # One output was requested and the batch holds one image.
    result = result[0][0]

    # tags[i] names the i-th score; plain floats keep the mapping free of
    # NumPy scalar types.
    scores = {tags[i]: float(score) for i, score in enumerate(result)}
    predicted_tags = [tag for tag, score in scores.items() if score > THRESHOLD]
    tag_string = ', '.join(predicted_tags)

    return tag_string, scores
|
|
|
|
# Resolve the model repository to a local directory.
print("Downloading model...")
path = Path(huggingface_hub.snapshot_download(MODEL_REPO))

print("Loading model...")
# The model location is passed as a plain string so that loading does not
# depend on the installed onnxruntime accepting path-like objects.
# Providers are listed in order of preference: CUDA first, then CPU.
session = rt.InferenceSession(
    str(path / 'model.onnx'),
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)

# Tag names in file order; predict() pairs tags[i] with the i-th model score.
# newline='' leaves newline handling to the csv module, as its docs recommend.
with open(path / 'tags-selected.csv', mode='r', encoding='utf-8', newline='') as file:
    tags = [row['name'].strip() for row in csv.DictReader(file)]
|
|
print("Starting server...")

# Single-image tagging UI: one image in, a tag string and a label list out.
gradio_app = gr.Interface(
    predict,
    inputs=gr.Image(label="Source", sources=['upload', 'webcam'], type='pil'),
    outputs=[
        gr.Textbox(label="Tag String"),
        gr.Label(label="Tag Predictions", num_top_classes=100),
    ],
    # The title names the model this demo serves (see MODEL_REPO).
    title="Z3D-E621-Convnext",
    description=DESCRIPTION,
    allow_flagging="never",
)


# Start the web server only when the file is run as a script.
if __name__ == '__main__':
    gradio_app.launch()
|
|