# AI image tagger — Hugging Face Space (WD EVA02 tagger + Chinese translation UI)
# Standard library
import json
import os

# Third-party
import gradio as gr
import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
from huggingface_hub import login
from PIL import Image

# Local
from translator import translate_texts
# ------------------------------------------------------------------
# Model configuration
# ------------------------------------------------------------------
MODEL_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# Authenticate against the Hugging Face Hub when a token is present;
# private model downloads may fail without one.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
if HF_TOKEN:
    login(token=HF_TOKEN)
else:
    print("⚠️ 未检测到 HF_TOKEN,私有模型可能下载失败")
# ------------------------------------------------------------------
# Tagger class
# ------------------------------------------------------------------
class Tagger:
    """WD EVA02 ONNX tagger.

    Downloads the model and label CSV from the Hugging Face Hub, then
    classifies images into rating / general / character tags.
    """

    def __init__(self):
        self.hf_token = HF_TOKEN
        self._load_model_and_labels()

    def _load_model_and_labels(self):
        """Fetch model + labels from the Hub and build per-category index arrays."""
        label_path = huggingface_hub.hf_hub_download(
            MODEL_REPO, LABEL_FILENAME, token=self.hf_token
        )
        model_path = huggingface_hub.hf_hub_download(
            MODEL_REPO, MODEL_FILENAME, token=self.hf_token
        )
        tags_df = pd.read_csv(label_path)
        self.tag_names = tags_df["name"].tolist()
        # WD tagger category codes: 9 = rating, 0 = general, 4 = character.
        self.categories = {
            name: np.where(tags_df["category"] == code)[0]
            for name, code in (("rating", 9), ("general", 0), ("character", 4))
        }
        self.model = rt.InferenceSession(model_path)
        # Model input is square (batch, H, W, 3); H is taken as the edge length.
        self.input_size = self.model.get_inputs()[0].shape[1]

    # ------------------------- preprocess -------------------------
    def _preprocess(self, img: Image.Image) -> np.ndarray:
        """Pad to a white square, resize to the model size, return float32 BGR HWC."""
        if img.mode != "RGB":
            img = img.convert("RGB")
        side = max(img.size)
        square = Image.new("RGB", (side, side), (255, 255, 255))
        square.paste(img, ((side - img.width) // 2, (side - img.height) // 2))
        if side != self.input_size:
            square = square.resize((self.input_size, self.input_size), Image.BICUBIC)
        # Channel reversal: the WD taggers consume BGR input.
        return np.asarray(square)[:, :, ::-1].astype(np.float32)

    # --------------------------- predict --------------------------
    def predict(self, img: Image.Image,
                gen_th: float = 0.35,
                char_th: float = 0.85):
        """Run inference and return {"ratings", "general", "characters"} score dicts.

        All rating scores are kept; general/character tags are kept only above
        their thresholds. General tags are sorted by descending score.
        Tag names have underscores replaced by spaces.
        """
        input_name = self.model.get_inputs()[0].name
        batch = self._preprocess(img)[None, ...]
        scores = self.model.run(None, {input_name: batch})[0][0]

        def pick(category, threshold=None):
            # Collect tags of one category, optionally filtered by threshold.
            chosen = {}
            for idx in self.categories[category]:
                if threshold is None or scores[idx] > threshold:
                    chosen[self.tag_names[idx].replace("_", " ")] = float(scores[idx])
            return chosen

        res = {
            "ratings": pick("rating"),
            "general": pick("general", gen_th),
            "characters": pick("character", char_th),
        }
        res["general"] = dict(
            sorted(res["general"].items(), key=lambda kv: kv[1], reverse=True)
        )
        return res
# ------------------------------------------------------------------
# Gradio UI — custom CSS for the tag display
# (scrollable tag container, per-tag rows with name / translation / score)
# ------------------------------------------------------------------
custom_css = """
.label-container {
max-height: 300px;
overflow-y: auto;
border: 1px solid #ddd;
padding: 10px;
border-radius: 5px;
background-color: #f9f9f9;
}
.tag-item {
display: flex;
justify-content: space-between;
align-items: center;
margin: 2px 0;
padding: 2px 5px;
border-radius: 3px;
background-color: #fff;
}
.tag-en {
font-weight: bold;
color: #333;
}
.tag-zh {
color: #666;
margin-left: 10px;
}
.tag-score {
color: #999;
font-size: 0.9em;
}
.btn-container {
margin-top: 20px;
}
"""
with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=custom_css) as demo:
    gr.Markdown("# 🖼️ AI 图像标签分析器")
    gr.Markdown("上传图片自动识别标签,并可一键翻译成中文")
    with gr.Row():
        with gr.Column(scale=1):
            img_in = gr.Image(type="pil", label="上传图片")
            with gr.Accordion("⚙️ 高级设置", open=True):
                gen_slider = gr.Slider(0, 1, 0.35,
                                       label="通用标签阈值", info="越高→标签更少更准")
                char_slider = gr.Slider(0, 1, 0.85,
                                        label="角色标签阈值", info="推荐保持较高阈值")
                show_zh = gr.Checkbox(True, label="显示中文翻译")
                gr.Markdown("### 汇总设置")
                with gr.Row():
                    sum_general = gr.Checkbox(True, label="通用标签")
                    sum_char = gr.Checkbox(True, label="角色标签")
                    sum_rating = gr.Checkbox(False, label="评分标签")
                sum_sep = gr.Dropdown(["逗号", "换行", "空格"], value="逗号", label="分隔符")
            btn = gr.Button("开始分析", variant="primary", elem_classes=["btn-container"])
            processing_info = gr.Markdown("", visible=False)
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("🏷️ 通用标签"):
                    out_general = gr.HTML(label="General Tags")
                with gr.TabItem("👤 角色标签"):
                    out_char = gr.HTML(label="Character Tags")
                with gr.TabItem("⭐ 评分标签"):
                    out_rating = gr.HTML(label="Rating Tags")
            gr.Markdown("### 标签汇总")
            out_summary = gr.Textbox(label="标签汇总",
                                     placeholder="选择需要汇总的标签类别...",
                                     lines=3)

    # ----------------- processing callbacks -----------------

    # FIX: the original constructed Tagger() inside process() on EVERY button
    # click, re-downloading / re-loading the ONNX model per request. Load it
    # lazily once and reuse it. (The `with` block does not create a scope, so
    # these names live at module level and `global` works in _get_tagger.)
    _tagger = None

    def _get_tagger():
        """Return the shared Tagger instance, creating it on first use."""
        global _tagger
        if _tagger is None:
            _tagger = Tagger()
        return _tagger

    def format_tags_html(tags_dict, translations, show_translation=True):
        """Render a {tag: score} dict as an HTML tag list.

        translations[i] pairs with the i-th tag in insertion order; tags
        beyond len(translations) are rendered without a translation.
        Returns a placeholder paragraph when tags_dict is empty.
        """
        if not tags_dict:
            return "<p>暂无标签</p>"
        # Build pieces in a list and join once (avoids quadratic string +=).
        parts = ['<div class="label-container">']
        for i, (tag, score) in enumerate(tags_dict.items()):
            item = '<div class="tag-item">'
            item += f'<div><span class="tag-en">{tag}</span>'
            if show_translation and i < len(translations):
                item += f'<span class="tag-zh">({translations[i]})</span>'
            item += '</div>'
            item += f'<span class="tag-score">{score:.3f}</span>'
            item += '</div>'
            parts.append(item)
        parts.append('</div>')
        return ''.join(parts)

    def process(img, g_th, c_th, show_zh, sum_gen, sum_char, sum_rat, sep_type):
        """Generator callback for the analyze button.

        First yield disables the button and shows a progress note; the final
        yield re-enables it with the rendered tag HTML and summary text (or an
        error message if anything failed).
        """
        yield (
            gr.update(interactive=False, value="处理中..."),
            gr.update(visible=True, value="🔄 正在分析图像..."),
            "", "", "", ""
        )
        try:
            tagger = _get_tagger()
            res = tagger.predict(img, g_th, c_th)
            tag_categories = {
                "general": list(res["general"].keys()),
                "characters": list(res["characters"].keys()),
                "ratings": list(res["ratings"].keys()),
            }
            # Translate everything in one batch, then slice per category below.
            if show_zh:
                all_tags = [t for tags in tag_categories.values() for t in tags]
                if all_tags:
                    translations = translate_texts(all_tags, src_lang="auto", tgt_lang="zh")
                else:
                    translations = []
            else:
                translations = []
            # Distribute the flat translation list back to its categories.
            translations_dict = {}
            offset = 0
            for category, tags in tag_categories.items():
                if show_zh and tags:
                    translations_dict[category] = translations[offset:offset + len(tags)]
                    offset += len(tags)
                else:
                    translations_dict[category] = []
            general_html = format_tags_html(res["general"], translations_dict["general"], show_zh)
            char_html = format_tags_html(res["characters"], translations_dict["characters"], show_zh)
            rating_html = format_tags_html(res["ratings"], translations_dict["ratings"], show_zh)
            # Summary text, per the user's category checkboxes and separator.
            separators = {"逗号": ", ", "换行": "\n", "空格": " "}
            separator = separators[sep_type]
            summary_parts = []
            for enabled, key, title in ((sum_gen, "general", "通用标签: "),
                                        (sum_char, "characters", "角色标签: "),
                                        (sum_rat, "ratings", "评分标签: ")):
                if enabled and res[key]:
                    if show_zh and translations_dict[key]:
                        tags = [f"{en}({zh})"
                                for en, zh in zip(res[key].keys(), translations_dict[key])]
                    else:
                        tags = list(res[key].keys())
                    summary_parts.append(title + separator.join(tags))
            summary_text = "\n\n".join(summary_parts) if summary_parts else "请选择要汇总的标签类别"
            yield (
                gr.update(interactive=True, value="开始分析"),
                gr.update(visible=False),
                general_html,
                char_html,
                rating_html,
                summary_text,
            )
        except Exception as e:
            # Surface the failure in the UI and re-enable the button.
            yield (
                gr.update(interactive=True, value="开始分析"),
                gr.update(visible=True, value=f"❌ 处理失败: {str(e)}"),
                "", "", "", ""
            )

    # Wire the button to the generator callback.
    btn.click(
        process,
        inputs=[img_in, gen_slider, char_slider, show_zh, sum_general, sum_char, sum_rating, sum_sep],
        outputs=[btn, processing_info, out_general, out_char, out_rating, out_summary],
        show_progress=True,
    )
# ------------------------------------------------------------------
# Entry point
# ------------------------------------------------------------------
if __name__ == "__main__":
    # Bind to all interfaces on the standard HF Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)