Johnny-Z commited on
Commit
f613759
·
verified ·
1 Parent(s): 7126f84

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -83
app.py CHANGED
@@ -13,41 +13,44 @@ from huggingface_hub import login, snapshot_download
13
  TITLE = "Danbooru Tagger"
14
  DESCRIPTION = """
15
  ## Dataset
16
- - Source: Cleaned Danbooru
17
-
18
- ## Metrics
19
  - Validation Split: 10% of Dataset
20
- - Validation Results:
 
21
 
22
  ### General
 
23
  | Metric | Value |
24
  |-----------------|-------------|
25
- | Macro F1 | 0.4678 |
26
- | Macro Precision | 0.4605 |
27
- | Macro Recall | 0.5229 |
28
- | Micro F1 | 0.6661 |
29
- | Micro Precision | 0.6049 |
30
- | Micro Recall | 0.7411 |
31
 
32
  ### Character
 
33
  | Metric | Value |
34
  |-----------------|-------------|
35
- | Macro F1 | 0.8925 |
36
- | Macro Precision | 0.9099 |
37
- | Macro Recall | 0.8935 |
38
- | Micro F1 | 0.9232 |
39
- | Micro Precision | 0.9264 |
40
- | Micro Recall | 0.9199 |
41
 
42
  ### Artist
 
43
  | Metric | Value |
44
  |-----------------|-------------|
45
- | Macro F1 | 0.7904 |
46
- | Macro Precision | 0.8286 |
47
- | Macro Recall | 0.7904 |
48
- | Micro F1 | 0.5989 |
49
- | Micro Precision | 0.5975 |
50
- | Micro Recall | 0.6004 |
51
  """
52
 
53
  kaomojis = [
@@ -81,10 +84,10 @@ if hf_token:
81
  else:
82
  raise ValueError("environment variable HF_TOKEN not found.")
83
 
84
- repo = snapshot_download('Johnny-Z/vit-e4')
85
  model = AutoModel.from_pretrained(repo, dtype=dtype, trust_remote_code=True, device_map=device)
86
 
87
- index_dir = snapshot_download('Johnny-Z/dan_index', repo_type='dataset')
88
 
89
  processor = CLIPImageProcessor.from_pretrained(repo)
90
 
@@ -131,30 +134,11 @@ class MLP(nn.Module):
131
  x = self.sigmoid(x)
132
  return x
133
 
134
- class MLP_Retrieval(nn.Module):
135
- def __init__(self, input_size, class_num):
136
- super().__init__()
137
- self.mlp_layer0 = nn.Sequential(
138
- nn.Linear(input_size, input_size // 2),
139
- nn.SiLU()
140
- )
141
- self.mlp_layer1 = nn.Linear(input_size // 2, class_num)
142
-
143
- def forward(self, x):
144
- x = self.mlp_layer0(x)
145
- x = self.mlp_layer1(x)
146
- x1, x2 = x[:, :15], x[:, 15:]
147
- x1 = torch.softmax(x1, dim=1)
148
- x2 = torch.softmax(x2, dim=1)
149
- x = torch.cat([x1, x2], dim=1)
150
-
151
- return x
152
-
153
  class MLP_R(nn.Module):
154
  def __init__(self, input_size):
155
  super().__init__()
156
  self.mlp_layer0 = nn.Sequential(
157
- nn.Linear(input_size, 256),
158
  )
159
 
160
  def forward(self, x):
@@ -186,25 +170,21 @@ model_map = MultiheadAttentionPoolingHead(2048)
186
  model_map.load_state_dict(torch.load(os.path.join(repo, "map_head.pth"), map_location=device, weights_only=True))
187
  model_map.to(device).to(dtype).eval()
188
 
189
- general_class = 9775
190
  mlp_general = MLP(2048, general_class)
191
  mlp_general.load_state_dict(torch.load(os.path.join(repo, "cls_predictor_general.pth"), map_location=device, weights_only=True))
192
  mlp_general.to(device).to(dtype).eval()
193
 
194
- character_class = 7568
195
  mlp_character = MLP(2048, character_class)
196
  mlp_character.load_state_dict(torch.load(os.path.join(repo, "cls_predictor_character.pth"), map_location=device, weights_only=True))
197
  mlp_character.to(device).to(dtype).eval()
198
 
199
- artist_class = 13957
200
  mlp_artist = MLP(2048, artist_class)
201
  mlp_artist.load_state_dict(torch.load(os.path.join(repo, "cls_predictor_artist.pth"), map_location=device, weights_only=True))
202
  mlp_artist.to(device).to(dtype).eval()
203
 
204
- mlp_artist_retrieval = MLP_Retrieval(2048, artist_class)
205
- mlp_artist_retrieval.load_state_dict(torch.load(os.path.join(repo, "cls_predictor_artist_retrieval.pth"), map_location=device, weights_only=True))
206
- mlp_artist_retrieval.to(device).to(dtype).eval()
207
-
208
  mlp_r = MLP_R(2048)
209
  mlp_r.load_state_dict(torch.load(os.path.join(repo, "retrieval_head.pth"), map_location=device, weights_only=True))
210
  mlp_r.to(device).to(dtype).eval()
@@ -244,29 +224,6 @@ def prediction_to_tag(prediction, tag_dict, class_num):
244
 
245
  return general, character, artist, date, rating
246
 
247
- def prediction_to_retrieval(prediction, tag_dict, class_num, top_k):
248
- prediction = prediction.view(class_num)
249
- predicted_ids = (prediction>=0.005).nonzero(as_tuple=True)[0].cpu().numpy() + 1
250
-
251
- artist = {}
252
- date = {}
253
-
254
- for tag, value in tag_dict.items():
255
- if value[2] in predicted_ids:
256
- tag_value = round(prediction[value[2] - 1].item(), 6)
257
- if value[1] == "artist":
258
- artist[tag] = tag_value
259
- elif value[1] == "date":
260
- date[tag] = tag_value
261
-
262
- artist = dict(sorted(artist.items(), key=lambda item: item[1], reverse=True))
263
- artist = dict(list(artist.items())[:top_k])
264
-
265
- if date:
266
- date = {max(date, key=date.get): date[max(date, key=date.get)]}
267
-
268
- return artist, date
269
-
270
  def load_id_map(id_map_path):
271
  with open(id_map_path, "r") as f:
272
  id_map = json.load(f)
@@ -309,7 +266,7 @@ def search_index(query_vector, k=32, distance_threshold_min=0, distance_threshol
309
 
310
  return results
311
 
312
- def fetch_retrieval_image_urls(retrieval_results, sleep_sec=0.25, timeout=4.0):
313
  pairs = []
314
  for item in retrieval_results:
315
  oid = item.get("original_id")
@@ -332,9 +289,10 @@ def fetch_retrieval_image_urls(retrieval_results, sleep_sec=0.25, timeout=4.0):
332
  url = "https:" + url
333
  elif url.startswith("/"):
334
  url = "https://danbooru.donmai.us" + url
335
- pairs.append((url, oid))
336
- except Exception:
337
 
 
 
 
338
  pass
339
  finally:
340
 
@@ -363,7 +321,18 @@ def process_image(image, k, distance_threshold_min, distance_threshold_max):
363
 
364
  url_id_pairs = fetch_retrieval_image_urls(retrieval_results)
365
 
366
- retrieval_gallery_items = [(url, f"https://danbooru.donmai.us/posts/{oid}") for url, oid in url_id_pairs]
 
 
 
 
 
 
 
 
 
 
 
367
 
368
  general_prediction = mlp_general(embedding)
369
  general_ = prediction_to_tag(general_prediction, general_dict, general_class)
@@ -374,10 +343,10 @@ def process_image(image, k, distance_threshold_min, distance_threshold_max):
374
  character_ = prediction_to_tag(character_prediction, character_dict, character_class)
375
  character_tags = character_[1]
376
 
377
- artist_retrieval_prediction = mlp_artist_retrieval(embedding)
378
- artist_retrieval_ = prediction_to_retrieval(artist_retrieval_prediction, artist_dict, artist_class, 10)
379
- artist_tags = artist_retrieval_[0]
380
- date = artist_retrieval_[1]
381
 
382
  combined_tags = {**general_tags}
383
 
@@ -400,6 +369,7 @@ def process_image(image, k, distance_threshold_min, distance_threshold_max):
400
  rating,
401
  date,
402
  retrieval_gallery_items,
 
403
  )
404
 
405
  def main():
@@ -414,7 +384,7 @@ def main():
414
  image = gr.Image(type="pil", image_mode="RGBA", label="Input")
415
  k_slider = gr.Slider(1, 100, value=32, step=1, label="Top K Results")
416
  distance_min_slider = gr.Slider(0, 128, value=0, step=1, label="Min Distance Threshold")
417
- distance_max_slider = gr.Slider(0, 128, value=80, step=1, label="Max Distance Threshold")
418
  with gr.Row():
419
  clear = gr.ClearButton(
420
  components=[
@@ -440,6 +410,8 @@ def main():
440
  label="Retrieval Preview",
441
  columns=5,
442
  )
 
 
443
  clear.add(
444
  [
445
  tags_str,
@@ -449,6 +421,7 @@ def main():
449
  rating,
450
  date,
451
  retrieval_gallery,
 
452
  ]
453
  )
454
 
@@ -463,6 +436,7 @@ def main():
463
  rating,
464
  date,
465
  retrieval_gallery,
 
466
  ],
467
  )
468
 
 
13
  TITLE = "Danbooru Tagger"
14
  DESCRIPTION = """
15
  ## Dataset
16
+ - Source: Danbooru
17
+ - Cutoff Date: 2025-11-27
 
18
  - Validation Split: 10% of Dataset
19
+
20
+ ## Validation Results
21
 
22
  ### General
23
+ Tags Count: 11046
24
  | Metric | Value |
25
  |-----------------|-------------|
26
+ | Macro F1 | 0.4439 |
27
+ | Macro Precision | 0.4168 |
28
+ | Macro Recall | 0.4964 |
29
+ | Micro F1 | 0.6595 |
30
+ | Micro Precision | 0.5982 |
31
+ | Micro Recall | 0.7349 |
32
 
33
  ### Character
34
+ Tags Count: 9148
35
  | Metric | Value |
36
  |-----------------|-------------|
37
+ | Macro F1 | 0.8646 |
38
+ | Macro Precision | 0.8897 |
39
+ | Macro Recall | 0.8492 |
40
+ | Micro F1 | 0.9092 |
41
+ | Micro Precision | 0.9195 |
42
+ | Micro Recall | 0.8991 |
43
 
44
  ### Artist
45
+ Tags Count: 17171
46
  | Metric | Value |
47
  |-----------------|-------------|
48
+ | Macro F1 | 0.8008 |
49
+ | Macro Precision | 0.8669 |
50
+ | Macro Recall | 0.7641 |
51
+ | Micro F1 | 0.8596 |
52
+ | Micro Precision | 0.8948 |
53
+ | Micro Recall | 0.8271 |
54
  """
55
 
56
  kaomojis = [
 
84
  else:
85
  raise ValueError("environment variable HF_TOKEN not found.")
86
 
87
+ repo = snapshot_download('Johnny-Z/danbooru_vfm')
88
  model = AutoModel.from_pretrained(repo, dtype=dtype, trust_remote_code=True, device_map=device)
89
 
90
+ index_dir = snapshot_download('Johnny-Z/dan_index_a', repo_type='dataset')
91
 
92
  processor = CLIPImageProcessor.from_pretrained(repo)
93
 
 
134
  x = self.sigmoid(x)
135
  return x
136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  class MLP_R(nn.Module):
138
  def __init__(self, input_size):
139
  super().__init__()
140
  self.mlp_layer0 = nn.Sequential(
141
+ nn.Linear(input_size, 384),
142
  )
143
 
144
  def forward(self, x):
 
170
  model_map.load_state_dict(torch.load(os.path.join(repo, "map_head.pth"), map_location=device, weights_only=True))
171
  model_map.to(device).to(dtype).eval()
172
 
173
+ general_class = 11046
174
  mlp_general = MLP(2048, general_class)
175
  mlp_general.load_state_dict(torch.load(os.path.join(repo, "cls_predictor_general.pth"), map_location=device, weights_only=True))
176
  mlp_general.to(device).to(dtype).eval()
177
 
178
+ character_class = 9148
179
  mlp_character = MLP(2048, character_class)
180
  mlp_character.load_state_dict(torch.load(os.path.join(repo, "cls_predictor_character.pth"), map_location=device, weights_only=True))
181
  mlp_character.to(device).to(dtype).eval()
182
 
183
+ artist_class = 17171
184
  mlp_artist = MLP(2048, artist_class)
185
  mlp_artist.load_state_dict(torch.load(os.path.join(repo, "cls_predictor_artist.pth"), map_location=device, weights_only=True))
186
  mlp_artist.to(device).to(dtype).eval()
187
 
 
 
 
 
188
  mlp_r = MLP_R(2048)
189
  mlp_r.load_state_dict(torch.load(os.path.join(repo, "retrieval_head.pth"), map_location=device, weights_only=True))
190
  mlp_r.to(device).to(dtype).eval()
 
224
 
225
  return general, character, artist, date, rating
226
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  def load_id_map(id_map_path):
228
  with open(id_map_path, "r") as f:
229
  id_map = json.load(f)
 
266
 
267
  return results
268
 
269
+ def fetch_retrieval_image_urls(retrieval_results, sleep_sec=0.1, timeout=2.0):
270
  pairs = []
271
  for item in retrieval_results:
272
  oid = item.get("original_id")
 
289
  url = "https:" + url
290
  elif url.startswith("/"):
291
  url = "https://danbooru.donmai.us" + url
 
 
292
 
293
+ dist = item.get("l2_distance")
294
+ pairs.append((url, oid, dist))
295
+ except Exception:
296
  pass
297
  finally:
298
 
 
321
 
322
  url_id_pairs = fetch_retrieval_image_urls(retrieval_results)
323
 
324
+ retrieval_gallery_items = [
325
+ (
326
+ url,
327
+ f"distance={dist:.3f} | id={oid}"
328
+ )
329
+ for url, oid, dist in url_id_pairs
330
+ ]
331
+
332
+ retrieval_links = "\n".join(
333
+ f"[id={oid}](https://danbooru.donmai.us/posts/{oid})"
334
+ for url, oid, dist in url_id_pairs
335
+ )
336
 
337
  general_prediction = mlp_general(embedding)
338
  general_ = prediction_to_tag(general_prediction, general_dict, general_class)
 
343
  character_ = prediction_to_tag(character_prediction, character_dict, character_class)
344
  character_tags = character_[1]
345
 
346
+ artist_prediction = mlp_artist(embedding)
347
+ artist_ = prediction_to_tag(artist_prediction, artist_dict, artist_class)
348
+ artist_tags = artist_[2]
349
+ date = artist_[3]
350
 
351
  combined_tags = {**general_tags}
352
 
 
369
  rating,
370
  date,
371
  retrieval_gallery_items,
372
+ retrieval_links,
373
  )
374
 
375
  def main():
 
384
  image = gr.Image(type="pil", image_mode="RGBA", label="Input")
385
  k_slider = gr.Slider(1, 100, value=32, step=1, label="Top K Results")
386
  distance_min_slider = gr.Slider(0, 128, value=0, step=1, label="Min Distance Threshold")
387
+ distance_max_slider = gr.Slider(0, 128, value=64, step=1, label="Max Distance Threshold")
388
  with gr.Row():
389
  clear = gr.ClearButton(
390
  components=[
 
410
  label="Retrieval Preview",
411
  columns=5,
412
  )
413
+ retrieval_links = gr.Markdown(label="Retrieval Links")
414
+
415
  clear.add(
416
  [
417
  tags_str,
 
421
  rating,
422
  date,
423
  retrieval_gallery,
424
+ retrieval_links,
425
  ]
426
  )
427
 
 
436
  rating,
437
  date,
438
  retrieval_gallery,
439
+ retrieval_links,
440
  ],
441
  )
442