Spaces:
Sleeping
Sleeping
File size: 6,528 Bytes
4cad170 b6b23d6 4cad170 15fbeb3 4cad170 b6b23d6 15fbeb3 b6b23d6 4cad170 b6b23d6 4cad170 b6b23d6 4cad170 b6b23d6 4cad170 b6b23d6 4cad170 b6b23d6 4cad170 b6b23d6 4cad170 b6b23d6 4cad170 b6b23d6 4cad170 b6b23d6 4cad170 b6b23d6 4cad170 b6b23d6 4cad170 b6b23d6 4cad170 b6b23d6 4cad170 b6b23d6 4cad170 b6b23d6 4cad170 b6b23d6 4cad170 b6b23d6 4cad170 b6b23d6 4cad170 b6b23d6 4cad170 b6b23d6 4cad170 b6b23d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 |
import gradio as gr
import torch
import timm
from PIL import Image
import json
from torchvision import transforms
from huggingface_hub import hf_hub_download
# --- Configuration ---
# Hugging Face Hub repository that both the weights and the tag map are downloaded from.
REPO_ID = "telecomadm1145/convnext_large.dinov3_tagger_2"
# File inside REPO_ID that is loaded as the model state dict.
MODEL_FILENAME = "pytorch_model.bin"
# File inside REPO_ID that is parsed as the JSON tag map.
TAGS_FILENAME = "tag_map.json"
# Architecture name passed to timm.create_model.
MODEL_NAME = "convnext_large.dinov3_lvd1689m"
# Size handed to build_transform; LetterboxPad unpacks it as (height, width).
INPUT_SIZE = (512,512)
# --- 1. Preprocessing (Letterbox) ---
class LetterboxPad:
    """Resize an image to fit a fixed canvas and pad the remaining area.

    The aspect ratio of the input is preserved: the image is scaled so that it
    fits entirely inside the target canvas, then pasted centred on a new RGB
    canvas filled with ``fill``.

    Args:
        size: Target canvas size, either a ``(height, width)`` tuple/list or a
            single int for a square canvas.
        fill: Colour used for the padded area of the RGB canvas.
    """

    def __init__(self, size, fill=(255, 255, 255)):
        # Accept lists as well as tuples; a bare scalar means a square canvas.
        if isinstance(size, (tuple, list)):
            self.size = tuple(size)
        else:
            self.size = (size, size)
        self.fill = fill

    def __call__(self, img):
        """Return a new RGB image of the target size containing ``img``.

        Args:
            img: A PIL image (anything exposing ``size`` and ``resize``).

        Returns:
            A new RGB PIL image of size ``(target_w, target_h)``.
        """
        w, h = img.size
        target_h, target_w = self.size
        # The smaller ratio guarantees both sides fit inside the canvas.
        scale = min(target_w / w, target_h / h)
        # int() truncates, so for extreme aspect ratios the short side would
        # become 0 pixels; keep at least one pixel on each side.
        new_w = max(1, int(w * scale))
        new_h = max(1, int(h * scale))
        img = img.resize((new_w, new_h), Image.BICUBIC)
        new_img = Image.new("RGB", (target_w, target_h), self.fill)
        # Centre the resized image on the padded canvas.
        paste_x = (target_w - new_w) // 2
        paste_y = (target_h - new_h) // 2
        new_img.paste(img, (paste_x, paste_y))
        return new_img
def build_transform(size):
    """Build the preprocessing pipeline applied to every input image.

    The pipeline letterbox-pads the image to ``size``, converts it to a
    tensor and normalises each channel with fixed mean/std constants.

    Args:
        size: Target size forwarded unchanged to ``LetterboxPad``.

    Returns:
        A ``transforms.Compose`` object chaining the three steps.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        LetterboxPad(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)
# --- 2. Load resources and tag groups ---
print("Loading model and tags...")
device = torch.device("cpu")
# One list of (tag name, output index) pairs per group; each group gets its
# own empty list, which is filled in when the tag map is parsed.
tag_groups = {group: [] for group in ("rating", "character", "general")}
# Download the tag map and sort every tag into its group. Any failure is
# reported and replaced by a fallback class count so start-up can continue.
try:
    json_path = hf_hub_download(repo_id=REPO_ID, filename=TAGS_FILENAME)
    # Pin the encoding: without it open() uses the locale default, so
    # non-ASCII tag names could be mis-decoded on some platforms.
    with open(json_path, 'r', encoding="utf-8") as f:
        grouped_json = json.load(f)
    # Parse the groups. Assumes the JSON looks like
    # {"rating": {"safe": 0, ...}, "general": {...}, ...}
    total_tags = 0
    for group_key, tags_dict in grouped_json.items():
        # Compatibility: a group key that is not one of the known groups is
        # folded into "general".
        target_group = group_key if group_key in tag_groups else "general"
        for name, idx in tags_dict.items():
            tag_groups[target_group].append((name, int(idx)))
            total_tags += 1
    print(f"Loaded {total_tags} tags.")
    print(f" - Rating: {len(tag_groups['rating'])}")
    print(f" - Character: {len(tag_groups['character'])}")
    print(f" - General: {len(tag_groups['general'])}")
except Exception as e:
    # Best effort: keep going with a fixed class count. tag_groups keeps
    # whatever was collected before the failure (possibly nothing).
    print(f"Error loading tags: {e}")
    total_tags = 12000  # Fallback
# Load the model: build the architecture with one output per tag, then try to
# fill it with the weights downloaded from the Hub.
model = timm.create_model(
    MODEL_NAME,
    pretrained=False,
    num_classes=total_tags,
)
try:
    model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
    # NOTE(review): torch.load unpickles the downloaded file; consider passing
    # weights_only=True if the installed torch version supports it.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    print("Model weights loaded.")
except Exception as e:
    # Best effort: the error is only reported, and the model keeps the
    # weights it was created with (pretrained=False).
    print(f"Error loading weights: {e}")
# Move to the target device and switch to inference mode.
model.to(device).eval()
transform = build_transform(INPUT_SIZE)
# --- 3. Inference logic ---
def _score_group(logits, entries, threshold=None, softmax=False):
    """Turn the logits of one tag group into a ``{tag name: probability}`` dict.

    Args:
        logits: 1-D tensor of raw model outputs, indexed by tag index.
        entries: List of ``(tag name, output index)`` pairs for the group.
        threshold: If given, only tags whose probability is strictly greater
            than this value are kept; ``None`` keeps every tag.
        softmax: If true, probabilities are a softmax over the group's logits;
            otherwise each logit is passed through a sigmoid independently.

    Returns:
        A dict ordered by descending probability (ties keep the order of
        ``entries``); empty when ``entries`` is empty.
    """
    if not entries:
        return {}
    names = [name for name, _ in entries]
    # Index with a tensor so the whole group is gathered in one operation.
    index = torch.tensor([idx for _, idx in entries], device=logits.device)
    group_logits = logits[index]
    if softmax:
        probs = torch.nn.functional.softmax(group_logits, dim=0)
    else:
        probs = torch.sigmoid(group_logits)
    # One bulk conversion instead of a float() call per 0-d tensor element.
    values = probs.tolist()
    if threshold is None:
        scored = dict(zip(names, values))
    else:
        # Compare on the tensor so the strict ">" behaves exactly as it would
        # element by element, then convert the mask in one go.
        mask = (probs > threshold).tolist()
        scored = {
            name: value
            for name, value, keep in zip(names, values, mask)
            if keep
        }
    return dict(sorted(scored.items(), key=lambda item: item[1], reverse=True))


@torch.no_grad()
def predict(image, threshold_gen, threshold_char):
    """Tag an image and return rating, character and general tag scores.

    Args:
        image: Input image accepted by the module-level ``transform``, or
            ``None`` when nothing was supplied.
        threshold_gen: General tags are kept only if their sigmoid
            probability is strictly greater than this value.
        threshold_char: Same as ``threshold_gen`` but for character tags.

    Returns:
        A tuple ``(rating, characters, general)`` of dicts mapping tag name to
        probability, each sorted by descending probability. Rating scores are
        a softmax over the rating group and are never thresholded. All three
        dicts are empty when ``image`` is ``None``.
    """
    if image is None:
        return {}, {}, {}
    img_tensor = transform(image).unsqueeze(0).to(device)
    # Drop the batch dimension added above; one logit per class remains.
    logits = model(img_tensor)[0]
    # A. Rating: mutually exclusive classes, so softmax within the group.
    rating_res = _score_group(logits, tag_groups["rating"], softmax=True)
    # B. Character: multi-label, so sigmoid plus threshold.
    char_res = _score_group(
        logits, tag_groups["character"], threshold=threshold_char
    )
    # C. General: multi-label, so sigmoid plus threshold.
    gen_res = _score_group(
        logits, tag_groups["general"], threshold=threshold_gen
    )
    return rating_res, char_res, gen_res
# --- 4. User interface ---
# Gradio layout: image input and threshold sliders on one side, the three
# result panels on the other, wired to predict() through the button.
# NOTE(review): the nesting below was reconstructed from a flattened listing;
# confirm that it matches the intended layout.
with gr.Blocks() as demo:
    gr.Markdown(f"# Anime Tagger (DINOv3)\nModel: {REPO_ID}")
    with gr.Row():
        with gr.Column(scale=1):
            input_img = gr.Image(type="pil", label="Input Image")
            run_btn = gr.Button("Tag It!", variant="primary")
            gr.Markdown("### Thresholds")
            # Separate thresholds per category usually work better; character
            # tags often need a lower threshold to be recalled.
            threshold_gen = gr.Slider(0.0, 1.0, value=0.25, step=0.05, label="General Tags Threshold")
            threshold_char = gr.Slider(0.0, 1.0, value=0.15, step=0.05, label="Character Threshold")
        with gr.Column(scale=1):
            # Show each tag group in its own panel.
            gr.Markdown("### 1. Rating (Softmax)")
            out_rating = gr.Label(label="Rating")
            gr.Markdown("### 2. Characters")
            out_char = gr.Label(label="Characters", num_top_classes=10)
            gr.Markdown("### 3. General Tags")
            out_gen = gr.Label(label="General Tags", num_top_classes=50)
    # The input order matches predict(image, threshold_gen, threshold_char);
    # the output order matches its (rating, characters, general) return value.
    run_btn.click(
        fn=predict,
        inputs=[input_img, threshold_gen, threshold_char],
        outputs=[out_rating, out_char, out_gen]
    )
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|