Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
CHANGED
|
@@ -13,41 +13,44 @@ from huggingface_hub import login, snapshot_download
|
|
| 13 |
TITLE = "Danbooru Tagger"
|
| 14 |
DESCRIPTION = """
|
| 15 |
## Dataset
|
| 16 |
-
- Source:
|
| 17 |
-
|
| 18 |
-
## Metrics
|
| 19 |
- Validation Split: 10% of Dataset
|
| 20 |
-
|
|
|
|
| 21 |
|
| 22 |
### General
|
|
|
|
| 23 |
| Metric | Value |
|
| 24 |
|-----------------|-------------|
|
| 25 |
-
| Macro F1 | 0.
|
| 26 |
-
| Macro Precision | 0.
|
| 27 |
-
| Macro Recall | 0.
|
| 28 |
-
| Micro F1 | 0.
|
| 29 |
-
| Micro Precision | 0.
|
| 30 |
-
| Micro Recall | 0.
|
| 31 |
|
| 32 |
### Character
|
|
|
|
| 33 |
| Metric | Value |
|
| 34 |
|-----------------|-------------|
|
| 35 |
-
| Macro F1 | 0.
|
| 36 |
-
| Macro Precision | 0.
|
| 37 |
-
| Macro Recall | 0.
|
| 38 |
-
| Micro F1 | 0.
|
| 39 |
-
| Micro Precision | 0.
|
| 40 |
-
| Micro Recall | 0.
|
| 41 |
|
| 42 |
### Artist
|
|
|
|
| 43 |
| Metric | Value |
|
| 44 |
|-----------------|-------------|
|
| 45 |
-
| Macro F1 | 0.
|
| 46 |
-
| Macro Precision | 0.
|
| 47 |
-
| Macro Recall | 0.
|
| 48 |
-
| Micro F1 | 0.
|
| 49 |
-
| Micro Precision | 0.
|
| 50 |
-
| Micro Recall | 0.
|
| 51 |
"""
|
| 52 |
|
| 53 |
kaomojis = [
|
|
@@ -81,10 +84,10 @@ if hf_token:
|
|
| 81 |
else:
|
| 82 |
raise ValueError("environment variable HF_TOKEN not found.")
|
| 83 |
|
| 84 |
-
repo = snapshot_download('Johnny-Z/
|
| 85 |
model = AutoModel.from_pretrained(repo, dtype=dtype, trust_remote_code=True, device_map=device)
|
| 86 |
|
| 87 |
-
index_dir = snapshot_download('Johnny-Z/
|
| 88 |
|
| 89 |
processor = CLIPImageProcessor.from_pretrained(repo)
|
| 90 |
|
|
@@ -131,30 +134,11 @@ class MLP(nn.Module):
|
|
| 131 |
x = self.sigmoid(x)
|
| 132 |
return x
|
| 133 |
|
| 134 |
-
class MLP_Retrieval(nn.Module):
|
| 135 |
-
def __init__(self, input_size, class_num):
|
| 136 |
-
super().__init__()
|
| 137 |
-
self.mlp_layer0 = nn.Sequential(
|
| 138 |
-
nn.Linear(input_size, input_size // 2),
|
| 139 |
-
nn.SiLU()
|
| 140 |
-
)
|
| 141 |
-
self.mlp_layer1 = nn.Linear(input_size // 2, class_num)
|
| 142 |
-
|
| 143 |
-
def forward(self, x):
|
| 144 |
-
x = self.mlp_layer0(x)
|
| 145 |
-
x = self.mlp_layer1(x)
|
| 146 |
-
x1, x2 = x[:, :15], x[:, 15:]
|
| 147 |
-
x1 = torch.softmax(x1, dim=1)
|
| 148 |
-
x2 = torch.softmax(x2, dim=1)
|
| 149 |
-
x = torch.cat([x1, x2], dim=1)
|
| 150 |
-
|
| 151 |
-
return x
|
| 152 |
-
|
| 153 |
class MLP_R(nn.Module):
|
| 154 |
def __init__(self, input_size):
|
| 155 |
super().__init__()
|
| 156 |
self.mlp_layer0 = nn.Sequential(
|
| 157 |
-
nn.Linear(input_size,
|
| 158 |
)
|
| 159 |
|
| 160 |
def forward(self, x):
|
|
@@ -186,25 +170,21 @@ model_map = MultiheadAttentionPoolingHead(2048)
|
|
| 186 |
model_map.load_state_dict(torch.load(os.path.join(repo, "map_head.pth"), map_location=device, weights_only=True))
|
| 187 |
model_map.to(device).to(dtype).eval()
|
| 188 |
|
| 189 |
-
general_class =
|
| 190 |
mlp_general = MLP(2048, general_class)
|
| 191 |
mlp_general.load_state_dict(torch.load(os.path.join(repo, "cls_predictor_general.pth"), map_location=device, weights_only=True))
|
| 192 |
mlp_general.to(device).to(dtype).eval()
|
| 193 |
|
| 194 |
-
character_class =
|
| 195 |
mlp_character = MLP(2048, character_class)
|
| 196 |
mlp_character.load_state_dict(torch.load(os.path.join(repo, "cls_predictor_character.pth"), map_location=device, weights_only=True))
|
| 197 |
mlp_character.to(device).to(dtype).eval()
|
| 198 |
|
| 199 |
-
artist_class =
|
| 200 |
mlp_artist = MLP(2048, artist_class)
|
| 201 |
mlp_artist.load_state_dict(torch.load(os.path.join(repo, "cls_predictor_artist.pth"), map_location=device, weights_only=True))
|
| 202 |
mlp_artist.to(device).to(dtype).eval()
|
| 203 |
|
| 204 |
-
mlp_artist_retrieval = MLP_Retrieval(2048, artist_class)
|
| 205 |
-
mlp_artist_retrieval.load_state_dict(torch.load(os.path.join(repo, "cls_predictor_artist_retrieval.pth"), map_location=device, weights_only=True))
|
| 206 |
-
mlp_artist_retrieval.to(device).to(dtype).eval()
|
| 207 |
-
|
| 208 |
mlp_r = MLP_R(2048)
|
| 209 |
mlp_r.load_state_dict(torch.load(os.path.join(repo, "retrieval_head.pth"), map_location=device, weights_only=True))
|
| 210 |
mlp_r.to(device).to(dtype).eval()
|
|
@@ -244,29 +224,6 @@ def prediction_to_tag(prediction, tag_dict, class_num):
|
|
| 244 |
|
| 245 |
return general, character, artist, date, rating
|
| 246 |
|
| 247 |
-
def prediction_to_retrieval(prediction, tag_dict, class_num, top_k):
|
| 248 |
-
prediction = prediction.view(class_num)
|
| 249 |
-
predicted_ids = (prediction>=0.005).nonzero(as_tuple=True)[0].cpu().numpy() + 1
|
| 250 |
-
|
| 251 |
-
artist = {}
|
| 252 |
-
date = {}
|
| 253 |
-
|
| 254 |
-
for tag, value in tag_dict.items():
|
| 255 |
-
if value[2] in predicted_ids:
|
| 256 |
-
tag_value = round(prediction[value[2] - 1].item(), 6)
|
| 257 |
-
if value[1] == "artist":
|
| 258 |
-
artist[tag] = tag_value
|
| 259 |
-
elif value[1] == "date":
|
| 260 |
-
date[tag] = tag_value
|
| 261 |
-
|
| 262 |
-
artist = dict(sorted(artist.items(), key=lambda item: item[1], reverse=True))
|
| 263 |
-
artist = dict(list(artist.items())[:top_k])
|
| 264 |
-
|
| 265 |
-
if date:
|
| 266 |
-
date = {max(date, key=date.get): date[max(date, key=date.get)]}
|
| 267 |
-
|
| 268 |
-
return artist, date
|
| 269 |
-
|
| 270 |
def load_id_map(id_map_path):
|
| 271 |
with open(id_map_path, "r") as f:
|
| 272 |
id_map = json.load(f)
|
|
@@ -309,7 +266,7 @@ def search_index(query_vector, k=32, distance_threshold_min=0, distance_threshol
|
|
| 309 |
|
| 310 |
return results
|
| 311 |
|
| 312 |
-
def fetch_retrieval_image_urls(retrieval_results, sleep_sec=0.
|
| 313 |
pairs = []
|
| 314 |
for item in retrieval_results:
|
| 315 |
oid = item.get("original_id")
|
|
@@ -332,9 +289,10 @@ def fetch_retrieval_image_urls(retrieval_results, sleep_sec=0.25, timeout=4.0):
|
|
| 332 |
url = "https:" + url
|
| 333 |
elif url.startswith("/"):
|
| 334 |
url = "https://danbooru.donmai.us" + url
|
| 335 |
-
pairs.append((url, oid))
|
| 336 |
-
except Exception:
|
| 337 |
|
|
|
|
|
|
|
|
|
|
| 338 |
pass
|
| 339 |
finally:
|
| 340 |
|
|
@@ -363,7 +321,18 @@ def process_image(image, k, distance_threshold_min, distance_threshold_max):
|
|
| 363 |
|
| 364 |
url_id_pairs = fetch_retrieval_image_urls(retrieval_results)
|
| 365 |
|
| 366 |
-
retrieval_gallery_items = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
|
| 368 |
general_prediction = mlp_general(embedding)
|
| 369 |
general_ = prediction_to_tag(general_prediction, general_dict, general_class)
|
|
@@ -374,10 +343,10 @@ def process_image(image, k, distance_threshold_min, distance_threshold_max):
|
|
| 374 |
character_ = prediction_to_tag(character_prediction, character_dict, character_class)
|
| 375 |
character_tags = character_[1]
|
| 376 |
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
artist_tags =
|
| 380 |
-
date =
|
| 381 |
|
| 382 |
combined_tags = {**general_tags}
|
| 383 |
|
|
@@ -400,6 +369,7 @@ def process_image(image, k, distance_threshold_min, distance_threshold_max):
|
|
| 400 |
rating,
|
| 401 |
date,
|
| 402 |
retrieval_gallery_items,
|
|
|
|
| 403 |
)
|
| 404 |
|
| 405 |
def main():
|
|
@@ -414,7 +384,7 @@ def main():
|
|
| 414 |
image = gr.Image(type="pil", image_mode="RGBA", label="Input")
|
| 415 |
k_slider = gr.Slider(1, 100, value=32, step=1, label="Top K Results")
|
| 416 |
distance_min_slider = gr.Slider(0, 128, value=0, step=1, label="Min Distance Threshold")
|
| 417 |
-
distance_max_slider = gr.Slider(0, 128, value=
|
| 418 |
with gr.Row():
|
| 419 |
clear = gr.ClearButton(
|
| 420 |
components=[
|
|
@@ -440,6 +410,8 @@ def main():
|
|
| 440 |
label="Retrieval Preview",
|
| 441 |
columns=5,
|
| 442 |
)
|
|
|
|
|
|
|
| 443 |
clear.add(
|
| 444 |
[
|
| 445 |
tags_str,
|
|
@@ -449,6 +421,7 @@ def main():
|
|
| 449 |
rating,
|
| 450 |
date,
|
| 451 |
retrieval_gallery,
|
|
|
|
| 452 |
]
|
| 453 |
)
|
| 454 |
|
|
@@ -463,6 +436,7 @@ def main():
|
|
| 463 |
rating,
|
| 464 |
date,
|
| 465 |
retrieval_gallery,
|
|
|
|
| 466 |
],
|
| 467 |
)
|
| 468 |
|
|
|
|
| 13 |
TITLE = "Danbooru Tagger"
|
| 14 |
DESCRIPTION = """
|
| 15 |
## Dataset
|
| 16 |
+
- Source: Danbooru
|
| 17 |
+
- Cutoff Date: 2025-11-27
|
|
|
|
| 18 |
- Validation Split: 10% of Dataset
|
| 19 |
+
|
| 20 |
+
## Validation Results
|
| 21 |
|
| 22 |
### General
|
| 23 |
+
Tags Count: 11046
|
| 24 |
| Metric | Value |
|
| 25 |
|-----------------|-------------|
|
| 26 |
+
| Macro F1 | 0.4439 |
|
| 27 |
+
| Macro Precision | 0.4168 |
|
| 28 |
+
| Macro Recall | 0.4964 |
|
| 29 |
+
| Micro F1 | 0.6595 |
|
| 30 |
+
| Micro Precision | 0.5982 |
|
| 31 |
+
| Micro Recall | 0.7349 |
|
| 32 |
|
| 33 |
### Character
|
| 34 |
+
Tags Count: 9148
|
| 35 |
| Metric | Value |
|
| 36 |
|-----------------|-------------|
|
| 37 |
+
| Macro F1 | 0.8646 |
|
| 38 |
+
| Macro Precision | 0.8897 |
|
| 39 |
+
| Macro Recall | 0.8492 |
|
| 40 |
+
| Micro F1 | 0.9092 |
|
| 41 |
+
| Micro Precision | 0.9195 |
|
| 42 |
+
| Micro Recall | 0.8991 |
|
| 43 |
|
| 44 |
### Artist
|
| 45 |
+
Tags Count: 17171
|
| 46 |
| Metric | Value |
|
| 47 |
|-----------------|-------------|
|
| 48 |
+
| Macro F1 | 0.8008 |
|
| 49 |
+
| Macro Precision | 0.8669 |
|
| 50 |
+
| Macro Recall | 0.7641 |
|
| 51 |
+
| Micro F1 | 0.8596 |
|
| 52 |
+
| Micro Precision | 0.8948 |
|
| 53 |
+
| Micro Recall | 0.8271 |
|
| 54 |
"""
|
| 55 |
|
| 56 |
kaomojis = [
|
|
|
|
| 84 |
else:
|
| 85 |
raise ValueError("environment variable HF_TOKEN not found.")
|
| 86 |
|
| 87 |
+
repo = snapshot_download('Johnny-Z/danbooru_vfm')
|
| 88 |
model = AutoModel.from_pretrained(repo, dtype=dtype, trust_remote_code=True, device_map=device)
|
| 89 |
|
| 90 |
+
index_dir = snapshot_download('Johnny-Z/dan_index_a', repo_type='dataset')
|
| 91 |
|
| 92 |
processor = CLIPImageProcessor.from_pretrained(repo)
|
| 93 |
|
|
|
|
| 134 |
x = self.sigmoid(x)
|
| 135 |
return x
|
| 136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
class MLP_R(nn.Module):
|
| 138 |
def __init__(self, input_size):
|
| 139 |
super().__init__()
|
| 140 |
self.mlp_layer0 = nn.Sequential(
|
| 141 |
+
nn.Linear(input_size, 384),
|
| 142 |
)
|
| 143 |
|
| 144 |
def forward(self, x):
|
|
|
|
| 170 |
model_map.load_state_dict(torch.load(os.path.join(repo, "map_head.pth"), map_location=device, weights_only=True))
|
| 171 |
model_map.to(device).to(dtype).eval()
|
| 172 |
|
| 173 |
+
general_class = 11046
|
| 174 |
mlp_general = MLP(2048, general_class)
|
| 175 |
mlp_general.load_state_dict(torch.load(os.path.join(repo, "cls_predictor_general.pth"), map_location=device, weights_only=True))
|
| 176 |
mlp_general.to(device).to(dtype).eval()
|
| 177 |
|
| 178 |
+
character_class = 9148
|
| 179 |
mlp_character = MLP(2048, character_class)
|
| 180 |
mlp_character.load_state_dict(torch.load(os.path.join(repo, "cls_predictor_character.pth"), map_location=device, weights_only=True))
|
| 181 |
mlp_character.to(device).to(dtype).eval()
|
| 182 |
|
| 183 |
+
artist_class = 17171
|
| 184 |
mlp_artist = MLP(2048, artist_class)
|
| 185 |
mlp_artist.load_state_dict(torch.load(os.path.join(repo, "cls_predictor_artist.pth"), map_location=device, weights_only=True))
|
| 186 |
mlp_artist.to(device).to(dtype).eval()
|
| 187 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
mlp_r = MLP_R(2048)
|
| 189 |
mlp_r.load_state_dict(torch.load(os.path.join(repo, "retrieval_head.pth"), map_location=device, weights_only=True))
|
| 190 |
mlp_r.to(device).to(dtype).eval()
|
|
|
|
| 224 |
|
| 225 |
return general, character, artist, date, rating
|
| 226 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
def load_id_map(id_map_path):
|
| 228 |
with open(id_map_path, "r") as f:
|
| 229 |
id_map = json.load(f)
|
|
|
|
| 266 |
|
| 267 |
return results
|
| 268 |
|
| 269 |
+
def fetch_retrieval_image_urls(retrieval_results, sleep_sec=0.1, timeout=2.0):
|
| 270 |
pairs = []
|
| 271 |
for item in retrieval_results:
|
| 272 |
oid = item.get("original_id")
|
|
|
|
| 289 |
url = "https:" + url
|
| 290 |
elif url.startswith("/"):
|
| 291 |
url = "https://danbooru.donmai.us" + url
|
|
|
|
|
|
|
| 292 |
|
| 293 |
+
dist = item.get("l2_distance")
|
| 294 |
+
pairs.append((url, oid, dist))
|
| 295 |
+
except Exception:
|
| 296 |
pass
|
| 297 |
finally:
|
| 298 |
|
|
|
|
| 321 |
|
| 322 |
url_id_pairs = fetch_retrieval_image_urls(retrieval_results)
|
| 323 |
|
| 324 |
+
retrieval_gallery_items = [
|
| 325 |
+
(
|
| 326 |
+
url,
|
| 327 |
+
f"distance={dist:.3f} | id={oid}"
|
| 328 |
+
)
|
| 329 |
+
for url, oid, dist in url_id_pairs
|
| 330 |
+
]
|
| 331 |
+
|
| 332 |
+
retrieval_links = "\n".join(
|
| 333 |
+
f"[id={oid}](https://danbooru.donmai.us/posts/{oid})"
|
| 334 |
+
for url, oid, dist in url_id_pairs
|
| 335 |
+
)
|
| 336 |
|
| 337 |
general_prediction = mlp_general(embedding)
|
| 338 |
general_ = prediction_to_tag(general_prediction, general_dict, general_class)
|
|
|
|
| 343 |
character_ = prediction_to_tag(character_prediction, character_dict, character_class)
|
| 344 |
character_tags = character_[1]
|
| 345 |
|
| 346 |
+
artist_prediction = mlp_artist(embedding)
|
| 347 |
+
artist_ = prediction_to_tag(artist_prediction, artist_dict, artist_class)
|
| 348 |
+
artist_tags = artist_[2]
|
| 349 |
+
date = artist_[3]
|
| 350 |
|
| 351 |
combined_tags = {**general_tags}
|
| 352 |
|
|
|
|
| 369 |
rating,
|
| 370 |
date,
|
| 371 |
retrieval_gallery_items,
|
| 372 |
+
retrieval_links,
|
| 373 |
)
|
| 374 |
|
| 375 |
def main():
|
|
|
|
| 384 |
image = gr.Image(type="pil", image_mode="RGBA", label="Input")
|
| 385 |
k_slider = gr.Slider(1, 100, value=32, step=1, label="Top K Results")
|
| 386 |
distance_min_slider = gr.Slider(0, 128, value=0, step=1, label="Min Distance Threshold")
|
| 387 |
+
distance_max_slider = gr.Slider(0, 128, value=64, step=1, label="Max Distance Threshold")
|
| 388 |
with gr.Row():
|
| 389 |
clear = gr.ClearButton(
|
| 390 |
components=[
|
|
|
|
| 410 |
label="Retrieval Preview",
|
| 411 |
columns=5,
|
| 412 |
)
|
| 413 |
+
retrieval_links = gr.Markdown(label="Retrieval Links")
|
| 414 |
+
|
| 415 |
clear.add(
|
| 416 |
[
|
| 417 |
tags_str,
|
|
|
|
| 421 |
rating,
|
| 422 |
date,
|
| 423 |
retrieval_gallery,
|
| 424 |
+
retrieval_links,
|
| 425 |
]
|
| 426 |
)
|
| 427 |
|
|
|
|
| 436 |
rating,
|
| 437 |
date,
|
| 438 |
retrieval_gallery,
|
| 439 |
+
retrieval_links,
|
| 440 |
],
|
| 441 |
)
|
| 442 |
|