Spaces:
Sleeping
Sleeping
| from transformers import CLIPImageProcessor, AutoModel | |
| import torch | |
| import json | |
| import torch.nn as nn | |
| from PIL import Image | |
| import gradio as gr | |
| import os | |
| from huggingface_hub import login, snapshot_download | |
# Application title: used as the Blocks page title and in the <h1> header built by main().
TITLE = "Danbooru Tagger"
# Markdown shown in the UI by main(): dataset provenance and per-category validation
# metrics. The "Tags Count" figures equal the classifier output sizes used for the
# general/character/artist heads (11046 / 9148 / 17171).
DESCRIPTION = """
## Dataset
- Source: Danbooru
- Cutoff Date: 2025-11-27
- Validation Split: 10% of Dataset
## Validation Results
### General
Tags Count: 11046
| Metric | Value |
|-----------------|-------------|
| Macro F1 | 0.4439 |
| Macro Precision | 0.4168 |
| Macro Recall | 0.4964 |
| Micro F1 | 0.6595 |
| Micro Precision | 0.5982 |
| Micro Recall | 0.7349 |
### Character
Tags Count: 9148
| Metric | Value |
|-----------------|-------------|
| Macro F1 | 0.8646 |
| Macro Precision | 0.8897 |
| Macro Recall | 0.8492 |
| Micro F1 | 0.9092 |
| Micro Precision | 0.9195 |
| Micro Recall | 0.8991 |
### Artist
Tags Count: 17171
| Metric | Value |
|-----------------|-------------|
| Macro F1 | 0.8008 |
| Macro Precision | 0.8669 |
| Macro Recall | 0.7641 |
| Micro F1 | 0.8596 |
| Micro Precision | 0.8948 |
| Micro Recall | 0.8271 |
"""
# Tags whose underscores are part of the tag itself (face emoticons). process_image
# leaves these exactly as-is instead of turning "_" into a space.
# None of them contain whitespace, so a single split() yields the list in order.
kaomojis = (
    "0_0 (o)_(o) +_+ +_- ._. <o>_<o> <|>_<|> =_= >_< 3_3 "
    "6_9 >_o @_@ ^_^ o_o u_u x_x |_| ||_||"
).split()
# All inference runs on CPU in float32; the backbone below and the classifier
# heads are placed on this device/dtype pair.
device = torch.device('cpu')
dtype = torch.float32

# A Hub token is mandatory: start-up aborts when HF_TOKEN is unset or empty.
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    raise ValueError("environment variable HF_TOKEN not found.")
login(token=hf_token)

# repo_dir is the local snapshot from which the tag tables and head weights are read;
# the remote-code backbone and its image preprocessor are loaded by repo id.
repo_id = "Johnny-Z/danbooru_vfm"
repo_dir = snapshot_download(repo_id)
model = AutoModel.from_pretrained(repo_id, dtype=dtype, trust_remote_code=True, device_map=device)
processor = CLIPImageProcessor.from_pretrained(repo_id)
class MultiheadAttentionPoolingHead(nn.Module):
    """Pool a token sequence into a single vector using a learned attention probe.

    A learned query (``map_probe``) cross-attends over the layer-normalised input
    tokens. The attended vector is layer-normalised again and refined by a residual
    SiLU feed-forward block.

    Attribute names must not change: they are the ``state_dict`` keys that the
    ``map_head.pth`` weights are loaded into.
    """

    def __init__(self, input_size):
        super().__init__()
        # Submodules are created in a fixed order so random initialisation is reproducible.
        self.map_probe = nn.Parameter(torch.randn(1, 1, input_size))
        self.map_layernorm0 = nn.LayerNorm(input_size, eps=1e-08)
        heads = input_size // 64  # one attention head per 64 channels
        self.map_attention = nn.MultiheadAttention(input_size, heads, batch_first=True)
        self.map_layernorm1 = nn.LayerNorm(input_size, eps=1e-08)
        expanded = input_size * 4
        self.map_ffn = nn.Sequential(
            nn.Linear(input_size, expanded),
            nn.SiLU(),
            nn.Linear(expanded, input_size),
        )

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Return one pooled ``(batch, input_size)`` vector per sequence.

        ``hidden_state`` is read batch-first (the attention layer uses
        ``batch_first=True``), i.e. ``(batch, tokens, input_size)``.
        """
        query = self.map_probe.repeat(hidden_state.shape[0], 1, 1)
        tokens = self.map_layernorm0(hidden_state)
        pooled, _attn_weights = self.map_attention(query, tokens, tokens)
        pooled = self.map_layernorm1(pooled)
        # Residual is taken after the second LayerNorm, matching the trained weights.
        pooled = pooled + self.map_ffn(pooled)
        return pooled[:, 0]
class MLP(nn.Module):
    """Classifier head producing an independent sigmoid score per class.

    LayerNorm -> Linear(input_size -> input_size // 2) -> SiLU
    -> Linear(-> class_num) -> Sigmoid.
    Attribute names are the ``state_dict`` keys of the ``cls_predictor_*.pth`` files.
    """

    def __init__(self, input_size, class_num):
        super().__init__()
        bottleneck = input_size // 2
        self.mlp_layer0 = nn.Sequential(
            nn.LayerNorm(input_size, eps=1e-08),
            nn.Linear(input_size, bottleneck),
            nn.SiLU(),
        )
        self.mlp_layer1 = nn.Linear(bottleneck, class_num)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map features ``x`` (last dim ``input_size``) to per-class scores in [0, 1]."""
        return self.sigmoid(self.mlp_layer1(self.mlp_layer0(x)))
def _load_repo_json(filename):
    """Parse ``filename`` from the downloaded repository snapshot as UTF-8 JSON."""
    with open(os.path.join(repo_dir, filename), 'r', encoding='utf-8') as f:
        return json.load(f)


# Per-head tag tables: tag -> value, where value[1] is the category and value[2]
# the 1-based score index (as consumed by prediction_to_tag).
general_dict = _load_repo_json('general_tag_dict.json')
character_dict = _load_repo_json('character_tag_dict.json')
artist_dict = _load_repo_json('artist_tag_dict.json')
# tag -> tags it implies; implied tags are removed from the output string.
implications_list = _load_repo_json('implications_list.json')
# Per-tag decision thresholds, looked up as table[tag]["Threshold"].
artist_thresholds = _load_repo_json('artist_threshold.json')
character_thresholds = _load_repo_json('character_threshold.json')
general_thresholds = _load_repo_json('general_threshold.json')
# character tag -> general tags dropped from the output when that character is kept.
character_features = _load_repo_json('character_feature.json')
def _load_head(head, weights_file):
    """Load ``weights_file`` from the repo snapshot into ``head`` and ready it for inference.

    Returns the same module object after moving it to the global device/dtype
    and switching it to eval mode.
    """
    state_dict = torch.load(os.path.join(repo_dir, weights_file), map_location=device, weights_only=True)
    head.load_state_dict(state_dict)
    head.to(device).to(dtype).eval()
    return head


# Attention-pooling head shared by all three classifiers (heads are built for 2048-dim features).
model_map = _load_head(MultiheadAttentionPoolingHead(2048), "map_head.pth")

# Class counts are the classifier output widths; prediction_to_tag indexes into that many scores.
general_class = 11046
mlp_general = _load_head(MLP(2048, general_class), "cls_predictor_general.pth")

character_class = 9148
mlp_character = _load_head(MLP(2048, character_class), "cls_predictor_character.pth")

artist_class = 17171
mlp_artist = _load_head(MLP(2048, artist_class), "cls_predictor_artist.pth")
def prediction_to_tag(prediction, tag_dict, class_num, *, min_score=0.2, default_threshold=0.75):
    """Convert one image's sigmoid scores into per-category ``{tag: score}`` dicts.

    Args:
        prediction: Scores for a single image. It must hold exactly ``class_num``
            elements, because it is reshaped with ``view(class_num)``.
        tag_dict: Mapping ``tag -> value``. ``value[1]`` is the category
            ("general", "character", "artist", "rating" or "date") and
            ``value[2]`` is the tag's 1-based index into ``prediction``.
        class_num: Number of scores in ``prediction``.
        min_score: Pre-filter; scores below it are ignored for every category.
        default_threshold: Threshold for general/character/artist tags that have
            no entry in the matching ``*_thresholds`` table.

    Returns:
        ``(general, character, artist, date, rating)``.
        - ``general``, ``character`` and ``artist`` are sorted by descending score.
        - ``date`` and ``rating`` contain only their single highest-scoring tag,
          or are empty.
        - Scores are rounded to 6 decimals.
    """
    prediction = prediction.view(class_num)
    # 1-based indices clearing the pre-filter, as a set. This gives O(1) membership
    # per tag instead of scanning an array for each of the thousands of tags.
    predicted_ids = set(((prediction >= min_score).nonzero(as_tuple=True)[0] + 1).tolist())

    general, character, artist, date, rating = {}, {}, {}, {}, {}
    for tag, value in tag_dict.items():
        category, index = value[1], value[2]
        if index not in predicted_ids:
            continue
        tag_value = round(prediction[index - 1].item(), 6)
        # Threshold tables are only consulted for the category that needs them.
        if category == "general" and tag_value >= general_thresholds.get(tag, {}).get("Threshold", default_threshold):
            general[tag] = tag_value
        elif category == "character" and tag_value >= character_thresholds.get(tag, {}).get("Threshold", default_threshold):
            character[tag] = tag_value
        elif category == "artist" and tag_value >= artist_thresholds.get(tag, {}).get("Threshold", default_threshold):
            artist[tag] = tag_value
        elif category == "rating":
            rating[tag] = tag_value
        elif category == "date":
            date[tag] = tag_value

    def _sorted_desc(scores):
        # Highest score first; ties keep tag_dict iteration order (sort is stable).
        return dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))

    def _top_one(scores):
        # Keep only the best-scoring entry (empty input stays empty).
        if not scores:
            return scores
        best = max(scores, key=scores.get)
        return {best: scores[best]}

    return _sorted_desc(general), _sorted_desc(character), _sorted_desc(artist), _top_one(date), _top_one(rating)
def process_image(image):
    """Tag one image and build the six values shown by the Gradio UI.

    The image is flattened onto an opaque white background, encoded by the
    backbone and attention-pooling head, and scored by the general, character
    and artist classifiers.

    Args:
        image: A PIL image, or ``None`` when Submit is pressed with no image loaded.

    Returns:
        ``(tags_str, artist_tags, character_tags, general_tags, rating, date)``.
        ``tags_str`` is a comma-separated string with these properties:
        - kept character tags come first, followed by general tags;
        - implied and character-feature tags are removed;
        - general-tag underscores become spaces, except for kaomojis;
        - parentheses are backslash-escaped.
        The remaining items are the ``{tag: score}`` dicts from prediction_to_tag.
        When there is no image, or preprocessing raises ``OSError``, an empty
        result of the same shape is returned so every output component still
        receives a value.
    """
    empty_result = ("", {}, {}, {}, {}, {})
    if image is None:
        return empty_result
    try:
        image = image.convert('RGBA')
        background = Image.new('RGBA', image.size, (255, 255, 255, 255))
        image = Image.alpha_composite(background, image).convert('RGB')
        image_inputs = processor(images=[image], return_tensors="pt").to(device).to(dtype)
    except OSError as e:  # IOError is an alias of OSError
        print(f"Error opening image: {e}")
        return empty_result

    with torch.no_grad():
        embedding = model_map(model(image_inputs.pixel_values))
        general_prediction = mlp_general(embedding)
        character_prediction = mlp_character(embedding)
        artist_prediction = mlp_artist(embedding)

    general_result = prediction_to_tag(general_prediction, general_dict, general_class)
    general_tags, rating = general_result[0], general_result[4]
    character_tags = prediction_to_tag(character_prediction, character_dict, character_class)[1]
    artist_result = prediction_to_tag(artist_prediction, artist_dict, artist_class)
    artist_tags, date = artist_result[2], artist_result[3]

    # Drop character tags implied by another predicted character tag.
    implied_characters = set()
    for tag in character_tags:
        if tag in implications_list:
            implied_characters.update(implications_list[tag])
    character_tags_list = [tag for tag in character_tags if tag not in implied_characters]

    # Drop general tags implied by other general tags or listed as features of a kept character.
    redundant_tags = set()
    for tag in general_tags:
        if tag in implications_list:
            redundant_tags.update(implications_list[tag])
    for char_tag in character_tags_list:
        if char_tag in character_features:
            redundant_tags.update(character_features[char_tag])
    tags_list = [tag for tag in general_tags if tag not in redundant_tags]

    # Underscores become spaces, except in kaomojis where "_" is part of the tag.
    tags_list = [tag if tag in kaomojis else tag.replace("_", " ") for tag in tags_list]
    # NOTE(review): parentheses are escaped, presumably so prompt syntaxes that use
    # "(...)" for weighting keep them literal — confirm intended consumer.
    tags_str = ", ".join(character_tags_list + tags_list).replace("(", r"\(").replace(")", r"\)")
    return tags_str, artist_tags, character_tags, general_tags, rating, date
def main():
    """Build the Gradio interface for the tagger, then queue and launch it."""
    with gr.Blocks(title=TITLE) as demo:
        with gr.Column():
            gr.Markdown(
                value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
            )
            with gr.Row():
                # Left panel: submit button, image input, clear button and model card.
                with gr.Column(variant="panel"):
                    submit_button = gr.Button(value="Submit", variant="primary", size="lg")
                    input_image = gr.Image(type="pil", image_mode="RGBA", label="Input")
                    with gr.Row():
                        clear_button = gr.ClearButton(
                            components=[input_image],
                            variant="secondary",
                            size="lg",
                        )
                    gr.Markdown(value=DESCRIPTION)
                # Right panel: prompt string plus per-category score labels.
                with gr.Column(variant="panel"):
                    prompt_box = gr.Textbox(label="Output", lines=4)
                    with gr.Row():
                        rating_label = gr.Label(label="Rating")
                        year_label = gr.Label(label="Year")
                    artist_label = gr.Label(label="Artist")
                    character_label = gr.Label(label="Character")
                    general_label = gr.Label(label="General")
                    clear_button.add(
                        [
                            prompt_box,
                            artist_label,
                            general_label,
                            character_label,
                            rating_label,
                            year_label,
                        ]
                    )
                    # Output order mirrors the tuple returned by process_image.
                    submit_button.click(
                        process_image,
                        inputs=[input_image],
                        outputs=[
                            prompt_box,
                            artist_label,
                            character_label,
                            general_label,
                            rating_label,
                            year_label,
                        ],
                    )
    demo.queue(max_size=10)
    demo.launch()


if __name__ == "__main__":
    main()