Johnny-Z committed on
Commit
beefff8
·
verified ·
1 Parent(s): 93eed81

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +300 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import CLIPImageProcessor, AutoModel
2
+ import torch
3
+ import json
4
+ import torch.nn as nn
5
+ from PIL import Image
6
+ import gradio as gr
7
+ import os
8
+ from huggingface_hub import login, snapshot_download
9
+
10
# UI title and the markdown shown beside the input panel.
TITLE = "Danbooru Tagger"
DESCRIPTION = """
## Dataset
- Source: Cleaned Danbooru

## Metrics
- Validation Split: 10% of Dataset
- Validation Results:

### General
| Metric | Value |
|-----------------|-------------|
| Macro F1 | 0.4678 |
| Macro Precision | 0.4605 |
| Macro Recall | 0.5229 |
| Micro F1 | 0.6661 |
| Micro Precision | 0.6049 |
| Micro Recall | 0.7411 |

### Character
| Metric | Value |
|-----------------|-------------|
| Macro F1 | 0.8925 |
| Macro Precision | 0.9099 |
| Macro Recall | 0.8935 |
| Micro F1 | 0.9232 |
| Micro Precision | 0.9264 |
| Micro Recall | 0.9199 |

### Artist
| Metric | Value |
|-----------------|-------------|
| Macro F1 | 0.7904 |
| Macro Precision | 0.8286 |
| Macro Recall | 0.7904 |
| Micro F1 | 0.5989 |
| Micro Precision | 0.5975 |
| Micro Recall | 0.6004 |
"""

# Tags whose underscores belong to the emoticon itself and must NOT be
# rewritten as spaces when the prompt string is assembled.
kaomojis = [
    "0_0", "(o)_(o)", "+_+", "+_-", "._.",
    "<o>_<o>", "<|>_<|>", "=_=", ">_<", "3_3",
    "6_9", ">_o", "@_@", "^_^", "o_o",
    "u_u", "x_x", "|_|", "||_||",
]
71
+
72
# Inference runs on CPU in full precision.
device = torch.device('cpu')
dtype = torch.float32

# Authenticate with the Hub before downloading the model repo; fail fast
# if no token is configured.
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    raise ValueError("environment variable HF_TOKEN not found.")
login(token=hf_token)

repo_id = "Johnny-Z/vit-e4"
# Local snapshot holds the tag dictionaries, thresholds and head weights.
repo_dir = snapshot_download(repo_id)
model = AutoModel.from_pretrained(
    repo_id, dtype=dtype, trust_remote_code=True, device_map=device
)

processor = CLIPImageProcessor.from_pretrained(repo_id)
86
+
87
class MultiheadAttentionPoolingHead(nn.Module):
    """Pool a sequence of hidden states into a single embedding vector.

    A learned probe token attends over the layer-normed hidden states;
    the attended result is normalized again and refined by a residual
    feed-forward block.  Weights are loaded from ``map_head.pth``, so
    attribute names must stay in sync with that checkpoint.
    """

    def __init__(self, input_size):
        super().__init__()
        # Single learned query token used to pool the whole sequence.
        self.map_probe = nn.Parameter(torch.randn(1, 1, input_size))
        self.map_layernorm0 = nn.LayerNorm(input_size, eps=1e-08)
        # One attention head per 64 channels, matching the checkpoint layout.
        self.map_attention = torch.nn.MultiheadAttention(input_size, input_size // 64, batch_first=True)
        self.map_layernorm1 = nn.LayerNorm(input_size, eps=1e-08)
        self.map_ffn = nn.Sequential(
            nn.Linear(input_size, input_size * 4),
            nn.SiLU(),
            nn.Linear(input_size * 4, input_size),
        )

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Return a (batch, input_size) pooled embedding."""
        queries = self.map_probe.repeat(hidden_state.shape[0], 1, 1)

        normed = self.map_layernorm0(hidden_state)
        attended, _ = self.map_attention(queries, normed, normed)
        attended = self.map_layernorm1(attended)

        # Residual feed-forward refinement, then drop the probe axis.
        pooled = attended + self.map_ffn(attended)
        return pooled[:, 0]
112
+
113
class MLP(nn.Module):
    """Two-layer classification head producing independent sigmoid scores.

    Attribute names (``mlp_layer0``/``mlp_layer1``) must match the
    ``cls_predictor_*.pth`` checkpoints loaded at startup.
    """

    def __init__(self, input_size, class_num):
        super().__init__()
        self.mlp_layer0 = nn.Sequential(
            nn.LayerNorm(input_size, eps=1e-08),
            nn.Linear(input_size, input_size // 2),
            nn.SiLU(),
        )
        self.mlp_layer1 = nn.Linear(input_size // 2, class_num)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map (batch, input_size) features to (batch, class_num) scores in (0, 1)."""
        hidden = self.mlp_layer0(x)
        logits = self.mlp_layer1(hidden)
        return self.sigmoid(logits)
129
+
130
def _load_json(filename):
    """Read a JSON asset from the downloaded model snapshot."""
    with open(os.path.join(repo_dir, filename), 'r', encoding='utf-8') as f:
        return json.load(f)

# Tag vocabularies: tag name -> metadata (index 1 is the category,
# index 2 the 1-based class id — see prediction_to_tag).
general_dict = _load_json('general_tag_dict.json')
character_dict = _load_json('character_tag_dict.json')
artist_dict = _load_json('artist_tag_dict.json')

# tag -> implied tags; used to drop redundant implications from the output.
implications_list = _load_json('implications_list.json')

# Per-tag decision thresholds, keyed by tag name.
artist_thresholds = _load_json('artist_threshold.json')
character_thresholds = _load_json('character_threshold.json')
general_thresholds = _load_json('general_threshold.json')
150
+
151
def _load_head(module, weights_name):
    """Load checkpoint weights into *module* and prepare it for CPU inference."""
    state = torch.load(os.path.join(repo_dir, weights_name), map_location=device, weights_only=True)
    module.load_state_dict(state)
    return module.to(device).to(dtype).eval()

# Attention-pooling head over the backbone's 2048-dim embeddings.
model_map = _load_head(MultiheadAttentionPoolingHead(2048), "map_head.pth")

# One sigmoid classifier head per tag category; class counts must match
# the corresponding *_tag_dict.json vocabularies.
general_class = 9775
mlp_general = _load_head(MLP(2048, general_class), "cls_predictor_general.pth")

character_class = 7568
mlp_character = _load_head(MLP(2048, character_class), "cls_predictor_character.pth")

artist_class = 13957
mlp_artist = _load_head(MLP(2048, artist_class), "cls_predictor_artist.pth")
169
+
170
def prediction_to_tag(prediction, tag_dict, class_num):
    """Split a sigmoid prediction vector into per-category tag score dicts.

    Args:
        prediction: tensor of sigmoid scores, reshapeable to (class_num,).
        tag_dict: mapping of tag name -> metadata where index 1 is the
            category ("general"/"character"/"artist"/"rating"/"date") and
            index 2 is the 1-based class id.
        class_num: number of classes in the prediction vector.

    Returns:
        Tuple of (general, character, artist, date, rating) dicts mapping
        tag -> score.  general/character/artist are filtered by per-tag
        thresholds (default 0.75) and sorted by descending score;
        date/rating keep only their single highest-scoring entry.
    """
    prediction = prediction.view(class_num)
    # 0.2 is a coarse recall floor; the per-tag thresholds below do the real
    # filtering.  A set gives O(1) membership tests in the loop instead of
    # scanning a numpy array once per vocabulary entry.
    predicted_ids = set(((prediction >= 0.2).nonzero(as_tuple=True)[0] + 1).tolist())

    general = {}
    character = {}
    artist = {}
    date = {}
    rating = {}

    for tag, value in tag_dict.items():
        if value[2] not in predicted_ids:
            continue
        tag_value = round(prediction[value[2] - 1].item(), 6)
        category = value[1]
        if category == "general" and tag_value >= general_thresholds.get(tag, {}).get("Threshold", 0.75):
            general[tag] = tag_value
        elif category == "character" and tag_value >= character_thresholds.get(tag, {}).get("Threshold", 0.75):
            character[tag] = tag_value
        elif category == "artist" and tag_value >= artist_thresholds.get(tag, {}).get("Threshold", 0.75):
            artist[tag] = tag_value
        elif category == "rating":
            rating[tag] = tag_value
        elif category == "date":
            date[tag] = tag_value

    general = dict(sorted(general.items(), key=lambda item: item[1], reverse=True))
    character = dict(sorted(character.items(), key=lambda item: item[1], reverse=True))
    artist = dict(sorted(artist.items(), key=lambda item: item[1], reverse=True))

    # Presumably date and rating are mutually exclusive labels: keep only the
    # top-scoring entry of each (computed once instead of three times).
    if date:
        top = max(date, key=date.get)
        date = {top: date[top]}
    if rating:
        top = max(rating, key=rating.get)
        rating = {top: rating[top]}

    return general, character, artist, date, rating
204
+
205
def process_image(image):
    """Tag a PIL image and return Gradio-ready outputs.

    Returns:
        Tuple of (prompt string, artist, character, general, rating, date),
        where each dict maps tag -> confidence score.
    """
    try:
        # Flatten any alpha channel onto a white background before tagging.
        image = image.convert('RGBA')
        background = Image.new('RGBA', image.size, (255, 255, 255, 255))
        image = Image.alpha_composite(background, image).convert('RGB')

        image_inputs = processor(images=[image], return_tensors="pt").to(device).to(dtype)

    except (OSError, IOError) as e:
        print(f"Error opening image: {e}")
        # Return six empty outputs instead of a bare None: Gradio expects one
        # value per output component and cannot unpack a lone None.
        return "", {}, {}, {}, {}, {}

    with torch.no_grad():
        # Backbone embeddings, attention-pooled to one vector per image.
        embedding = model(image_inputs.pixel_values)
        embedding = model_map(embedding)

        # Each head's output is split into category dicts; only the slots
        # relevant to that head are consumed from the returned tuple.
        general_prediction = mlp_general(embedding)
        general_ = prediction_to_tag(general_prediction, general_dict, general_class)
        general_tags = general_[0]
        rating = general_[4]

        character_prediction = mlp_character(embedding)
        character_ = prediction_to_tag(character_prediction, character_dict, character_class)
        character_tags = character_[1]

        artist_prediction = mlp_artist(embedding)
        artist_ = prediction_to_tag(artist_prediction, artist_dict, artist_class)
        artist_tags = artist_[2]
        date = artist_[3]

        # Drop tags that are implied by another predicted tag.  A set makes
        # the final filter O(1) per tag instead of scanning a list.
        tags_list = list(general_tags)
        implied = set()
        for tag in tags_list:
            if tag in implications_list:
                implied.update(implications_list[tag])
        tags_list = [tag for tag in tags_list if tag not in implied]
        # Kaomoji tags keep their underscores; everything else is prettified.
        tags_list = [tag.replace("_", " ") if tag not in kaomojis else tag for tag in tags_list]

        # Escape parentheses for prompt syntax compatibility.
        tags_str = ", ".join(tags_list).replace("(", r"\(").replace(")", r"\)")

        return tags_str, artist_tags, character_tags, general_tags, rating, date
249
+
250
def main():
    """Build and launch the Gradio tagging interface."""
    with gr.Blocks(title=TITLE) as demo:
        with gr.Column():
            gr.Markdown(
                value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
            )
            with gr.Row():
                # Left panel: image input, submit/clear controls, model notes.
                with gr.Column(variant="panel"):
                    submit = gr.Button(value="Submit", variant="primary", size="lg")
                    image = gr.Image(type="pil", image_mode="RGBA", label="Input")
                    with gr.Row():
                        clear = gr.ClearButton(components=[image], variant="secondary", size="lg")
                    gr.Markdown(value=DESCRIPTION)
                # Right panel: prompt string plus per-category predictions.
                with gr.Column(variant="panel"):
                    tags_str = gr.Textbox(label="Output", lines=4)
                    with gr.Row():
                        rating = gr.Label(label="Rating")
                        date = gr.Label(label="Year")
                    artist_tags = gr.Label(label="Artist")
                    character_tags = gr.Label(label="Character")
                    general_tags = gr.Label(label="General")
                    # Clearing the input also resets every output component.
                    clear.add([tags_str, artist_tags, general_tags, character_tags, rating, date])

            submit.click(
                process_image,
                inputs=[image],
                outputs=[tags_str, artist_tags, character_tags, general_tags, rating, date],
            )

    demo.queue(max_size=10)
    demo.launch()

if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
torch
transformers
Pillow
gradio
einops
timm
accelerate
huggingface_hub