from transformers import CLIPImageProcessor, AutoModel
import torch
import json
import torch.nn as nn
from PIL import Image
import gradio as gr
import os
from huggingface_hub import login, snapshot_download

# UI title and the markdown model card shown beside the input panel.
TITLE = "Danbooru Tagger"
DESCRIPTION = """
## Dataset
- Source: Danbooru
- Cutoff Date: 2025-11-27
- Validation Split: 10% of Dataset

## Validation Results

### General Tags
Count: 11046

| Metric | Value |
|-----------------|-------------|
| Macro F1 | 0.4439 |
| Macro Precision | 0.4168 |
| Macro Recall | 0.4964 |
| Micro F1 | 0.6595 |
| Micro Precision | 0.5982 |
| Micro Recall | 0.7349 |

### Character Tags
Count: 9148

| Metric | Value |
|-----------------|-------------|
| Macro F1 | 0.8646 |
| Macro Precision | 0.8897 |
| Macro Recall | 0.8492 |
| Micro F1 | 0.9092 |
| Micro Precision | 0.9195 |
| Micro Recall | 0.8991 |

### Artist Tags
Count: 17171

| Metric | Value |
|-----------------|-------------|
| Macro F1 | 0.8008 |
| Macro Precision | 0.8669 |
| Macro Recall | 0.7641 |
| Micro F1 | 0.8596 |
| Micro Precision | 0.8948 |
| Micro Recall | 0.8271 |
"""

# Emoticon-style tags whose underscores are meaningful and must not be
# rewritten to spaces when the output string is assembled.
kaomojis = [
    "0_0",
    "(o)_(o)",
    "+_+",
    "+_-",
    "._.",
    "_",
    "<|>_<|>",
    "=_=",
    ">_<",
    "3_3",
    "6_9",
    ">_o",
    "@_@",
    "^_^",
    "o_o",
    "u_u",
    "x_x",
    "|_|",
    "||_||",
]

# Inference runs on CPU in float32.
device = torch.device('cpu')
dtype = torch.float32

# The model repo requires authenticated Hub access, so a token is mandatory.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    raise ValueError("environment variable HF_TOKEN not found.")

repo_id = "Johnny-Z/danbooru_vfm"
repo_dir = snapshot_download(repo_id)
model = AutoModel.from_pretrained(repo_id, dtype=dtype, trust_remote_code=True, device_map=device)
processor = CLIPImageProcessor.from_pretrained(repo_id)


class MultiheadAttentionPoolingHead(nn.Module):
    """Attention-pool a sequence of vision features into a single embedding.

    A learned probe token attends over the (layer-normed) hidden states; the
    attended result goes through a second LayerNorm and a residual FFN, and
    the probe position is returned.  Attribute names must stay in sync with
    the checkpoint ``map_head.pth`` loaded below.
    """

    def __init__(self, input_size):
        super().__init__()
        self.map_probe = nn.Parameter(torch.randn(1, 1, input_size))
        self.map_layernorm0 = nn.LayerNorm(input_size, eps=1e-08)
        # One attention head per 64 channels.
        self.map_attention = torch.nn.MultiheadAttention(input_size, input_size // 64, batch_first=True)
        self.map_layernorm1 = nn.LayerNorm(input_size, eps=1e-08)
        self.map_ffn = nn.Sequential(
            nn.Linear(input_size, input_size * 4),
            nn.SiLU(),
            nn.Linear(input_size * 4, input_size),
        )

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Return one pooled vector per batch element."""
        probe = self.map_probe.repeat(hidden_state.shape[0], 1, 1)
        normed = self.map_layernorm0(hidden_state)
        pooled, _ = self.map_attention(probe, normed, normed)
        pooled = self.map_layernorm1(pooled)
        pooled = pooled + self.map_ffn(pooled)
        return pooled[:, 0]


class MLP(nn.Module):
    """Two-layer classifier head with sigmoid outputs (multi-label)."""

    def __init__(self, input_size, class_num):
        super().__init__()
        self.mlp_layer0 = nn.Sequential(
            nn.LayerNorm(input_size, eps=1e-08),
            nn.Linear(input_size, input_size // 2),
            nn.SiLU(),
        )
        self.mlp_layer1 = nn.Linear(input_size // 2, class_num)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return per-class probabilities in [0, 1]."""
        return self.sigmoid(self.mlp_layer1(self.mlp_layer0(x)))


def _load_json(filename):
    """Read one JSON asset from the downloaded model snapshot."""
    with open(os.path.join(repo_dir, filename), 'r', encoding='utf-8') as f:
        return json.load(f)


general_dict = _load_json('general_tag_dict.json')
character_dict = _load_json('character_tag_dict.json')
artist_dict = _load_json('artist_tag_dict.json')
implications_list = _load_json('implications_list.json')
artist_thresholds = _load_json('artist_threshold.json')
character_thresholds = _load_json('character_threshold.json')
general_thresholds = _load_json('general_threshold.json')
character_features = _load_json('character_feature.json')

model_map = MultiheadAttentionPoolingHead(2048)
# Load the pooling head and the three classifier heads from the snapshot.
model_map.load_state_dict(torch.load(os.path.join(repo_dir, "map_head.pth"), map_location=device, weights_only=True))
model_map.to(device).to(dtype).eval()

general_class = 11046
mlp_general = MLP(2048, general_class)
mlp_general.load_state_dict(torch.load(os.path.join(repo_dir, "cls_predictor_general.pth"), map_location=device, weights_only=True))
mlp_general.to(device).to(dtype).eval()

character_class = 9148
mlp_character = MLP(2048, character_class)
mlp_character.load_state_dict(torch.load(os.path.join(repo_dir, "cls_predictor_character.pth"), map_location=device, weights_only=True))
mlp_character.to(device).to(dtype).eval()

artist_class = 17171
mlp_artist = MLP(2048, artist_class)
mlp_artist.load_state_dict(torch.load(os.path.join(repo_dir, "cls_predictor_artist.pth"), map_location=device, weights_only=True))
mlp_artist.to(device).to(dtype).eval()


def prediction_to_tag(prediction, tag_dict, class_num):
    """Bucket per-class probabilities into named tag dictionaries.

    Args:
        prediction: tensor of `class_num` sigmoid scores (any shape that
            views to `(class_num,)`).
        tag_dict: maps tag name -> metadata list; index 1 is the category
            string, index 2 the 1-based class id.
        class_num: number of classes in this head.

    Returns:
        (general, character, artist, date, rating) dicts of tag -> score,
        the first three sorted by descending score; date and rating are
        reduced to their single best entry.
    """
    prediction = prediction.view(class_num)
    # A set gives O(1) membership tests; the original NumPy-array `in`
    # check was O(k) per tag across dictionaries of thousands of tags.
    predicted_ids = set(((prediction >= 0.2).nonzero(as_tuple=True)[0].cpu().numpy() + 1).tolist())
    general = {}
    character = {}
    artist = {}
    date = {}
    rating = {}
    for tag, value in tag_dict.items():
        if value[2] not in predicted_ids:
            continue
        tag_value = round(prediction[value[2] - 1].item(), 6)
        category = value[1]
        # Per-tag thresholds with a 0.75 fallback when a tag has no entry.
        if category == "general" and tag_value >= general_thresholds.get(tag, {}).get("Threshold", 0.75):
            general[tag] = tag_value
        elif category == "character" and tag_value >= character_thresholds.get(tag, {}).get("Threshold", 0.75):
            character[tag] = tag_value
        elif category == "artist" and tag_value >= artist_thresholds.get(tag, {}).get("Threshold", 0.75):
            artist[tag] = tag_value
        elif category == "rating":
            rating[tag] = tag_value
        elif category == "date":
            date[tag] = tag_value
    general = dict(sorted(general.items(), key=lambda item: item[1], reverse=True))
    character = dict(sorted(character.items(), key=lambda item: item[1], reverse=True))
    artist = dict(sorted(artist.items(), key=lambda item: item[1], reverse=True))
    # Keep only the single highest-scoring date / rating prediction.
    if date:
        best = max(date, key=date.get)
        date = {best: date[best]}
    if rating:
        best = max(rating, key=rating.get)
        rating = {best: rating[best]}
    return general, character, artist, date, rating


def process_image(image):
    """Run the full tagging pipeline on one uploaded PIL image.

    Returns the 6-tuple (tags_str, artist_tags, character_tags,
    general_tags, rating, date) matching the Gradio outputs.  On missing
    or unreadable input, returns empty placeholders of the same arity —
    the original bare `return` (None) broke the six-output callback.
    """
    empty = ("", {}, {}, {}, {}, {})
    # Guard: Submit can fire with no image uploaded (component value None).
    if image is None:
        return empty
    try:
        # Flatten any transparency onto a white background before
        # preprocessing, so RGBA uploads match RGB training data.
        image = image.convert('RGBA')
        background = Image.new('RGBA', image.size, (255, 255, 255, 255))
        image = Image.alpha_composite(background, image).convert('RGB')
        image_inputs = processor(images=[image], return_tensors="pt").to(device).to(dtype)
    except OSError as e:  # IOError is an alias of OSError since Python 3.3
        print(f"Error opening image: {e}")
        return empty

    with torch.no_grad():
        embedding = model(image_inputs.pixel_values)
        embedding = model_map(embedding)
        general_prediction = mlp_general(embedding)
        character_prediction = mlp_character(embedding)
        artist_prediction = mlp_artist(embedding)

    # Each head contributes its own category; rating rides with the general
    # head and date with the artist head, as in the tag dictionaries.
    general_tags, _, _, _, rating = prediction_to_tag(general_prediction, general_dict, general_class)
    _, character_tags, _, _, _ = prediction_to_tag(character_prediction, character_dict, character_class)
    _, _, artist_tags, date, _ = prediction_to_tag(artist_prediction, artist_dict, artist_class)

    # Drop character tags implied by another predicted character tag.
    implied = set()
    for tag in character_tags:
        implied.update(implications_list.get(tag, []))
    character_tags_list = [tag for tag in character_tags if tag not in implied]

    # Drop general tags that are implied by other general tags, or that are
    # canonical features of a detected character.
    tags_list = list(general_tags)
    redundant = set()
    for tag in tags_list:
        redundant.update(implications_list.get(tag, []))
    for char_tag in character_tags_list:
        redundant.update(character_features.get(char_tag, []))
    tags_list = [tag for tag in tags_list if tag not in redundant]

    # Underscores become spaces except in kaomoji tags; parentheses are
    # escaped for downstream prompt usage.
    tags_list = [tag if tag in kaomojis else tag.replace("_", " ") for tag in tags_list]
    tags_str = ", ".join(character_tags_list + tags_list).replace("(", r"\(").replace(")", r"\)")
    return tags_str, artist_tags, character_tags, general_tags, rating, date


def main():
    """Build and launch the Gradio UI."""
    with gr.Blocks(title=TITLE) as demo:
        with gr.Column():
            gr.Markdown(
                value=f"""
{TITLE}
"""
            )
            with gr.Row():
                with gr.Column(variant="panel"):
                    submit = gr.Button(value="Submit", variant="primary", size="lg")
                    image = gr.Image(type="pil", image_mode="RGBA", label="Input")
                    with gr.Row():
                        clear = gr.ClearButton(
                            components=[
                                image,
                            ],
                            variant="secondary",
                            size="lg",
                        )
                    gr.Markdown(value=DESCRIPTION)
                with gr.Column(variant="panel"):
                    tags_str = gr.Textbox(label="Output", lines=4)
                    with gr.Row():
                        rating = gr.Label(label="Rating")
                        date = gr.Label(label="Year")
                    artist_tags = gr.Label(label="Artist")
                    character_tags = gr.Label(label="Character")
                    general_tags = gr.Label(label="General")
                    # Clearing the image also resets every output component.
                    clear.add(
                        [
                            tags_str,
                            artist_tags,
                            general_tags,
                            character_tags,
                            rating,
                            date,
                        ]
                    )
            submit.click(
                process_image,
                inputs=[
                    image
                ],
                outputs=[tags_str, artist_tags, character_tags, general_tags, rating, date],
            )
    demo.queue(max_size=10)
    demo.launch()


if __name__ == "__main__":
    main()