IdlecloudX commited on
Commit
07ef9c9
·
verified ·
1 Parent(s): 0599754

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +271 -234
app.py CHANGED
@@ -6,26 +6,22 @@ import numpy as np
6
  import onnxruntime as rt
7
  import pandas as pd
8
  from PIL import Image
9
- from huggingface_hub import whoami, HfApi
10
- from translator import translate_texts
11
 
12
- # ------------------------------------------------------------------
13
- # Model Configuration
14
- # ------------------------------------------------------------------
15
  MODEL_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
16
  MODEL_FILENAME = "model.onnx"
17
  LABEL_FILENAME = "selected_tags.csv"
 
 
 
 
 
 
 
 
 
18
 
19
- # It's recommended to manage the token within the HF Spaces secrets
20
- HF_TOKEN = os.environ.get("HF_TOKEN")
21
- # A more robust way to get the space owner
22
- SPACE_ID = os.environ.get("SPACE_ID")
23
- SPACE_OWNER = SPACE_ID.split('/')[0] if SPACE_ID else None
24
-
25
-
26
- # ------------------------------------------------------------------
27
- # Tagger Class (Global Instance)
28
- # ------------------------------------------------------------------
29
  class Tagger:
30
  def __init__(self):
31
  self.hf_token = HF_TOKEN
@@ -53,15 +49,17 @@ class Tagger:
53
  }
54
  self.model = rt.InferenceSession(model_path)
55
  self.input_size = self.model.get_inputs()[0].shape[1]
56
- print("✅ Model and labels loaded successfully.")
57
  except Exception as e:
58
- print(f"❌ Failed to load model or labels: {e}")
59
- raise RuntimeError(f"Model initialization failed: {e}")
 
60
 
61
- # ------------------------- preprocess -------------------------
62
  def _preprocess(self, img: Image.Image) -> np.ndarray:
63
- if img is None: raise ValueError("Input image cannot be None.")
64
- if img.mode != "RGB": img = img.convert("RGB")
 
 
65
  size = max(img.size)
66
  canvas = Image.new("RGB", (size, size), (255, 255, 255))
67
  canvas.paste(img, ((size - img.width) // 2, (size - img.height) // 2))
@@ -69,39 +67,48 @@ class Tagger:
69
  canvas = canvas.resize((self.input_size, self.input_size), Image.BICUBIC)
70
  return np.array(canvas)[:, :, ::-1].astype(np.float32)
71
 
72
- # --------------------------- predict --------------------------
73
  def predict(self, img: Image.Image, gen_th: float = 0.35, char_th: float = 0.85):
74
- if self.model is None: raise RuntimeError("Model not loaded, cannot predict.")
 
75
  inp_name = self.model.get_inputs()[0].name
76
  outputs = self.model.run(None, {inp_name: self._preprocess(img)[None, ...]})[0][0]
77
 
78
  res = {"ratings": {}, "general": {}, "characters": {}}
79
  tag_categories_for_translation = {"ratings": [], "general": [], "characters": []}
80
 
81
- for cat_key, cat_indices in self.categories.items():
82
- sub_res = {}
83
- if cat_key == "rating":
84
- for idx in cat_indices:
85
- tag_name = self.tag_names[idx].replace("_", " ")
86
- sub_res[tag_name] = float(outputs[idx])
87
- else:
88
- threshold = char_th if cat_key == "character" else gen_th
89
- for idx in cat_indices:
90
- if outputs[idx] > threshold:
91
- tag_name = self.tag_names[idx].replace("_", " ")
92
- sub_res[tag_name] = float(outputs[idx])
93
-
94
- res_key = "characters" if cat_key == "character" else cat_key
95
- res[res_key] = dict(sorted(sub_res.items(), key=lambda kv: kv[1], reverse=True))
96
- tag_categories_for_translation[res_key] = list(res[res_key].keys())
 
 
 
 
 
 
 
 
 
 
97
 
98
  return res, tag_categories_for_translation
99
 
100
- # Global Tagger instance
101
  try:
102
  tagger_instance = Tagger()
103
  except RuntimeError as e:
104
- print(f"Tagger initialization failed on app startup: {e}")
105
  tagger_instance = None
106
 
107
  # ------------------------------------------------------------------
@@ -115,6 +122,7 @@ custom_css = """
115
  .tag-zh { color: #666; margin-left: 10px; }
116
  .tag-score { color: #999; font-size: 0.9em; }
117
  .btn-analyze-container { margin-top: 15px; margin-bottom: 15px; }
 
118
  """
119
 
120
  _js_functions = """
@@ -125,217 +133,246 @@ function copyToClipboard(text) {
125
  }
126
  navigator.clipboard.writeText(text).then(() => {
127
  const feedback = document.createElement('div');
128
- let displayText = String(text).substring(0, 30) + (String(text).length > 30 ? '...' : '');
 
129
  feedback.textContent = '已复制: ' + displayText;
130
- Object.assign(feedback.style, {
131
- position: 'fixed', bottom: '20px', left: '50%', transform: 'translateX(-50%)',
132
- backgroundColor: '#4CAF50', color: 'white', padding: '10px 20px',
133
- borderRadius: '5px', zIndex: '10000', transition: 'opacity 0.5s ease-out'
134
- });
135
  document.body.appendChild(feedback);
136
  setTimeout(() => {
137
  feedback.style.opacity = '0';
138
- setTimeout(() => { if (document.body.contains(feedback)) document.body.removeChild(feedback); }, 500);
139
  }, 1500);
140
  }).catch(err => {
141
- console.error('Failed to copy tag. Error:', err, 'Attempted to copy text:', text);
142
  });
143
  }
144
  """
145
 
146
- with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=custom_css, js=_js_functions) as demo:
147
- gr.Markdown("# 🖼️ AI 图像标签分析器")
148
- gr.Markdown("上传图片自动识别标签,支持中英文显示和一键复制。[NovelAI在线绘画](https://nai.idlecloud.cc/)")
149
-
150
- with gr.Row():
151
- with gr.Column(scale=1):
152
- login_button = gr.LoginButton(value="🤗 通过 Hugging Face 登录")
153
- user_status_md = gr.Markdown("ℹ️ 正在检查登录状态...")
154
-
155
- state_res = gr.State({})
156
- state_translations_dict = gr.State({})
157
-
158
- with gr.Row():
159
- with gr.Column(scale=1):
160
- img_in = gr.Image(type="pil", label="上传图片", height=300)
161
- btn = gr.Button("🚀 开始分析", variant="primary", elem_classes=["btn-analyze-container"])
162
-
163
- with gr.Accordion("⚙️ 高级设置", open=False):
164
- gen_slider = gr.Slider(0, 1, value=0.35, step=0.01, label="通用标签阈值")
165
- char_slider = gr.Slider(0, 1, value=0.85, step=0.01, label="角色标签阈值")
166
- show_tag_scores = gr.Checkbox(True, label="在列表中显示标签置信度")
167
-
168
- with gr.Accordion("🔑 自定义翻译密钥 (可选)", open=False, visible=False) as api_key_accordion:
169
- gr.Markdown("如果你不是空间所有者,需要在这里提供自己的API密钥才能使用翻译功能。")
170
- tencent_id_in = gr.Textbox(label="腾讯云 Secret ID", lines=1)
171
- tencent_key_in = gr.Textbox(label="腾讯云 Secret Key", lines=1, type="password")
172
- baidu_json_in = gr.Textbox(label="百度翻译凭证 (JSON 格式)", lines=3, placeholder='[{"app_id": "...", "secret_key": "..."}]')
173
-
174
- with gr.Accordion("📊 标签汇总设置", open=True):
175
- sum_cats = gr.CheckboxGroup(["通用标签", "角色标签", "评分标签"], value=["通用标签", "角色标签"], label="汇总类别")
176
- sum_sep = gr.Dropdown(["逗号", "换行", "空格"], value="逗号", label="标签分隔符")
177
- sum_show_zh = gr.Checkbox(False, label="在汇总中显示中文翻译")
178
-
179
- processing_info = gr.Markdown("", visible=False)
180
-
181
- with gr.Column(scale=2):
182
- with gr.Tabs():
183
- with gr.TabItem("🏷️ 通用标签"): out_general = gr.HTML(label="General Tags")
184
- with gr.TabItem("👤 角色标签"): out_char = gr.HTML(label="Character Tags")
185
- with gr.TabItem("⭐ 评分标签"): out_rating = gr.HTML(label="Rating Tags")
186
- gr.Markdown("### 标签汇总结果")
187
- out_summary = gr.Textbox(label="标签汇总", lines=5, show_copy_button=True)
188
-
189
- def get_token_from_request(request: gr.Request) -> str | None:
190
- auth_header = request.headers.get("authorization")
191
- if auth_header and auth_header.startswith("Bearer "):
192
- return auth_header.split(" ")[1]
193
- return None
194
-
195
- def is_user_space_owner(user_info: dict | None) -> bool:
196
- """
197
- Robustly checks if the user is the owner of the space by parsing SPACE_ID.
198
- """
199
- if not user_info or not SPACE_OWNER:
200
- if not SPACE_OWNER:
201
- print("⚠️ Warning: SPACE_ID environment variable not found.")
202
- return False
203
-
204
- user_name = user_info.get("name")
205
- user_orgs = [org.get("name") for org in user_info.get("orgs", [])]
206
-
207
- print(f"ℹ️ [Auth Check] Space Owner: '{SPACE_OWNER}', User: '{user_name}', User Orgs: {user_orgs}")
208
-
209
- is_owner = (user_name == SPACE_OWNER) or (SPACE_OWNER in user_orgs)
210
- return is_owner
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
- def check_user_status(request: gr.Request):
213
- token = get_token_from_request(request)
214
- if token:
215
  try:
216
- user_info = whoami(token=token)
 
 
 
 
 
 
 
217
 
218
- if is_user_space_owner(user_info):
219
- return f"✅ 以所有者 **{user_info.get('fullname', user_info.get('name'))}** 身份登录,将使用空间配置的密钥。", gr.update(visible=False)
220
- else:
221
- return f"👋 你好, **{user_info.get('fullname', '用户')}**!请在下方提供你自己的翻译 API 密钥。", gr.update(visible=True, open=True)
 
 
 
 
 
 
 
 
 
 
222
  except Exception as e:
223
- print(f"Error getting user info: {e}")
224
- return "⚠️ 无法验证您的登录状态。请提供 API 密钥。", gr.update(visible=True, open=True)
225
- return "ℹ️ **访客模式**。如需使用翻译功能,请<a href='/login?redirect=/'>登录</a>或提供 API 密钥。", gr.update(visible=True, open=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
- def format_tags_html(tags_dict, translations_list, show_scores):
228
- if not tags_dict: return "<p>暂无标签</p>"
229
- html = '<div class="label-container">'
230
- for i, (tag, score) in enumerate(tags_dict.items()):
231
- escaped_tag = tag.replace("'", "\\'")
232
- html += '<div class="tag-item">'
233
- tag_display_html = f'<span class="tag-en" onclick="copyToClipboard(\'{escaped_tag}\')">{tag}</span>'
234
- if i < len(translations_list) and translations_list[i]:
235
- tag_display_html += f'<span class="tag-zh">({translations_list[i]})</span>'
236
- html += f'<div>{tag_display_html}</div>'
237
- if show_scores: html += f'<span class="tag-score">{score:.3f}</span>'
238
- html += '</div>'
239
- return html + '</div>'
240
-
241
- def generate_summary_text_content(current_res, translations, sum_cats, sep_type, show_zh):
242
- if not current_res: return "请先分析图像。"
243
- parts, sep = [], {"逗号": ", ", "换行": "\n", "空格": " "}.get(sep_type, ", ")
244
- cat_map = {"通用标签": "general", "角色标签": "characters", "评分标签": "ratings"}
245
- for cat_name in sum_cats:
246
- cat_key = cat_map.get(cat_name)
247
- if cat_key and current_res.get(cat_key):
248
- tags_en, trans = list(current_res[cat_key].keys()), translations.get(cat_key, [])
249
- tags_to_join = [f"{en}({zh})" if show_zh and i < len(trans) and trans[i] else en for i, en in enumerate(tags_en)]
250
- if tags_to_join: parts.append(sep.join(tags_to_join))
251
- return "\n".join(parts) if parts else "选定的类别中没有找到标签。"
252
-
253
- def process_image_and_generate_outputs(
254
- img, g_th, c_th, s_scores,
255
- user_tencent_id, user_tencent_key, user_baidu_json,
256
- sum_cats, s_sep, s_zh_in_sum,
257
- request: gr.Request
258
- ):
259
- if img is None:
260
- raise gr.Error("请先上传图片。")
261
- if tagger_instance is None:
262
- raise gr.Error("分析器未成功初始化,请检查后台错误。")
 
 
 
 
 
 
 
 
 
263
 
264
- yield gr.update(interactive=False, value="🔄 处理中..."), gr.update(visible=True, value="🔄 正在分析..."), *["<p>分析中...</p>"]*3, "分析中...", {}, {}
 
265
 
266
- token = get_token_from_request(request)
267
- is_owner = False
268
- if token:
269
- try:
270
- user_info = whoami(token=token)
271
- if is_user_space_owner(user_info):
272
- is_owner = True
273
- except Exception: pass
274
-
275
- final_tencent_id, final_tencent_key, baidu_json_str = (
276
- (os.environ.get("TENCENT_SECRET_ID"), os.environ.get("TENCENT_SECRET_KEY"), os.environ.get("BAIDU_CREDENTIALS_JSON", "[]"))
277
- if is_owner else (user_tencent_id, user_tencent_key, user_baidu_json)
278
- )
279
-
280
- final_baidu_creds_list = []
281
- if baidu_json_str and baidu_json_str.strip():
282
- try:
283
- parsed_data = json.loads(baidu_json_str)
284
- if isinstance(parsed_data, list): final_baidu_creds_list = parsed_data
285
- except json.JSONDecodeError: print("提供的百度凭证JSON无效。")
286
 
287
- try:
288
- res, tag_cats_original = tagger_instance.predict(img, g_th, c_th)
289
- all_tags = [tag for cat in tag_cats_original.values() for tag in cat]
290
-
291
- translations_flat = translate_texts(
292
- all_tags,
293
- tencent_secret_id=final_tencent_id,
294
- tencent_secret_key=final_tencent_key,
295
- baidu_credentials_list=final_baidu_creds_list
296
- ) if all_tags else []
297
-
298
- translations, offset = {}, 0
299
- for cat_key, tags in tag_cats_original.items():
300
- translations[cat_key] = translations_flat[offset : offset + len(tags)]
301
- offset += len(tags)
302
-
303
- outputs_html = {k: format_tags_html(res.get(k, {}), translations.get(k, []), s_scores) for k in ["general", "characters", "ratings"]}
304
- summary = generate_summary_text_content(res, translations, sum_cats, s_sep, s_zh_in_sum)
305
 
306
- yield gr.update(interactive=True, value="🚀 开始分析"), gr.update(visible=True, value="✅ 分析完成!"), outputs_html["general"], outputs_html["characters"], outputs_html["ratings"], summary, res, translations
307
-
308
- except Exception as e:
309
- import traceback
310
- traceback.print_exc()
311
- raise gr.Error(f"处理时发生错误: {e}")
312
-
313
- demo.load(fn=check_user_status, inputs=None, outputs=[user_status_md, api_key_accordion], queue=False)
314
-
315
- btn.click(
316
- process_image_and_generate_outputs,
317
- inputs=[
318
- img_in, gen_slider, char_slider, show_tag_scores,
319
- tencent_id_in, tencent_key_in, baidu_json_in,
320
- sum_cats, sum_sep, sum_show_zh
321
- ],
322
- outputs=[
323
- btn, processing_info,
324
- out_general, out_char, out_rating,
325
- out_summary,
326
- state_res, state_translations_dict
327
- ],
328
- )
329
-
330
- summary_controls = [sum_cats, sum_sep, sum_show_zh]
331
- for ctrl in summary_controls:
332
- ctrl.change(
333
- fn=lambda r, t, c, s, z: generate_summary_text_content(r, t, c, s, z),
334
- inputs=[state_res, state_translations_dict] + summary_controls,
335
- outputs=[out_summary],
336
- )
337
-
338
  if __name__ == "__main__":
339
  if tagger_instance is None:
340
- print("CRITICAL: Tagger failed to initialize. App functionality will be limited.")
341
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
 
6
  import onnxruntime as rt
7
  import pandas as pd
8
  from PIL import Image
9
+ from huggingface_hub import login, HfApi
10
+ from translator import translate_texts, set_user_provided_keys, clear_user_provided_keys
11
 
 
 
 
12
  MODEL_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
13
  MODEL_FILENAME = "model.onnx"
14
  LABEL_FILENAME = "selected_tags.csv"
15
+ HF_TOKEN = os.environ.get("HF_TOKEN", "")
16
+ if HF_TOKEN:
17
+ try:
18
+ login(token=HF_TOKEN)
19
+ print("✅ 应用已使用 HF_TOKEN 登录")
20
+ except Exception as e:
21
+ print(f"⚠️ 使用 HF_TOKEN 登录失败: {e}")
22
+ else:
23
+ print("⚠️ 未检测到应用级别的 HF_TOKEN,私有模型可能下载失败")
24
 
 
 
 
 
 
 
 
 
 
 
25
  class Tagger:
26
  def __init__(self):
27
  self.hf_token = HF_TOKEN
 
49
  }
50
  self.model = rt.InferenceSession(model_path)
51
  self.input_size = self.model.get_inputs()[0].shape[1]
52
+ print("✅ 模型和标签加载成功")
53
  except Exception as e:
54
+ print(f"❌ 模型或标签加载失败: {e}")
55
+ raise RuntimeError(f"模型初始化失败: {e}")
56
+
57
 
 
58
  def _preprocess(self, img: Image.Image) -> np.ndarray:
59
+ if img is None:
60
+ raise ValueError("输入图像不能为空")
61
+ if img.mode != "RGB":
62
+ img = img.convert("RGB")
63
  size = max(img.size)
64
  canvas = Image.new("RGB", (size, size), (255, 255, 255))
65
  canvas.paste(img, ((size - img.width) // 2, (size - img.height) // 2))
 
67
  canvas = canvas.resize((self.input_size, self.input_size), Image.BICUBIC)
68
  return np.array(canvas)[:, :, ::-1].astype(np.float32)
69
 
 
70
  def predict(self, img: Image.Image, gen_th: float = 0.35, char_th: float = 0.85):
71
+ if self.model is None:
72
+ raise RuntimeError("模型未成功加载,无法进行预测。")
73
  inp_name = self.model.get_inputs()[0].name
74
  outputs = self.model.run(None, {inp_name: self._preprocess(img)[None, ...]})[0][0]
75
 
76
  res = {"ratings": {}, "general": {}, "characters": {}}
77
  tag_categories_for_translation = {"ratings": [], "general": [], "characters": []}
78
 
79
+ for idx in self.categories["rating"]:
80
+ tag_name = self.tag_names[idx].replace("_", " ")
81
+ res["ratings"][tag_name] = float(outputs[idx])
82
+ tag_categories_for_translation["ratings"].append(tag_name)
83
+
84
+ for idx in self.categories["general"]:
85
+ if outputs[idx] > gen_th:
86
+ tag_name = self.tag_names[idx].replace("_", " ")
87
+ res["general"][tag_name] = float(outputs[idx])
88
+ tag_categories_for_translation["general"].append(tag_name)
89
+
90
+ for idx in self.categories["character"]:
91
+ if outputs[idx] > char_th:
92
+ tag_name = self.tag_names[idx].replace("_", " ")
93
+ res["characters"][tag_name] = float(outputs[idx])
94
+ tag_categories_for_translation["characters"].append(tag_name)
95
+
96
+
97
+ res["general"] = dict(sorted(res["general"].items(), key=lambda kv: kv[1], reverse=True))
98
+ res["characters"] = dict(sorted(res["characters"].items(), key=lambda kv: kv[1], reverse=True))
99
+ res["ratings"] = dict(sorted(res["ratings"].items(), key=lambda kv: kv[1], reverse=True))
100
+
101
+
102
+ tag_categories_for_translation["general"] = list(res["general"].keys())
103
+ tag_categories_for_translation["characters"] = list(res["characters"].keys())
104
+ tag_categories_for_translation["ratings"] = list(res["ratings"].keys())
105
 
106
  return res, tag_categories_for_translation
107
 
 
108
  try:
109
  tagger_instance = Tagger()
110
  except RuntimeError as e:
111
+ print(f"应用启动时Tagger初始化失败: {e}")
112
  tagger_instance = None
113
 
114
  # ------------------------------------------------------------------
 
122
  .tag-zh { color: #666; margin-left: 10px; }
123
  .tag-score { color: #999; font-size: 0.9em; }
124
  .btn-analyze-container { margin-top: 15px; margin-bottom: 15px; }
125
+ .user-info { text-align: right; color: #666; font-size: 0.9em; padding: 5px; }
126
  """
127
 
128
  _js_functions = """
 
133
  }
134
  navigator.clipboard.writeText(text).then(() => {
135
  const feedback = document.createElement('div');
136
+ let displayText = String(text);
137
+ displayText = displayText.substring(0, 30) + (displayText.length > 30 ? '...' : '');
138
  feedback.textContent = '已复制: ' + displayText;
139
+ feedback.style.cssText = 'position:fixed; bottom:20px; left:50%; transform:translateX(-50%); background-color:#4CAF50; color:white; padding:10px 20px; border-radius:5px; z-index:10000; transition:opacity 0.5s ease-out;';
 
 
 
 
140
  document.body.appendChild(feedback);
141
  setTimeout(() => {
142
  feedback.style.opacity = '0';
143
+ setTimeout(() => { document.body.removeChild(feedback); }, 500);
144
  }, 1500);
145
  }).catch(err => {
146
+ console.error('Failed to copy tag.', err, 'Text:', text);
147
  });
148
  }
149
  """
150
 
151
+ def main_interface(user_info: gr.UserInfo):
152
+ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, js=_js_functions) as demo:
153
+ gr.Markdown(f"<div class='user-info'>已登录: {user_info.name} ({user_info.email})</div>")
154
+ gr.Markdown("# 🖼️ AI 图像标签分析器")
155
+ gr.Markdown("上传图片自动识别标签,支持中英文显示和一键复制。[NovelAI在线绘画](https://nai.idlecloud.cc/)")
156
+
157
+ state_res = gr.State({})
158
+ state_translations_dict = gr.State({})
159
+ state_tag_categories_for_translation = gr.State({})
160
+
161
+ with gr.Row():
162
+ with gr.Column(scale=1):
163
+ img_in = gr.Image(type="pil", label="上传图片", height=300)
164
+ btn = gr.Button("🚀 开始分析", variant="primary", elem_classes=["btn-analyze-container"])
165
+
166
+ with gr.Accordion("⚙️ 高级设置", open=False):
167
+ gen_slider = gr.Slider(0, 1, value=0.35, step=0.01, label="通用标签阈值", info="越高 → 标签更少更准")
168
+ char_slider = gr.Slider(0, 1, value=0.85, step=0.01, label="角色标签阈值", info="推荐保持较高阈值")
169
+ show_tag_scores = gr.Checkbox(True, label="在列表中显示标签置信度")
170
+
171
+ with gr.Accordion("🔑 翻译服务设置", open=False):
172
+ gr.Markdown("如果应用配置了全局翻译密钥,可在此输入访问密钥以使用。否则,请在此处填入您自己的翻译API密钥。")
173
+ access_key_input = gr.Textbox(label="访问密钥 (Access Key)", type="password", placeholder="如果需要,请输入访问密钥")
174
+
175
+ gr.Markdown("---")
176
+ gr.Markdown("**或者**,使用你自己的密钥:")
177
+ user_tencent_id = gr.Textbox(label="腾讯云 Secret ID", type="password")
178
+ user_tencent_key = gr.Textbox(label="腾讯云 Secret Key", type="password")
179
+ user_baidu_json = gr.Textbox(label="百度翻译凭证 (JSON格式)", type="password", lines=3, placeholder='[{"app_id":"...", "secret_key":"..."}]')
180
+
181
+ with gr.Accordion("📊 标签汇总设置", open=True):
182
+ gr.Markdown("选择要包含在下方汇总文本框中的标签类别:")
183
+ with gr.Row():
184
+ sum_general = gr.Checkbox(True, label="通用标签", min_width=50)
185
+ sum_char = gr.Checkbox(True, label="角色标签", min_width=50)
186
+ sum_rating = gr.Checkbox(False, label="评分标签", min_width=50)
187
+ sum_sep = gr.Dropdown(["逗号", "换行", "空格"], value="逗号", label="标签之间的分隔符")
188
+ sum_show_zh = gr.Checkbox(False, label="在汇总中显示中文翻译")
189
+
190
+ processing_info = gr.Markdown("", visible=False)
191
+
192
+ with gr.Column(scale=2):
193
+ with gr.Tabs():
194
+ with gr.TabItem("🏷️ 通用标签"):
195
+ out_general = gr.HTML(label="General Tags")
196
+ with gr.TabItem("👤 角色标签"):
197
+ gr.Markdown("<p style='color:gray; font-size:small;'>提示:角色标签推测基于截至2024年2月的数据。</p>")
198
+ out_char = gr.HTML(label="Character Tags")
199
+ with gr.TabItem("⭐ 评分标签"):
200
+ out_rating = gr.HTML(label="Rating Tags")
201
+
202
+ gr.Markdown("### 标签汇总结果")
203
+ out_summary = gr.Textbox(label="标签汇总", placeholder="分析完成后,此处将显示汇总的英文标签...", lines=5, show_copy_button=True)
204
+
205
+ def format_tags_html(tags_dict, translations_list, category_name, show_scores=True, show_translation_in_list=True):
206
+ if not tags_dict: return "<p>暂无标签</p>"
207
+ html = '<div class="label-container">'
208
+ if not isinstance(translations_list, list): translations_list = []
209
+ tag_keys = list(tags_dict.keys())
210
+ for i, tag in enumerate(tag_keys):
211
+ score = tags_dict[tag]
212
+ escaped_tag = tag.replace("'", "\\'")
213
+ html += '<div class="tag-item">'
214
+ tag_display_html = f'<span class="tag-en" onclick="copyToClipboard(\'{escaped_tag}\')">{tag}</span>'
215
+ if show_translation_in_list and i < len(translations_list) and translations_list[i]:
216
+ tag_display_html += f'<span class="tag-zh">({translations_list[i]})</span>'
217
+ html += f'<div>{tag_display_html}</div>'
218
+ if show_scores: html += f'<span class="tag-score">{score:.3f}</span>'
219
+ html += '</div>'
220
+ html += '</div>'
221
+ return html
222
+
223
+ def generate_summary_text_content(current_res, current_translations_dict, s_gen, s_char, s_rat, s_sep_type, s_show_zh):
224
+ if not current_res: return "请先分析图像或选择要汇总的标签类别。"
225
+ summary_parts = []
226
+ separator = {"逗号": ", ", "换行": "\n", "空格": " "}.get(s_sep_type, ", ")
227
+ categories_to_summarize = []
228
+ if s_gen: categories_to_summarize.append("general")
229
+ if s_char: categories_to_summarize.append("characters")
230
+ if s_rat: categories_to_summarize.append("ratings")
231
+ if not categories_to_summarize: return "请至少选择一个标签类别进行汇总。"
232
+ for cat_key in categories_to_summarize:
233
+ if current_res.get(cat_key):
234
+ tags_to_join = []
235
+ cat_tags_en = list(current_res[cat_key].keys())
236
+ cat_translations = current_translations_dict.get(cat_key, [])
237
+ for i, en_tag in enumerate(cat_tags_en):
238
+ if s_show_zh and i < len(cat_translations) and cat_translations[i]:
239
+ tags_to_join.append(f"{en_tag}({cat_translations[i]})")
240
+ else:
241
+ tags_to_join.append(en_tag)
242
+ if tags_to_join: summary_parts.append(separator.join(tags_to_join))
243
+ joiner = "\n\n" if separator != "\n" and len(summary_parts) > 1 else separator if separator == "\n" else " "
244
+ final_summary = joiner.join(summary_parts)
245
+ return final_summary if final_summary else "选定的类别中没有找到标签。"
246
+
247
+ def process_image_and_generate_outputs(
248
+ img, g_th, c_th, s_scores,
249
+ s_gen, s_char, s_rat, s_sep, s_zh_in_sum,
250
+ access_key, u_tencent_id, u_tencent_key, u_baidu_json,
251
+ request: gr.Request
252
+ ):
253
+ if img is None:
254
+ yield (gr.update(interactive=True, value="🚀 开始分析"), gr.update(visible=True, value="❌ 请先上传图片。"), "", "", "", "", gr.update(placeholder="请先上传图片并开始分析..."), {}, {}, {})
255
+ return
256
+
257
+ if tagger_instance is None:
258
+ yield (gr.update(interactive=True, value="🚀 开始分析"), gr.update(visible=True, value="❌ 分析器未成功初始化,请检查控制台错误。"), "", "", "", "", gr.update(placeholder="分析器初始化失败..."), {}, {}, {})
259
+ return
260
 
261
+ yield (gr.update(interactive=False, value="🔄 处理中..."), gr.update(visible=True, value="🔄 正在分析图像,请稍候..."), gr.HTML(value="<p>分析中...</p>"), gr.HTML(value="<p>分析中...</p>"), gr.HTML(value="<p>分析中...</p>"), gr.update(value="分析中,请稍候..."), {}, {}, {})
262
+
 
263
  try:
264
+ set_user_provided_keys(u_tencent_id, u_tencent_key, u_baidu_json)
265
+
266
+ res, tag_categories_original_order = tagger_instance.predict(img, g_th, c_th)
267
+ all_tags_to_translate = [tag for cat in tag_categories_original_order.values() for tag in cat]
268
+
269
+ all_translations_flat = []
270
+ if all_tags_to_translate:
271
+ all_translations_flat = translate_texts(all_tags_to_translate, src_lang="auto", tgt_lang="zh", access_key=access_key)
272
 
273
+ current_translations_dict = {}
274
+ offset = 0
275
+ for cat_key in ["general", "characters", "ratings"]:
276
+ num_tags_in_cat = len(tag_categories_original_order.get(cat_key, []))
277
+ current_translations_dict[cat_key] = all_translations_flat[offset : offset + num_tags_in_cat]
278
+ offset += num_tags_in_cat
279
+
280
+ general_html = format_tags_html(res.get("general", {}), current_translations_dict.get("general", []), "general", s_scores, True)
281
+ char_html = format_tags_html(res.get("characters", {}), current_translations_dict.get("characters", []), "characters", s_scores, True)
282
+ rating_html = format_tags_html(res.get("ratings", {}), current_translations_dict.get("ratings", []), "ratings", s_scores, True)
283
+ summary_text = generate_summary_text_content(res, current_translations_dict, s_gen, s_char, s_rat, s_sep, s_zh_in_sum)
284
+
285
+ yield (gr.update(interactive=True, value="🚀 开始分析"), gr.update(visible=True, value="✅ 分析完成!"), general_html, char_html, rating_html, gr.update(value=summary_text), res, current_translations_dict, tag_categories_original_order)
286
+
287
  except Exception as e:
288
+ import traceback
289
+ tb_str = traceback.format_exc()
290
+ print(f"处理时发生错误: {e}\n{tb_str}")
291
+ yield (gr.update(interactive=True, value="🚀 开始分析"), gr.update(visible=True, value=f"❌ 处理失败: {str(e)}"), "<p>处理出错</p>", "<p>处理出错</p>", "<p>处理出错</p>", gr.update(value=f"错误: {str(e)}", placeholder="分析失败..."), {}, {}, {})
292
+ finally:
293
+ clear_user_provided_keys()
294
+
295
+ def update_summary_display(s_gen, s_char, s_rat, s_sep, s_zh_in_sum, current_res_from_state, current_translations_from_state):
296
+ if not current_res_from_state:
297
+ return gr.update(placeholder="请先完成一次图像分析以生成汇总。", value="")
298
+ new_summary_text = generate_summary_text_content(current_res_from_state, current_translations_from_state, s_gen, s_char, s_rat, s_sep, s_zh_in_sum)
299
+ return gr.update(value=new_summary_text)
300
+
301
+ btn.click(
302
+ process_image_and_generate_outputs,
303
+ inputs=[
304
+ img_in, gen_slider, char_slider, show_tag_scores,
305
+ sum_general, sum_char, sum_rating, sum_sep, sum_show_zh,
306
+ access_key_input, user_tencent_id, user_tencent_key, user_baidu_json
307
+ ],
308
+ outputs=[
309
+ btn, processing_info,
310
+ out_general, out_char, out_rating,
311
+ out_summary,
312
+ state_res, state_translations_dict, state_tag_categories_for_translation
313
+ ]
314
+ )
315
 
316
+ summary_controls = [sum_general, sum_char, sum_rating, sum_sep, sum_show_zh]
317
+ for ctrl in summary_controls:
318
+ ctrl.change(fn=update_summary_display, inputs=summary_controls + [state_res, state_translations_dict], outputs=[out_summary])
319
+
320
+ return demo
321
+
322
+ with gr.Blocks(title="登录到图像标签分析器") as demo:
323
+ CLIENT_ID = os.environ.get("HUGGING_FACE_CLIENT_ID")
324
+ if not CLIENT_ID:
325
+ gr.Markdown("# 错误:应用未配置 OIDC 客户端ID\n请在 Space secrets 中设置 `HUGGING_FACE_CLIENT_ID`")
326
+ else:
327
+ gr.Markdown("# 欢迎使用 AI 图像标签分析器\n请通过 Hugging Face 登录以继续")
328
+ login_button = gr.LoginButton(
329
+ value="🤗 通过 Hugging Face 登录",
330
+ oauth_client_id=CLIENT_ID,
331
+ oauth_scopes=["openid", "profile", "email"],
332
+ oauth_redirect_uri=f"https://huggingface.co/spaces/{os.environ.get('SPACE_ID')}"
333
+ )
334
+ user_info_state = gr.State()
335
+ login_button.login(lambda: None, None, None, js="""
336
+ (btn) => {
337
+ const url = new URL(window.location);
338
+ if (url.searchParams.has('code')) {
339
+ btn.style.display = 'none';
340
+ }
341
+ return btn;
342
+ }
343
+ """)
344
+ demo.load(
345
+ fn=lambda request: request.auth,
346
+ inputs=gr.Request(inputs=[]),
347
+ outputs=user_info_state,
348
+ queue=False,
349
+ js="""
350
+ (request) => {
351
+ const url = new URL(window.location);
352
+ if (!url.searchParams.has('code') && !request.auth) {
353
+ } else {
354
+ document.getElementById('login-interface').style.display = 'none';
355
+ document.getElementById('main-app-interface').style.display = 'block';
356
+ }
357
+ return request;
358
+ }
359
+ """
360
+ )
361
 
362
+ with gr.Column(elem_id="login-interface", visible=True):
363
+ pass
364
 
365
+ with gr.Column(elem_id="main-app-interface", visible=False):
366
+ main_app = main_interface(gr.UserInfo())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
  if __name__ == "__main__":
370
  if tagger_instance is None:
371
+ print("CRITICAL: Tagger 未能初始化,应用功能将受限。请检查之前的错误信息。")
372
+ if "SPACE_ID" in os.environ:
373
+ demo.launch()
374
+ else:
375
+ with gr.Blocks() as local_demo:
376
+ fake_user_info = gr.UserInfo(name="local_user", email="local@test.com")
377
+ main_interface(fake_user_info)
378
+ local_demo.launch(server_name="0.0.0.0", server_port=7860)