IdlecloudX committed on
Commit
0f7b781
·
verified ·
1 Parent(s): 73b5b76

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +339 -143
app.py CHANGED
@@ -1,15 +1,17 @@
1
  import os
2
  import json
 
 
3
  import warnings
 
 
 
4
 
5
  import gradio as gr
6
- import huggingface_hub
7
- import numpy as np
8
- import onnxruntime as rt
9
- import pandas as pd
10
  from PIL import Image, ImageFile
11
- from huggingface_hub import login
12
 
 
13
  from translator import translate_texts
14
 
15
  # ------------------------------------------------------------------
@@ -97,7 +99,6 @@ def validate_and_open_image(image_path: str) -> Image.Image:
97
  f"图片总像素过大:{total_pixels:,},超过限制 {MAX_IMAGE_PIXELS:,}。"
98
  )
99
 
100
- # 估算解码为 RGB 后的内存占用
101
  estimated_decompressed_bytes = total_pixels * 3
102
  if estimated_decompressed_bytes > MAX_DECOMPRESSED_BYTES:
103
  raise ImageValidationError(
@@ -106,7 +107,6 @@ def validate_and_open_image(image_path: str) -> Image.Image:
106
  f"超过限制 {_format_size(MAX_DECOMPRESSED_BYTES)}。"
107
  )
108
 
109
- # 第二次打开,真正加载像素数据
110
  try:
111
  with Image.open(image_path) as img:
112
  img.load()
@@ -123,108 +123,195 @@ def validate_and_open_image(image_path: str) -> Image.Image:
123
 
124
 
125
  # ------------------------------------------------------------------
126
- # 模型配置
127
  # ------------------------------------------------------------------
128
- MODEL_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
129
- MODEL_FILENAME = "model.onnx"
130
- LABEL_FILENAME = "selected_tags.csv"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
- HF_TOKEN = os.environ.get("HF_TOKEN", "")
133
- if HF_TOKEN:
134
- login(token=HF_TOKEN)
135
- else:
136
- print("⚠️ 未检测到 HF_TOKEN,私有模型可能下载失败")
137
 
138
  # ------------------------------------------------------------------
139
- # Tagger 类 (全局实例化)
140
  # ------------------------------------------------------------------
141
  class Tagger:
142
  def __init__(self):
143
- self.hf_token = HF_TOKEN
144
- self.tag_names = []
145
- self.categories = {}
146
- self.model = None
147
- self.input_size = 0
148
  self._load_model_and_labels()
149
 
150
- def _load_model_and_labels(self):
151
  try:
152
- label_path = huggingface_hub.hf_hub_download(
153
- MODEL_REPO, LABEL_FILENAME, token=self.hf_token, resume_download=True
154
- )
155
- model_path = huggingface_hub.hf_hub_download(
156
- MODEL_REPO, MODEL_FILENAME, token=self.hf_token, resume_download=True
157
- )
158
-
159
- tags_df = pd.read_csv(label_path)
160
- self.tag_names = tags_df["name"].tolist()
161
- self.categories = {
162
- "rating": np.where(tags_df["category"] == 9)[0],
163
- "general": np.where(tags_df["category"] == 0)[0],
164
- "character": np.where(tags_df["category"] == 4)[0],
165
- }
166
- self.model = rt.InferenceSession(model_path)
167
- self.input_size = self.model.get_inputs()[0].shape[1]
168
- print("✅ 模型和标签加载成功")
169
  except Exception as e:
170
- print(f"❌ 模型或标签加载失败: {e}")
171
- raise RuntimeError(f"模型初始化失败: {e}")
172
-
173
- def _preprocess(self, img: Image.Image) -> np.ndarray:
174
- if img is None:
175
- raise ValueError("输入图像不能为空")
176
- if img.mode != "RGB":
177
- img = img.convert("RGB")
178
- size = max(img.size)
179
- canvas = Image.new("RGB", (size, size), (255, 255, 255))
180
- canvas.paste(img, ((size - img.width) // 2, (size - img.height) // 2))
181
- if size != self.input_size:
182
- canvas = canvas.resize((self.input_size, self.input_size), Image.BICUBIC)
183
- return np.array(canvas)[:, :, ::-1].astype(np.float32) # to BGR
184
-
185
- def predict(self, img: Image.Image, gen_th: float = 0.35, char_th: float = 0.85):
186
- if self.model is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  raise RuntimeError("模型未成功加载,无法进行预测。")
188
- inp_name = self.model.get_inputs()[0].name
189
- outputs = self.model.run(None, {inp_name: self._preprocess(img)[None, ...]})[0][0]
190
-
191
- res = {"ratings": {}, "general": {}, "characters": {}}
192
- tag_categories_for_translation = {"ratings": [], "general": [], "characters": []}
193
-
194
- for idx in self.categories["rating"]:
195
- tag_name = self.tag_names[idx].replace("_", " ")
196
- res["ratings"][tag_name] = float(outputs[idx])
197
- tag_categories_for_translation["ratings"].append(tag_name)
198
-
199
- for idx in self.categories["general"]:
200
- if outputs[idx] > gen_th:
201
- tag_name = self.tag_names[idx].replace("_", " ")
202
- res["general"][tag_name] = float(outputs[idx])
203
- tag_categories_for_translation["general"].append(tag_name)
204
-
205
- for idx in self.categories["character"]:
206
- if outputs[idx] > char_th:
207
- tag_name = self.tag_names[idx].replace("_", " ")
208
- res["characters"][tag_name] = float(outputs[idx])
209
- tag_categories_for_translation["characters"].append(tag_name)
210
-
211
- res["general"] = dict(sorted(res["general"].items(), key=lambda kv: kv[1], reverse=True))
212
- res["characters"] = dict(sorted(res["characters"].items(), key=lambda kv: kv[1], reverse=True))
213
- res["ratings"] = dict(sorted(res["ratings"].items(), key=lambda kv: kv[1], reverse=True))
214
 
215
- tag_categories_for_translation["general"] = list(res["general"].keys())
216
- tag_categories_for_translation["characters"] = list(res["characters"].keys())
217
- tag_categories_for_translation["ratings"] = list(res["ratings"].keys())
218
-
219
- return res, tag_categories_for_translation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
 
221
 
222
  # 全局 Tagger 实例
223
  try:
224
  tagger_instance = Tagger()
225
  except RuntimeError as e:
226
- print(f"应用启动时Tagger初始化失败: {e}")
227
- tagger_instance = None # 允许应用启动,但在处理时会失败
 
 
 
 
 
 
228
 
229
  # ------------------------------------------------------------------
230
  # Gradio UI
@@ -263,6 +350,7 @@ custom_css = """
263
  .tag-score {
264
  color: #999;
265
  font-size: 0.9em;
 
266
  }
267
  .btn-analyze-container {
268
  margin-top: 15px;
@@ -333,11 +421,14 @@ function copyToClipboard(text) {
333
  }
334
  """
335
 
 
336
  with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=custom_css, js=_js_functions) as demo:
337
  gr.Markdown("# 🖼️ AI 图像标签分析器")
338
  gr.Markdown(
339
  "上传图片自动识别标签,支持中英文显示和一键复制。"
340
  "[NovelAI在线绘画](https://nai.idlecloud.cc/)\n\n"
 
 
341
  )
342
 
343
  state_res = gr.State({})
@@ -346,22 +437,39 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=cus
346
 
347
  with gr.Row():
348
  with gr.Column(scale=1):
349
- # 改为 filepath,确保可以拿到原始文件路径与体积进行校验
350
  img_in = gr.Image(type="filepath", label="上传图片", height=300)
351
 
352
  btn = gr.Button("🚀 开始分析", variant="primary", elem_classes=["btn-analyze-container"])
353
 
354
  with gr.Accordion("⚙️ 高级设置", open=False):
355
- gen_slider = gr.Slider(0, 1, value=0.35, step=0.01, label="通用标签阈值", info="越高 → 标签更少更准")
356
- char_slider = gr.Slider(0, 1, value=0.85, step=0.01, label="角色标签阈值", info="推荐保持较高阈值")
357
- show_tag_scores = gr.Checkbox(True, label="在列表中显示标签置信度")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
 
359
  with gr.Accordion("📊 标签汇总设置", open=True):
360
  gr.Markdown("选择要包含在下方汇总文本框中的标签类别:")
361
  with gr.Row():
362
  sum_general = gr.Checkbox(True, label="通用标签", min_width=50)
363
  sum_char = gr.Checkbox(True, label="角色标签", min_width=50)
364
- sum_rating = gr.Checkbox(False, label="评分标签", min_width=50)
365
  sum_sep = gr.Dropdown(["逗号", "换行", "空格"], value="逗号", label="标签之间的分隔符")
366
  sum_show_zh = gr.Checkbox(False, label="在汇总中显示中文翻译")
367
 
@@ -372,19 +480,24 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=cus
372
  with gr.TabItem("🏷️ 通用标签"):
373
  out_general = gr.HTML(label="General Tags")
374
  with gr.TabItem("👤 角色标签"):
375
- gr.Markdown("<p style='color:gray; font-size:small;'>提示:角色标签推测基于截至2024年2月的数据。</p>")
376
  out_char = gr.HTML(label="Character Tags")
377
- with gr.TabItem(" 评分标签"):
378
- out_rating = gr.HTML(label="Rating Tags")
 
379
 
380
  gr.Markdown("### 标签汇总结果")
381
  out_summary = gr.Textbox(
382
  label="标签汇总",
383
  placeholder="分析完成后,此处将显示汇总的英文标签...",
384
  lines=5,
385
- show_copy_button=True
386
  )
387
 
 
 
 
 
388
  # ----------------- 辅助函数 -----------------
389
  def format_tags_html(tags_dict, translations_list, category_name, show_scores=True, show_translation_in_list=True):
390
  if not tags_dict:
@@ -399,24 +512,37 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=cus
399
 
400
  for i, tag in enumerate(tag_keys):
401
  score = tags_dict[tag]
402
- escaped_tag = tag.replace("'", "\\'")
 
403
 
404
  html += '<div class="tag-item">'
405
- tag_display_html = f'<span class="tag-en" onclick="copyToClipboard(\'{escaped_tag}\')">{tag}</span>'
 
 
 
406
 
407
  if show_translation_in_list and i < len(translations_list) and translations_list[i]:
408
- tag_display_html += f'<span class="tag-zh">({translations_list[i]})</span>'
 
 
409
 
410
- html += f'<div>{tag_display_html}</div>'
411
- if show_scores:
412
  html += f'<span class="tag-score">{score:.3f}</span>'
413
- html += '</div>'
414
- html += '</div>'
 
 
415
  return html
416
 
 
417
  def generate_summary_text_content(
418
- current_res, current_translations_dict,
419
- s_gen, s_char, s_rat, s_sep_type, s_show_zh
 
 
 
 
 
420
  ):
421
  if not current_res:
422
  return "请先分析图像或选择要汇总的标签类别。"
@@ -430,8 +556,8 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=cus
430
  categories_to_summarize.append("general")
431
  if s_char:
432
  categories_to_summarize.append("characters")
433
- if s_rat:
434
- categories_to_summarize.append("ratings")
435
 
436
  if not categories_to_summarize:
437
  return "请至少选择一个标签类别进行汇总。"
@@ -447,6 +573,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=cus
447
  tags_to_join.append(f"{en_tag}/*{cat_translations[i]}*/")
448
  else:
449
  tags_to_join.append(en_tag)
 
450
  if tags_to_join:
451
  summary_parts.append(separator.join(tags_to_join))
452
 
@@ -455,16 +582,30 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=cus
455
  final_summary = joiner.join(summary_parts)
456
  return final_summary if final_summary else "选定的类别中没有找到标签。"
457
 
 
458
  def process_image_and_generate_outputs(
459
- image_path, g_th, c_th, s_scores,
460
- s_gen, s_char, s_rat, s_sep, s_zh_in_sum
 
 
 
 
 
 
 
461
  ):
462
  if image_path is None:
463
  yield (
464
  gr.update(interactive=True, value="🚀 开始分析"),
465
  gr.update(visible=True, value="❌ 请先上传图片。"),
466
- "", "", "", "",
467
- {}, {}, {}
 
 
 
 
 
 
468
  )
469
  return
470
 
@@ -472,8 +613,14 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=cus
472
  yield (
473
  gr.update(interactive=True, value="🚀 开始分析"),
474
  gr.update(visible=True, value="❌ 分析器未成功初始化,请检查控制台错误。"),
475
- "", "", "", "",
476
- {}, {}, {}
 
 
 
 
 
 
477
  )
478
  return
479
 
@@ -484,26 +631,34 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=cus
484
  gr.HTML(value="<p>分析中...</p>"),
485
  gr.HTML(value="<p>分析中...</p>"),
486
  gr.update(value="分析中,请稍候..."),
487
- {}, {}, {}
 
 
 
488
  )
489
 
490
  try:
491
  img = validate_and_open_image(image_path)
492
- res, tag_categories_original_order = tagger_instance.predict(img, g_th, c_th)
493
 
494
  all_tags_to_translate = []
495
- for cat_key in ["general", "characters", "ratings"]:
496
  all_tags_to_translate.extend(tag_categories_original_order.get(cat_key, []))
497
 
498
  all_translations_flat = []
499
  if all_tags_to_translate:
500
- all_translations_flat = translate_texts(all_tags_to_translate, src_lang="auto", tgt_lang="zh")
 
 
 
 
501
 
502
  current_translations_dict = {}
503
  offset = 0
504
- for cat_key in ["general", "characters", "ratings"]:
505
  cat_original_tags = tag_categories_original_order.get(cat_key, [])
506
  num_tags_in_cat = len(cat_original_tags)
 
507
  if num_tags_in_cat > 0:
508
  current_translations_dict[cat_key] = all_translations_flat[offset: offset + num_tags_in_cat]
509
  offset += num_tags_in_cat
@@ -524,17 +679,22 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=cus
524
  s_scores,
525
  True,
526
  )
527
- rating_html = format_tags_html(
528
- res.get("ratings", {}),
529
- current_translations_dict.get("ratings", []),
530
- "ratings",
531
  s_scores,
532
  True,
533
  )
534
 
535
  summary_text = generate_summary_text_content(
536
- res, current_translations_dict,
537
- s_gen, s_char, s_rat, s_sep, s_zh_in_sum
 
 
 
 
 
538
  )
539
 
540
  yield (
@@ -542,11 +702,12 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=cus
542
  gr.update(visible=True, value="✅ 分析完成!"),
543
  general_html,
544
  char_html,
545
- rating_html,
546
  gr.update(value=summary_text),
547
  res,
548
  current_translations_dict,
549
- tag_categories_original_order
 
550
  )
551
 
552
  except ImageValidationError as e:
@@ -557,48 +718,82 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=cus
557
  "<p>图片已被安全策略拒绝</p>",
558
  "<p>图片已被安全策略拒绝</p>",
559
  gr.update(value=f"错误: {str(e)}", placeholder="上传图片未通过安全校验..."),
560
- {}, {}, {}
 
 
 
561
  )
562
  except Exception as e:
563
  import traceback
 
564
  tb_str = traceback.format_exc()
565
  print(f"处理时发生错误: {e}\n{tb_str}")
566
  yield (
567
  gr.update(interactive=True, value="🚀 开始分析"),
568
  gr.update(visible=True, value=f"❌ 处理失败: {str(e)}"),
569
- "<p>处理出错</p>", "<p>处理出错</p>", "<p>处理出错</p>",
 
 
570
  gr.update(value=f"错误: {str(e)}", placeholder="分析失败..."),
571
- {}, {}, {}
 
 
 
572
  )
573
 
 
574
  def update_summary_display(
575
- s_gen, s_char, s_rat, s_sep, s_zh_in_sum,
576
- current_res_from_state, current_translations_from_state
 
 
 
 
 
577
  ):
578
  if not current_res_from_state:
579
  return gr.update(placeholder="请先完成一次图像分析以生成汇总。", value="")
580
 
581
  new_summary_text = generate_summary_text_content(
582
- current_res_from_state, current_translations_from_state,
583
- s_gen, s_char, s_rat, s_sep, s_zh_in_sum
 
 
 
 
 
584
  )
585
  return gr.update(value=new_summary_text)
586
 
 
587
  btn.click(
588
  process_image_and_generate_outputs,
589
  inputs=[
590
- img_in, gen_slider, char_slider, show_tag_scores,
591
- sum_general, sum_char, sum_rating, sum_sep, sum_show_zh
 
 
 
 
 
 
 
592
  ],
593
  outputs=[
594
- btn, processing_info,
595
- out_general, out_char, out_rating,
 
 
 
596
  out_summary,
597
- state_res, state_translations_dict, state_tag_categories_for_translation
 
 
 
598
  ],
599
  )
600
 
601
- summary_controls = [sum_general, sum_char, sum_rating, sum_sep, sum_show_zh]
602
  for ctrl in summary_controls:
603
  ctrl.change(
604
  fn=update_summary_display,
@@ -606,7 +801,8 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=cus
606
  outputs=[out_summary],
607
  )
608
 
 
609
  if __name__ == "__main__":
610
  if tagger_instance is None:
611
  print("CRITICAL: Tagger 未能初始化,应用功能将受限。请检查之前的错误信息。")
612
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import os
2
  import json
3
+ import time
4
+ import shutil
5
  import warnings
6
+ from html import escape
7
+ from pathlib import Path
8
+ from typing import Optional
9
 
10
  import gradio as gr
11
+ from huggingface_hub import snapshot_download
 
 
 
12
  from PIL import Image, ImageFile
 
13
 
14
+ from handler import EndpointHandler
15
  from translator import translate_texts
16
 
17
  # ------------------------------------------------------------------
 
99
  f"图片总像素过大:{total_pixels:,},超过限制 {MAX_IMAGE_PIXELS:,}。"
100
  )
101
 
 
102
  estimated_decompressed_bytes = total_pixels * 3
103
  if estimated_decompressed_bytes > MAX_DECOMPRESSED_BYTES:
104
  raise ImageValidationError(
 
107
  f"超过限制 {_format_size(MAX_DECOMPRESSED_BYTES)}。"
108
  )
109
 
 
110
  try:
111
  with Image.open(image_path) as img:
112
  img.load()
 
123
 
124
 
125
  # ------------------------------------------------------------------
126
+ # 新版 PixAI Tagger v0.9 模型配置
127
  # ------------------------------------------------------------------
128
+ ASSETS_REPO_ID = os.environ.get("ASSETS_REPO_ID", "pixai-labs/pixai-tagger-v0.9")
129
+ ASSETS_REVISION = os.environ.get("ASSETS_REVISION")
130
+ MODEL_DIR = os.environ.get("MODEL_DIR", "./assets")
131
+
132
+ HF_TOKEN = (
133
+ os.environ.get("HUGGINGFACE_HUB_TOKEN")
134
+ or os.environ.get("HF_TOKEN")
135
+ or os.environ.get("HUGGINGFACE_TOKEN")
136
+ or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
137
+ )
138
+
139
+ REQUIRED_FILES = [
140
+ "model_v0.9.pth",
141
+ "tags_v0.9_13k.json",
142
+ "char_ip_map.json",
143
+ ]
144
+
145
+
146
def ensure_assets(repo_id: str, revision: Optional[str], target_dir: str) -> None:
    """Make sure every file in ``REQUIRED_FILES`` exists inside ``target_dir``.

    Nothing is downloaded when all required files are already present.
    Otherwise the required files are fetched with ``snapshot_download`` and
    copied from the returned snapshot directory into ``target_dir``.

    Args:
        repo_id: Hub repository id forwarded to ``snapshot_download``.
        revision: Revision forwarded to ``snapshot_download``; ``None`` is
            passed through unchanged.
        target_dir: Local directory that must end up holding the files. It is
            created (including parents) when it does not exist.

    Raises:
        FileNotFoundError: A required file is absent from the downloaded
            snapshot.
    """
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)

    if all((target / fname).exists() for fname in REQUIRED_FILES):
        return

    snapshot_path = Path(
        snapshot_download(
            repo_id=repo_id,
            revision=revision,
            allow_patterns=REQUIRED_FILES,
            token=HF_TOKEN,
        )
    )

    # As soon as anything is missing, every required file is refreshed from
    # the same snapshot, so the local copies always come from one revision.
    for fname in REQUIRED_FILES:
        src = snapshot_path / fname
        dst = target / fname

        if not src.exists():
            raise FileNotFoundError(
                f"模型资源缺失:'{fname}' 未在 {repo_id} @ {revision or 'default'} 中找到。"
            )

        if src.resolve() == dst.resolve():
            continue

        # Copy to a sibling temp file and rename it into place. An interrupted
        # copy then never leaves a truncated file under the final name, which
        # the existence check above would otherwise accept on the next start.
        tmp = dst.with_name(dst.name + ".part")
        try:
            shutil.copyfile(src, tmp)
            os.replace(tmp, dst)
        finally:
            # Only still present when the copy or the rename failed.
            if tmp.exists():
                tmp.unlink()
176
 
 
 
 
 
 
177
 
178
  # ------------------------------------------------------------------
179
+ # Tagger 类:使用新版 EndpointHandler
180
  # ------------------------------------------------------------------
181
class Tagger:
    """Wrap ``EndpointHandler`` and reshape its output for the UI code.

    Attributes set by ``__init__``:
        handler: The ``EndpointHandler`` instance, or ``None`` before loading.
        device: The handler's ``device`` attribute, or ``"unknown"`` when the
            handler does not expose one.
    """

    def __init__(self):
        self.handler = None
        self.device = "unknown"
        self._load_model_and_labels()

    def _load_model_and_labels(self) -> None:
        """Ensure the model assets exist locally and build the handler.

        Raises:
            RuntimeError: Any failure while fetching assets or constructing
                the handler; the original exception is chained as the cause.
        """
        try:
            ensure_assets(ASSETS_REPO_ID, ASSETS_REVISION, MODEL_DIR)
            self.handler = EndpointHandler(MODEL_DIR)
            self.device = getattr(self.handler, "device", "unknown")
            print(f"✅ PixAI Tagger v0.9 加载成功,设备:{str(self.device).upper()}")
        except Exception as e:
            print(f"❌ PixAI Tagger v0.9 加载失败: {e}")
            raise RuntimeError(f"模型初始化失败: {e}") from e

    @staticmethod
    def _display_tag(tag: str) -> str:
        """Return the tag as text with underscores replaced by spaces."""
        return str(tag).replace("_", " ")

    @staticmethod
    def _get_score(scores: dict, tag: str) -> float:
        """Look up the score for ``tag`` in ``scores``.

        The tag is tried as given, then with underscores turned into spaces,
        then with spaces turned into underscores, so a key-style mismatch
        between the tag list and the score dict still finds the value.

        Returns:
            The score as ``float``; ``0.0`` when ``scores`` is not a dict, no
            candidate key is present, or the stored value cannot be converted.
        """
        if not isinstance(scores, dict):
            return 0.0

        text = str(tag)
        for key in (tag, text.replace("_", " "), text.replace(" ", "_")):
            if key in scores:
                try:
                    return float(scores[key])
                except (TypeError, ValueError, OverflowError):
                    # Unconvertible score value: treat it as "no score".
                    return 0.0

        return 0.0

    def predict(self, img: "Image.Image", gen_th: float = 0.30, char_th: float = 0.85):
        """Run the handler on ``img`` and return UI-ready results.

        Args:
            img: Image passed to the handler as ``inputs``.
            gen_th: Sent to the handler as ``general_threshold``.
            char_th: Sent to the handler as ``character_threshold``.

        Returns:
            A 3-tuple ``(res, tag_categories_for_translation, raw_meta)``:

            * ``res`` has keys ``"general"`` and ``"characters"`` (display tag
              -> score, sorted by descending score) and ``"ips"`` (display tag
              -> ``None``, because no score is read for IP tags).
            * ``tag_categories_for_translation`` maps the same three keys to
              the display-tag lists in the same order as ``res``.
            * ``raw_meta`` holds the device, the measured total latency in
              seconds, and the handler's ``_params`` / ``_timings`` if present.

        Raises:
            RuntimeError: The handler was never loaded.
            ValueError: ``img`` is ``None``.
        """
        if self.handler is None:
            raise RuntimeError("模型未成功加载,无法进行预测。")

        if img is None:
            raise ValueError("输入图像不能为空。")

        params = {
            "general_threshold": float(gen_th),
            "character_threshold": float(char_th),
            "mode": "threshold",
            "topk_general": 25,
            "topk_character": 10,
            "include_scores": True,
        }

        data = {
            "inputs": img,
            "parameters": params,
        }

        # perf_counter is monotonic, so the measured latency cannot go
        # negative or jump when the wall clock is adjusted.
        started = time.perf_counter()
        out = self.handler(data)
        latency = round(time.perf_counter() - started, 4)

        feature_tags = out.get("feature", []) or []
        character_tags = out.get("character", []) or []
        ip_tags = out.get("ip", []) or []

        feature_scores = out.get("feature_scores", {}) or {}
        character_scores = out.get("character_scores", {}) or {}

        general = {
            self._display_tag(tag): self._get_score(feature_scores, tag)
            for tag in feature_tags
        }
        characters = {
            self._display_tag(tag): self._get_score(character_scores, tag)
            for tag in character_tags
        }

        # IP tags carry no score; None means "do not display a confidence".
        ips = {
            self._display_tag(tag): None
            for tag in ip_tags
        }

        general = dict(sorted(general.items(), key=lambda kv: kv[1], reverse=True))
        characters = dict(sorted(characters.items(), key=lambda kv: kv[1], reverse=True))

        res = {
            "general": general,
            "characters": characters,
            "ips": ips,
        }

        tag_categories_for_translation = {
            "general": list(general.keys()),
            "characters": list(characters.keys()),
            "ips": list(ips.keys()),
        }

        raw_meta = {
            "device": str(self.device),
            "latency_s_total": latency,
            "_params": out.get("_params", params),
            "_timings": out.get("_timings", {}),
        }

        return res, tag_categories_for_translation, raw_meta
301
 
302
 
303
  # 全局 Tagger 实例
304
  try:
305
  tagger_instance = Tagger()
306
  except RuntimeError as e:
307
+ print(f"应用启动时 Tagger 初始化失败: {e}")
308
+ tagger_instance = None
309
+
310
+ DEVICE_LABEL = (
311
+ f"设备:{str(tagger_instance.device).upper()}"
312
+ if tagger_instance is not None
313
+ else "设备:UNKNOWN"
314
+ )
315
 
316
  # ------------------------------------------------------------------
317
  # Gradio UI
 
350
  .tag-score {
351
  color: #999;
352
  font-size: 0.9em;
353
+ white-space: nowrap;
354
  }
355
  .btn-analyze-container {
356
  margin-top: 15px;
 
421
  }
422
  """
423
 
424
+
425
  with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=custom_css, js=_js_functions) as demo:
426
  gr.Markdown("# 🖼️ AI 图像标签分析器")
427
  gr.Markdown(
428
  "上传图片自动识别标签,支持中英文显示和一键复制。"
429
  "[NovelAI在线绘画](https://nai.idlecloud.cc/)\n\n"
430
+ f"**当前模型:pixai-labs/pixai-tagger-v0.9** | **{DEVICE_LABEL}**\n\n"
431
+ "说明:新版模型不再返回评分标签,本页面已将原“评分标签”区域改为“IP 标签”。"
432
  )
433
 
434
  state_res = gr.State({})
 
437
 
438
  with gr.Row():
439
  with gr.Column(scale=1):
 
440
  img_in = gr.Image(type="filepath", label="上传图片", height=300)
441
 
442
  btn = gr.Button("🚀 开始分析", variant="primary", elem_classes=["btn-analyze-container"])
443
 
444
  with gr.Accordion("⚙️ 高级设置", open=False):
445
+ gen_slider = gr.Slider(
446
+ 0,
447
+ 1,
448
+ value=0.30,
449
+ step=0.01,
450
+ label="通用标签阈值",
451
+ info="越高 → 标签更少更准",
452
+ )
453
+ char_slider = gr.Slider(
454
+ 0,
455
+ 1,
456
+ value=0.85,
457
+ step=0.01,
458
+ label="角色标签阈值",
459
+ info="推荐保持较高阈值",
460
+ )
461
+ show_tag_scores = gr.Checkbox(
462
+ True,
463
+ label="在列表中显示标签置信度",
464
+ info="IP 标签不返回置信度,因此不会显示分数。",
465
+ )
466
 
467
  with gr.Accordion("📊 标签汇总设置", open=True):
468
  gr.Markdown("选择要包含在下方汇总文本框中的标签类别:")
469
  with gr.Row():
470
  sum_general = gr.Checkbox(True, label="通用标签", min_width=50)
471
  sum_char = gr.Checkbox(True, label="角色标签", min_width=50)
472
+ sum_ip = gr.Checkbox(False, label="IP 标签", min_width=50)
473
  sum_sep = gr.Dropdown(["逗号", "换行", "空格"], value="逗号", label="标签之间的分隔符")
474
  sum_show_zh = gr.Checkbox(False, label="在汇总中显示中文翻译")
475
 
 
480
  with gr.TabItem("🏷️ 通用标签"):
481
  out_general = gr.HTML(label="General Tags")
482
  with gr.TabItem("👤 角色标签"):
483
+ gr.Markdown("<p style='color:gray; font-size:small;'>提示:角色标签由模型断,建议保持较高阈值。</p>")
484
  out_char = gr.HTML(label="Character Tags")
485
+ with gr.TabItem("🌐 IP 标签"):
486
+ gr.Markdown("<p style='color:gray; font-size:small;'>提示:新版模型输出 IP 标签,但不返回评分标签/评分置信度。</p>")
487
+ out_ip = gr.HTML(label="IP Tags")
488
 
489
  gr.Markdown("### 标签汇总结果")
490
  out_summary = gr.Textbox(
491
  label="标签汇总",
492
  placeholder="分析完成后,此处将显示汇总的英文标签...",
493
  lines=5,
494
+ show_copy_button=True,
495
  )
496
 
497
+ with gr.Accordion("🧾 推理元数据", open=False):
498
+ out_meta = gr.JSON(label="Metadata")
499
+
500
+
501
  # ----------------- 辅助函数 -----------------
502
  def format_tags_html(tags_dict, translations_list, category_name, show_scores=True, show_translation_in_list=True):
503
  if not tags_dict:
 
512
 
513
  for i, tag in enumerate(tag_keys):
514
  score = tags_dict[tag]
515
+ safe_tag_text = escape(str(tag))
516
+ js_arg = json.dumps(str(tag), ensure_ascii=False)
517
 
518
  html += '<div class="tag-item">'
519
+
520
+ tag_display_html = (
521
+ f'<span class="tag-en" onclick=\'copyToClipboard({js_arg})\'>{safe_tag_text}</span>'
522
+ )
523
 
524
  if show_translation_in_list and i < len(translations_list) and translations_list[i]:
525
+ tag_display_html += f'<span class="tag-zh">({escape(str(translations_list[i]))})</span>'
526
+
527
+ html += f"<div>{tag_display_html}</div>"
528
 
529
+ if show_scores and isinstance(score, (int, float)):
 
530
  html += f'<span class="tag-score">{score:.3f}</span>'
531
+
532
+ html += "</div>"
533
+
534
+ html += "</div>"
535
  return html
536
 
537
+
538
  def generate_summary_text_content(
539
+ current_res,
540
+ current_translations_dict,
541
+ s_gen,
542
+ s_char,
543
+ s_ip,
544
+ s_sep_type,
545
+ s_show_zh,
546
  ):
547
  if not current_res:
548
  return "请先分析图像或选择要汇总的标签类别。"
 
556
  categories_to_summarize.append("general")
557
  if s_char:
558
  categories_to_summarize.append("characters")
559
+ if s_ip:
560
+ categories_to_summarize.append("ips")
561
 
562
  if not categories_to_summarize:
563
  return "请至少选择一个标签类别进行汇总。"
 
573
  tags_to_join.append(f"{en_tag}/*{cat_translations[i]}*/")
574
  else:
575
  tags_to_join.append(en_tag)
576
+
577
  if tags_to_join:
578
  summary_parts.append(separator.join(tags_to_join))
579
 
 
582
  final_summary = joiner.join(summary_parts)
583
  return final_summary if final_summary else "选定的类别中没有找到标签。"
584
 
585
+
586
  def process_image_and_generate_outputs(
587
+ image_path,
588
+ g_th,
589
+ c_th,
590
+ s_scores,
591
+ s_gen,
592
+ s_char,
593
+ s_ip,
594
+ s_sep,
595
+ s_zh_in_sum,
596
  ):
597
  if image_path is None:
598
  yield (
599
  gr.update(interactive=True, value="🚀 开始分析"),
600
  gr.update(visible=True, value="❌ 请先上传图片。"),
601
+ "",
602
+ "",
603
+ "",
604
+ "",
605
+ {},
606
+ {},
607
+ {},
608
+ {},
609
  )
610
  return
611
 
 
613
  yield (
614
  gr.update(interactive=True, value="🚀 开始分析"),
615
  gr.update(visible=True, value="❌ 分析器未成功初始化,请检查控制台错误。"),
616
+ "",
617
+ "",
618
+ "",
619
+ "",
620
+ {},
621
+ {},
622
+ {},
623
+ {},
624
  )
625
  return
626
 
 
631
  gr.HTML(value="<p>分析中...</p>"),
632
  gr.HTML(value="<p>分析中...</p>"),
633
  gr.update(value="分析中,请稍候..."),
634
+ {},
635
+ {},
636
+ {},
637
+ {},
638
  )
639
 
640
  try:
641
  img = validate_and_open_image(image_path)
642
+ res, tag_categories_original_order, meta = tagger_instance.predict(img, g_th, c_th)
643
 
644
  all_tags_to_translate = []
645
+ for cat_key in ["general", "characters", "ips"]:
646
  all_tags_to_translate.extend(tag_categories_original_order.get(cat_key, []))
647
 
648
  all_translations_flat = []
649
  if all_tags_to_translate:
650
+ try:
651
+ all_translations_flat = translate_texts(all_tags_to_translate, src_lang="auto", tgt_lang="zh")
652
+ except Exception as translate_error:
653
+ print(f"⚠️ 标签翻译失败,将仅显示英文标签:{translate_error}")
654
+ all_translations_flat = [""] * len(all_tags_to_translate)
655
 
656
  current_translations_dict = {}
657
  offset = 0
658
+ for cat_key in ["general", "characters", "ips"]:
659
  cat_original_tags = tag_categories_original_order.get(cat_key, [])
660
  num_tags_in_cat = len(cat_original_tags)
661
+
662
  if num_tags_in_cat > 0:
663
  current_translations_dict[cat_key] = all_translations_flat[offset: offset + num_tags_in_cat]
664
  offset += num_tags_in_cat
 
679
  s_scores,
680
  True,
681
  )
682
+ ip_html = format_tags_html(
683
+ res.get("ips", {}),
684
+ current_translations_dict.get("ips", []),
685
+ "ips",
686
  s_scores,
687
  True,
688
  )
689
 
690
  summary_text = generate_summary_text_content(
691
+ res,
692
+ current_translations_dict,
693
+ s_gen,
694
+ s_char,
695
+ s_ip,
696
+ s_sep,
697
+ s_zh_in_sum,
698
  )
699
 
700
  yield (
 
702
  gr.update(visible=True, value="✅ 分析完成!"),
703
  general_html,
704
  char_html,
705
+ ip_html,
706
  gr.update(value=summary_text),
707
  res,
708
  current_translations_dict,
709
+ tag_categories_original_order,
710
+ meta,
711
  )
712
 
713
  except ImageValidationError as e:
 
718
  "<p>图片已被安全策略拒绝</p>",
719
  "<p>图片已被安全策略拒绝</p>",
720
  gr.update(value=f"错误: {str(e)}", placeholder="上传图片未通过安全校验..."),
721
+ {},
722
+ {},
723
+ {},
724
+ {},
725
  )
726
  except Exception as e:
727
  import traceback
728
+
729
  tb_str = traceback.format_exc()
730
  print(f"处理时发生错误: {e}\n{tb_str}")
731
  yield (
732
  gr.update(interactive=True, value="🚀 开始分析"),
733
  gr.update(visible=True, value=f"❌ 处理失败: {str(e)}"),
734
+ "<p>处理出错</p>",
735
+ "<p>处理出错</p>",
736
+ "<p>处理出错</p>",
737
  gr.update(value=f"错误: {str(e)}", placeholder="分析失败..."),
738
+ {},
739
+ {},
740
+ {},
741
+ {},
742
  )
743
 
744
+
745
  def update_summary_display(
746
+ s_gen,
747
+ s_char,
748
+ s_ip,
749
+ s_sep,
750
+ s_zh_in_sum,
751
+ current_res_from_state,
752
+ current_translations_from_state,
753
  ):
754
  if not current_res_from_state:
755
  return gr.update(placeholder="请先完成一次图像分析以生成汇总。", value="")
756
 
757
  new_summary_text = generate_summary_text_content(
758
+ current_res_from_state,
759
+ current_translations_from_state,
760
+ s_gen,
761
+ s_char,
762
+ s_ip,
763
+ s_sep,
764
+ s_zh_in_sum,
765
  )
766
  return gr.update(value=new_summary_text)
767
 
768
+
769
  btn.click(
770
  process_image_and_generate_outputs,
771
  inputs=[
772
+ img_in,
773
+ gen_slider,
774
+ char_slider,
775
+ show_tag_scores,
776
+ sum_general,
777
+ sum_char,
778
+ sum_ip,
779
+ sum_sep,
780
+ sum_show_zh,
781
  ],
782
  outputs=[
783
+ btn,
784
+ processing_info,
785
+ out_general,
786
+ out_char,
787
+ out_ip,
788
  out_summary,
789
+ state_res,
790
+ state_translations_dict,
791
+ state_tag_categories_for_translation,
792
+ out_meta,
793
  ],
794
  )
795
 
796
+ summary_controls = [sum_general, sum_char, sum_ip, sum_sep, sum_show_zh]
797
  for ctrl in summary_controls:
798
  ctrl.change(
799
  fn=update_summary_display,
 
801
  outputs=[out_summary],
802
  )
803
 
804
+
805
  if __name__ == "__main__":
806
  if tagger_instance is None:
807
  print("CRITICAL: Tagger 未能初始化,应用功能将受限。请检查之前的错误信息。")
808
+ demo.queue(max_size=8).launch(server_name="0.0.0.0", server_port=7860)