Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| import timm | |
| from PIL import Image | |
| import json | |
| from torchvision import transforms | |
| from huggingface_hub import hf_hub_download | |
# --- Configuration ---
# Hugging Face Hub repository that the model weights and the tag map are
# downloaded from.
REPO_ID = "telecomadm1145/convnext_large.dinov3_tagger_2"

# File names fetched from REPO_ID.
MODEL_FILENAME = "pytorch_model.bin"
TAGS_FILENAME = "tag_map.json"

# timm architecture name passed to timm.create_model().
MODEL_NAME = "convnext_large.dinov3_lvd1689m"

# Size handed to build_transform(); LetterboxPad unpacks a tuple as
# (height, width).
INPUT_SIZE = (512, 512)
# --- 1. Preprocessing (letterbox) ---
class LetterboxPad:
    """Resize an image to fit a canvas while keeping its aspect ratio.

    The image is scaled so that it fits entirely inside the target size,
    then pasted centred onto a solid-colour RGB canvas of exactly that size.

    Args:
        size: Either a single number (square canvas) or a two-element
            tuple/list interpreted as ``(height, width)``.
        fill: Colour used for the padding area (default: white).

    Raises:
        ValueError: If ``size`` is a tuple/list that does not have exactly
            two elements.
    """

    def __init__(self, size, fill=(255, 255, 255)):
        if isinstance(size, (tuple, list)):
            # Lists are accepted as well; previously a list was wrapped into
            # a tuple of two lists, which broke __call__.
            if len(size) != 2:
                raise ValueError(
                    f"size must be a number or a (height, width) pair, got {size!r}"
                )
            self.size = tuple(size)
        else:
            self.size = (size, size)
        self.fill = fill

    def __call__(self, img):
        """Return a new RGB image of the target size containing ``img``.

        Args:
            img: A PIL image.

        Returns:
            A new PIL image in mode "RGB" whose size is the configured
            ``(width, height)``.
        """
        src_w, src_h = img.size
        target_h, target_w = self.size
        canvas = Image.new("RGB", (target_w, target_h), self.fill)

        # Nothing to paste for a degenerate image; also avoids dividing by 0.
        if src_w <= 0 or src_h <= 0:
            return canvas

        # Pick the limiting side with integer arithmetic. The float form
        # int(w * (target / w)) can land one pixel short of the target
        # because of rounding, and can truncate to 0 for extreme aspect
        # ratios (which makes resize() fail), hence the max(1, ...).
        if target_w * src_h <= target_h * src_w:
            new_w = target_w
            new_h = max(1, (src_h * target_w) // src_w)
        else:
            new_h = target_h
            new_w = max(1, (src_w * target_h) // src_h)

        resized = img.resize((new_w, new_h), Image.BICUBIC)
        offset = ((target_w - new_w) // 2, (target_h - new_h) // 2)
        canvas.paste(resized, offset)
        return canvas
def build_transform(size):
    """Build the preprocessing pipeline used before inference.

    The pipeline letterboxes the image to ``size``, converts it to a tensor
    and normalizes each channel with a fixed mean and standard deviation
    (the values are the commonly used ImageNet channel statistics).

    Args:
        size: Target size forwarded to ``LetterboxPad``.

    Returns:
        A ``transforms.Compose`` object applying the three steps in order.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        LetterboxPad(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)
# --- 2. Load resources and group the tags ---
def parse_tag_groups(grouped_json):
    """Bucket a grouped tag map into rating / character / general lists.

    Args:
        grouped_json: Mapping of group name to a ``{tag_name: index}``
            mapping, e.g. ``{"rating": {"safe": 0, ...}, "general": ...}``.

    Returns:
        A ``(groups, total)`` tuple. ``groups`` maps each of "rating",
        "character" and "general" to a list of ``(name, index)`` pairs;
        entries from any other group key are filed under "general".
        ``total`` is the number of entries across all groups.
    """
    groups = {"rating": [], "character": [], "general": []}
    for group_key, tags_dict in grouped_json.items():
        # Unknown group keys fall back to "general" so no tag is dropped.
        target_group = group_key if group_key in groups else "general"
        groups[target_group].extend(
            (name, int(idx)) for name, idx in tags_dict.items()
        )
    total = sum(len(entries) for entries in groups.values())
    return groups, total


print("Loading model and tags...")
device = torch.device("cpu")

# (name, index) pairs for every tag, keyed by group.
tag_groups = {"rating": [], "character": [], "general": []}

try:
    json_path = hf_hub_download(repo_id=REPO_ID, filename=TAGS_FILENAME)
    # Explicit UTF-8: without it the decoding depends on the platform's
    # locale, which can fail on non-ASCII tag names.
    with open(json_path, "r", encoding="utf-8") as f:
        grouped_json = json.load(f)
    # Parse into fresh containers and only then publish them, so a failure
    # half-way through never leaves a partially filled tag_groups behind.
    tag_groups, total_tags = parse_tag_groups(grouped_json)
    print(f"Loaded {total_tags} tags.")
    print(f" - Rating: {len(tag_groups['rating'])}")
    print(f" - Character: {len(tag_groups['character'])}")
    print(f" - General: {len(tag_groups['general'])}")
except Exception as e:
    # Best effort: keep the app starting even without a tag map. With empty
    # groups, predict() returns empty results for every category.
    print(f"Error loading tags: {e}")
    tag_groups = {"rating": [], "character": [], "general": []}
    total_tags = 12000  # Fallback class count used to size the model head
# Build the network and load the fine-tuned weights.
# pretrained=False: the architecture is created with fresh weights and the
# checkpoint from REPO_ID is loaded on top of it below.
model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=total_tags)
try:
    model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
    # The checkpoint is a downloaded pickle file. weights_only=True restricts
    # unpickling to tensors and plain containers, so a tampered file cannot
    # run arbitrary code while it is being loaded.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    print("Model weights loaded.")
except Exception as e:
    # Best effort: keep the UI available, but make it obvious that the
    # model is unusable instead of silently serving random predictions.
    print(f"Error loading weights: {e}")
    print("WARNING: model is running with uninitialized weights; predictions are meaningless.")
model.to(device)
model.eval()  # inference behaviour for layers such as dropout / batch norm
transform = build_transform(INPUT_SIZE)
# --- 3. Inference ---
def _gather_group(logits, group):
    """Return the tag names of ``group`` and the logits at their indices."""
    names = [name for name, _ in group]
    indices = torch.tensor([idx for _, idx in group], device=logits.device)
    return names, logits[indices]


def _ranked(names, probs, threshold=None):
    """Map names to probabilities, sorted by probability (highest first).

    When ``threshold`` is given, only entries strictly above it are kept.
    """
    values = probs.tolist()
    if threshold is None:
        pairs = zip(names, values)
    else:
        # One vectorised comparison instead of thousands of 0-d tensor
        # comparisons inside a Python loop.
        keep = (probs > threshold).tolist()
        pairs = (
            (name, value)
            for name, value, kept in zip(names, values, keep)
            if kept
        )
    return dict(sorted(pairs, key=lambda item: item[1], reverse=True))


def predict(image, threshold_gen, threshold_char):
    """Tag an image and return per-category probabilities.

    Args:
        image: PIL image to tag, or ``None`` when nothing was provided.
        threshold_gen: General tags are kept only if their sigmoid
            probability is strictly greater than this value.
        threshold_char: Same as ``threshold_gen`` but for character tags.

    Returns:
        A ``(rating, characters, general)`` tuple of ``{name: probability}``
        dicts, each sorted by descending probability. Rating probabilities
        come from a softmax over the rating group only; the other two groups
        use an independent sigmoid per tag. All three dicts are empty when
        ``image`` is ``None``.
    """
    if image is None:
        return {}, {}, {}

    # Inference only: do not record an autograd graph for the forward pass
    # (saves memory and time on every request).
    with torch.inference_mode():
        img_tensor = transform(image).unsqueeze(0).to(device)
        logits = model(img_tensor)[0]  # Shape: [num_classes]

        # --- A. Rating: mutually exclusive, softmax within the group ---
        rating_res = {}
        if tag_groups["rating"]:
            names, group_logits = _gather_group(logits, tag_groups["rating"])
            rating_res = _ranked(names, torch.softmax(group_logits, dim=0))

        # --- B. Character: multi-label, sigmoid + threshold ---
        char_res = {}
        if tag_groups["character"]:
            names, group_logits = _gather_group(logits, tag_groups["character"])
            char_res = _ranked(names, torch.sigmoid(group_logits), threshold_char)

        # --- C. General: multi-label, sigmoid + threshold ---
        gen_res = {}
        if tag_groups["general"]:
            names, group_logits = _gather_group(logits, tag_groups["general"])
            gen_res = _ranked(names, torch.sigmoid(group_logits), threshold_gen)

    return rating_res, char_res, gen_res
# --- 4. User interface ---
# Settings shared by both threshold sliders.
_slider_range = dict(minimum=0.0, maximum=1.0, step=0.05)

with gr.Blocks() as demo:
    gr.Markdown(f"# Anime Tagger (DINOv3)\nModel: {REPO_ID}")

    with gr.Row():
        # Left column: image input, run button and threshold controls.
        with gr.Column(scale=1):
            input_img = gr.Image(type="pil", label="Input Image")
            run_btn = gr.Button("Tag It!", variant="primary")

            gr.Markdown("### Thresholds")
            # Separate thresholds per category usually work better;
            # character tags often need a lower one to be recalled.
            threshold_gen = gr.Slider(
                value=0.25, label="General Tags Threshold", **_slider_range
            )
            threshold_char = gr.Slider(
                value=0.15, label="Character Threshold", **_slider_range
            )

        # Right column: one output widget per tag category.
        with gr.Column(scale=1):
            gr.Markdown("### 1. Rating (Softmax)")
            out_rating = gr.Label(label="Rating")

            gr.Markdown("### 2. Characters")
            out_char = gr.Label(label="Characters", num_top_classes=10)

            gr.Markdown("### 3. General Tags")
            out_gen = gr.Label(label="General Tags", num_top_classes=50)

    # Wire the button to predict(); its three return values fill the three
    # output widgets in order.
    run_btn.click(
        predict,
        [input_img, threshold_gen, threshold_char],
        [out_rating, out_char, out_gen],
    )

if __name__ == "__main__":
    demo.launch()