File size: 6,528 Bytes
4cad170
 
 
 
b6b23d6
 
 
4cad170
 
15fbeb3
4cad170
b6b23d6
15fbeb3
 
b6b23d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4cad170
b6b23d6
 
 
 
 
 
 
 
 
 
4cad170
 
b6b23d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4cad170
 
b6b23d6
4cad170
b6b23d6
 
4cad170
 
b6b23d6
4cad170
b6b23d6
4cad170
b6b23d6
4cad170
b6b23d6
 
 
4cad170
b6b23d6
4cad170
b6b23d6
4cad170
b6b23d6
4cad170
b6b23d6
 
4cad170
b6b23d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4cad170
b6b23d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4cad170
b6b23d6
 
 
 
4cad170
b6b23d6
4cad170
b6b23d6
4cad170
b6b23d6
 
4cad170
b6b23d6
4cad170
b6b23d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4cad170
 
b6b23d6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import gradio as gr
import torch
import timm
from PIL import Image
import json
from torchvision import transforms
from huggingface_hub import hf_hub_download

# --- Configuration ---
# Hugging Face Hub repository from which both the checkpoint and the tag map are downloaded.
REPO_ID = "telecomadm1145/convnext_large.dinov3_tagger_2"
MODEL_FILENAME = "pytorch_model.bin"           # checkpoint file inside REPO_ID
TAGS_FILENAME = "tag_map.json"                 # JSON: {group: {tag name: output index}}
MODEL_NAME = "convnext_large.dinov3_lvd1689m"  # timm architecture name
INPUT_SIZE = (512, 512)                        # letterbox canvas, read as (height, width)

# --- 1. Preprocessing (letterbox) ---
class LetterboxPad:
    """Fit an image inside a fixed canvas without distortion and pad the rest.

    The image is scaled by the largest factor that keeps it inside the canvas,
    centred, and the remaining area is filled with a solid colour.

    Args:
        size: Canvas size, either an int (square) or a ``(height, width)`` tuple.
        fill: RGB colour used for the padding and for transparent pixels.
    """

    def __init__(self, size, fill=(255, 255, 255)):
        self.size = size if isinstance(size, tuple) else (size, size)
        self.fill = fill

    @staticmethod
    def _fit(w, h, target_w, target_h):
        """Return ``(new_w, new_h, paste_x, paste_y)`` to centre a ``w x h`` image on the canvas."""
        scale = min(target_w / w, target_h / h)
        # round() rather than int(): w * (target_w / w) can come out a hair below
        # target_w in floating point and truncate to target_w - 1 (a 1px border).
        # max(1, ...) stops extreme aspect ratios from collapsing a side to 0 px,
        # which PIL's resize rejects.
        new_w = max(1, round(w * scale))
        new_h = max(1, round(h * scale))
        return new_w, new_h, (target_w - new_w) // 2, (target_h - new_h) // 2

    def __call__(self, img):
        """Return a new RGB image of ``self.size`` with *img* centred and letterboxed."""
        target_h, target_w = self.size
        new_w, new_h, paste_x, paste_y = self._fit(*img.size, target_w, target_h)

        has_alpha = img.mode in ("RGBA", "LA", "PA") or "transparency" in img.info
        # Convert before resizing so paletted images are also resampled bicubically.
        img = img.convert("RGBA" if has_alpha else "RGB")
        img = img.resize((new_w, new_h), Image.BICUBIC)

        canvas = Image.new("RGB", (target_w, target_h), self.fill)
        if has_alpha:
            # Composite transparent regions onto the fill colour instead of exposing
            # whatever RGB values sit underneath the alpha channel.
            canvas.paste(img, (paste_x, paste_y), mask=img)
        else:
            canvas.paste(img, (paste_x, paste_y))
        return canvas

def build_transform(size):
    """Build the preprocessing pipeline: letterbox, convert to tensor, normalize.

    Args:
        size: Canvas size forwarded to ``LetterboxPad`` (int or ``(height, width)``).

    Returns:
        A ``transforms.Compose`` that maps a PIL image to a normalized float tensor.
    """
    # Standard ImageNet per-channel statistics.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = [
        LetterboxPad(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(pipeline)

# --- 2. Load resources and group tags ---
print("Loading model and tags...")
device = torch.device("cpu")

# Per-group lists of (tag_name, output_index) pairs.
tag_groups = {
    "rating": [],
    "character": [],
    "general": []
}

try:
    json_path = hf_hub_download(repo_id=REPO_ID, filename=TAGS_FILENAME)
    # Decode explicitly as UTF-8: tag names may contain non-ASCII text, which the
    # platform default encoding (e.g. cp1252/GBK on Windows) can fail to decode.
    with open(json_path, 'r', encoding='utf-8') as f:
        grouped_json = json.load(f)

    # Assumed layout: {"rating": {"safe": 0, ...}, "general": {...}, ...}
    for group_key, tags_dict in grouped_json.items():
        # Unrecognised group names (e.g. a single catch-all group) fold into "general".
        target_group = group_key if group_key in tag_groups else "general"
        for name, idx in tags_dict.items():
            tag_groups[target_group].append((name, int(idx)))

    all_indices = [idx for group in tag_groups.values() for _, idx in group]
    total_tags = len(all_indices)
    # Size the classifier head by the largest referenced output index, not the tag
    # count: they agree only when indices are unique and contiguous, and a head that
    # is too narrow makes logits[idx] raise IndexError in predict().
    num_classes = max(all_indices, default=-1) + 1

    print(f"Loaded {total_tags} tags.")
    print(f" - Rating: {len(tag_groups['rating'])}")
    print(f" - Character: {len(tag_groups['character'])}")
    print(f" - General: {len(tag_groups['general'])}")
    if num_classes != total_tags:
        print(f"Warning: {total_tags} tags but max index is {num_classes - 1}; "
              f"using {num_classes} output classes.")

except Exception as e:
    print(f"Error loading tags: {e}")
    total_tags = 12000 # Fallback
    num_classes = total_tags

# Build the architecture without pretrained weights; the checkpoint below supplies them.
model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=num_classes)
try:
    model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
    # weights_only=True restricts unpickling to tensors and plain containers, so a
    # tampered checkpoint downloaded from the Hub cannot execute arbitrary code.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    print("Model weights loaded.")
except Exception as e:
    print(f"Error loading weights: {e}")
    # Best-effort startup is kept, but make it obvious the outputs are not trustworthy.
    print("Warning: continuing with randomly initialised weights; predictions will be meaningless.")

model.to(device)
model.eval()
transform = build_transform(INPUT_SIZE)

# --- 3. Inference ---
def _group_names_and_logits(logits, group):
    """Split *group* ((name, index) pairs) into its names and the matching slice of *logits*."""
    names = [name for name, _ in group]
    index = torch.tensor([idx for _, idx in group], device=logits.device)
    return names, logits[index]


def _softmax_scores(logits, group):
    """Return {name: prob} with a softmax taken over *group*'s logits only.

    Used for the rating group, whose tags are scored as one distribution.
    """
    if not group:
        return {}
    names, group_logits = _group_names_and_logits(logits, group)
    probs = torch.nn.functional.softmax(group_logits, dim=0).tolist()
    return dict(zip(names, probs))


def _sigmoid_scores_above(logits, group, threshold):
    """Return {name: prob} for tags in *group* whose independent sigmoid prob exceeds *threshold*."""
    if not group:
        return {}
    names, group_logits = _group_names_and_logits(logits, group)
    # A single .tolist() instead of iterating thousands of 0-d tensors in Python.
    probs = torch.sigmoid(group_logits).tolist()
    return {name: p for name, p in zip(names, probs) if p > threshold}


def _sorted_by_score(scores):
    """Return a copy of *scores* ordered by descending value (ties keep insertion order)."""
    return dict(sorted(scores.items(), key=lambda kv: kv[1], reverse=True))


@torch.no_grad()
def predict(image, threshold_gen, threshold_char):
    """Tag *image* and return per-group score dicts.

    Args:
        image: Input image (``None`` when the UI has no image yet).
        threshold_gen: Minimum sigmoid probability for general tags to be reported.
        threshold_char: Minimum sigmoid probability for character tags to be reported.

    Returns:
        ``(rating, character, general)`` dicts mapping tag name to probability,
        each sorted by descending probability. Ratings are softmax-normalised over
        the rating group and are always all present; the other two groups are
        filtered by their threshold. All three are empty when *image* is ``None``.
    """
    if image is None:
        return {}, {}, {}

    img_tensor = transform(image).unsqueeze(0).to(device)
    logits = model(img_tensor)[0]  # drop the batch dimension of the single image

    rating_res = _softmax_scores(logits, tag_groups["rating"])
    char_res = _sigmoid_scores_above(logits, tag_groups["character"], threshold_char)
    gen_res = _sigmoid_scores_above(logits, tag_groups["general"], threshold_gen)

    return _sorted_by_score(rating_res), _sorted_by_score(char_res), _sorted_by_score(gen_res)

# --- 4. UI ---
# Gradio Blocks layout: components render in the order they are created inside
# each `with` context, so the statement order below defines the page layout.
with gr.Blocks() as demo:
    gr.Markdown(f"# Anime Tagger (DINOv3)\nModel: {REPO_ID}")
    
    with gr.Row():
        # Left column: image input, run button and per-group thresholds.
        with gr.Column(scale=1):
            input_img = gr.Image(type="pil", label="Input Image")
            run_btn = gr.Button("Tag It!", variant="primary")
            
            gr.Markdown("### Thresholds")
            # Separate thresholds per category usually work better; character tags
            # often need a lower threshold to be recalled.
            # NOTE: the sliders are only read when the button is clicked; moving them
            # does not re-run predict on their own (only run_btn.click is wired).
            threshold_gen = gr.Slider(0.0, 1.0, value=0.25, step=0.05, label="General Tags Threshold")
            threshold_char = gr.Slider(0.0, 1.0, value=0.15, step=0.05, label="Character Threshold")

        # Right column: one output per tag group, in the same order as predict's
        # (rating, character, general) return tuple.
        with gr.Column(scale=1):
            # Show each group separately.
            gr.Markdown("### 1. Rating (Softmax)")
            out_rating = gr.Label(label="Rating")
            
            gr.Markdown("### 2. Characters")
            out_char = gr.Label(label="Characters", num_top_classes=10)
            
            gr.Markdown("### 3. General Tags")
            # NOTE(review): num_top_classes presumably caps the display at 50 entries even
            # when more tags pass the threshold — confirm this is intended.
            out_gen = gr.Label(label="General Tags", num_top_classes=50)

    run_btn.click(
        fn=predict, 
        inputs=[input_img, threshold_gen, threshold_char], 
        outputs=[out_rating, out_char, out_gen]
    )

# Start the Gradio server only when executed as a script.
if __name__ == "__main__":
    demo.launch()