IdlecloudX commited on
Commit
01d7dca
·
verified ·
1 Parent(s): 7f38460

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -32
app.py CHANGED
@@ -6,22 +6,25 @@ import numpy as np
6
  import onnxruntime as rt
7
  import pandas as pd
8
  from PIL import Image
9
- from huggingface_hub import whoami, get_space_runtime, HfApi
10
- from huggingface_hub.hf_api import SpaceRuntime
11
-
12
  from translator import translate_texts
13
 
14
  # ------------------------------------------------------------------
15
- # 模型配置
16
  # ------------------------------------------------------------------
17
  MODEL_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
18
  MODEL_FILENAME = "model.onnx"
19
  LABEL_FILENAME = "selected_tags.csv"
20
 
 
21
  HF_TOKEN = os.environ.get("HF_TOKEN")
 
 
 
 
22
 
23
  # ------------------------------------------------------------------
24
- # Tagger (全局实例化)
25
  # ------------------------------------------------------------------
26
  class Tagger:
27
  def __init__(self):
@@ -50,15 +53,14 @@ class Tagger:
50
  }
51
  self.model = rt.InferenceSession(model_path)
52
  self.input_size = self.model.get_inputs()[0].shape[1]
53
- print("✅ 模型和标签加载成功")
54
  except Exception as e:
55
- print(f"❌ 模型或标签加载失败: {e}")
56
- raise RuntimeError(f"模型初始化失败: {e}")
57
-
58
 
59
  # ------------------------- preprocess -------------------------
60
  def _preprocess(self, img: Image.Image) -> np.ndarray:
61
- if img is None: raise ValueError("输入图像不能为空")
62
  if img.mode != "RGB": img = img.convert("RGB")
63
  size = max(img.size)
64
  canvas = Image.new("RGB", (size, size), (255, 255, 255))
@@ -69,7 +71,7 @@ class Tagger:
69
 
70
  # --------------------------- predict --------------------------
71
  def predict(self, img: Image.Image, gen_th: float = 0.35, char_th: float = 0.85):
72
- if self.model is None: raise RuntimeError("模型未成功加载,无法进行预测。")
73
  inp_name = self.model.get_inputs()[0].name
74
  outputs = self.model.run(None, {inp_name: self._preprocess(img)[None, ...]})[0][0]
75
 
@@ -89,18 +91,17 @@ class Tagger:
89
  tag_name = self.tag_names[idx].replace("_", " ")
90
  sub_res[tag_name] = float(outputs[idx])
91
 
92
- # Use the correct key for 'character'
93
  res_key = "characters" if cat_key == "character" else cat_key
94
  res[res_key] = dict(sorted(sub_res.items(), key=lambda kv: kv[1], reverse=True))
95
  tag_categories_for_translation[res_key] = list(res[res_key].keys())
96
 
97
  return res, tag_categories_for_translation
98
 
99
- # 全局 Tagger 实例
100
  try:
101
  tagger_instance = Tagger()
102
  except RuntimeError as e:
103
- print(f"应用启动时Tagger初始化失败: {e}")
104
  tagger_instance = None
105
 
106
  # ------------------------------------------------------------------
@@ -145,12 +146,12 @@ function copyToClipboard(text) {
145
  with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=custom_css, js=_js_functions) as demo:
146
  gr.Markdown("# 🖼️ AI 图像标签分析器")
147
  gr.Markdown("上传图片自动识别标签,支持中英文显示和一键复制。[NovelAI在线绘画](https://nai.idlecloud.cc/)")
148
-
149
  with gr.Row():
150
  with gr.Column(scale=1):
151
  login_button = gr.LoginButton(value="🤗 通过 Hugging Face 登录")
152
  user_status_md = gr.Markdown("ℹ️ 正在检查登录状态...")
153
-
154
  state_res = gr.State({})
155
  state_translations_dict = gr.State({})
156
 
@@ -158,7 +159,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=cus
158
  with gr.Column(scale=1):
159
  img_in = gr.Image(type="pil", label="上传图片", height=300)
160
  btn = gr.Button("🚀 开始分析", variant="primary", elem_classes=["btn-analyze-container"])
161
-
162
  with gr.Accordion("⚙️ 高级设置", open=False):
163
  gen_slider = gr.Slider(0, 1, value=0.35, step=0.01, label="通用标签阈值")
164
  char_slider = gr.Slider(0, 1, value=0.85, step=0.01, label="角色标签阈值")
@@ -191,30 +192,35 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=cus
191
  return auth_header.split(" ")[1]
192
  return None
193
 
194
- def check_user_is_owner(user_info: dict | None, space_runtime: SpaceRuntime | None) -> bool:
195
- if not user_info or not space_runtime:
 
 
 
 
 
196
  return False
197
- if user_info.get("name") == space_runtime.owner:
198
- return True
199
- user_orgs = user_info.get("orgs", [])
200
- if any(org.get("name") == space_runtime.owner for org in user_orgs):
201
- return True
202
-
203
- return False
 
204
 
205
  def check_user_status(request: gr.Request):
206
  token = get_token_from_request(request)
207
  if token:
208
  try:
209
  user_info = whoami(token=token)
210
- space_runtime = get_space_runtime()
211
 
212
- if check_user_is_owner(user_info, space_runtime):
213
  return f"✅ 以所有者 **{user_info.get('fullname', user_info.get('name'))}** 身份登录,将使用空间配置的密钥。", gr.update(visible=False)
214
  else:
215
  return f"👋 你好, **{user_info.get('fullname', '用户')}**!请在下方提供你自己的翻译 API 密钥。", gr.update(visible=True, open=True)
216
  except Exception as e:
217
- print(f"获取用户信息时出错: {e}")
218
  return "⚠️ 无法验证您的登录状态。请提供 API 密钥。", gr.update(visible=True, open=True)
219
  return "ℹ️ **访客模式**。如需使用翻译功能,请<a href='/login?redirect=/'>登录</a>或提供 API 密钥。", gr.update(visible=True, open=True)
220
 
@@ -262,8 +268,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=cus
262
  if token:
263
  try:
264
  user_info = whoami(token=token)
265
- space_runtime = get_space_runtime()
266
- if check_user_is_owner(user_info, space_runtime):
267
  is_owner = True
268
  except Exception: pass
269
 
@@ -332,5 +337,5 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=cus
332
 
333
  if __name__ == "__main__":
334
  if tagger_instance is None:
335
- print("CRITICAL: Tagger 未能初始化,应用功能将受限。")
336
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
6
  import onnxruntime as rt
7
  import pandas as pd
8
  from PIL import Image
9
+ from huggingface_hub import whoami, HfApi
 
 
10
  from translator import translate_texts
11
 
12
  # ------------------------------------------------------------------
13
+ # Model Configuration
14
  # ------------------------------------------------------------------
15
  MODEL_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
16
  MODEL_FILENAME = "model.onnx"
17
  LABEL_FILENAME = "selected_tags.csv"
18
 
19
+ # It's recommended to manage the token within the HF Spaces secrets
20
  HF_TOKEN = os.environ.get("HF_TOKEN")
21
+ # A more robust way to get the space owner
22
+ SPACE_ID = os.environ.get("SPACE_ID")
23
+ SPACE_OWNER = SPACE_ID.split('/')[0] if SPACE_ID else None
24
+
25
 
26
  # ------------------------------------------------------------------
27
+ # Tagger Class (Global Instance)
28
  # ------------------------------------------------------------------
29
  class Tagger:
30
  def __init__(self):
 
53
  }
54
  self.model = rt.InferenceSession(model_path)
55
  self.input_size = self.model.get_inputs()[0].shape[1]
56
+ print("✅ Model and labels loaded successfully.")
57
  except Exception as e:
58
+ print(f"❌ Failed to load model or labels: {e}")
59
+ raise RuntimeError(f"Model initialization failed: {e}")
 
60
 
61
  # ------------------------- preprocess -------------------------
62
  def _preprocess(self, img: Image.Image) -> np.ndarray:
63
+ if img is None: raise ValueError("Input image cannot be None.")
64
  if img.mode != "RGB": img = img.convert("RGB")
65
  size = max(img.size)
66
  canvas = Image.new("RGB", (size, size), (255, 255, 255))
 
71
 
72
  # --------------------------- predict --------------------------
73
  def predict(self, img: Image.Image, gen_th: float = 0.35, char_th: float = 0.85):
74
+ if self.model is None: raise RuntimeError("Model not loaded, cannot predict.")
75
  inp_name = self.model.get_inputs()[0].name
76
  outputs = self.model.run(None, {inp_name: self._preprocess(img)[None, ...]})[0][0]
77
 
 
91
  tag_name = self.tag_names[idx].replace("_", " ")
92
  sub_res[tag_name] = float(outputs[idx])
93
 
 
94
  res_key = "characters" if cat_key == "character" else cat_key
95
  res[res_key] = dict(sorted(sub_res.items(), key=lambda kv: kv[1], reverse=True))
96
  tag_categories_for_translation[res_key] = list(res[res_key].keys())
97
 
98
  return res, tag_categories_for_translation
99
 
100
+ # Global Tagger instance
101
  try:
102
  tagger_instance = Tagger()
103
  except RuntimeError as e:
104
+ print(f"Tagger initialization failed on app startup: {e}")
105
  tagger_instance = None
106
 
107
  # ------------------------------------------------------------------
 
146
  with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=custom_css, js=_js_functions) as demo:
147
  gr.Markdown("# 🖼️ AI 图像标签分析器")
148
  gr.Markdown("上传图片自动识别标签,支持中英文显示和一键复制。[NovelAI在线绘画](https://nai.idlecloud.cc/)")
149
+
150
  with gr.Row():
151
  with gr.Column(scale=1):
152
  login_button = gr.LoginButton(value="🤗 通过 Hugging Face 登录")
153
  user_status_md = gr.Markdown("ℹ️ 正在检查登录状态...")
154
+
155
  state_res = gr.State({})
156
  state_translations_dict = gr.State({})
157
 
 
159
  with gr.Column(scale=1):
160
  img_in = gr.Image(type="pil", label="上传图片", height=300)
161
  btn = gr.Button("🚀 开始分析", variant="primary", elem_classes=["btn-analyze-container"])
162
+
163
  with gr.Accordion("⚙️ 高级设置", open=False):
164
  gen_slider = gr.Slider(0, 1, value=0.35, step=0.01, label="通用标签阈值")
165
  char_slider = gr.Slider(0, 1, value=0.85, step=0.01, label="角色标签阈值")
 
192
  return auth_header.split(" ")[1]
193
  return None
194
 
195
+ def is_user_space_owner(user_info: dict | None) -> bool:
196
+ """
197
+ Robustly checks if the user is the owner of the space by parsing SPACE_ID.
198
+ """
199
+ if not user_info or not SPACE_OWNER:
200
+ if not SPACE_OWNER:
201
+ print("⚠️ Warning: SPACE_ID environment variable not found.")
202
  return False
203
+
204
+ user_name = user_info.get("name")
205
+ user_orgs = [org.get("name") for org in user_info.get("orgs", [])]
206
+
207
+ print(f"ℹ️ [Auth Check] Space Owner: '{SPACE_OWNER}', User: '{user_name}', User Orgs: {user_orgs}")
208
+
209
+ is_owner = (user_name == SPACE_OWNER) or (SPACE_OWNER in user_orgs)
210
+ return is_owner
211
 
212
  def check_user_status(request: gr.Request):
213
  token = get_token_from_request(request)
214
  if token:
215
  try:
216
  user_info = whoami(token=token)
 
217
 
218
+ if is_user_space_owner(user_info):
219
  return f"✅ 以所有者 **{user_info.get('fullname', user_info.get('name'))}** 身份登录,将使用空间配置的密钥。", gr.update(visible=False)
220
  else:
221
  return f"👋 你好, **{user_info.get('fullname', '用户')}**!请在下方提供你自己的翻译 API 密钥。", gr.update(visible=True, open=True)
222
  except Exception as e:
223
+ print(f"Error getting user info: {e}")
224
  return "⚠️ 无法验证您的登录状态。请提供 API 密钥。", gr.update(visible=True, open=True)
225
  return "ℹ️ **访客模式**。如需使用翻译功能,请<a href='/login?redirect=/'>登录</a>或提供 API 密钥。", gr.update(visible=True, open=True)
226
 
 
268
  if token:
269
  try:
270
  user_info = whoami(token=token)
271
+ if is_user_space_owner(user_info):
 
272
  is_owner = True
273
  except Exception: pass
274
 
 
337
 
338
  if __name__ == "__main__":
339
  if tagger_instance is None:
340
+ print("CRITICAL: Tagger failed to initialize. App functionality will be limited.")
341
  demo.launch(server_name="0.0.0.0", server_port=7860)