Spaces:
Running
Running
| #!/usr/bin/env python | |
| from __future__ import annotations | |
| import gradio as gr | |
| import huggingface_hub | |
| import numpy as np | |
| import onnxruntime as rt | |
| import pandas as pd | |
| from PIL import Image | |
# Hugging Face Hub repository id that Predictor downloads its files from.
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
# File in the repository that is loaded as the ONNX inference session.
MODEL_FILENAME = "model.onnx"
# CSV in the repository describing the tags; it is read with pandas and its
# "name" and "category" columns are consumed by load_labels().
LABEL_FILENAME = "selected_tags.csv"
def load_labels(dataframe) -> tuple[list[str], list[int], list[int], list[int]]:
    """Split the tag table into tag names and per-category row indexes.

    Args:
        dataframe: Table with a ``name`` column (tag names) and a
            ``category`` column (numeric category id per tag).

    Returns:
        A 4-tuple ``(tag_names, rating_indexes, general_indexes,
        character_indexes)``. ``tag_names`` lists every tag name in row
        order; each index list holds the positional row numbers whose
        category id is 9 (rating), 0 (general) or 4 (character).
    """
    tag_names = dataframe["name"].tolist()
    categories = dataframe["category"].to_numpy()

    def positions_of(category_id: int) -> list[int]:
        # Positional (not label-based) row numbers, as plain Python ints.
        return np.flatnonzero(categories == category_id).tolist()

    rating_indexes = positions_of(9)
    general_indexes = positions_of(0)
    character_indexes = positions_of(4)
    return tag_names, rating_indexes, general_indexes, character_indexes
class Predictor:
    """Tag images with an ONNX tagger model downloaded from the Hugging Face Hub.

    Constructing an instance downloads the label CSV and the ONNX model of
    ``EVA02_LARGE_MODEL_DSV3_REPO`` and opens an inference session, so it
    performs network and disk I/O.

    Attributes set by ``load_model``:
        tag_names: Tag names in model-output order.
        rating_indexes / general_indexes / character_indexes: Row indexes
            of the tags in each category, as returned by ``load_labels``.
        model_target_size: Edge length the input image is resized to.
        model: The ``onnxruntime`` inference session.
    """

    def __init__(self):
        """Load the default tagger model."""
        self.model_target_size = None
        self.load_model(EVA02_LARGE_MODEL_DSV3_REPO)

    def download_model(self, model_repo):
        """Fetch the label CSV and ONNX model of *model_repo*.

        Returns:
            ``(csv_path, model_path)`` — local paths of the two files.
        """
        csv_path = huggingface_hub.hf_hub_download(model_repo, LABEL_FILENAME)
        model_path = huggingface_hub.hf_hub_download(model_repo, MODEL_FILENAME)
        return csv_path, model_path

    def load_model(self, model_repo):
        """Download *model_repo* and make it the active model of this instance."""
        csv_path, model_path = self.download_model(model_repo)

        tags_df = pd.read_csv(csv_path)
        (
            self.tag_names,
            self.rating_indexes,
            self.general_indexes,
            self.character_indexes,
        ) = load_labels(tags_df)

        model = rt.InferenceSession(model_path)
        # The input shape is unpacked as four dimensions; the second one is
        # used as the edge length of the (square) resize target.
        _, height, _, _ = model.get_inputs()[0].shape
        self.model_target_size = height
        self.model = model

    def prepare_image(self, image):
        """Convert a PIL image into the float32 batch array fed to the model.

        The image is flattened onto a white background, padded to a square
        with white borders, resized to ``model_target_size`` and converted
        from RGB to BGR channel order.

        Args:
            image: A PIL image in any mode.

        Returns:
            A float32 array with a leading batch axis of length 1.
        """
        target_size = self.model_target_size

        # alpha_composite() only accepts RGBA images; without this conversion
        # an RGB / L / P input raises ValueError.
        if image.mode != "RGBA":
            image = image.convert("RGBA")

        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")

        # Pad image to square, centring the original on a white background.
        width, height = image.size
        max_dim = max(width, height)
        pad_left = (max_dim - width) // 2
        pad_top = (max_dim - height) // 2
        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded_image.paste(image, (pad_left, pad_top))

        # Resize only when the padded square is not already the target size.
        if max_dim != target_size:
            padded_image = padded_image.resize(
                (target_size, target_size),
                Image.BICUBIC,
            )

        image_array = np.asarray(padded_image, dtype=np.float32)
        # Convert PIL-native RGB to BGR by reversing the channel axis.
        image_array = image_array[:, :, ::-1]
        return np.expand_dims(image_array, axis=0)

    def predict(self, image, general_thresh):
        """Run the model on *image* and return the selected tags.

        Args:
            image: A PIL image.
            general_thresh: General tags are kept only when their score is
                strictly greater than this value.

        Returns:
            A dict mapping tag name to score, ordered by descending score.
            It holds the general tags above the threshold plus one
            ``"rating:<name>"`` entry for the highest-scoring rating.
        """
        batch = self.prepare_image(image)
        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        preds = self.model.run([label_name], {input_name: batch})[0]
        return self._collect_tags(preds[0], general_thresh)

    def _collect_tags(self, scores, general_thresh):
        """Turn one row of model scores into the ordered tag-to-score dict.

        Args:
            scores: One score per entry of ``tag_names``, in the same order.
            general_thresh: Exclusive lower bound for general tag scores.

        Returns:
            Dict of selected tags ordered by descending score. Character
            tags are not included. No rating entry is added when
            ``rating_indexes`` is empty.
        """
        labels = list(zip(self.tag_names, np.asarray(scores).astype(float)))

        # General tags: keep every one whose confidence exceeds the threshold.
        selected = {
            name: score
            for name, score in (labels[i] for i in self.general_indexes)
            if score > general_thresh
        }

        # Ratings: keep only the single highest-scoring one.
        rating_scores = [labels[i] for i in self.rating_indexes]
        if rating_scores:
            top_name, top_score = max(rating_scores, key=lambda item: item[1])
            # The "general" rating is reported under the name "safe".
            if top_name == "general":
                top_name = "safe"
            selected["rating:" + top_name] = top_score

        ordered = sorted(selected.items(), key=lambda item: item[1], reverse=True)
        return dict(ordered)
# Shared instance, built once when the module is imported; constructing it
# downloads the model files and opens the inference session.
predictor = Predictor()


def genTag(image: Image.Image, score_threshold: float) -> dict[str, float]:
    """Tag *image* with the shared predictor.

    Args:
        image: The PIL image to tag.
        score_threshold: General tags are kept only when their score is
            greater than this value.

    Returns:
        The tag-to-score mapping produced by ``Predictor.predict``.
    """
    # The annotation uses ``Image.Image`` because the file binds only
    # ``Image`` (``from PIL import Image``); the name ``PIL`` is not in scope
    # and would raise NameError when the type hints are resolved.
    return predictor.predict(image, score_threshold)