IdlecloudX commited on
Commit
20d3044
·
verified ·
1 Parent(s): 1ebe87f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -3
app.py CHANGED
@@ -10,12 +10,18 @@ from huggingface_hub import whoami, get_space_runtime
10
 
11
  from translator import translate_texts
12
 
 
 
 
13
  MODEL_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
14
  MODEL_FILENAME = "model.onnx"
15
  LABEL_FILENAME = "selected_tags.csv"
16
 
17
  HF_TOKEN = os.environ.get("HF_TOKEN")
18
 
 
 
 
19
  class Tagger:
20
  def __init__(self):
21
  self.hf_token = HF_TOKEN
@@ -48,6 +54,8 @@ class Tagger:
48
  print(f"❌ 模型或标签加载失败: {e}")
49
  raise RuntimeError(f"模型初始化失败: {e}")
50
 
 
 
51
  def _preprocess(self, img: Image.Image) -> np.ndarray:
52
  if img is None: raise ValueError("输入图像不能为空")
53
  if img.mode != "RGB": img = img.convert("RGB")
@@ -58,6 +66,7 @@ class Tagger:
58
  canvas = canvas.resize((self.input_size, self.input_size), Image.BICUBIC)
59
  return np.array(canvas)[:, :, ::-1].astype(np.float32)
60
 
 
61
  def predict(self, img: Image.Image, gen_th: float = 0.35, char_th: float = 0.85):
62
  if self.model is None: raise RuntimeError("模型未成功加载,无法进行预测。")
63
  inp_name = self.model.get_inputs()[0].name
@@ -86,12 +95,16 @@ class Tagger:
86
 
87
  return res, tag_categories_for_translation
88
 
 
89
  try:
90
  tagger_instance = Tagger()
91
  except RuntimeError as e:
92
  print(f"应用启动时Tagger初始化失败: {e}")
93
  tagger_instance = None
94
 
 
 
 
95
  custom_css = """
96
  .label-container { max-height: 300px; overflow-y: auto; border: 1px solid #ddd; padding: 10px; border-radius: 5px; background-color: #f9f9f9; }
97
  .tag-item { display: flex; justify-content: space-between; align-items: center; margin: 2px 0; padding: 2px 5px; border-radius: 3px; background-color: #fff; transition: background-color 0.2s; }
@@ -171,8 +184,15 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=cus
171
  gr.Markdown("### 标签汇总结果")
172
  out_summary = gr.Textbox(label="标签汇总", lines=5, show_copy_button=True)
173
 
 
 
 
 
 
 
 
174
  def check_user_status(request: gr.Request):
175
- token = request.token
176
  if token:
177
  try:
178
  user_info = whoami(token=token)
@@ -225,7 +245,8 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=cus
225
 
226
  yield gr.update(interactive=False, value="🔄 处理中..."), gr.update(visible=True, value="🔄 正在分析..."), *["<p>分析中...</p>"]*3, "分析中...", {}, {}
227
 
228
- token, is_owner = request.token, False
 
229
  if token:
230
  try:
231
  user_info = whoami(token=token)
@@ -300,4 +321,4 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=cus
300
  if __name__ == "__main__":
301
  if tagger_instance is None:
302
  print("CRITICAL: Tagger 未能初始化,应用功能将受限。")
303
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
10
 
11
  from translator import translate_texts
12
 
13
+ # ------------------------------------------------------------------
14
+ # 模型配置
15
+ # ------------------------------------------------------------------
16
  MODEL_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
17
  MODEL_FILENAME = "model.onnx"
18
  LABEL_FILENAME = "selected_tags.csv"
19
 
20
  HF_TOKEN = os.environ.get("HF_TOKEN")
21
 
22
+ # ------------------------------------------------------------------
23
+ # Tagger 类 (全局实例化)
24
+ # ------------------------------------------------------------------
25
  class Tagger:
26
  def __init__(self):
27
  self.hf_token = HF_TOKEN
 
54
  print(f"❌ 模型或标签加载失败: {e}")
55
  raise RuntimeError(f"模型初始化失败: {e}")
56
 
57
+
58
+ # ------------------------- preprocess -------------------------
59
  def _preprocess(self, img: Image.Image) -> np.ndarray:
60
  if img is None: raise ValueError("输入图像不能为空")
61
  if img.mode != "RGB": img = img.convert("RGB")
 
66
  canvas = canvas.resize((self.input_size, self.input_size), Image.BICUBIC)
67
  return np.array(canvas)[:, :, ::-1].astype(np.float32)
68
 
69
+ # --------------------------- predict --------------------------
70
  def predict(self, img: Image.Image, gen_th: float = 0.35, char_th: float = 0.85):
71
  if self.model is None: raise RuntimeError("模型未成功加载,无法进行预测。")
72
  inp_name = self.model.get_inputs()[0].name
 
95
 
96
  return res, tag_categories_for_translation
97
 
98
+ # 全局 Tagger 实例
99
  try:
100
  tagger_instance = Tagger()
101
  except RuntimeError as e:
102
  print(f"应用启动时Tagger初始化失败: {e}")
103
  tagger_instance = None
104
 
105
+ # ------------------------------------------------------------------
106
+ # Gradio UI
107
+ # ------------------------------------------------------------------
108
  custom_css = """
109
  .label-container { max-height: 300px; overflow-y: auto; border: 1px solid #ddd; padding: 10px; border-radius: 5px; background-color: #f9f9f9; }
110
  .tag-item { display: flex; justify-content: space-between; align-items: center; margin: 2px 0; padding: 2px 5px; border-radius: 3px; background-color: #fff; transition: background-color 0.2s; }
 
184
  gr.Markdown("### 标签汇总结果")
185
  out_summary = gr.Textbox(label="标签汇总", lines=5, show_copy_button=True)
186
 
187
+ def get_token_from_request(request: gr.Request) -> str | None:
188
+ """Helper function to extract token from request headers."""
189
+ auth_header = request.headers.get("authorization")
190
+ if auth_header and auth_header.startswith("Bearer "):
191
+ return auth_header.split(" ")[1]
192
+ return None
193
+
194
  def check_user_status(request: gr.Request):
195
+ token = get_token_from_request(request)
196
  if token:
197
  try:
198
  user_info = whoami(token=token)
 
245
 
246
  yield gr.update(interactive=False, value="🔄 处理中..."), gr.update(visible=True, value="🔄 正在分析..."), *["<p>分析中...</p>"]*3, "分析中...", {}, {}
247
 
248
+ token = get_token_from_request(request)
249
+ is_owner = False
250
  if token:
251
  try:
252
  user_info = whoami(token=token)
 
321
  if __name__ == "__main__":
322
  if tagger_instance is None:
323
  print("CRITICAL: Tagger 未能初始化,应用功能将受限。")
324
+ demo.launch(server_name="0.0.0.0", server_port=7860)