import os
import json
import gradio as gr
import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
from PIL import Image
from huggingface_hub import whoami, HfApi
from translator import translate_texts
# ------------------------------------------------------------------
# Model Configuration
# ------------------------------------------------------------------
# WD (Waifu Diffusion) EVA02 large tagger, distributed as an ONNX graph.
MODEL_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"  # tag vocabulary + per-tag category codes
# It's recommended to manage the token within the HF Spaces secrets
HF_TOKEN = os.environ.get("HF_TOKEN")
# A more robust way to get the space owner
# SPACE_ID has the form "owner/space-name"; the owner may be a user or an org.
SPACE_ID = os.environ.get("SPACE_ID")
SPACE_OWNER = SPACE_ID.split('/')[0] if SPACE_ID else None
# ------------------------------------------------------------------
# Tagger Class (Global Instance)
# ------------------------------------------------------------------
class Tagger:
    """ONNX wrapper around the WD image tagger.

    Downloads the model and its tag vocabulary from the Hugging Face Hub at
    construction time, then serves per-image tag predictions grouped into
    rating / general / character categories.
    """

    def __init__(self):
        self.hf_token = HF_TOKEN
        self.tag_names: list[str] = []      # flat tag vocabulary, in CSV row order
        self.categories: dict[str, np.ndarray] = {}  # category name -> vocab indices
        self.model = None                   # onnxruntime.InferenceSession once loaded
        self.input_size = 0                 # square input edge expected by the model
        self._load_model_and_labels()

    def _load_model_and_labels(self):
        """Fetch the model and label CSV from the Hub and build index tables.

        Raises:
            RuntimeError: if the download or session creation fails.
        """
        try:
            # NOTE: `resume_download` is deprecated in recent huggingface_hub
            # releases (downloads always resume), so it is no longer passed.
            label_path = huggingface_hub.hf_hub_download(
                MODEL_REPO, LABEL_FILENAME, token=self.hf_token
            )
            model_path = huggingface_hub.hf_hub_download(
                MODEL_REPO, MODEL_FILENAME, token=self.hf_token
            )
            tags_df = pd.read_csv(label_path)
            self.tag_names = tags_df["name"].tolist()
            # Category codes used by selected_tags.csv: 9=rating, 0=general, 4=character.
            self.categories = {
                "rating": np.where(tags_df["category"] == 9)[0],
                "general": np.where(tags_df["category"] == 0)[0],
                "character": np.where(tags_df["category"] == 4)[0],
            }
            self.model = rt.InferenceSession(model_path)
            # Input is NHWC with square spatial dims; shape[1] is the edge length.
            self.input_size = self.model.get_inputs()[0].shape[1]
            print("✅ Model and labels loaded successfully.")
        except Exception as e:
            print(f"❌ Failed to load model or labels: {e}")
            # Chain the original exception for easier debugging.
            raise RuntimeError(f"Model initialization failed: {e}") from e

    # ------------------------- preprocess -------------------------
    def _preprocess(self, img: Image.Image) -> np.ndarray:
        """Pad *img* onto a white square, resize to the model edge, return float32 HWC.

        Channel order is flipped RGB -> BGR, which is what the WD taggers expect.
        """
        if img is None:
            raise ValueError("Input image cannot be None.")
        if img.mode != "RGB":
            img = img.convert("RGB")
        size = max(img.size)
        # Center on a white square canvas so the aspect ratio is preserved.
        canvas = Image.new("RGB", (size, size), (255, 255, 255))
        canvas.paste(img, ((size - img.width) // 2, (size - img.height) // 2))
        if size != self.input_size:
            canvas = canvas.resize((self.input_size, self.input_size), Image.BICUBIC)
        return np.array(canvas)[:, :, ::-1].astype(np.float32)

    # --------------------------- predict --------------------------
    def predict(self, img: Image.Image, gen_th: float = 0.35, char_th: float = 0.85):
        """Tag *img* and return (scores dict, per-category ordered tag lists).

        Args:
            img: input image.
            gen_th: confidence threshold for general tags.
            char_th: confidence threshold for character tags.

        Returns:
            Tuple of:
            - dict keyed "ratings"/"general"/"characters", each mapping
              tag name -> score sorted by descending score (ratings are
              reported unconditionally, the others are thresholded);
            - dict with the same keys mapping to the ordered tag-name lists
              (consumed downstream by the translation step).

        Raises:
            RuntimeError: if the model was never loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded, cannot predict.")
        inp_name = self.model.get_inputs()[0].name
        # Add the batch axis, run the session, take the first (only) output row.
        outputs = self.model.run(None, {inp_name: self._preprocess(img)[None, ...]})[0][0]
        res = {"ratings": {}, "general": {}, "characters": {}}
        tag_categories_for_translation = {"ratings": [], "general": [], "characters": []}
        for cat_key, cat_indices in self.categories.items():
            sub_res = {}
            if cat_key == "rating":
                # Ratings are mutually informative; report all of them.
                for idx in cat_indices:
                    tag_name = self.tag_names[idx].replace("_", " ")
                    sub_res[tag_name] = float(outputs[idx])
            else:
                threshold = char_th if cat_key == "character" else gen_th
                for idx in cat_indices:
                    if outputs[idx] > threshold:
                        tag_name = self.tag_names[idx].replace("_", " ")
                        sub_res[tag_name] = float(outputs[idx])
            # BUG FIX: map every internal category name to its plural output
            # key. Previously "rating" fell through unchanged, so the scores
            # were stored under a stray "rating" key while the pre-initialized
            # "ratings" entry stayed empty and the UI never showed them.
            res_key = {"rating": "ratings", "character": "characters"}.get(cat_key, cat_key)
            res[res_key] = dict(sorted(sub_res.items(), key=lambda kv: kv[1], reverse=True))
            tag_categories_for_translation[res_key] = list(res[res_key].keys())
        return res, tag_categories_for_translation
# Global Tagger instance
# Constructed once at import time so the model is downloaded/loaded before the
# UI serves requests; a failed load leaves `tagger_instance` as None and the
# handlers below surface a user-facing error instead of crashing the app.
try:
    tagger_instance = Tagger()
except RuntimeError as e:
    print(f"Tagger initialization failed on app startup: {e}")
    tagger_instance = None
# ------------------------------------------------------------------
# Gradio UI
# ------------------------------------------------------------------
# Scoped CSS for the tag lists: a scrollable container, hover highlight, and
# distinct styling for the English tag / Chinese translation / score spans.
custom_css = """
.label-container { max-height: 300px; overflow-y: auto; border: 1px solid #ddd; padding: 10px; border-radius: 5px; background-color: #f9f9f9; }
.tag-item { display: flex; justify-content: space-between; align-items: center; margin: 2px 0; padding: 2px 5px; border-radius: 3px; background-color: #fff; transition: background-color 0.2s; }
.tag-item:hover { background-color: #f0f0f0; }
.tag-en { font-weight: bold; color: #333; cursor: pointer; }
.tag-zh { color: #666; margin-left: 10px; }
.tag-score { color: #999; font-size: 0.9em; }
.btn-analyze-container { margin-top: 15px; margin-bottom: 15px; }
"""
# Page-level JS injected into the Blocks app: copyToClipboard(text) writes the
# clicked tag to the clipboard and flashes a transient confirmation toast at
# the bottom of the page (referenced by the onclick handlers in format_tags_html).
_js_functions = """
function copyToClipboard(text) {
if (typeof text === 'undefined' || text === null) {
console.warn('copyToClipboard was called with undefined or null text.');
return;
}
navigator.clipboard.writeText(text).then(() => {
const feedback = document.createElement('div');
let displayText = String(text).substring(0, 30) + (String(text).length > 30 ? '...' : '');
feedback.textContent = '已复制: ' + displayText;
Object.assign(feedback.style, {
position: 'fixed', bottom: '20px', left: '50%', transform: 'translateX(-50%)',
backgroundColor: '#4CAF50', color: 'white', padding: '10px 20px',
borderRadius: '5px', zIndex: '10000', transition: 'opacity 0.5s ease-out'
});
document.body.appendChild(feedback);
setTimeout(() => {
feedback.style.opacity = '0';
setTimeout(() => { if (document.body.contains(feedback)) document.body.removeChild(feedback); }, 500);
}, 1500);
}).catch(err => {
console.error('Failed to copy tag. Error:', err, 'Attempted to copy text:', text);
});
}
"""
with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=custom_css, js=_js_functions) as demo:
    gr.Markdown("# 🖼️ AI 图像标签分析器")
    gr.Markdown("上传图片自动识别标签,支持中英文显示和一键复制。[NovelAI在线绘画](https://nai.idlecloud.cc/)")
    # Login row: HF OAuth button plus a status line filled in on page load.
    with gr.Row():
        with gr.Column(scale=1):
            login_button = gr.LoginButton(value="🤗 通过 Hugging Face 登录")
            user_status_md = gr.Markdown("ℹ️ 正在检查登录状态...")
    # Cross-event state: raw prediction results and their translations.
    state_res = gr.State({})
    state_translations_dict = gr.State({})
    with gr.Row():
        # Left column: image input, analyze button, and all settings.
        with gr.Column(scale=1):
            img_in = gr.Image(type="pil", label="上传图片", height=300)
            btn = gr.Button("🚀 开始分析", variant="primary", elem_classes=["btn-analyze-container"])
            with gr.Accordion("⚙️ 高级设置", open=False):
                gen_slider = gr.Slider(0, 1, value=0.35, step=0.01, label="通用标签阈值")
                char_slider = gr.Slider(0, 1, value=0.85, step=0.01, label="角色标签阈值")
                show_tag_scores = gr.Checkbox(True, label="在列表中显示标签置信度")
            # Hidden for the Space owner; shown to everyone else so they can
            # supply their own translation API credentials.
            with gr.Accordion("🔑 自定义翻译密钥 (可选)", open=False, visible=False) as api_key_accordion:
                gr.Markdown("如果你不是空间所有者,需要在这里提供自己的API密钥才能使用翻译功能。")
                tencent_id_in = gr.Textbox(label="腾讯云 Secret ID", lines=1)
                tencent_key_in = gr.Textbox(label="腾讯云 Secret Key", lines=1, type="password")
                baidu_json_in = gr.Textbox(label="百度翻译凭证 (JSON 格式)", lines=3, placeholder='[{"app_id": "...", "secret_key": "..."}]')
            with gr.Accordion("📊 标签汇总设置", open=True):
                sum_cats = gr.CheckboxGroup(["通用标签", "角色标签", "评分标签"], value=["通用标签", "角色标签"], label="汇总类别")
                sum_sep = gr.Dropdown(["逗号", "换行", "空格"], value="逗号", label="标签分隔符")
                sum_show_zh = gr.Checkbox(False, label="在汇总中显示中文翻译")
            processing_info = gr.Markdown("", visible=False)
        # Right column: tabbed per-category tag lists and the copyable summary.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("🏷️ 通用标签"): out_general = gr.HTML(label="General Tags")
                with gr.TabItem("👤 角色标签"): out_char = gr.HTML(label="Character Tags")
                with gr.TabItem("⭐ 评分标签"): out_rating = gr.HTML(label="Rating Tags")
            gr.Markdown("### 标签汇总结果")
            out_summary = gr.Textbox(label="标签汇总", lines=5, show_copy_button=True)
def get_token_from_request(request: gr.Request) -> str | None:
    """Extract the OAuth bearer token from the request headers, if any."""
    header_value = request.headers.get("authorization")
    if not header_value or not header_value.startswith("Bearer "):
        return None
    return header_value.split(" ")[1]
def is_user_space_owner(user_info: dict | None) -> bool:
    """Return True when *user_info* identifies this Space's owner.

    Ownership counts both a direct username match against SPACE_OWNER and
    membership in an organization of that name.
    """
    if not SPACE_OWNER:
        print("⚠️ Warning: SPACE_ID environment variable not found.")
        return False
    if not user_info:
        return False
    user_name = user_info.get("name")
    user_orgs = [org.get("name") for org in user_info.get("orgs", [])]
    print(f"ℹ️ [Auth Check] Space Owner: '{SPACE_OWNER}', User: '{user_name}', User Orgs: {user_orgs}")
    return (user_name == SPACE_OWNER) or (SPACE_OWNER in user_orgs)
def check_user_status(request: gr.Request):
    """Resolve the login banner and API-key accordion visibility on page load.

    Returns a (markdown message, accordion gr.update) pair: the Space owner
    uses the Space's configured keys, everyone else is asked for their own.
    """
    token = get_token_from_request(request)
    if not token:
        # Guest: no OAuth token attached to the request.
        return "ℹ️ **访客模式**。如需使用翻译功能,请<a href='/login?redirect=/'>登录</a>或提供 API 密钥。", gr.update(visible=True, open=True)
    try:
        user_info = whoami(token=token)
        if is_user_space_owner(user_info):
            return f"✅ 以所有者 **{user_info.get('fullname', user_info.get('name'))}** 身份登录,将使用空间配置的密钥。", gr.update(visible=False)
        return f"👋 你好, **{user_info.get('fullname', '用户')}**!请在下方提供你自己的翻译 API 密钥。", gr.update(visible=True, open=True)
    except Exception as e:
        print(f"Error getting user info: {e}")
        return "⚠️ 无法验证您的登录状态。请提供 API 密钥。", gr.update(visible=True, open=True)
def format_tags_html(tags_dict, translations_list, show_scores):
    """Render a tag -> score mapping as a scrollable HTML list.

    Each row shows the English tag (click-to-copy via the injected JS
    `copyToClipboard`), an optional Chinese translation, and optionally the
    confidence score formatted to three decimals.
    """
    if not tags_dict:
        return "<p>暂无标签</p>"
    pieces = ['<div class="label-container">']
    for idx, (name, confidence) in enumerate(tags_dict.items()):
        # Escape single quotes so the tag survives the inline JS string literal.
        safe_name = name.replace("'", "\\'")
        label = f'<span class="tag-en" onclick="copyToClipboard(\'{safe_name}\')">{name}</span>'
        if idx < len(translations_list) and translations_list[idx]:
            label += f'<span class="tag-zh">({translations_list[idx]})</span>'
        row = f'<div class="tag-item"><div>{label}</div>'
        if show_scores:
            row += f'<span class="tag-score">{confidence:.3f}</span>'
        pieces.append(row + '</div>')
    pieces.append('</div>')
    return ''.join(pieces)
def generate_summary_text_content(current_res, translations, sum_cats, sep_type, show_zh):
    """Build the copyable summary string from the selected tag categories.

    Categories are joined internally with the chosen separator and stacked
    one-per-line; optional Chinese translations are appended as "en(zh)".
    """
    if not current_res:
        return "请先分析图像。"
    joiner = {"逗号": ", ", "换行": "\n", "空格": " "}.get(sep_type, ", ")
    name_to_key = {"通用标签": "general", "角色标签": "characters", "评分标签": "ratings"}
    sections = []
    for display_name in sum_cats:
        key = name_to_key.get(display_name)
        if not key or not current_res.get(key):
            continue
        zh_list = translations.get(key, [])
        rendered = []
        for pos, en in enumerate(current_res[key].keys()):
            if show_zh and pos < len(zh_list) and zh_list[pos]:
                rendered.append(f"{en}({zh_list[pos]})")
            else:
                rendered.append(en)
        if rendered:
            sections.append(joiner.join(rendered))
    return "\n".join(sections) if sections else "选定的类别中没有找到标签。"
def process_image_and_generate_outputs(
    img, g_th, c_th, s_scores,
    user_tencent_id, user_tencent_key, user_baidu_json,
    sum_cats, s_sep, s_zh_in_sum,
    request: gr.Request
):
    """Run the tagger on *img*, translate the tags, and stream UI updates.

    Generator: each yield is a tuple matching `btn.click`'s outputs —
    (analyze button, progress banner, general/character/rating HTML,
    summary text, raw-results state, translations state).

    Raises:
        gr.Error: on missing image, uninitialized tagger, or a processing failure.
    """
    if img is None:
        raise gr.Error("请先上传图片。")
    if tagger_instance is None:
        raise gr.Error("分析器未成功初始化,请检查后台错误。")
    # Disable the button and show progress placeholders while we work.
    yield gr.update(interactive=False, value="🔄 处理中..."), gr.update(visible=True, value="🔄 正在分析..."), *["<p>分析中...</p>"]*3, "分析中...", {}, {}
    # Best-effort ownership check: any failure simply falls back to user keys.
    token = get_token_from_request(request)
    is_owner = False
    if token:
        try:
            if is_user_space_owner(whoami(token=token)):
                is_owner = True
        except Exception as e:
            print(f"Ownership check failed, falling back to user keys: {e}")
    # Owners use the Space's secrets; everyone else uses the keys they typed in.
    final_tencent_id, final_tencent_key, baidu_json_str = (
        (os.environ.get("TENCENT_SECRET_ID"), os.environ.get("TENCENT_SECRET_KEY"), os.environ.get("BAIDU_CREDENTIALS_JSON", "[]"))
        if is_owner else (user_tencent_id, user_tencent_key, user_baidu_json)
    )
    final_baidu_creds_list = []
    if baidu_json_str and baidu_json_str.strip():
        try:
            parsed_data = json.loads(baidu_json_str)
            if isinstance(parsed_data, list):
                final_baidu_creds_list = parsed_data
        except json.JSONDecodeError:
            print("提供的百度凭证JSON无效。")
    try:
        res, tag_cats_original = tagger_instance.predict(img, g_th, c_th)
        all_tags = [tag for cat in tag_cats_original.values() for tag in cat]
        translations_flat = translate_texts(
            all_tags,
            tencent_secret_id=final_tencent_id,
            tencent_secret_key=final_tencent_key,
            baidu_credentials_list=final_baidu_creds_list
        ) if all_tags else []
        # Re-split the flat translation list back into per-category lists,
        # relying on dict iteration order matching how all_tags was built.
        translations, offset = {}, 0
        for cat_key, tags in tag_cats_original.items():
            translations[cat_key] = translations_flat[offset : offset + len(tags)]
            offset += len(tags)
        outputs_html = {k: format_tags_html(res.get(k, {}), translations.get(k, []), s_scores) for k in ["general", "characters", "ratings"]}
        summary = generate_summary_text_content(res, translations, sum_cats, s_sep, s_zh_in_sum)
        yield gr.update(interactive=True, value="🚀 开始分析"), gr.update(visible=True, value="✅ 分析完成!"), outputs_html["general"], outputs_html["characters"], outputs_html["ratings"], summary, res, translations
    except Exception as e:
        import traceback
        traceback.print_exc()
        # BUG FIX: re-enable the analyze button and hide the progress banner
        # before surfacing the error — previously the button stayed stuck at
        # "🔄 处理中..." (disabled) after any failure.
        yield gr.update(interactive=True, value="🚀 开始分析"), gr.update(visible=False), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
        raise gr.Error(f"处理时发生错误: {e}") from e
# Resolve login state as soon as the page loads.
demo.load(fn=check_user_status, inputs=None, outputs=[user_status_md, api_key_accordion], queue=False)
# Main analysis pipeline: streams button/progress/result updates.
btn.click(
    process_image_and_generate_outputs,
    inputs=[
        img_in, gen_slider, char_slider, show_tag_scores,
        tencent_id_in, tencent_key_in, baidu_json_in,
        sum_cats, sum_sep, sum_show_zh
    ],
    outputs=[
        btn, processing_info,
        out_general, out_char, out_rating,
        out_summary,
        state_res, state_translations_dict
    ],
)
# Recompute the summary text whenever any summary setting changes.
summary_controls = [sum_cats, sum_sep, sum_show_zh]
for ctrl in summary_controls:
    ctrl.change(
        # Pass the handler directly: the previous pass-through lambda
        # forwarded the same five arguments unchanged and added nothing.
        fn=generate_summary_text_content,
        inputs=[state_res, state_translations_dict] + summary_controls,
        outputs=[out_summary],
    )
if __name__ == "__main__":
    # Warn loudly (but still launch) if the model failed to load at import
    # time; handlers will then raise user-facing errors instead of predicting.
    if tagger_instance is None:
        print("CRITICAL: Tagger failed to initialize. App functionality will be limited.")
    # Bind on all interfaces on the standard HF Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)