# dan_tagger/app.py — Danbooru tagger Gradio app (Johnny-Z, upload rev 63e09a6)
from transformers import CLIPImageProcessor, AutoModel
import torch
import json
import torch.nn as nn
from PIL import Image
import gradio as gr
import os
from huggingface_hub import login, snapshot_download
# App title and the markdown model card rendered in the left panel of the UI.
TITLE = "Danbooru Tagger"
# NOTE(review): the metric tables below describe the upstream checkpoint's
# validation results; the tag counts (11046 / 9148 / 17171) must match the
# *_class constants defined later in this file.
DESCRIPTION = """
## Dataset
- Source: Danbooru
- Cutoff Date: 2025-11-27
- Validation Split: 10% of Dataset
## Validation Results
### General
Tags Count: 11046
| Metric | Value |
|-----------------|-------------|
| Macro F1 | 0.4439 |
| Macro Precision | 0.4168 |
| Macro Recall | 0.4964 |
| Micro F1 | 0.6595 |
| Micro Precision | 0.5982 |
| Micro Recall | 0.7349 |
### Character
Tags Count: 9148
| Metric | Value |
|-----------------|-------------|
| Macro F1 | 0.8646 |
| Macro Precision | 0.8897 |
| Macro Recall | 0.8492 |
| Micro F1 | 0.9092 |
| Micro Precision | 0.9195 |
| Micro Recall | 0.8991 |
### Artist
Tags Count: 17171
| Metric | Value |
|-----------------|-------------|
| Macro F1 | 0.8008 |
| Macro Precision | 0.8669 |
| Macro Recall | 0.7641 |
| Micro F1 | 0.8596 |
| Micro Precision | 0.8948 |
| Micro Recall | 0.8271 |
"""
# Emoticon ("kaomoji") tags whose underscores are meaningful; they are
# exempted from the underscore-to-space replacement in process_image.
kaomojis = [
    "0_0",
    "(o)_(o)",
    "+_+",
    "+_-",
    "._.",
    "<o>_<o>",
    "<|>_<|>",
    "=_=",
    ">_<",
    "3_3",
    "6_9",
    ">_o",
    "@_@",
    "^_^",
    "o_o",
    "u_u",
    "x_x",
    "|_|",
    "||_||",
]
# Inference is pinned to CPU / float32.
device = torch.device('cpu')
dtype = torch.float32

# A Hugging Face token is mandatory: the script refuses to start without it
# (presumably the model repo requires authentication — TODO confirm).
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    raise ValueError("environment variable HF_TOKEN not found.")

repo_id = "Johnny-Z/danbooru_vfm"
# The local snapshot also provides the tag dictionaries, thresholds and
# classifier-head weights loaded further down.
repo_dir = snapshot_download(repo_id)
# Backbone (custom code from the repo) and its CLIP image preprocessor.
model = AutoModel.from_pretrained(repo_id, dtype=dtype, trust_remote_code=True, device_map=device)
processor = CLIPImageProcessor.from_pretrained(repo_id)
class MultiheadAttentionPoolingHead(nn.Module):
    """Attention-pooling head that collapses a token sequence to one vector.

    A single learned probe token attends over the layer-normed input
    tokens; the attended result passes through a second LayerNorm and a
    residual SiLU feed-forward block.  Attribute names must not change:
    they are the keys of the pretrained state dict (``map_head.pth``).
    """

    def __init__(self, input_size):
        super().__init__()
        self.map_probe = nn.Parameter(torch.randn(1, 1, input_size))
        self.map_layernorm0 = nn.LayerNorm(input_size, eps=1e-08)
        # One attention head per 64 channels.
        self.map_attention = torch.nn.MultiheadAttention(input_size, input_size // 64, batch_first=True)
        self.map_layernorm1 = nn.LayerNorm(input_size, eps=1e-08)
        self.map_ffn = nn.Sequential(
            nn.Linear(input_size, input_size * 4),
            nn.SiLU(),
            nn.Linear(input_size * 4, input_size)
        )

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Pool ``hidden_state`` of shape (batch, tokens, dim) to (batch, dim)."""
        query = self.map_probe.repeat(hidden_state.shape[0], 1, 1)
        keys = self.map_layernorm0(hidden_state)
        attended, _ = self.map_attention(query, keys, keys)
        attended = self.map_layernorm1(attended)
        pooled = attended + self.map_ffn(attended)
        # The probe contributes exactly one output token per sample.
        return pooled[:, 0]
class MLP(nn.Module):
    """Sigmoid classifier head: LayerNorm -> half-width SiLU layer -> logits.

    Attribute names are the pretrained state-dict keys; do not rename.
    """

    def __init__(self, input_size, class_num):
        super().__init__()
        self.mlp_layer0 = nn.Sequential(
            nn.LayerNorm(input_size, eps=1e-08),
            nn.Linear(input_size, input_size // 2),
            nn.SiLU()
        )
        self.mlp_layer1 = nn.Linear(input_size // 2, class_num)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return per-class probabilities in [0, 1]."""
        hidden = self.mlp_layer0(x)
        logits = self.mlp_layer1(hidden)
        return self.sigmoid(logits)
# Tag metadata from the snapshot.  Each *_tag_dict maps
# tag -> list where [1] is the category name and [2] is the 1-based class
# index (see prediction_to_tag); [0] is not used by this app.
with open(os.path.join(repo_dir, 'general_tag_dict.json'), 'r', encoding='utf-8') as f:
    general_dict = json.load(f)
with open(os.path.join(repo_dir, 'character_tag_dict.json'), 'r', encoding='utf-8') as f:
    character_dict = json.load(f)
with open(os.path.join(repo_dir, 'artist_tag_dict.json'), 'r', encoding='utf-8') as f:
    artist_dict = json.load(f)
# tag -> list of tags it implies; implied tags are dropped as redundant
# in process_image.
with open(os.path.join(repo_dir, 'implications_list.json'), 'r', encoding='utf-8') as f:
    implications_list = json.load(f)
# Per-tag decision thresholds: tag -> {"Threshold": float} (default 0.75
# applied in prediction_to_tag when a tag is missing).
with open(os.path.join(repo_dir, 'artist_threshold.json'), 'r', encoding='utf-8') as f:
    artist_thresholds = json.load(f)
with open(os.path.join(repo_dir, 'character_threshold.json'), 'r', encoding='utf-8') as f:
    character_thresholds = json.load(f)
with open(os.path.join(repo_dir, 'general_threshold.json'), 'r', encoding='utf-8') as f:
    general_thresholds = json.load(f)
# character tag -> list of general tags describing that character's
# appearance; those general tags are removed once the character is predicted.
with open(os.path.join(repo_dir, 'character_feature.json'), 'r', encoding='utf-8') as f:
    character_features = json.load(f)
# Pooling head + three per-category classifier heads, restored from the
# snapshot.  2048 is the backbone's embedding width; the class counts must
# match the tag dictionaries (and the "Tags Count" rows in DESCRIPTION).
model_map = MultiheadAttentionPoolingHead(2048)
model_map.load_state_dict(torch.load(os.path.join(repo_dir, "map_head.pth"), map_location=device, weights_only=True))
model_map.to(device).to(dtype).eval()

general_class = 11046
mlp_general = MLP(2048, general_class)
mlp_general.load_state_dict(torch.load(os.path.join(repo_dir, "cls_predictor_general.pth"), map_location=device, weights_only=True))
mlp_general.to(device).to(dtype).eval()

character_class = 9148
mlp_character = MLP(2048, character_class)
mlp_character.load_state_dict(torch.load(os.path.join(repo_dir, "cls_predictor_character.pth"), map_location=device, weights_only=True))
mlp_character.to(device).to(dtype).eval()

artist_class = 17171
mlp_artist = MLP(2048, artist_class)
mlp_artist.load_state_dict(torch.load(os.path.join(repo_dir, "cls_predictor_artist.pth"), map_location=device, weights_only=True))
mlp_artist.to(device).to(dtype).eval()
def prediction_to_tag(prediction, tag_dict, class_num):
    """Turn one sigmoid prediction vector into per-category tag score dicts.

    Args:
        prediction: tensor holding ``class_num`` sigmoid scores (any shape
            that views to ``(class_num,)``).
        tag_dict: mapping tag -> list where [1] is the category name
            ("general" / "character" / "artist" / "rating" / "date") and
            [2] is the tag's 1-based class index.
        class_num: number of classes in ``prediction``.

    Returns:
        ``(general, character, artist, date, rating)`` dicts of
        tag -> rounded score.  general/character/artist are sorted by
        descending score; date and rating keep only their best entry.
    """
    prediction = prediction.view(class_num)
    # Candidate 1-based class ids above a coarse 0.2 floor.  A set makes
    # the per-tag membership test O(1) instead of scanning a numpy array
    # for each of the ~10-17k dictionary entries.
    predicted_ids = set(((prediction >= 0.2).nonzero(as_tuple=True)[0].cpu() + 1).tolist())
    # One bulk tensor -> Python conversion instead of a .item() per hit.
    scores = prediction.tolist()
    general = {}
    character = {}
    artist = {}
    date = {}
    rating = {}
    for tag, value in tag_dict.items():
        if value[2] not in predicted_ids:
            continue
        tag_value = round(scores[value[2] - 1], 6)
        category = value[1]
        # general/character/artist must also clear the tag's calibrated
        # threshold (default 0.75); rating/date only need the 0.2 floor.
        if category == "general" and tag_value >= general_thresholds.get(tag, {}).get("Threshold", 0.75):
            general[tag] = tag_value
        elif category == "character" and tag_value >= character_thresholds.get(tag, {}).get("Threshold", 0.75):
            character[tag] = tag_value
        elif category == "artist" and tag_value >= artist_thresholds.get(tag, {}).get("Threshold", 0.75):
            artist[tag] = tag_value
        elif category == "rating":
            rating[tag] = tag_value
        elif category == "date":
            date[tag] = tag_value
    general = dict(sorted(general.items(), key=lambda item: item[1], reverse=True))
    character = dict(sorted(character.items(), key=lambda item: item[1], reverse=True))
    artist = dict(sorted(artist.items(), key=lambda item: item[1], reverse=True))
    # Reduce date / rating to their single most likely tag.
    if date:
        best = max(date, key=date.get)
        date = {best: date[best]}
    if rating:
        best = max(rating, key=rating.get)
        rating = {best: rating[best]}
    return general, character, artist, date, rating
def process_image(image):
    """Tag one PIL image and return the six values the Gradio UI expects.

    Returns:
        (tag string, artist scores, character scores, general scores,
        rating scores, date scores).
    """
    try:
        # Flatten any transparency onto a white background before the
        # CLIP preprocessor sees the image.
        image = image.convert('RGBA')
        background = Image.new('RGBA', image.size, (255, 255, 255, 255))
        image = Image.alpha_composite(background, image).convert('RGB')
        image_inputs = processor(images=[image], return_tensors="pt").to(device).to(dtype)
    except OSError as e:  # IOError is an alias of OSError
        print(f"Error opening image: {e}")
        # Bug fix: a bare ``return`` handed a single None to Gradio,
        # which expects six output values; return empty results instead.
        return "", {}, {}, {}, {}, {}
    with torch.no_grad():
        embedding = model(image_inputs.pixel_values)
        embedding = model_map(embedding)
        general_tags, _, _, _, rating = prediction_to_tag(mlp_general(embedding), general_dict, general_class)
        _, character_tags, _, _, _ = prediction_to_tag(mlp_character(embedding), character_dict, character_class)
        _, _, artist_tags, date, _ = prediction_to_tag(mlp_artist(embedding), artist_dict, artist_class)
    # Drop character tags implied by another predicted character tag.
    implied_characters = set()
    for tag in character_tags:
        implied_characters.update(implications_list.get(tag, ()))
    character_tags_list = [tag for tag in character_tags if tag not in implied_characters]
    # Drop general tags implied by other general tags, or already covered
    # by a surviving character's feature list.  Sets keep this filtering
    # linear instead of the previous O(n^2) list-membership scans.
    redundant = set()
    for tag in general_tags:
        redundant.update(implications_list.get(tag, ()))
    for char_tag in character_tags_list:
        redundant.update(character_features.get(char_tag, ()))
    tags_list = [tag for tag in general_tags if tag not in redundant]
    # Kaomoji tags keep their underscores; all others become space-separated.
    tags_list = [tag if tag in kaomojis else tag.replace("_", " ") for tag in tags_list]
    # Backslash-escape parentheses in the combined tag string.
    tags_str = ", ".join(character_tags_list + tags_list).replace("(", r"\(").replace(")", r"\)")
    return tags_str, artist_tags, character_tags, general_tags, rating, date
def main():
    """Build and launch the Gradio interface."""
    with gr.Blocks(title=TITLE) as demo:
        with gr.Column():
            gr.Markdown(
                value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
            )
            with gr.Row():
                # Left panel: submit/clear controls, input image, model card.
                with gr.Column(variant="panel"):
                    submit = gr.Button(value="Submit", variant="primary", size="lg")
                    image = gr.Image(type="pil", image_mode="RGBA", label="Input")
                    with gr.Row():
                        clear = gr.ClearButton(
                            components=[
                                image,
                            ],
                            variant="secondary",
                            size="lg",
                        )
                    gr.Markdown(value=DESCRIPTION)
                # Right panel: combined tag string plus per-category labels.
                with gr.Column(variant="panel"):
                    tags_str = gr.Textbox(label="Output", lines=4)
                    with gr.Row():
                        rating = gr.Label(label="Rating")
                        date = gr.Label(label="Year")
                    artist_tags = gr.Label(label="Artist")
                    character_tags = gr.Label(label="Character")
                    general_tags = gr.Label(label="General")
        # The clear button also resets every result component.
        clear.add(
            [
                tags_str,
                artist_tags,
                general_tags,
                character_tags,
                rating,
                date,
            ]
        )
        # Output order must match process_image's return tuple.
        submit.click(
            process_image,
            inputs=[
                image
            ],
            outputs=[tags_str, artist_tags, character_tags, general_tags, rating, date],
        )
    demo.queue(max_size=10)
    demo.launch()
# Script entry point.
if __name__ == "__main__":
    main()