IdlecloudX commited on
Commit
a98172e
·
verified ·
1 Parent(s): 9077880

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +188 -207
app.py CHANGED
@@ -6,29 +6,22 @@ import numpy as np
6
  import onnxruntime as rt
7
  import pandas as pd
8
  from PIL import Image
9
- from huggingface_hub import login
 
10
 
11
- from translator import translate_texts, translate_texts_with_dynamic_keys
12
 
 
 
 
13
  MODEL_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
14
  MODEL_FILENAME = "model.onnx"
15
  LABEL_FILENAME = "selected_tags.csv"
16
 
17
- OWNER_USERNAME = os.environ.get("OWNER_USERNAME", "").lower()
18
-
19
- HF_TOKEN = os.environ.get("HF_TOKEN", "")
20
- if HF_TOKEN:
21
- try:
22
- login(token=HF_TOKEN)
23
- print("✅ HF_TOKEN 登录成功")
24
- except Exception as e:
25
- print(f"⚠️ HF_TOKEN 登录失败: {e}")
26
- else:
27
- print("⚠️ 未检测到 HF_TOKEN,私有模型可能下载失败")
28
-
29
 
30
  # ------------------------------------------------------------------
31
- # Tagger Class (Global Instantiation)
32
  # ------------------------------------------------------------------
33
  class Tagger:
34
  def __init__(self):
@@ -63,11 +56,10 @@ class Tagger:
63
  raise RuntimeError(f"模型初始化失败: {e}")
64
 
65
 
 
66
  def _preprocess(self, img: Image.Image) -> np.ndarray:
67
- if img is None:
68
- raise ValueError("输入图像不能为空")
69
- if img.mode != "RGB":
70
- img = img.convert("RGB")
71
  size = max(img.size)
72
  canvas = Image.new("RGB", (size, size), (255, 255, 255))
73
  canvas.paste(img, ((size - img.width) // 2, (size - img.height) // 2))
@@ -75,44 +67,36 @@ class Tagger:
75
  canvas = canvas.resize((self.input_size, self.input_size), Image.BICUBIC)
76
  return np.array(canvas)[:, :, ::-1].astype(np.float32)
77
 
 
78
  def predict(self, img: Image.Image, gen_th: float = 0.35, char_th: float = 0.85):
79
- if self.model is None:
80
- raise RuntimeError("模型未成功加载,无法进行预测。")
81
  inp_name = self.model.get_inputs()[0].name
82
  outputs = self.model.run(None, {inp_name: self._preprocess(img)[None, ...]})[0][0]
83
 
84
  res = {"ratings": {}, "general": {}, "characters": {}}
85
  tag_categories_for_translation = {"ratings": [], "general": [], "characters": []}
86
 
87
- for idx in self.categories["rating"]:
88
- tag_name = self.tag_names[idx].replace("_", " ")
89
- res["ratings"][tag_name] = float(outputs[idx])
90
- tag_categories_for_translation["ratings"].append(tag_name)
91
-
92
- for idx in self.categories["general"]:
93
- if outputs[idx] > gen_th:
94
- tag_name = self.tag_names[idx].replace("_", " ")
95
- res["general"][tag_name] = float(outputs[idx])
96
- tag_categories_for_translation["general"].append(tag_name)
97
-
98
- for idx in self.categories["character"]:
99
- if outputs[idx] > char_th:
100
- tag_name = self.tag_names[idx].replace("_", " ")
101
- res["characters"][tag_name] = float(outputs[idx])
102
- tag_categories_for_translation["characters"].append(tag_name)
103
-
104
-
105
- res["general"] = dict(sorted(res["general"].items(), key=lambda kv: kv[1], reverse=True))
106
- res["characters"] = dict(sorted(res["characters"].items(), key=lambda kv: kv[1], reverse=True))
107
- res["ratings"] = dict(sorted(res["ratings"].items(), key=lambda kv: kv[1], reverse=True))
108
-
109
 
110
- tag_categories_for_translation["general"] = list(res["general"].keys())
111
- tag_categories_for_translation["characters"] = list(res["characters"].keys())
112
- tag_categories_for_translation["ratings"] = list(res["ratings"].keys())
 
113
 
114
  return res, tag_categories_for_translation
115
 
 
116
  try:
117
  tagger_instance = Tagger()
118
  except RuntimeError as e:
@@ -140,10 +124,13 @@ function copyToClipboard(text) {
140
  }
141
  navigator.clipboard.writeText(text).then(() => {
142
  const feedback = document.createElement('div');
143
- let displayText = String(text);
144
- displayText = displayText.substring(0, 30) + (displayText.length > 30 ? '...' : '');
145
  feedback.textContent = '已复制: ' + displayText;
146
- feedback.style.cssText = 'position: fixed; bottom: 20px; left: 50%; transform: translateX(-50%); background-color: #4CAF50; color: white; padding: 10px 20px; border-radius: 5px; z-index: 10000; transition: opacity 0.5s ease-out;';
 
 
 
 
147
  document.body.appendChild(feedback);
148
  setTimeout(() => {
149
  feedback.style.opacity = '0';
@@ -156,17 +143,16 @@ function copyToClipboard(text) {
156
  """
157
 
158
  with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=custom_css, js=_js_functions) as demo:
159
- with gr.Row():
160
- login_button = gr.LoginButton()
161
- user_info = gr.Markdown("正在检查登录状态...", visible=True)
162
-
163
  gr.Markdown("# 🖼️ AI 图像标签分析器")
164
  gr.Markdown("上传图片自动识别标签,支持中英文显示和一键复制。[NovelAI在线绘画](https://nai.idlecloud.cc/)")
165
-
 
 
 
 
 
166
  state_res = gr.State({})
167
  state_translations_dict = gr.State({})
168
- state_tag_categories_for_translation = gr.State({})
169
-
170
 
171
  with gr.Row():
172
  with gr.Column(scale=1):
@@ -174,191 +160,186 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=cus
174
  btn = gr.Button("🚀 开始分析", variant="primary", elem_classes=["btn-analyze-container"])
175
 
176
  with gr.Accordion("⚙️ 高级设置", open=False):
177
- gen_slider = gr.Slider(0, 1, value=0.35, step=0.01, label="通用标签阈值", info="越高 → 标签更少更准")
178
- char_slider = gr.Slider(0, 1, value=0.85, step=0.01, label="角色标签阈值", info="推荐保持较高阈值")
179
  show_tag_scores = gr.Checkbox(True, label="在列表中显示标签置信度")
180
-
181
- with gr.Group(visible=False) as guest_api_group:
182
- gr.Markdown("### 访客翻译API配置\n由于您不是本空间所有者,需要提供自己的翻译API密钥才能使用翻译功能。")
183
- guest_tencent_id = gr.Textbox(label="腾讯云 Secret ID", type="password")
184
- guest_tencent_key = gr.Textbox(label="腾讯云 Secret Key", type="password")
185
- guest_baidu_json = gr.TextArea(
186
- label="百度翻译凭证 JSON",
187
- placeholder='[{"app_id": "...", "secret_key": "..."}, ...]',
188
- lines=3
189
- )
190
-
191
  with gr.Accordion("📊 标签汇总设置", open=True):
192
- gr.Markdown("选择要包含在下方汇总文本框中的标签类别:")
193
- with gr.Row():
194
- sum_general = gr.Checkbox(True, label="通用标签", min_width=50)
195
- sum_char = gr.Checkbox(True, label="角色标签", min_width=50)
196
- sum_rating = gr.Checkbox(False, label="评分标签", min_width=50)
197
- sum_sep = gr.Dropdown(["逗号", "换行", "空格"], value="逗号", label="标签之间的分隔符")
198
  sum_show_zh = gr.Checkbox(False, label="在汇总中显示中文翻译")
199
 
200
  processing_info = gr.Markdown("", visible=False)
201
 
202
  with gr.Column(scale=2):
203
  with gr.Tabs():
204
- with gr.TabItem("🏷️ 通用标签"):
205
- out_general = gr.HTML(label="General Tags")
206
- with gr.TabItem("👤 角色标签"):
207
- gr.Markdown("<p style='color:gray; font-size:small;'>提示:角色标签推测基于截至2024年2月的数据。</p>")
208
- out_char = gr.HTML(label="Character Tags")
209
- with gr.TabItem("⭐ 评分标签"):
210
- out_rating = gr.HTML(label="Rating Tags")
211
-
212
  gr.Markdown("### 标签汇总结果")
213
- out_summary = gr.Textbox(
214
- label="标签汇总",
215
- placeholder="分析完成后,此处将显示汇总的英文标签...",
216
- lines=5,
217
- show_copy_button=True
218
- )
219
-
220
- def format_tags_html(tags_dict, translations_list, category_name, show_scores=True, show_translation_in_list=True):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  if not tags_dict: return "<p>暂无标签</p>"
222
  html = '<div class="label-container">'
223
- tag_keys = list(tags_dict.keys())
224
- for i, tag in enumerate(tag_keys):
225
- score = tags_dict[tag]
226
  escaped_tag = tag.replace("'", "\\'")
227
  html += '<div class="tag-item">'
228
  tag_display_html = f'<span class="tag-en" onclick="copyToClipboard(\'{escaped_tag}\')">{tag}</span>'
229
- if show_translation_in_list and i < len(translations_list) and translations_list[i]:
230
  tag_display_html += f'<span class="tag-zh">({translations_list[i]})</span>'
231
  html += f'<div>{tag_display_html}</div>'
232
- if show_scores: html += f'<span class.tag-score">{score:.3f}</span>'
233
  html += '</div>'
234
- html += '</div>'
235
- return html
236
-
237
- def generate_summary_text_content(current_res, current_translations_dict, s_gen, s_char, s_rat, s_sep_type, s_show_zh):
238
- if not current_res: return "请先分析图像或选择要汇总的标签类别。"
239
- summary_parts, separator = [], {"逗号": ", ", "换行": "\n", "空格": " "}.get(s_sep_type, ", ")
240
- categories_to_summarize = []
241
- if s_gen: categories_to_summarize.append("general")
242
- if s_char: categories_to_summarize.append("characters")
243
- if s_rat: categories_to_summarize.append("ratings")
244
- if not categories_to_summarize: return "请至少选择一个标签类别进行汇总。"
245
-
246
- for cat_key in categories_to_summarize:
247
- if current_res.get(cat_key):
248
- tags_to_join = []
249
- cat_tags_en = list(current_res[cat_key].keys())
250
- cat_translations = current_translations_dict.get(cat_key, [])
251
- for i, en_tag in enumerate(cat_tags_en):
252
- if s_show_zh and i < len(cat_translations) and cat_translations[i]:
253
- tags_to_join.append(f"{en_tag}({cat_translations[i]})")
254
- else:
255
- tags_to_join.append(en_tag)
256
- if tags_to_join: summary_parts.append(separator.join(tags_to_join))
257
- joiner = "\n\n" if separator != "\n" and len(summary_parts) > 1 else separator
258
- final_summary = joiner.join(summary_parts)
259
- return final_summary if final_summary else "选定的类别中没有找到标签。"
260
-
261
- def update_auth_ui(profile: gr.OAuthProfile | None):
262
- if profile is None:
263
- return gr.update(visible=True), "请先登录..."
264
-
265
- username = profile.username.lower()
266
- is_owner = username == OWNER_USERNAME
267
-
268
- if is_owner:
269
- user_info_md = f"✅ **所有者模式**: 欢迎, {profile.name}! 将使用预置翻译服务。"
270
- else:
271
- user_info_md = f"👋 **访客模式**: 欢迎, {profile.name}! 请在高级设置中提供您自己的翻译密钥。"
272
-
273
- return gr.update(visible=not is_owner), user_info_md
274
-
275
- def process_image_and_generate_outputs(img, g_th, c_th, s_scores, s_gen, s_char, s_rat, s_sep, s_zh_in_sum, guest_tc_id, guest_tc_key, guest_bd_json, profile: gr.OAuthProfile | None):
276
- guest_api_update, user_info_update = update_auth_ui(profile)
277
-
278
- if profile is None:
279
- gr.Warning("请先登录后再进行分析!")
280
- yield (gr.update(), gr.update(visible=False), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), guest_api_update, user_info_update)
281
- return
282
-
283
  if img is None:
284
- yield (gr.update(interactive=True, value="🚀 开始分析"), gr.update(visible=True, value="❌ 请先上传图片。"), "", "", "", gr.update(placeholder="请先上传图片并开始分析..."), {}, {}, {}, guest_api_update, user_info_update)
285
- return
286
  if tagger_instance is None:
287
- yield (gr.update(interactive=True, value="🚀 开始分析"), gr.update(visible=True, value="❌ 分析器未成功初始化,请检查控制台错误。"), "", "", "", gr.update(placeholder="分析器初始化失败..."), {}, {}, {}, guest_api_update, user_info_update)
288
- return
289
-
290
- yield (gr.update(interactive=False, value="🔄 处理中..."), gr.update(visible=True, value="🔄 正在分析图像..."), gr.HTML(value="<p>分析中...</p>"), gr.HTML(value="<p>分析中...</p>"), gr.HTML(value="<p>分析中...</p>"), gr.update(value="分析中..."), {}, {}, {}, guest_api_update, user_info_update)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
 
 
 
 
 
 
 
 
292
  try:
293
- res, tag_categories_original_order = tagger_instance.predict(img, g_th, c_th)
294
- all_tags_to_translate = [tag for cat_key in ["general", "characters", "ratings"] for tag in tag_categories_original_order.get(cat_key, [])]
295
 
296
- all_translations_flat = []
297
- if all_tags_to_translate:
298
- is_owner = profile and profile.username.lower() == OWNER_USERNAME
299
- if is_owner:
300
- print("- [Auth] 所有者身份,使用预置密钥进行翻译。")
301
- all_translations_flat = translate_texts(all_tags_to_translate)
302
- else:
303
- print("- [Auth] 访客身份,使用用户提供的密钥进行翻译。")
304
- if not guest_tc_id and not guest_bd_json:
305
- print(" - [Warning] 访客未提供任何API密钥,将跳过翻译。")
306
- all_translations_flat = all_tags_to_translate
307
- else:
308
- all_translations_flat = translate_texts_with_dynamic_keys(all_tags_to_translate, guest_tc_id, guest_tc_key, guest_bd_json)
309
-
310
- current_translations_dict = {}
311
- offset = 0
312
- for cat_key in ["general", "characters", "ratings"]:
313
- num_tags = len(tag_categories_original_order.get(cat_key, []))
314
- current_translations_dict[cat_key] = all_translations_flat[offset : offset + num_tags]
315
- offset += num_tags
316
 
317
- general_html = format_tags_html(res.get("general", {}), current_translations_dict.get("general", []), "general", s_scores)
318
- char_html = format_tags_html(res.get("characters", {}), current_translations_dict.get("characters", []), "characters", s_scores)
319
- rating_html = format_tags_html(res.get("ratings", {}), current_translations_dict.get("ratings", []), "ratings", s_scores)
320
- summary_text = generate_summary_text_content(res, current_translations_dict, s_gen, s_char, s_rat, s_sep, s_zh_in_sum)
321
 
322
- yield (gr.update(interactive=True, value="🚀 开始分析"), gr.update(visible=True, value="✅ 分析完成!"), general_html, char_html, rating_html, gr.update(value=summary_text), res, current_translations_dict, tag_categories_original_order, guest_api_update, user_info_update)
323
 
324
  except Exception as e:
325
  import traceback
326
- tb_str = traceback.format_exc()
327
- print(f"处理时发生错误: {e}\n{tb_str}")
328
- yield (gr.update(interactive=True, value="🚀 开始分析"), gr.update(visible=True, value=f"❌ 处理失败: {str(e)}"), "<p>处理出错</p>", "<p>处理出错</p>", "<p>处理出错</p>", gr.update(value=f"错误: {str(e)}", placeholder="分析失败..."), {}, {}, {}, guest_api_update, user_info_update)
329
 
330
- def update_summary_display(s_gen, s_char, s_rat, s_sep, s_zh_in_sum, current_res, current_translations):
331
- if not current_res: return gr.update(placeholder="请先完成一次图像分析以生成汇总。", value="")
332
- new_summary_text = generate_summary_text_content(current_res, current_translations, s_gen, s_char, s_rat, s_sep, s_zh_in_sum)
333
- return gr.update(value=new_summary_text)
334
-
335
- def check_user_auth(profile: gr.OAuthProfile | None, request: gr.Request = None):
336
- if not OWNER_USERNAME: print("⚠️ 警告: 未设置 OWNER_USERNAME 环境变量。所有用户都将被视为访客。")
337
-
338
- if profile is None:
339
- print("- [Auth] 用户未登录。")
340
- else:
341
- if profile.username.lower() == OWNER_USERNAME:
342
- print(f"- [Auth] 所有者 '{profile.username}' 已连接。")
343
- else:
344
- print(f"- [Auth] 访客 '{profile.username}' 已连接,显示 API Key 输入框。")
345
- return update_auth_ui(profile)
346
 
347
- demo.load(fn=check_user_auth, inputs=[login_button], outputs=[guest_api_group, user_info])
348
-
349
- img_in.upload(fn=update_auth_ui, inputs=[login_button], outputs=[guest_api_group, user_info])
350
-
351
  btn.click(
352
  process_image_and_generate_outputs,
353
- inputs=[img_in, gen_slider, char_slider, show_tag_scores, sum_general, sum_char, sum_rating, sum_sep, sum_show_zh, guest_tencent_id, guest_tencent_key, guest_baidu_json, login_button],
354
- outputs=[btn, processing_info, out_general, out_char, out_rating, out_summary, state_res, state_translations_dict, state_tag_categories_for_translation, guest_api_group, user_info]
 
 
 
 
 
 
 
 
 
355
  )
356
 
357
- summary_controls = [sum_general, sum_char, sum_rating, sum_sep, sum_show_zh]
358
  for ctrl in summary_controls:
359
- ctrl.change(fn=update_summary_display, inputs=summary_controls + [state_res, state_translations_dict], outputs=[out_summary])
 
 
 
 
360
 
361
  if __name__ == "__main__":
362
  if tagger_instance is None:
363
- print("CRITICAL: Tagger 未能初始化,应用功能将受限。请检查之前的错误信息。")
364
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
6
  import onnxruntime as rt
7
  import pandas as pd
8
  from PIL import Image
9
+ from huggingface_hub import whoami, get_space_runtime, HfApi
10
+ from huggingface_hub.hf_api import SpaceRuntime
11
 
12
+ from translator import translate_texts
13
 
14
+ # ------------------------------------------------------------------
15
+ # 模型配置
16
+ # ------------------------------------------------------------------
17
  MODEL_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
18
  MODEL_FILENAME = "model.onnx"
19
  LABEL_FILENAME = "selected_tags.csv"
20
 
21
+ HF_TOKEN = os.environ.get("HF_TOKEN")
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  # ------------------------------------------------------------------
24
+ # Tagger (全局实例化)
25
  # ------------------------------------------------------------------
26
  class Tagger:
27
  def __init__(self):
 
56
  raise RuntimeError(f"模型初始化失败: {e}")
57
 
58
 
59
+ # ------------------------- preprocess -------------------------
60
  def _preprocess(self, img: Image.Image) -> np.ndarray:
61
+ if img is None: raise ValueError("输入图像不能为空")
62
+ if img.mode != "RGB": img = img.convert("RGB")
 
 
63
  size = max(img.size)
64
  canvas = Image.new("RGB", (size, size), (255, 255, 255))
65
  canvas.paste(img, ((size - img.width) // 2, (size - img.height) // 2))
 
67
  canvas = canvas.resize((self.input_size, self.input_size), Image.BICUBIC)
68
  return np.array(canvas)[:, :, ::-1].astype(np.float32)
69
 
70
+ # --------------------------- predict --------------------------
71
  def predict(self, img: Image.Image, gen_th: float = 0.35, char_th: float = 0.85):
72
+ if self.model is None: raise RuntimeError("模型未成功加载,无法进行预测。")
 
73
  inp_name = self.model.get_inputs()[0].name
74
  outputs = self.model.run(None, {inp_name: self._preprocess(img)[None, ...]})[0][0]
75
 
76
  res = {"ratings": {}, "general": {}, "characters": {}}
77
  tag_categories_for_translation = {"ratings": [], "general": [], "characters": []}
78
 
79
+ for cat_key, cat_indices in self.categories.items():
80
+ sub_res = {}
81
+ if cat_key == "rating":
82
+ for idx in cat_indices:
83
+ tag_name = self.tag_names[idx].replace("_", " ")
84
+ sub_res[tag_name] = float(outputs[idx])
85
+ else:
86
+ threshold = char_th if cat_key == "character" else gen_th
87
+ for idx in cat_indices:
88
+ if outputs[idx] > threshold:
89
+ tag_name = self.tag_names[idx].replace("_", " ")
90
+ sub_res[tag_name] = float(outputs[idx])
 
 
 
 
 
 
 
 
 
 
91
 
92
+ # Use the correct key for 'character'
93
+ res_key = "characters" if cat_key == "character" else cat_key
94
+ res[res_key] = dict(sorted(sub_res.items(), key=lambda kv: kv[1], reverse=True))
95
+ tag_categories_for_translation[res_key] = list(res[res_key].keys())
96
 
97
  return res, tag_categories_for_translation
98
 
99
+ # 全局 Tagger 实例
100
  try:
101
  tagger_instance = Tagger()
102
  except RuntimeError as e:
 
124
  }
125
  navigator.clipboard.writeText(text).then(() => {
126
  const feedback = document.createElement('div');
127
+ let displayText = String(text).substring(0, 30) + (String(text).length > 30 ? '...' : '');
 
128
  feedback.textContent = '已复制: ' + displayText;
129
+ Object.assign(feedback.style, {
130
+ position: 'fixed', bottom: '20px', left: '50%', transform: 'translateX(-50%)',
131
+ backgroundColor: '#4CAF50', color: 'white', padding: '10px 20px',
132
+ borderRadius: '5px', zIndex: '10000', transition: 'opacity 0.5s ease-out'
133
+ });
134
  document.body.appendChild(feedback);
135
  setTimeout(() => {
136
  feedback.style.opacity = '0';
 
143
  """
144
 
145
  with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=custom_css, js=_js_functions) as demo:
 
 
 
 
146
  gr.Markdown("# 🖼️ AI 图像标签分析器")
147
  gr.Markdown("上传图片自动识别标签,支持中英文显示和一键复制。[NovelAI在线绘画](https://nai.idlecloud.cc/)")
148
+
149
+ with gr.Row():
150
+ with gr.Column(scale=1):
151
+ login_button = gr.LoginButton(value="🤗 通过 Hugging Face 登录")
152
+ user_status_md = gr.Markdown("ℹ️ 正在检查登录状态...")
153
+
154
  state_res = gr.State({})
155
  state_translations_dict = gr.State({})
 
 
156
 
157
  with gr.Row():
158
  with gr.Column(scale=1):
 
160
  btn = gr.Button("🚀 开始分析", variant="primary", elem_classes=["btn-analyze-container"])
161
 
162
  with gr.Accordion("⚙️ 高级设置", open=False):
163
+ gen_slider = gr.Slider(0, 1, value=0.35, step=0.01, label="通用标签阈值")
164
+ char_slider = gr.Slider(0, 1, value=0.85, step=0.01, label="角色标签阈值")
165
  show_tag_scores = gr.Checkbox(True, label="在列表中显示标签置信度")
166
+
167
+ with gr.Accordion("🔑 自定义翻译密钥 (可选)", open=False, visible=False) as api_key_accordion:
168
+ gr.Markdown("如果你不是空间所有者,需要在这里提供自己的API密钥才能使用翻译功能。")
169
+ tencent_id_in = gr.Textbox(label="腾讯云 Secret ID", lines=1)
170
+ tencent_key_in = gr.Textbox(label="腾讯云 Secret Key", lines=1, type="password")
171
+ baidu_json_in = gr.Textbox(label="百度翻译凭证 (JSON 格式)", lines=3, placeholder='[{"app_id": "...", "secret_key": "..."}]')
172
+
 
 
 
 
173
  with gr.Accordion("📊 标签汇总设置", open=True):
174
+ sum_cats = gr.CheckboxGroup(["通用标签", "角色标签", "评分标签"], value=["��用标签", "角色标签"], label="汇总类别")
175
+ sum_sep = gr.Dropdown(["逗号", "换行", "空格"], value="逗号", label="标签分隔符")
 
 
 
 
176
  sum_show_zh = gr.Checkbox(False, label="在汇总中显示中文翻译")
177
 
178
  processing_info = gr.Markdown("", visible=False)
179
 
180
  with gr.Column(scale=2):
181
  with gr.Tabs():
182
+ with gr.TabItem("🏷️ 通用标签"): out_general = gr.HTML(label="General Tags")
183
+ with gr.TabItem("👤 角色标签"): out_char = gr.HTML(label="Character Tags")
184
+ with gr.TabItem(" 评分标签"): out_rating = gr.HTML(label="Rating Tags")
 
 
 
 
 
185
  gr.Markdown("### 标签汇总结果")
186
+ out_summary = gr.Textbox(label="标签汇总", lines=5, show_copy_button=True)
187
+
188
+ # ----------------- 辅助函数 -----------------
189
+ def get_token_from_request(request: gr.Request) -> str | None:
190
+ auth_header = request.headers.get("authorization")
191
+ if auth_header and auth_header.startswith("Bearer "):
192
+ return auth_header.split(" ")[1]
193
+ return None
194
+
195
+ def check_user_is_owner(user_info: dict | None, space_runtime: SpaceRuntime | None) -> bool:
196
+ """检查用户是否是空间的所有者(处理个人和组织空间)。"""
197
+ if not user_info or not space_runtime:
198
+ return False
199
+
200
+ # 情况1:空间由用户直接拥有
201
+ if user_info.get("name") == space_runtime.owner:
202
+ return True
203
+
204
+ # 情况2:空间由用户所属的组织拥有
205
+ user_orgs = user_info.get("orgs", [])
206
+ if any(org.get("name") == space_runtime.owner for org in user_orgs):
207
+ return True
208
+
209
+ return False
210
+
211
+ def check_user_status(request: gr.Request):
212
+ token = get_token_from_request(request)
213
+ if token:
214
+ try:
215
+ user_info = whoami(token=token)
216
+ space_runtime = get_space_runtime()
217
+
218
+ if check_user_is_owner(user_info, space_runtime):
219
+ return f"✅ 以所有者 **{user_info.get('fullname', user_info.get('name'))}** 身份登录,将使用空间配置的密钥。", gr.update(visible=False)
220
+ else:
221
+ return f"👋 你好, **{user_info.get('fullname', '用户')}**!请在下方提供你自己的翻译 API 密钥。", gr.update(visible=True, open=True)
222
+ except Exception as e:
223
+ print(f"获取用户信息时出错: {e}")
224
+ return "⚠️ 无法验证您的登录状态。请提供 API 密钥。", gr.update(visible=True, open=True)
225
+ return "ℹ️ **访客模式**。如需使用翻译功能,请<a href='/login?redirect=/'>登录</a>或提供 API 密钥。", gr.update(visible=True, open=True)
226
+
227
+ def format_tags_html(tags_dict, translations_list, show_scores):
228
  if not tags_dict: return "<p>暂无标签</p>"
229
  html = '<div class="label-container">'
230
+ for i, (tag, score) in enumerate(tags_dict.items()):
 
 
231
  escaped_tag = tag.replace("'", "\\'")
232
  html += '<div class="tag-item">'
233
  tag_display_html = f'<span class="tag-en" onclick="copyToClipboard(\'{escaped_tag}\')">{tag}</span>'
234
+ if i < len(translations_list) and translations_list[i]:
235
  tag_display_html += f'<span class="tag-zh">({translations_list[i]})</span>'
236
  html += f'<div>{tag_display_html}</div>'
237
+ if show_scores: html += f'<span class="tag-score">{score:.3f}</span>'
238
  html += '</div>'
239
+ return html + '</div>'
240
+
241
+ def generate_summary_text_content(current_res, translations, sum_cats, sep_type, show_zh):
242
+ if not current_res: return "请先分析图像。"
243
+ parts, sep = [], {"逗号": ", ", "换行": "\n", "空格": " "}.get(sep_type, ", ")
244
+ cat_map = {"通用标签": "general", "角色标签": "characters", "评分标签": "ratings"}
245
+ for cat_name in sum_cats:
246
+ cat_key = cat_map.get(cat_name)
247
+ if cat_key and current_res.get(cat_key):
248
+ tags_en, trans = list(current_res[cat_key].keys()), translations.get(cat_key, [])
249
+ tags_to_join = [f"{en}({zh})" if show_zh and i < len(trans) and trans[i] else en for i, en in enumerate(tags_en)]
250
+ if tags_to_join: parts.append(sep.join(tags_to_join))
251
+ return "\n".join(parts) if parts else "选定的类别中没有找到标签。"
252
+
253
+ # ----------------- 主要处理回调 -----------------
254
+ def process_image_and_generate_outputs(
255
+ img, g_th, c_th, s_scores,
256
+ user_tencent_id, user_tencent_key, user_baidu_json,
257
+ sum_cats, s_sep, s_zh_in_sum,
258
+ request: gr.Request
259
+ ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  if img is None:
261
+ raise gr.Error("请先上传图片。")
 
262
  if tagger_instance is None:
263
+ raise gr.Error("分析器未成功初始化,请检查后台错误。")
264
+
265
+ yield gr.update(interactive=False, value="🔄 处理中..."), gr.update(visible=True, value="🔄 正在分析..."), *["<p>分析中...</p>"]*3, "分析中...", {}, {}
266
+
267
+ token = get_token_from_request(request)
268
+ is_owner = False
269
+ if token:
270
+ try:
271
+ user_info = whoami(token=token)
272
+ space_runtime = get_space_runtime()
273
+ if check_user_is_owner(user_info, space_runtime):
274
+ is_owner = True
275
+ except Exception: pass
276
+
277
+ final_tencent_id, final_tencent_key, baidu_json_str = (
278
+ (os.environ.get("TENCENT_SECRET_ID"), os.environ.get("TENCENT_SECRET_KEY"), os.environ.get("BAIDU_CREDENTIALS_JSON", "[]"))
279
+ if is_owner else (user_tencent_id, user_tencent_key, user_baidu_json)
280
+ )
281
 
282
+ final_baidu_creds_list = []
283
+ if baidu_json_str and baidu_json_str.strip():
284
+ try:
285
+ parsed_data = json.loads(baidu_json_str)
286
+ if isinstance(parsed_data, list): final_baidu_creds_list = parsed_data
287
+ except json.JSONDecodeError: print("提供的百度凭证JSON无效。")
288
+
289
  try:
290
+ res, tag_cats_original = tagger_instance.predict(img, g_th, c_th)
291
+ all_tags = [tag for cat in tag_cats_original.values() for tag in cat]
292
 
293
+ translations_flat = translate_texts(
294
+ all_tags,
295
+ tencent_secret_id=final_tencent_id,
296
+ tencent_secret_key=final_tencent_key,
297
+ baidu_credentials_list=final_baidu_creds_list
298
+ ) if all_tags else []
299
+
300
+ translations, offset = {}, 0
301
+ for cat_key, tags in tag_cats_original.items():
302
+ translations[cat_key] = translations_flat[offset : offset + len(tags)]
303
+ offset += len(tags)
 
 
 
 
 
 
 
 
 
304
 
305
+ outputs_html = {k: format_tags_html(res.get(k, {}), translations.get(k, []), s_scores) for k in ["general", "characters", "ratings"]}
306
+ summary = generate_summary_text_content(res, translations, sum_cats, s_sep, s_zh_in_sum)
 
 
307
 
308
+ yield gr.update(interactive=True, value="🚀 开始分析"), gr.update(visible=True, value="✅ 分析完成!"), outputs_html["general"], outputs_html["characters"], outputs_html["ratings"], summary, res, translations
309
 
310
  except Exception as e:
311
  import traceback
312
+ traceback.print_exc()
313
+ raise gr.Error(f"处理时发生错误: {e}")
 
314
 
315
+ # ----------------- 绑定事件 -----------------
316
+ demo.load(fn=check_user_status, inputs=None, outputs=[user_status_md, api_key_accordion], queue=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
317
 
 
 
 
 
318
  btn.click(
319
  process_image_and_generate_outputs,
320
+ inputs=[
321
+ img_in, gen_slider, char_slider, show_tag_scores,
322
+ tencent_id_in, tencent_key_in, baidu_json_in,
323
+ sum_cats, sum_sep, sum_show_zh
324
+ ],
325
+ outputs=[
326
+ btn, processing_info,
327
+ out_general, out_char, out_rating,
328
+ out_summary,
329
+ state_res, state_translations_dict
330
+ ],
331
  )
332
 
333
+ summary_controls = [sum_cats, sum_sep, sum_show_zh]
334
  for ctrl in summary_controls:
335
+ ctrl.change(
336
+ fn=lambda r, t, c, s, z: generate_summary_text_content(r, t, c, s, z),
337
+ inputs=[state_res, state_translations_dict] + summary_controls,
338
+ outputs=[out_summary],
339
+ )
340
 
341
  if __name__ == "__main__":
342
  if tagger_instance is None:
343
+ print("CRITICAL: Tagger 未能初始化,应用功能将受限。")
344
+ demo.launch(server_name="0.0.0.0", server_port=7860)
345
+