FrostIce committed on
Commit
c8b2c02
·
verified ·
1 Parent(s): 900e8ea

Upload 2 files

Browse files
Files changed (2) hide show
  1. app (2).py +184 -0
  2. requirements (2).txt +8 -0
app (2).py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from typing import Mapping, Tuple, Dict
4
+
5
+ import cv2
6
+ import gradio as gr
7
+ import numpy as np
8
+ import pandas as pd
9
+ from PIL import Image
10
+ from huggingface_hub import hf_hub_download
11
+ from onnxruntime import InferenceSession
12
+
13
+
14
# noinspection PyUnresolvedReferences
def make_square(img, target_size):
    """Pad an image with white borders so it becomes a square.

    The resulting side length is the larger of the image's longest side
    and ``target_size``; padding is split as evenly as possible between
    the opposite edges.
    """
    height, width = img.shape[:2]
    side = max(height, width, target_size)

    pad_h = side - height
    pad_w = side - width
    top = pad_h // 2
    left = pad_w // 2

    # White padding matches the white background the alpha channel is
    # flattened onto before inference.
    return cv2.copyMakeBorder(
        img,
        top, pad_h - top,
        left, pad_w - left,
        cv2.BORDER_CONSTANT,
        value=[255, 255, 255],
    )
27
+
28
+
29
# noinspection PyUnresolvedReferences
def smart_resize(img, size):
    """Resize a square image to ``size`` x ``size``.

    Assumes the image has already gone through make_square.  INTER_AREA
    is used when shrinking and INTER_CUBIC when enlarging; an image that
    is already the right size is returned unchanged.
    """
    current = img.shape[0]
    if current == size:
        return img

    interpolation = cv2.INTER_AREA if current > size else cv2.INTER_CUBIC
    return cv2.resize(img, (size, size), interpolation=interpolation)
40
+
41
+
42
class WaifuDiffusionInterrogator:
    """Lazily initialised wrapper around a WaifuDiffusion ONNX tagger.

    On first use it downloads the model file and its tag list from the
    Hugging Face Hub, then exposes :meth:`interrogate`, which returns
    rating confidences and per-tag confidences for a PIL image.
    """

    def __init__(
        self,
        repo='SmilingWolf/wd-v1-4-vit-tagger',
        model_path='model.onnx',
        tags_path='selected_tags.csv',
        mode: str = "auto"
    ) -> None:
        # Hub repo and file names; resolved via hf_hub_download in _init().
        self.__repo = repo
        self.__model_path = model_path
        self.__tags_path = tags_path
        # NOTE(review): `mode` is stored but never read anywhere in this
        # file — presumably a leftover execution-provider switch; confirm
        # before removing.
        self._provider_mode = mode

        # Model/tags are loaded lazily so constructing registry entries
        # at import time does not trigger downloads.
        self.__initialized = False
        self._model, self._tags = None, None

    def _init(self) -> None:
        """Download and load the ONNX session and the tag CSV exactly once."""
        if self.__initialized:
            return

        model_path = hf_hub_download(self.__repo, filename=self.__model_path)
        tags_path = hf_hub_download(self.__repo, filename=self.__tags_path)

        self._model = InferenceSession(str(model_path))
        self._tags = pd.read_csv(tags_path)

        self.__initialized = True

    def _calculation(self, image: Image.Image) -> pd.DataFrame:
        """Preprocess *image*, run the model, and return a DataFrame with
        ``name``, ``category`` and ``confidence`` columns (one row per tag)."""
        self._init()

        # code for converting the image and running the model is taken from the link below
        # thanks, SmilingWolf!
        # https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py

        # The model input is NHWC; pick the expected height (== width)
        # from the ONNX input shape.
        _, height, _, _ = self._model.get_inputs()[0].shape

        # Flatten transparency onto a white background before dropping alpha.
        image = image.convert('RGBA')
        new_image = Image.new('RGBA', image.size, 'WHITE')
        new_image.paste(image, mask=image)
        image = new_image.convert('RGB')
        image = np.asarray(image)

        # PIL RGB to OpenCV BGR — the model was trained on BGR input.
        image = image[:, :, ::-1]

        # Pad to square, scale to the model's input size, add batch dim.
        image = make_square(image, height)
        image = smart_resize(image, height)
        image = image.astype(np.float32)
        image = np.expand_dims(image, 0)

        # evaluate model
        input_name = self._model.get_inputs()[0].name
        label_name = self._model.get_outputs()[0].name
        confidence = self._model.run([label_name], {input_name: image})[0]

        # Pair each tag row from the CSV with its predicted confidence.
        full_tags = self._tags[['name', 'category']].copy()
        full_tags['confidence'] = confidence[0]

        return full_tags

    def interrogate(self, image: Image) -> Tuple[Dict[str, float], Dict[str, float]]:
        """Return ``(ratings, tags)`` — two name→confidence dicts."""
        full_tags = self._calculation(image)

        # Category 9 rows are the rating tags (general / sensitive /
        # questionable / explicit in SmilingWolf's selected_tags.csv —
        # confirm against the CSV if the repo changes).
        ratings = dict(full_tags[full_tags['category'] == 9][['name', 'confidence']].values)

        # Every other category is a regular content tag.
        tags = dict(full_tags[full_tags['category'] != 9][['name', 'confidence']].values)

        return ratings, tags
115
+
116
+
117
# Registry of available taggers, keyed by the value shown in the UI radio
# buttons.  Construction is cheap: model weights download lazily on first use.
WAIFU_MODELS: Mapping[str, WaifuDiffusionInterrogator] = {
    'wd14-vit': WaifuDiffusionInterrogator(),
    'wd14-convnext': WaifuDiffusionInterrogator(
        repo='SmilingWolf/wd-v1-4-convnext-tagger'
    ),
}
# Matches backslashes and parentheses, which are significant in prompt
# syntax and therefore get escaped when exporting tag text.
RE_SPECIAL = re.compile(r'([\\()])')
124
+
125
+
126
def image_to_wd14_tags(image: Image.Image, model_name: str, threshold: float,
                       use_spaces: bool, use_escape: bool, include_ranks: bool, score_descend: bool) \
        -> Tuple[Mapping[str, float], str, Mapping[str, float]]:
    """Run the selected tagger on *image* and format its output.

    Returns ``(ratings, exported_text, filtered_tags)`` where
    ``filtered_tags`` keeps only tags scoring at or above *threshold*
    and ``exported_text`` is a comma-separated prompt-style string.
    """
    ratings, raw_tags = WAIFU_MODELS[model_name].interrogate(image)

    # Drop everything below the confidence threshold.
    kept = {name: conf for name, conf in raw_tags.items() if conf >= threshold}

    # Optionally order by descending confidence, ties broken alphabetically.
    if score_descend:
        pairs = sorted(kept.items(), key=lambda kv: (-kv[1], kv[0]))
    else:
        pairs = list(kept.items())

    pieces = []
    for name, conf in pairs:
        text = name
        if use_spaces:
            text = text.replace('_', ' ')
        if use_escape:
            text = re.sub(RE_SPECIAL, r'\\\1', text)
        if include_ranks:
            text = f"({text}:{conf:.3f})"
        pieces.append(text)

    return ratings, ', '.join(pieces), kept
153
+
154
+
155
if __name__ == '__main__':
    # Build the Gradio UI: inputs/options on the left, results on the right.
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                gr_input_image = gr.Image(type='pil', label='Original Image')
                with gr.Row():
                    gr_model = gr.Radio(list(WAIFU_MODELS.keys()), value='wd14-vit', label='Waifu Model')
                    gr_threshold = gr.Slider(0.0, 1.0, 0.5, label='Tagging Confidence Threshold')
                with gr.Row():
                    # Export-formatting options fed straight into
                    # image_to_wd14_tags.
                    gr_space = gr.Checkbox(value=False, label='Use Space Instead Of _')
                    gr_escape = gr.Checkbox(value=True, label='Use Text Escape')
                    gr_confidence = gr.Checkbox(value=False, label='Keep Confidences')
                    gr_order = gr.Checkbox(value=True, label='Descend By Confidence')

                gr_btn_submit = gr.Button(value='Tagging', variant='primary')

            with gr.Column():
                gr_ratings = gr.Label(label='Ratings')
                with gr.Tabs():
                    with gr.Tab("Tags"):
                        gr_tags = gr.Label(label='Tags')
                    with gr.Tab("Exported Text"):
                        gr_output_text = gr.TextArea(label='Exported Text')

        # Wire the button to the tagging function; outputs map to
        # (ratings, output_text, filtered_tags) in that order.
        gr_btn_submit.click(
            image_to_wd14_tags,
            inputs=[gr_input_image, gr_model, gr_threshold, gr_space, gr_escape, gr_confidence, gr_order],
            outputs=[gr_ratings, gr_output_text, gr_tags],
        )
    # Queue sized to CPU count so concurrent requests don't contend on
    # the ONNX session.
    demo.queue(os.cpu_count()).launch()
requirements (2).txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio==3.16.1
2
+ numpy
3
+ pillow
4
+ onnxruntime
5
+ huggingface_hub
6
+ scikit-image
7
+ pandas
8
+ opencv-python>=4.6.0