import json

import gradio as gr
import timm
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms

# --- Configuration ---
REPO_ID = "telecomadm1145/convnext_large.dinov3_tagger_2"
MODEL_FILENAME = "pytorch_model.bin"
TAGS_FILENAME = "tag_map.json"
MODEL_NAME = "convnext_large.dinov3_lvd1689m"
INPUT_SIZE = (512, 512)


# --- 1. Preprocessing (letterbox) ---
class LetterboxPad:
    """Fit an image inside a fixed-size canvas without changing its aspect ratio.

    The image is scaled so that it fits entirely within the target size, then
    pasted centred onto a solid-colour RGB canvas of exactly that size.

    Args:
        size: Target size. Either a single int (square canvas) or a 2-item
            tuple/list interpreted as ``(height, width)``.
        fill: RGB colour used for the padding area. Defaults to white.
    """

    def __init__(self, size, fill=(255, 255, 255)):
        # Accept lists as well as tuples; a bare number means a square canvas.
        if isinstance(size, (tuple, list)):
            self.size = tuple(size)
        else:
            self.size = (size, size)
        self.fill = fill

    def __call__(self, img):
        """Return a new RGB image of the target size containing ``img``.

        Args:
            img: A ``PIL.Image.Image``.

        Returns:
            A new ``PIL.Image.Image`` in mode ``"RGB"`` whose size is
            ``(target_w, target_h)``.
        """
        # Convert up front so the bicubic resampling below runs on RGB data.
        # Per the Pillow docs, mode "P" and "1" images are always resized with
        # NEAREST, which would silently degrade palette inputs.
        if img.mode != "RGB":
            img = img.convert("RGB")

        w, h = img.size
        target_h, target_w = self.size
        scale = min(target_w / w, target_h / h)

        # Clamp to 1 pixel: for very elongated inputs the short side can
        # truncate to 0, which resize() rejects.
        new_w = max(1, int(w * scale))
        new_h = max(1, int(h * scale))
        img = img.resize((new_w, new_h), Image.BICUBIC)

        new_img = Image.new("RGB", (target_w, target_h), self.fill)
        paste_x = (target_w - new_w) // 2
        paste_y = (target_h - new_h) // 2
        new_img.paste(img, (paste_x, paste_y))
        return new_img


def build_transform(size):
    """Build the preprocessing pipeline: letterbox, to tensor, normalize.

    Args:
        size: Target size forwarded to ``LetterboxPad``.

    Returns:
        A ``transforms.Compose`` that maps a PIL image to a normalized tensor.
    """
    return transforms.Compose([
        LetterboxPad(size),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])
# --- 2. Load resources and tag groups ---
print("Loading model and tags...")
device = torch.device("cpu")

# Per-group lists of (name, index) pairs.
tag_groups = {
    "rating": [],
    "character": [],
    "general": [],
}

try:
    json_path = hf_hub_download(repo_id=REPO_ID, filename=TAGS_FILENAME)
    # Explicit UTF-8: tag names may be non-ASCII and the platform default
    # encoding is not guaranteed to be UTF-8.
    with open(json_path, "r", encoding="utf-8") as f:
        grouped_json = json.load(f)

    # Expected JSON layout: {"rating": {"safe": 0, ...}, "general": {...}, ...}
    total_tags = 0
    for group_key, tags_dict in grouped_json.items():
        # Any group name we do not know about is folded into "general".
        target_group = group_key if group_key in tag_groups else "general"
        for name, idx in tags_dict.items():
            tag_groups[target_group].append((name, int(idx)))
            total_tags += 1

    print(f"Loaded {total_tags} tags.")
    print(f" - Rating: {len(tag_groups['rating'])}")
    print(f" - Character: {len(tag_groups['character'])}")
    print(f" - General: {len(tag_groups['general'])}")
except Exception as e:
    # Best effort: keep the app starting even without a tag map.
    print(f"Error loading tags: {e}")
    total_tags = 12000  # Fallback

# Load the model.
# NOTE(review): num_classes is the tag count, which presumably assumes the tag
# indices are dense (0..N-1) - confirm against the published tag map.
model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=total_tags)
try:
    model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
    # NOTE(security): torch.load unpickles the file. Depending on the installed
    # torch version this may execute arbitrary code from the checkpoint; consider
    # passing weights_only=True if the minimum supported torch version allows it.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    print("Model weights loaded.")
except Exception as e:
    # Best effort: the app still starts, but with uninitialised weights.
    print(f"Error loading weights: {e}")

model.to(device)
model.eval()
transform = build_transform(INPUT_SIZE)


def _index_group(entries):
    """Split (name, index) pairs into a name list and a long index tensor."""
    names = [name for name, _ in entries]
    indices = torch.tensor(
        [idx for _, idx in entries], dtype=torch.long, device=device
    )
    return names, indices


# Built once at import time so predict() does not rebuild them per request.
# Empty groups are left out entirely.
_group_index = {
    group: _index_group(entries)
    for group, entries in tag_groups.items()
    if entries
}


def _sorted_by_score(scores):
    """Return a copy of ``scores`` ordered by value, highest first."""
    return dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))


def _sigmoid_above(logits, group, threshold):
    """Return {name: probability} for tags in ``group`` scoring above ``threshold``.

    Probabilities are independent sigmoids (multi-label). Returns an empty
    dict when the group has no tags.
    """
    if group not in _group_index:
        return {}
    names, indices = _group_index[group]
    probs = torch.sigmoid(logits[indices])
    # Compare on the tensor so the threshold is applied in the tensor's dtype.
    keep = (probs > threshold).tolist()
    return {
        name: prob
        for name, prob, kept in zip(names, probs.tolist(), keep)
        if kept
    }


# --- 3. Inference ---
@torch.no_grad()
def predict(image, threshold_gen, threshold_char):
    """Tag one image and return per-group scores.

    Args:
        image: A PIL image, or ``None`` when no image was supplied.
        threshold_gen: Minimum sigmoid probability for a general tag to be kept.
        threshold_char: Minimum sigmoid probability for a character tag to be kept.

    Returns:
        A tuple ``(rating, characters, general)`` of ``{name: probability}``
        dicts, each ordered from highest to lowest probability. All three are
        empty when ``image`` is ``None``.
    """
    if image is None:
        return {}, {}, {}

    img_tensor = transform(image).unsqueeze(0).to(device)
    logits = model(img_tensor)[0]  # First (only) item of the batch.

    # A. Rating: softmax restricted to the rating logits, so the ratings
    #    form one distribution that sums to 1. No threshold is applied.
    rating_res = {}
    if "rating" in _group_index:
        r_names, r_indices = _group_index["rating"]
        r_probs = torch.softmax(logits[r_indices], dim=0)
        rating_res = dict(zip(r_names, r_probs.tolist()))

    # B. Characters and C. general tags: independent sigmoids plus a threshold.
    char_res = _sigmoid_above(logits, "character", threshold_char)
    gen_res = _sigmoid_above(logits, "general", threshold_gen)

    return (
        _sorted_by_score(rating_res),
        _sorted_by_score(char_res),
        _sorted_by_score(gen_res),
    )
# --- 4. User interface ---
with gr.Blocks() as demo:
    gr.Markdown(f"# Anime Tagger (DINOv3)\nModel: {REPO_ID}")

    with gr.Row():
        # Left column: image input, run button and threshold controls.
        with gr.Column(scale=1):
            input_img = gr.Image(type="pil", label="Input Image")
            run_btn = gr.Button("Tag It!", variant="primary")

            gr.Markdown("### Thresholds")
            # Per-category thresholds usually work better; character tags tend
            # to need a lower threshold for good recall.
            threshold_gen = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.25,
                step=0.05,
                label="General Tags Threshold",
            )
            threshold_char = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.15,
                step=0.05,
                label="Character Threshold",
            )

        # Right column: one output panel per tag group.
        with gr.Column(scale=1):
            gr.Markdown("### 1. Rating (Softmax)")
            out_rating = gr.Label(label="Rating")

            gr.Markdown("### 2. Characters")
            out_char = gr.Label(label="Characters", num_top_classes=10)

            gr.Markdown("### 3. General Tags")
            out_gen = gr.Label(label="General Tags", num_top_classes=50)

    # Order matches predict()'s parameters and its returned tuple.
    predict_inputs = [input_img, threshold_gen, threshold_char]
    predict_outputs = [out_rating, out_char, out_gen]
    run_btn.click(fn=predict, inputs=predict_inputs, outputs=predict_outputs)


if __name__ == "__main__":
    demo.launch()