telecomadm1145 commited on
Commit
b6b23d6
·
verified ·
1 Parent(s): cfad561

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -83
app.py CHANGED
@@ -1,115 +1,184 @@
1
  import gradio as gr
2
  import torch
3
  import timm
4
- from timm.data import resolve_data_config
5
- from timm.data.transforms_factory import create_transform
6
- from huggingface_hub import hf_hub_download
7
- import json
8
- import os
9
  from PIL import Image
 
 
 
10
 
11
  # --- 配置 ---
12
- # 必须与您训练脚本中的 target_repo 一致
13
- REPO_ID = "telecomadm1145/convnext_dinov3_tagger_test_2w_asl_frozen"
14
- # 必须与您训练时的模型架构名称一致
15
- # 注意:如果在 Space 运行时报错找不到该模型名称,请尝试改为标准的 'convnext_base'
16
- MODEL_NAME = 'convnext_base.dinov3_lvd1689m'
17
- TAGS_FILENAME = "tag_map.json"
18
  MODEL_FILENAME = "pytorch_model.bin"
19
- THRESHOLD = 0.25 # 显示标签的置信度阈值
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- print(f"Loading resources from {REPO_ID}...")
 
 
 
 
 
 
 
 
 
22
 
23
- # 1. 下载并加载标签映射
24
  try:
25
- tags_path = hf_hub_download(repo_id=REPO_ID, filename=TAGS_FILENAME)
26
- with open(tags_path, 'r') as f:
27
- tag_map = json.load(f)
28
- # 将 {'tag': 0, 'tag2': 1} 反转为 {0: 'tag', 1: 'tag2'} 以便索引
29
- idx_to_tag = {v: k for k, v in tag_map.items()}
30
- num_classes = len(tag_map)
31
- print(f"Loaded {num_classes} tags.")
 
 
 
 
 
 
 
 
 
 
 
 
32
  except Exception as e:
33
  print(f"Error loading tags: {e}")
34
- # 提供一个用于调试的空映射,防止 App 直接崩溃
35
- idx_to_tag = {}
36
- num_classes = 12476
37
 
38
- # 2. 准备模型
39
- print(f"Creating model: {MODEL_NAME}")
40
  try:
41
- # 创建模型结构
42
- model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=num_classes)
43
-
44
- # 下载并加载权重
45
  model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
46
- state_dict = torch.load(model_path, map_location='cpu')
47
  model.load_state_dict(state_dict)
48
- model.eval()
49
- print("Model loaded successfully.")
50
  except Exception as e:
51
- print(f"Error loading model: {e}")
52
- print("Fallback: Attempting to use generic 'convnext_base'...")
53
- try:
54
- # 如果特定的 dinov3 命名在普通环境中不可用,尝试使用基础架构
55
- model = timm.create_model('convnext_base', pretrained=False, num_classes=num_classes)
56
- model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
57
- state_dict = torch.load(model_path, map_location='cpu')
58
- model.load_state_dict(state_dict)
59
- model.eval()
60
- print("Fallback model loaded successfully.")
61
- except Exception as e2:
62
- raise RuntimeError(f"Failed to load model: {e2}")
63
 
64
- # 3. 准备图像预处理 (Transforms)
65
- config = resolve_data_config({}, model=model)
66
- transform = create_transform(**config)
67
 
68
- # 4. 推理函数
69
  @torch.no_grad()
70
- def predict(image):
71
  if image is None:
72
- return {}
73
 
74
- # 转换图像 (PIL -> Tensor, Normalize, Resize)
75
- img_tensor = transform(image).unsqueeze(0) # Add batch dimension: [1, 3, H, W]
76
 
77
- # 前向传播
78
- logits = model(img_tensor)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
- # 多标签分类使用 Sigmoid
81
- probs = torch.sigmoid(logits)[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- # 整理结果
84
- results = {}
85
- for idx, score in enumerate(probs):
86
- score_val = float(score)
87
- if score_val > THRESHOLD:
88
- tag_name = idx_to_tag.get(idx, f"Unknown_{idx}")
89
- results[tag_name] = score_val
90
-
91
- # 按照置信度排序 (Gradio 的 Label 组件会自动排序,但手动排序方便调试)
92
- # 限制返回前 50 个标签,防止 UI 过于拥挤
93
- sorted_results = dict(sorted(results.items(), key=lambda x: x[1], reverse=True)[:50])
94
 
95
- return sorted_results
96
 
97
- # 5. 构建 Gradio 界面
98
  with gr.Blocks() as demo:
 
 
99
  with gr.Row():
100
- with gr.Column():
101
  input_img = gr.Image(type="pil", label="Input Image")
102
- run_btn = gr.Button("Run Tagger", variant="primary")
103
-
104
- with gr.Column():
105
- # Label 组件非常适合显示分概率
106
- output_tags = gr.Label(label="Detected Tags", num_top_classes=30)
107
-
108
- run_btn.click(fn=predict, inputs=input_img, outputs=output_tags)
109
-
110
- examples = []
111
- if examples:
112
- gr.Examples(examples=examples, inputs=input_img)
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  if __name__ == "__main__":
115
- demo.queue().launch()
 
1
  import gradio as gr
2
  import torch
3
  import timm
 
 
 
 
 
4
  from PIL import Image
5
+ import json
6
+ from torchvision import transforms
7
+ from huggingface_hub import hf_hub_download
8
 
9
  # --- 配置 ---
10
+ REPO_ID = "telecomadm1145/convnext_dinov3_tagger_test_epoch_4_asl_letterbox"
 
 
 
 
 
11
  MODEL_FILENAME = "pytorch_model.bin"
12
+ TAGS_FILENAME = "tag_map.json"
13
+ MODEL_NAME = "convnext_base.dinov3_lvd1689m"
14
+ INPUT_SIZE = (512, 512)
15
+
16
+ # --- 1. 预处理 (Letterbox) ---
17
+ class LetterboxPad:
18
+ def __init__(self, size, fill=(255, 255, 255)):
19
+ self.size = size if isinstance(size, tuple) else (size, size)
20
+ self.fill = fill
21
+
22
+ def __call__(self, img):
23
+ w, h = img.size
24
+ target_h, target_w = self.size
25
+ scale = min(target_w / w, target_h / h)
26
+ new_w = int(w * scale)
27
+ new_h = int(h * scale)
28
+ img = img.resize((new_w, new_h), Image.BICUBIC)
29
+ new_img = Image.new("RGB", (target_w, target_h), self.fill)
30
+ paste_x = (target_w - new_w) // 2
31
+ paste_y = (target_h - new_h) // 2
32
+ new_img.paste(img, (paste_x, paste_y))
33
+ return new_img
34
+
35
+ def build_transform(size):
36
+ return transforms.Compose([
37
+ LetterboxPad(size),
38
+ transforms.ToTensor(),
39
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
40
+ ])
41
 
42
+ # --- 2. 加载资源与分组 ---
43
+ print("Loading model and tags...")
44
+ device = torch.device("cpu")
45
+
46
+ # 存储不同组的 (name, index) 列表
47
+ tag_groups = {
48
+ "rating": [],
49
+ "character": [],
50
+ "general": []
51
+ }
52
 
 
53
  try:
54
+ json_path = hf_hub_download(repo_id=REPO_ID, filename=TAGS_FILENAME)
55
+ with open(json_path, 'r') as f:
56
+ grouped_json = json.load(f)
57
+
58
+ # 解析分组: 假设 JSON 结构为 {"rating": {"safe": 0, ...}, "general": ...}
59
+ total_tags = 0
60
+ for group_key, tags_dict in grouped_json.items():
61
+ # 兼容处理:确保 key 是我们预期的,如果只有 standard tags 可能会归类到 general
62
+ target_group = group_key if group_key in tag_groups else "general"
63
+
64
+ for name, idx in tags_dict.items():
65
+ tag_groups[target_group].append((name, int(idx)))
66
+ total_tags += 1
67
+
68
+ print(f"Loaded {total_tags} tags.")
69
+ print(f" - Rating: {len(tag_groups['rating'])}")
70
+ print(f" - Character: {len(tag_groups['character'])}")
71
+ print(f" - General: {len(tag_groups['general'])}")
72
+
73
  except Exception as e:
74
  print(f"Error loading tags: {e}")
75
+ total_tags = 12000 # Fallback
 
 
76
 
77
+ # 加载模型
78
+ model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=total_tags)
79
  try:
 
 
 
 
80
  model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
81
+ state_dict = torch.load(model_path, map_location=device)
82
  model.load_state_dict(state_dict)
83
+ print("Model weights loaded.")
 
84
  except Exception as e:
85
+ print(f"Error loading weights: {e}")
 
 
 
 
 
 
 
 
 
 
 
86
 
87
+ model.to(device)
88
+ model.eval()
89
+ transform = build_transform(INPUT_SIZE)
90
 
91
+ # --- 3. 推理逻辑 ---
92
  @torch.no_grad()
93
+ def predict(image, threshold_gen, threshold_char):
94
  if image is None:
95
+ return {}, {}, {}
96
 
97
+ img_tensor = transform(image).unsqueeze(0).to(device)
98
+ logits = model(img_tensor)[0] # Shape: [num_classes]
99
 
100
+ # --- A. 处理 Rating (Softmax) ---
101
+ rating_res = {}
102
+ if tag_groups["rating"]:
103
+ # 提取 rating 对应的 logits
104
+ r_indices = [idx for _, idx in tag_groups["rating"]]
105
+ r_names = [name for name, _ in tag_groups["rating"]]
106
+
107
+ # 将 indices 转为 tensor 以便索引
108
+ r_indices_tensor = torch.tensor(r_indices, device=device)
109
+ r_logits = logits[r_indices_tensor]
110
+
111
+ # 核心修改:对 Rating 组内进行 Softmax
112
+ r_probs = torch.nn.functional.softmax(r_logits, dim=0)
113
+
114
+ for name, prob in zip(r_names, r_probs):
115
+ rating_res[name] = float(prob)
116
 
117
+ # --- B. 处理 Character (Sigmoid + Threshold) ---
118
+ char_res = {}
119
+ if tag_groups["character"]:
120
+ c_indices = [idx for _, idx in tag_groups["character"]]
121
+ c_names = [name for name, _ in tag_groups["character"]]
122
+
123
+ c_indices_tensor = torch.tensor(c_indices, device=device)
124
+ c_logits = logits[c_indices_tensor]
125
+ c_probs = torch.sigmoid(c_logits) # 多标签使用 Sigmoid
126
+
127
+ for name, prob in zip(c_names, c_probs):
128
+ if prob > threshold_char:
129
+ char_res[name] = float(prob)
130
+
131
+ # --- C. 处理 General (Sigmoid + Threshold) ---
132
+ gen_res = {}
133
+ if tag_groups["general"]:
134
+ g_indices = [idx for _, idx in tag_groups["general"]]
135
+ g_names = [name for name, _ in tag_groups["general"]]
136
+
137
+ g_indices_tensor = torch.tensor(g_indices, device=device)
138
+ g_logits = logits[g_indices_tensor]
139
+ g_probs = torch.sigmoid(g_logits) # 多标签使用 Sigmoid
140
+
141
+ for name, prob in zip(g_names, g_probs):
142
+ if prob > threshold_gen:
143
+ gen_res[name] = float(prob)
144
 
145
+ # 排序
146
+ rating_res = dict(sorted(rating_res.items(), key=lambda x: x[1], reverse=True))
147
+ char_res = dict(sorted(char_res.items(), key=lambda x: x[1], reverse=True))
148
+ gen_res = dict(sorted(gen_res.items(), key=lambda x: x[1], reverse=True))
 
 
 
 
 
 
 
149
 
150
+ return rating_res, char_res, gen_res
151
 
152
+ # --- 4. 界面 ---
153
  with gr.Blocks() as demo:
154
+ gr.Markdown(f"# Anime Tagger (DINOv3)\nModel: {REPO_ID}")
155
+
156
  with gr.Row():
157
+ with gr.Column(scale=1):
158
  input_img = gr.Image(type="pil", label="Input Image")
159
+ run_btn = gr.Button("Tag It!", variant="primary")
160
+
161
+ gr.Markdown("### Thresholds")
162
+ # 为不同别设置不同的阈值通常更好,Character 往往需要更低的阈值来召回
163
+ threshold_gen = gr.Slider(0.0, 1.0, value=0.25, step=0.05, label="General Tags Threshold")
164
+ threshold_char = gr.Slider(0.0, 1.0, value=0.15, step=0.05, label="Character Threshold")
165
+
166
+ with gr.Column(scale=1):
167
+ # 分开显示
168
+ gr.Markdown("### 1. Rating (Softmax)")
169
+ out_rating = gr.Label(label="Rating")
170
+
171
+ gr.Markdown("### 2. Characters")
172
+ out_char = gr.Label(label="Characters", num_top_classes=10)
173
+
174
+ gr.Markdown("### 3. General Tags")
175
+ out_gen = gr.Label(label="General Tags", num_top_classes=50)
176
+
177
+ run_btn.click(
178
+ fn=predict,
179
+ inputs=[input_img, threshold_gen, threshold_char],
180
+ outputs=[out_rating, out_char, out_gen]
181
+ )
182
 
183
  if __name__ == "__main__":
184
+ demo.launch()