| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | import numpy as np |
| | import csv |
| | import onnxruntime as ort |
| |
|
| | from PIL import Image |
| | from onnxruntime import InferenceSession |
| | from modules.config import path_clip_vision |
| | from modules.model_loader import load_file_from_url |
| |
|
| |
|
# Module-level caches, filled lazily on first use by default_interrogator:
# the ONNX inference session and the parsed rows of the tag CSV.
global_model = global_csv = None
| |
|
| |
|
def _model_file(model_name, extension):
    """Return the local path of ``{model_name}.{extension}`` resolved via load_file_from_url."""
    file_name = f'{model_name}.{extension}'
    return load_file_from_url(
        url=f'https://huggingface.co/lllyasviel/misc/resolve/main/{file_name}',
        model_dir=path_clip_vision,
        file_name=file_name,
    )


def _get_model(model_name):
    """Return the cached ONNX session, creating it (and fetching its file) on the first call."""
    global global_model
    if global_model is None:
        model_path = _model_file(model_name, 'onnx')
        global_model = InferenceSession(model_path, providers=ort.get_available_providers())
    return global_model


def _get_tag_rows(model_name):
    """Return the tag CSV rows (header skipped), reading the file on the first call."""
    global global_csv
    if global_csv is None:
        csv_path = _model_file(model_name, 'csv')
        # Explicit UTF-8 so decoding does not depend on the platform's default
        # encoding; newline='' is what the csv module expects.
        with open(csv_path, encoding='utf-8', newline='') as f:
            reader = csv.reader(f)
            next(reader, None)  # skip the header row
            global_csv = list(reader)
    return global_csv


def _preprocess_image(image_rgb, size):
    """Letterbox ``image_rgb`` onto a white size x size square; return a (1, size, size, 3) float32 batch.

    The channel order is reversed (RGB -> BGR) before batching.
    """
    image = Image.fromarray(image_rgb)
    ratio = float(size) / max(image.size)
    # Clamp to 1 px so very elongated images cannot yield a zero-sized resize.
    new_size = tuple(max(1, int(dim * ratio)) for dim in image.size)
    image = image.resize(new_size, Image.LANCZOS)
    square = Image.new("RGB", (size, size), (255, 255, 255))
    square.paste(image, ((size - new_size[0]) // 2, (size - new_size[1]) // 2))
    batch = np.array(square).astype(np.float32)
    batch = batch[:, :, ::-1]  # RGB -> BGR
    return np.expand_dims(batch, 0)


def _normalize_tag(tag):
    """Canonical form for comparing tags: unescaped parens, spaces for underscores, stripped, lower-case."""
    return tag.replace('\\(', '(').replace('\\)', ')').replace('_', ' ').strip().lower()


def default_interrogator(image_rgb, threshold=0.35, character_threshold=0.85, exclude_tags=""):
    """Tag an image with the WD14 MOAT tagger and return the tags as a prompt string.

    Args:
        image_rgb: Image data accepted by ``PIL.Image.fromarray`` (presumably an
            HxWx3 uint8 RGB array -- confirm against callers).
        threshold: Tags from CSV category "0" are kept when their score is
            strictly greater than this.
        character_threshold: Tags from CSV category "4" are kept when their
            score is strictly greater than this.
        exclude_tags: Comma-separated tags to drop. Matching is case-insensitive
            and treats "_" and " " (and "\\(" / "(") as equivalent, so tags
            copied from this function's output can be excluded.

    Returns:
        A ", "-joined string: category-"4" tags first, then category-"0" tags,
        each in CSV order, with "(" / ")" backslash-escaped and "_" replaced by
        spaces.
    """
    model_name = "wd-v1-4-moat-tagger-v2"
    model = _get_model(model_name)
    tag_rows = _get_tag_rows(model_name)

    model_input = model.get_inputs()[0]
    # NOTE(review): assumes an NHWC input (batch, H, W, C) with H == W -- confirm for other models.
    size = model_input.shape[1]
    batch = _preprocess_image(image_rgb, size)

    label_name = model.get_outputs()[0].name
    probs = model.run([label_name], {model_input.name: batch})[0][0]

    # Select by each row's category rather than by index ranges, so a CSV that
    # lacks one of the categories cannot make the two lists overlap.
    general, character = [], []
    for row, prob in zip(tag_rows, probs):
        category = row[2]
        if category == "0":
            if prob > threshold:
                general.append(row[1])
        elif category == "4":
            if prob > character_threshold:
                character.append(row[1])

    excluded = {_normalize_tag(t) for t in (exclude_tags or "").split(",")}
    excluded.discard("")
    selected = [t for t in character + general if _normalize_tag(t) not in excluded]

    return ", ".join(t.replace("(", "\\(").replace(")", "\\)") for t in selected).replace('_', ' ')
| |
|