Alptekinege commited on
Commit
5538a88
·
verified ·
1 Parent(s): 45ee4a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +316 -4
app.py CHANGED
@@ -1,7 +1,319 @@
 
 
1
  import gradio as gr
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
  import gradio as gr
4
+ import huggingface_hub
5
+ import numpy as np
6
+ import onnxruntime as rt
7
+ import pandas as pd
8
+ from PIL import Image
9
+ import json # Added for loading metadata.json from the inference file
10
 
11
+ TITLE = "WaifuDiffusion Tagger"
12
+ DESCRIPTION = """
13
+ Demo for the WaifuDiffusion tagger models
14
 
15
+ Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
16
+ """
17
+
18
+ HF_TOKEN = os.environ.get("HF_TOKEN", "")
19
+
20
+ # Dataset v3 series of models:
21
+ SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
22
+ CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
23
+ VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
24
+ VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
25
+ EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
26
+
27
+ # Dataset v2 series of models:
28
+ MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
29
+ SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
30
+ CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
31
+ CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
32
+ VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
33
+
34
+ # IdolSankaku series of models:
35
+ EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
36
+ SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
37
+
38
+ # Files to download from the repos
39
+ MODEL_FILENAME = "model.onnx"
40
+ LABEL_FILENAME = "selected_tags.csv"
41
+
42
+ kaomojis = [
43
+ "0_0",
44
+ "(o)_(o)",
45
+ "+_+",
46
+ "+_-",
47
+ "._.",
48
+ "<o>_<o>",
49
+ "<|>_<|>",
50
+ "=_=",
51
+ ">_<",
52
+ "3_3",
53
+ "6_9",
54
+ ">_o",
55
+ "@_@",
56
+ "^_^",
57
+ "o_o",
58
+ "u_u",
59
+ "x_x",
60
+ "|_|",
61
+ "||_||",
62
+ ]
63
+
64
+ def parse_args() -> argparse.Namespace:
65
+ parser = argparse.ArgumentParser()
66
+ parser.add_argument("--score-slider-step", type=float, default=0.05)
67
+ parser.add_argument("--score-general-threshold", type=float, default=0.35)
68
+ parser.add_argument("--score-character-threshold", type=float, default=0.85)
69
+ return parser.parse_args()
70
+
71
+ def load_labels(dataframe) -> list[str]:
72
+ name_series = dataframe["name"]
73
+ name_series = name_series.map(
74
+ lambda x: x.replace("_", " ") if x not in kaomojis else x
75
+ )
76
+ tag_names = name_series.tolist()
77
+
78
+ rating_indexes = list(np.where(dataframe["category"] == 9)[0])
79
+ general_indexes = list(np.where(dataframe["category"] == 0)[0])
80
+ character_indexes = list(np.where(dataframe["category"] == 4)[0])
81
+ return tag_names, rating_indexes, general_indexes, character_indexes
82
+
83
+ def mcut_threshold(probs):
84
+ sorted_probs = probs[probs.argsort()[::-1]]
85
+ difs = sorted_probs[:-1] - sorted_probs[1:]
86
+ t = difs.argmax()
87
+ thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
88
+ return thresh
89
+
90
+ class Predictor:
91
+ def __init__(self):
92
+ self.model_target_size = None
93
+ self.last_loaded_repo = None
94
+ # Added flag to distinguish between custom and Hugging Face models
95
+ self.is_custom_model = False
96
+
97
+ def download_model(self, model_repo):
98
+ csv_path = huggingface_hub.hf_hub_download(
99
+ model_repo,
100
+ LABEL_FILENAME,
101
+ use_auth_token=HF_TOKEN,
102
+ )
103
+ model_path = huggingface_hub.hf_hub_download(
104
+ model_repo,
105
+ MODEL_FILENAME,
106
+ use_auth_token=HF_TOKEN,
107
+ )
108
+ return csv_path, model_path
109
+
110
+ def load_model(self, model_repo, onnx_path=None, metadata_path=None):
111
+ # Modified to accept onnx_path and metadata_path for custom model support
112
+ if model_repo == "Custom Model" and onnx_path and metadata_path:
113
+ # Check if the custom model files have already been loaded
114
+ if self.last_loaded_repo == (onnx_path, metadata_path):
115
+ return
116
+ self.is_custom_model = True
117
+ # Load the ONNX model from the provided path (from inference file)
118
+ self.model = rt.InferenceSession(onnx_path)
119
+ # Load metadata from metadata.json (from inference file)
120
+ with open(metadata_path, "r", encoding="utf-8") as f:
121
+ metadata = json.load(f)
122
+ self.idx_to_tag = metadata["idx_to_tag"]
123
+ # Create tag_names list from idx_to_tag dictionary
124
+ self.tag_names = [self.idx_to_tag[str(i)] for i in range(len(self.idx_to_tag))]
125
+ # Set target size to 512 for custom model, as per inference file
126
+ self.model_target_size = 512
127
+ self.last_loaded_repo = (onnx_path, metadata_path)
128
+ else:
129
+ # Existing logic for Hugging Face models
130
+ self.is_custom_model = False
131
+ if self.last_loaded_repo == model_repo:
132
+ return
133
+ csv_path, model_path = self.download_model(model_repo)
134
+ tags_df = pd.read_csv(csv_path)
135
+ sep_tags = load_labels(tags_df)
136
+ self.tag_names = sep_tags[0]
137
+ self.rating_indexes = sep_tags[1]
138
+ self.general_indexes = sep_tags[2]
139
+ self.character_indexes = sep_tags[3]
140
+ self.model = rt.InferenceSession(model_path)
141
+ _, height, width, _ = self.model.get_inputs()[0].shape
142
+ self.model_target_size = height
143
+ self.last_loaded_repo = model_repo
144
+
145
+ def prepare_image(self, image):
146
+ if self.is_custom_model:
147
+ # Added preprocessing logic from inference file's preprocess_image function
148
+ # Adapted to take a PIL image instead of a file path
149
+ target_size = self.model_target_size
150
+ img = image.convert("RGB")
151
+ w, h = img.size
152
+ aspect = w / h
153
+ if aspect > 1:
154
+ new_w = target_size
155
+ new_h = int(new_w / aspect)
156
+ else:
157
+ new_h = target_size
158
+ new_w = int(new_h * aspect)
159
+ img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
160
+ background = Image.new("RGB", (target_size, target_size), (0, 0, 0))
161
+ paste_x = (target_size - new_w) // 2
162
+ paste_y = (target_size - new_h) // 2
163
+ background.paste(img, (paste_x, paste_y))
164
+ arr = np.array(background).astype("float32") / 255.0
165
+ arr = np.transpose(arr, (2, 0, 1)) # HWC to CHW as per inference file
166
+ arr = np.expand_dims(arr, axis=0)
167
+ return arr
168
+ else:
169
+ # Existing preprocessing logic for Hugging Face models
170
+ target_size = self.model_target_size
171
+ canvas = Image.new("RGBA", image.size, (255, 255, 255))
172
+ canvas.alpha_composite(image)
173
+ image = canvas.convert("RGB")
174
+ image_shape = image.size
175
+ max_dim = max(image_shape)
176
+ pad_left = (max_dim - image_shape[0]) // 2
177
+ pad_top = (max_dim - image_shape[1]) // 2
178
+ padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
179
+ padded_image.paste(image, (pad_left, pad_top))
180
+ if max_dim != target_size:
181
+ padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC)
182
+ image_array = np.asarray(padded_image, dtype=np.float32)
183
+ image_array = image_array[:, :, ::-1] # RGB to BGR
184
+ return np.expand_dims(image_array, axis=0)
185
+
186
+ def predict(
187
+ self,
188
+ image,
189
+ model_repo,
190
+ general_thresh,
191
+ general_mcut_enabled,
192
+ character_thresh,
193
+ character_mcut_enabled,
194
+ onnx_path=None,
195
+ metadata_path=None,
196
+ ):
197
+ # Modified to accept onnx_path and metadata_path for custom model
198
+ self.load_model(model_repo, onnx_path, metadata_path)
199
+ # Added check to ensure custom model files are provided
200
+ if self.is_custom_model and (onnx_path is None or metadata_path is None):
201
+ return "Please upload ONNX model and metadata JSON files.", {}, {}, {}
202
+ image_tensor = self.prepare_image(image)
203
+ input_name = self.model.get_inputs()[0].name
204
+ # Changed to use None for output names to get all outputs, supporting custom model
205
+ outputs = self.model.run(None, {input_name: image_tensor})
206
+ if self.is_custom_model:
207
+ # Added inference logic from inference file for custom model
208
+ # Handle case where model might output initial and refined predictions
209
+ refined_preds = outputs[1] if len(outputs) == 2 else outputs[0]
210
+ ref_logit = refined_preds[0] # Shape (N_tags,)
211
+ # Apply sigmoid to convert logits to probabilities (from inference file)
212
+ ref_prob = 1.0 / (1.0 + np.exp(-ref_logit))
213
+ pred_indices = np.where(ref_prob >= general_thresh)[0]
214
+ predicted_tags = [self.tag_names[idx] for idx in pred_indices]
215
+ sorted_general_strings = ", ".join(predicted_tags)
216
+ # Custom model doesn't use category separation, so return empty for rating and character
217
+ rating = {}
218
+ character_res = {}
219
+ general_res = {self.tag_names[idx]: ref_prob[idx] for idx in pred_indices}
220
+ else:
221
+ # Existing inference logic for Hugging Face models
222
+ preds = outputs[0] # Assumes single output tensor
223
+ labels = list(zip(self.tag_names, preds[0].astype(float)))
224
+ ratings_names = [labels[i] for i in self.rating_indexes]
225
+ rating = dict(ratings_names)
226
+ general_names = [labels[i] for i in self.general_indexes]
227
+ if general_mcut_enabled:
228
+ general_probs = np.array([x[1] for x in general_names])
229
+ general_thresh = mcut_threshold(general_probs)
230
+ general_res = [x for x in general_names if x[1] > general_thresh]
231
+ general_res = dict(general_res)
232
+ character_names = [labels[i] for i in self.character_indexes]
233
+ if character_mcut_enabled:
234
+ character_probs = np.array([x[1] for x in character_names])
235
+ character_thresh = mcut_threshold(character_probs)
236
+ character_thresh = max(0.15, character_thresh)
237
+ character_res = [x for x in character_names if x[1] > character_thresh]
238
+ character_res = dict(character_res)
239
+ sorted_general_strings = sorted(
240
+ general_res.items(),
241
+ key=lambda x: x[1],
242
+ reverse=True,
243
+ )
244
+ sorted_general_strings = [x[0] for x in sorted_general_strings]
245
+ sorted_general_strings = ", ".join(sorted_general_strings).replace("(", r"\(").replace(")", r"\)")
246
+ return sorted_general_strings, rating, character_res, general_res
247
+
248
+ def main():
249
+ args = parse_args()
250
+ predictor = Predictor()
251
+ # Added "Custom Model" to the dropdown list to support local ONNX model
252
+ dropdown_list = [
253
+ SWINV2_MODEL_DSV3_REPO,
254
+ CONV_MODEL_DSV3_REPO,
255
+ VIT_MODEL_DSV3_REPO,
256
+ VIT_LARGE_MODEL_DSV3_REPO,
257
+ EVA02_LARGE_MODEL_DSV3_REPO,
258
+ MOAT_MODEL_DSV2_REPO,
259
+ SWIN_MODEL_DSV2_REPO,
260
+ CONV_MODEL_DSV2_REPO,
261
+ CONV2_MODEL_DSV2_REPO,
262
+ VIT_MODEL_DSV2_REPO,
263
+ SWINV2_MODEL_IS_DSV1_REPO,
264
+ EVA02_LARGE_MODEL_IS_DSV1_REPO,
265
+ "Custom Model",
266
+ ]
267
+ with gr.Blocks(title=TITLE) as demo:
268
+ with gr.Column():
269
+ gr.Markdown(value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>")
270
+ gr.Markdown(value=DESCRIPTION)
271
+ with gr.Row():
272
+ with gr.Column(variant="panel"):
273
+ image = gr.Image(type="pil", image_mode="RGBA", label="Input")
274
+ model_repo = gr.Dropdown(dropdown_list, value=SWINV2_MODEL_DSV3_REPO, label="Model")
275
+ # Added file inputs for ONNX model and metadata, hidden by default
276
+ with gr.Row(visible=False) as custom_model_inputs:
277
+ onnx_file = gr.File(label="ONNX Model File", file_types=[".onnx"])
278
+ metadata_file = gr.File(label="Metadata JSON File", file_types=[".json"])
279
+ with gr.Row():
280
+ general_thresh = gr.Slider(0, 1, step=args.score_slider_step, value=args.score_general_threshold, label="General Tags Threshold", scale=3)
281
+ general_mcut_enabled = gr.Checkbox(value=False, label="Use MCut threshold", scale=1)
282
+ with gr.Row():
283
+ character_thresh = gr.Slider(0, 1, step=args.score_slider_step, value=args.score_character_threshold, label="Character Tags Threshold", scale=3)
284
+ character_mcut_enabled = gr.Checkbox(value=False, label="Use MCut threshold", scale=1)
285
+ with gr.Row():
286
+ # Updated clear button to include new file inputs
287
+ clear = gr.ClearButton(
288
+ components=[image, model_repo, general_thresh, general_mcut_enabled, character_thresh, character_mcut_enabled, onnx_file, metadata_file],
289
+ variant="secondary",
290
+ size="lg"
291
+ )
292
+ submit = gr.Button(value="Submit", variant="primary", size="lg")
293
+ with gr.Column(variant="panel"):
294
+ sorted_general_strings = gr.Textbox(label="Output (string)")
295
+ rating = gr.Label(label="Rating")
296
+ character_res = gr.Label(label="Output (characters)")
297
+ general_res = gr.Label(label="Output (tags)")
298
+ clear.add([sorted_general_strings, rating, character_res, general_res])
299
+ # Added event listener to show/hide custom model inputs based on model selection
300
+ model_repo.change(
301
+ lambda x: gr.update(visible=(x == "Custom Model")),
302
+ inputs=model_repo,
303
+ outputs=custom_model_inputs,
304
+ )
305
+ # Updated submit event to pass onnx_file and metadata_file to predict
306
+ submit.click(
307
+ predictor.predict,
308
+ inputs=[image, model_repo, general_thresh, general_mcut_enabled, character_thresh, character_mcut_enabled, onnx_file, metadata_file],
309
+ outputs=[sorted_general_strings, rating, character_res, general_res],
310
+ )
311
+ gr.Examples(
312
+ [["power.jpg", SWINV2_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
313
+ inputs=[image, model_repo, general_thresh, general_mcut_enabled, character_thresh, character_mcut_enabled],
314
+ )
315
+ demo.queue(max_size=10)
316
+ demo.launch()
317
+
318
+ if __name__ == "__main__":
319
+ main()