Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -10,12 +10,18 @@ from huggingface_hub import whoami, get_space_runtime
|
|
| 10 |
|
| 11 |
from translator import translate_texts
|
| 12 |
|
|
|
|
|
|
|
|
|
|
| 13 |
MODEL_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
|
| 14 |
MODEL_FILENAME = "model.onnx"
|
| 15 |
LABEL_FILENAME = "selected_tags.csv"
|
| 16 |
|
| 17 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 18 |
|
|
|
|
|
|
|
|
|
|
| 19 |
class Tagger:
|
| 20 |
def __init__(self):
|
| 21 |
self.hf_token = HF_TOKEN
|
|
@@ -48,6 +54,8 @@ class Tagger:
|
|
| 48 |
print(f"❌ 模型或标签加载失败: {e}")
|
| 49 |
raise RuntimeError(f"模型初始化失败: {e}")
|
| 50 |
|
|
|
|
|
|
|
| 51 |
def _preprocess(self, img: Image.Image) -> np.ndarray:
|
| 52 |
if img is None: raise ValueError("输入图像不能为空")
|
| 53 |
if img.mode != "RGB": img = img.convert("RGB")
|
|
@@ -58,6 +66,7 @@ class Tagger:
|
|
| 58 |
canvas = canvas.resize((self.input_size, self.input_size), Image.BICUBIC)
|
| 59 |
return np.array(canvas)[:, :, ::-1].astype(np.float32)
|
| 60 |
|
|
|
|
| 61 |
def predict(self, img: Image.Image, gen_th: float = 0.35, char_th: float = 0.85):
|
| 62 |
if self.model is None: raise RuntimeError("模型未成功加载,无法进行预测。")
|
| 63 |
inp_name = self.model.get_inputs()[0].name
|
|
@@ -86,12 +95,16 @@ class Tagger:
|
|
| 86 |
|
| 87 |
return res, tag_categories_for_translation
|
| 88 |
|
|
|
|
| 89 |
try:
|
| 90 |
tagger_instance = Tagger()
|
| 91 |
except RuntimeError as e:
|
| 92 |
print(f"应用启动时Tagger初始化失败: {e}")
|
| 93 |
tagger_instance = None
|
| 94 |
|
|
|
|
|
|
|
|
|
|
| 95 |
custom_css = """
|
| 96 |
.label-container { max-height: 300px; overflow-y: auto; border: 1px solid #ddd; padding: 10px; border-radius: 5px; background-color: #f9f9f9; }
|
| 97 |
.tag-item { display: flex; justify-content: space-between; align-items: center; margin: 2px 0; padding: 2px 5px; border-radius: 3px; background-color: #fff; transition: background-color 0.2s; }
|
|
@@ -171,8 +184,15 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=cus
|
|
| 171 |
gr.Markdown("### 标签汇总结果")
|
| 172 |
out_summary = gr.Textbox(label="标签汇总", lines=5, show_copy_button=True)
|
| 173 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
def check_user_status(request: gr.Request):
|
| 175 |
-
token = request
|
| 176 |
if token:
|
| 177 |
try:
|
| 178 |
user_info = whoami(token=token)
|
|
@@ -225,7 +245,8 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=cus
|
|
| 225 |
|
| 226 |
yield gr.update(interactive=False, value="🔄 处理中..."), gr.update(visible=True, value="🔄 正在分析..."), *["<p>分析中...</p>"]*3, "分析中...", {}, {}
|
| 227 |
|
| 228 |
-
token
|
|
|
|
| 229 |
if token:
|
| 230 |
try:
|
| 231 |
user_info = whoami(token=token)
|
|
@@ -300,4 +321,4 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=cus
|
|
| 300 |
if __name__ == "__main__":
|
| 301 |
if tagger_instance is None:
|
| 302 |
print("CRITICAL: Tagger 未能初始化,应用功能将受限。")
|
| 303 |
-
demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
|
| 10 |
|
| 11 |
from translator import translate_texts
|
| 12 |
|
| 13 |
+
# ------------------------------------------------------------------
|
| 14 |
+
# 模型配置
|
| 15 |
+
# ------------------------------------------------------------------
|
| 16 |
MODEL_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
|
| 17 |
MODEL_FILENAME = "model.onnx"
|
| 18 |
LABEL_FILENAME = "selected_tags.csv"
|
| 19 |
|
| 20 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 21 |
|
| 22 |
+
# ------------------------------------------------------------------
|
| 23 |
+
# Tagger 类 (全局实例化)
|
| 24 |
+
# ------------------------------------------------------------------
|
| 25 |
class Tagger:
|
| 26 |
def __init__(self):
|
| 27 |
self.hf_token = HF_TOKEN
|
|
|
|
| 54 |
print(f"❌ 模型或标签加载失败: {e}")
|
| 55 |
raise RuntimeError(f"模型初始化失败: {e}")
|
| 56 |
|
| 57 |
+
|
| 58 |
+
# ------------------------- preprocess -------------------------
|
| 59 |
def _preprocess(self, img: Image.Image) -> np.ndarray:
|
| 60 |
if img is None: raise ValueError("输入图像不能为空")
|
| 61 |
if img.mode != "RGB": img = img.convert("RGB")
|
|
|
|
| 66 |
canvas = canvas.resize((self.input_size, self.input_size), Image.BICUBIC)
|
| 67 |
return np.array(canvas)[:, :, ::-1].astype(np.float32)
|
| 68 |
|
| 69 |
+
# --------------------------- predict --------------------------
|
| 70 |
def predict(self, img: Image.Image, gen_th: float = 0.35, char_th: float = 0.85):
|
| 71 |
if self.model is None: raise RuntimeError("模型未成功加载,无法进行预测。")
|
| 72 |
inp_name = self.model.get_inputs()[0].name
|
|
|
|
| 95 |
|
| 96 |
return res, tag_categories_for_translation
|
| 97 |
|
| 98 |
+
# 全局 Tagger 实例
|
| 99 |
try:
|
| 100 |
tagger_instance = Tagger()
|
| 101 |
except RuntimeError as e:
|
| 102 |
print(f"应用启动时Tagger初始化失败: {e}")
|
| 103 |
tagger_instance = None
|
| 104 |
|
| 105 |
+
# ------------------------------------------------------------------
|
| 106 |
+
# Gradio UI
|
| 107 |
+
# ------------------------------------------------------------------
|
| 108 |
custom_css = """
|
| 109 |
.label-container { max-height: 300px; overflow-y: auto; border: 1px solid #ddd; padding: 10px; border-radius: 5px; background-color: #f9f9f9; }
|
| 110 |
.tag-item { display: flex; justify-content: space-between; align-items: center; margin: 2px 0; padding: 2px 5px; border-radius: 3px; background-color: #fff; transition: background-color 0.2s; }
|
|
|
|
| 184 |
gr.Markdown("### 标签汇总结果")
|
| 185 |
out_summary = gr.Textbox(label="标签汇总", lines=5, show_copy_button=True)
|
| 186 |
|
| 187 |
+
def get_token_from_request(request: gr.Request) -> str | None:
|
| 188 |
+
"""Helper function to extract token from request headers."""
|
| 189 |
+
auth_header = request.headers.get("authorization")
|
| 190 |
+
if auth_header and auth_header.startswith("Bearer "):
|
| 191 |
+
return auth_header.split(" ")[1]
|
| 192 |
+
return None
|
| 193 |
+
|
| 194 |
def check_user_status(request: gr.Request):
|
| 195 |
+
token = get_token_from_request(request)
|
| 196 |
if token:
|
| 197 |
try:
|
| 198 |
user_info = whoami(token=token)
|
|
|
|
| 245 |
|
| 246 |
yield gr.update(interactive=False, value="🔄 处理中..."), gr.update(visible=True, value="🔄 正在分析..."), *["<p>分析中...</p>"]*3, "分析中...", {}, {}
|
| 247 |
|
| 248 |
+
token = get_token_from_request(request)
|
| 249 |
+
is_owner = False
|
| 250 |
if token:
|
| 251 |
try:
|
| 252 |
user_info = whoami(token=token)
|
|
|
|
| 321 |
if __name__ == "__main__":
|
| 322 |
if tagger_instance is None:
|
| 323 |
print("CRITICAL: Tagger 未能初始化,应用功能将受限。")
|
| 324 |
+
demo.launch(server_name="0.0.0.0", server_port=7860)
|