File size: 16,721 Bytes
393d2d7
 
d5894b1
393d2d7
 
 
 
d5894b1
01d7dca
6701bf8
fcde2f2
20d3044
01d7dca
20d3044
393d2d7
 
 
d5894b1
01d7dca
6701bf8
01d7dca
 
 
 
d5894b1
20d3044
01d7dca
20d3044
d5894b1
 
393d2d7
 
 
 
 
fcde2f2
 
 
393d2d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
01d7dca
393d2d7
01d7dca
 
20d3044
 
fcde2f2
01d7dca
6701bf8
393d2d7
fcde2f2
393d2d7
fcde2f2
 
3553faa
fcde2f2
20d3044
393d2d7
01d7dca
393d2d7
 
fcde2f2
 
393d2d7
fcde2f2
6701bf8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393d2d7
 
1ebe87f
01d7dca
393d2d7
 
 
01d7dca
3553faa
fcde2f2
20d3044
 
 
7c7be00
3553faa
 
 
 
 
 
 
7c7be00
 
393d2d7
 
da15f0e
3553faa
c4ebff6
 
393d2d7
 
6701bf8
da15f0e
6701bf8
 
 
 
 
393d2d7
 
 
3553faa
393d2d7
 
da15f0e
cf2d24e
 
 
 
393d2d7
7c7be00
7f24625
01d7dca
6701bf8
 
 
 
01d7dca
393d2d7
 
fcde2f2
d5894b1
 
393d2d7
 
01d7dca
cf2d24e
6701bf8
 
393d2d7
6701bf8
 
 
 
 
 
 
393d2d7
1ebe87f
6701bf8
393d2d7
fcde2f2
1eb8a26
 
d5894b1
 
6701bf8
 
 
393d2d7
6701bf8
7f38460
20d3044
 
 
 
 
 
01d7dca
 
 
 
 
 
 
7f38460
01d7dca
 
 
 
 
 
 
 
7f38460
6701bf8
20d3044
6701bf8
 
 
7f38460
01d7dca
7f38460
6701bf8
 
 
01d7dca
6701bf8
 
 
 
3553faa
7c7be00
6701bf8
3553faa
393d2d7
7f24625
6701bf8
393d2d7
 
3553faa
cf2d24e
6701bf8
 
 
 
 
 
 
1ebe87f
 
6701bf8
 
 
 
 
 
 
 
 
 
 
393d2d7
6701bf8
393d2d7
6701bf8
 
 
 
20d3044
 
6701bf8
 
 
01d7dca
6701bf8
 
7c7be00
6701bf8
 
 
 
 
 
 
 
 
 
 
 
7c7be00
6701bf8
 
393d2d7
6701bf8
 
 
 
 
 
 
 
 
 
 
7c7be00
6701bf8
 
7c7be00
6701bf8
7c7be00
 
393d2d7
6701bf8
 
3553faa
7f38460
6701bf8
fcde2f2
393d2d7
6701bf8
 
 
1ebe87f
6701bf8
 
 
 
 
 
 
d5894b1
 
1ebe87f
393d2d7
6701bf8
 
 
 
 
393d2d7
d5894b1
393d2d7
01d7dca
7f38460
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
import os
import json
import gradio as gr
import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
from PIL import Image
from huggingface_hub import whoami, HfApi
from translator import translate_texts

# ------------------------------------------------------------------
# Model Configuration
# ------------------------------------------------------------------
# Hugging Face Hub repository and file names of the WD EVA02 tagger.
MODEL_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# Optional Hub token; best kept in the Space's secrets. None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
# SPACE_ID is expected to look like "<owner>/<space-name>"; the owner is the
# text before the first "/". Both are None outside of a Space.
SPACE_ID = os.environ.get("SPACE_ID")
SPACE_OWNER = SPACE_ID.partition('/')[0] if SPACE_ID else None


# ------------------------------------------------------------------
# Tagger Class (Global Instance)
# ------------------------------------------------------------------
class Tagger:
    """ONNX wrapper around the WD EVA02 tagger stored in ``MODEL_REPO``.

    Construction downloads the model and its label CSV from the Hugging Face
    Hub and opens an ``onnxruntime`` session; any failure is re-raised as
    ``RuntimeError``. :meth:`predict` turns a PIL image into per-category
    tag scores.
    """

    # Maps the category names used in ``self.categories`` to the keys of the
    # dictionaries returned by ``predict`` (which the UI reads).
    _RESULT_KEYS = {"rating": "ratings", "general": "general", "character": "characters"}

    def __init__(self):
        self.hf_token = HF_TOKEN
        self.tag_names = []    # tag name for each model output index
        self.categories = {}   # category name -> indices into the model output
        self.model = None      # onnxruntime.InferenceSession once loaded
        self.input_size = 0    # side length of the square model input
        self._load_model_and_labels()

    def _load_model_and_labels(self):
        """Download the model and label CSV, then create the ONNX session.

        Label CSV ``category`` codes used here: 9 = rating, 0 = general,
        4 = character.

        Raises:
            RuntimeError: if downloading, CSV parsing or session creation fails.
        """
        try:
            label_path = huggingface_hub.hf_hub_download(
                MODEL_REPO, LABEL_FILENAME, token=self.hf_token, resume_download=True
            )
            model_path = huggingface_hub.hf_hub_download(
                MODEL_REPO, MODEL_FILENAME, token=self.hf_token, resume_download=True
            )

            tags_df = pd.read_csv(label_path)
            self.tag_names = tags_df["name"].tolist()
            self.categories = {
                "rating": np.where(tags_df["category"] == 9)[0],
                "general": np.where(tags_df["category"] == 0)[0],
                "character": np.where(tags_df["category"] == 4)[0],
            }
            self.model = rt.InferenceSession(model_path)
            # Presumably an NHWC input (1, H, W, 3) with H == W; shape[1] is
            # taken as the side length. TODO confirm for other model variants.
            self.input_size = self.model.get_inputs()[0].shape[1]
            print("✅ Model and labels loaded successfully.")
        except Exception as e:
            print(f"❌ Failed to load model or labels: {e}")
            raise RuntimeError(f"Model initialization failed: {e}") from e

    # ------------------------- preprocess -------------------------
    def _preprocess(self, img: "Image.Image") -> np.ndarray:
        """Turn *img* into a float32 BGR array of shape (size, size, 3).

        Transparent pixels are flattened onto white, the image is centred on a
        white square canvas and resized to ``self.input_size`` (bicubic).

        Raises:
            ValueError: if *img* is None.
        """
        if img is None:
            raise ValueError("Input image cannot be None.")
        if img.mode in ("RGBA", "LA", "PA") or (img.mode == "P" and "transparency" in img.info):
            # convert("RGB") would drop alpha and expose whatever colour is
            # stored under transparent pixels (often black); composite onto
            # white so it matches the white padding below.
            rgba = img.convert("RGBA")
            background = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
            img = Image.alpha_composite(background, rgba).convert("RGB")
        elif img.mode != "RGB":
            img = img.convert("RGB")
        size = max(img.size)
        canvas = Image.new("RGB", (size, size), (255, 255, 255))
        canvas.paste(img, ((size - img.width) // 2, (size - img.height) // 2))
        if size != self.input_size:
            canvas = canvas.resize((self.input_size, self.input_size), Image.BICUBIC)
        # Reverse the channel axis: RGB -> BGR.
        return np.array(canvas)[:, :, ::-1].astype(np.float32)

    # --------------------------- predict --------------------------
    def predict(self, img: "Image.Image", gen_th: float = 0.35, char_th: float = 0.85):
        """Run the model on *img* and return thresholded tag scores.

        Args:
            img: input image.
            gen_th: minimum score (exclusive) for general tags.
            char_th: minimum score (exclusive) for character tags.

        Returns:
            ``(res, tag_lists)``: ``res`` maps ``"ratings"``, ``"general"`` and
            ``"characters"`` to ``{tag: score}`` dicts sorted by descending
            score (ratings are never thresholded); ``tag_lists`` maps the same
            keys to the tag names in that order.

        Raises:
            RuntimeError: if the model is not loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded, cannot predict.")
        inp_name = self.model.get_inputs()[0].name
        batch = self._preprocess(img)[None, ...]
        outputs = self.model.run(None, {inp_name: batch})[0][0]
        return self._collect_tags(outputs, gen_th, char_th)

    def _collect_tags(self, outputs: np.ndarray, gen_th: float = 0.35, char_th: float = 0.85):
        """Group raw model scores into the ``(res, tag_lists)`` pair returned by ``predict``."""
        res = {"ratings": {}, "general": {}, "characters": {}}
        for cat_key, cat_indices in self.categories.items():
            if cat_key == "rating":
                threshold = None  # every rating is reported
            elif cat_key == "character":
                threshold = char_th
            else:
                threshold = gen_th
            scored = {}
            for idx in cat_indices:
                score = float(outputs[idx])
                if threshold is None or score > threshold:
                    scored[self.tag_names[idx].replace("_", " ")] = score
            # Map "rating" -> "ratings" as well as "character" -> "characters";
            # otherwise ratings end up under a key the UI never reads.
            res_key = self._RESULT_KEYS.get(cat_key, cat_key)
            res[res_key] = dict(sorted(scored.items(), key=lambda kv: kv[1], reverse=True))
        tag_lists = {key: list(tags) for key, tags in res.items()}
        return res, tag_lists

# Shared tagger built once at import time. If loading fails the error is
# printed and ``tagger_instance`` stays None so the UI can still start.
tagger_instance = None
try:
    tagger_instance = Tagger()
except RuntimeError as e:
    print(f"Tagger initialization failed on app startup: {e}")

# ------------------------------------------------------------------
# Gradio UI
# ------------------------------------------------------------------
custom_css = """
.label-container { max-height: 300px; overflow-y: auto; border: 1px solid #ddd; padding: 10px; border-radius: 5px; background-color: #f9f9f9; }
.tag-item { display: flex; justify-content: space-between; align-items: center; margin: 2px 0; padding: 2px 5px; border-radius: 3px; background-color: #fff; transition: background-color 0.2s; }
.tag-item:hover { background-color: #f0f0f0; }
.tag-en { font-weight: bold; color: #333; cursor: pointer; }
.tag-zh { color: #666; margin-left: 10px; }
.tag-score { color: #999; font-size: 0.9em; }
.btn-analyze-container { margin-top: 15px; margin-bottom: 15px; }
"""

_js_functions = """
function copyToClipboard(text) {
    if (typeof text === 'undefined' || text === null) {
        console.warn('copyToClipboard was called with undefined or null text.');
        return;
    }
    navigator.clipboard.writeText(text).then(() => {
        const feedback = document.createElement('div');
        let displayText = String(text).substring(0, 30) + (String(text).length > 30 ? '...' : '');
        feedback.textContent = '已复制: ' + displayText;
        Object.assign(feedback.style, {
            position: 'fixed', bottom: '20px', left: '50%', transform: 'translateX(-50%)',
            backgroundColor: '#4CAF50', color: 'white', padding: '10px 20px',
            borderRadius: '5px', zIndex: '10000', transition: 'opacity 0.5s ease-out'
        });
        document.body.appendChild(feedback);
        setTimeout(() => {
            feedback.style.opacity = '0';
            setTimeout(() => { if (document.body.contains(feedback)) document.body.removeChild(feedback); }, 500);
        }, 1500);
    }).catch(err => {
        console.error('Failed to copy tag. Error:', err, 'Attempted to copy text:', text);
    });
}
"""

# Page layout. The custom CSS styles the tag lists and the js= script provides
# the copyToClipboard helper used by the tag items' onclick handlers.
with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=custom_css, js=_js_functions) as demo:
    gr.Markdown("# 🖼️ AI 图像标签分析器")
    gr.Markdown("上传图片自动识别标签,支持中英文显示和一键复制。[NovelAI在线绘画](https://nai.idlecloud.cc/)")

    # Login row; the status text is filled in by check_user_status on page load.
    with gr.Row():
        with gr.Column(scale=1):
            login_button = gr.LoginButton(value="🤗 通过 Hugging Face 登录")
            user_status_md = gr.Markdown("ℹ️ 正在检查登录状态...")

    # Per-session results of the last analysis, reused when only summary
    # options change so the model does not have to run again.
    state_res = gr.State({})
    state_translations_dict = gr.State({})

    with gr.Row():
        # Left column: input image and all settings.
        with gr.Column(scale=1):
            img_in = gr.Image(type="pil", label="上传图片", height=300)
            btn = gr.Button("🚀 开始分析", variant="primary", elem_classes=["btn-analyze-container"])

            # Score thresholds and display options.
            with gr.Accordion("⚙️ 高级设置", open=False):
                gen_slider = gr.Slider(0, 1, value=0.35, step=0.01, label="通用标签阈值")
                char_slider = gr.Slider(0, 1, value=0.85, step=0.01, label="角色标签阈值")
                show_tag_scores = gr.Checkbox(True, label="在列表中显示标签置信度")

            # Hidden until check_user_status decides whether the visitor must
            # supply their own translation credentials.
            with gr.Accordion("🔑 自定义翻译密钥 (可选)", open=False, visible=False) as api_key_accordion:
                gr.Markdown("如果你不是空间所有者,需要在这里提供自己的API密钥才能使用翻译功能。")
                tencent_id_in = gr.Textbox(label="腾讯云 Secret ID", lines=1)
                tencent_key_in = gr.Textbox(label="腾讯云 Secret Key", lines=1, type="password")
                baidu_json_in = gr.Textbox(label="百度翻译凭证 (JSON 格式)", lines=3, placeholder='[{"app_id": "...", "secret_key": "..."}]')

            # Which categories go into the summary text and how they are joined.
            with gr.Accordion("📊 标签汇总设置", open=True):
                sum_cats = gr.CheckboxGroup(["通用标签", "角色标签", "评分标签"], value=["通用标签", "角色标签"], label="汇总类别")
                sum_sep = gr.Dropdown(["逗号", "换行", "空格"], value="逗号", label="标签分隔符")
                sum_show_zh = gr.Checkbox(False, label="在汇总中显示中文翻译")

            # Progress / completion message, shown once an analysis starts.
            processing_info = gr.Markdown("", visible=False)

        # Right column: one HTML tag list per category, plus the summary text.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("🏷️ 通用标签"): out_general = gr.HTML(label="General Tags")
                with gr.TabItem("👤 角色标签"): out_char = gr.HTML(label="Character Tags")
                with gr.TabItem("⭐ 评分标签"): out_rating = gr.HTML(label="Rating Tags")
            gr.Markdown("### 标签汇总结果")
            out_summary = gr.Textbox(label="标签汇总", lines=5, show_copy_button=True)

    def get_token_from_request(request: gr.Request) -> str | None:
        """Return the token from an ``Authorization: Bearer <token>`` header, else None.

        NOTE(review): sessions created through ``gr.LoginButton`` may not send
        an Authorization header at all; Gradio usually exposes the OAuth login
        via ``gr.OAuthProfile`` / ``gr.OAuthToken`` parameters. Confirm this
        path actually receives a token in the deployed Space.
        """
        header_value = request.headers.get("authorization")
        if not header_value or not header_value.startswith("Bearer "):
            return None
        return header_value.split(" ")[1]

    def is_user_space_owner(user_info: dict | None) -> bool:
        """Return True when *user_info* belongs to the Space owner.

        The user counts as the owner if their ``name`` equals ``SPACE_OWNER``
        or they are a member of an organisation with that name. Returns False
        when *user_info* is empty or ``SPACE_OWNER`` is unknown (a warning is
        printed in the latter case).
        """
        if not SPACE_OWNER:
            print("⚠️ Warning: SPACE_ID environment variable not found.")
            return False
        if not user_info:
            return False

        name = user_info.get("name")
        org_names = [org.get("name") for org in user_info.get("orgs", [])]

        print(f"ℹ️ [Auth Check] Space Owner: '{SPACE_OWNER}', User: '{name}', User Orgs: {org_names}")

        return name == SPACE_OWNER or SPACE_OWNER in org_names

    def check_user_status(request: gr.Request):
        """Build the login-status text and the key-panel update for ``demo.load``.

        Returns a ``(markdown, accordion_update)`` pair. The Space owner gets a
        confirmation and the custom-key panel hidden. Other logged-in users,
        unverifiable logins and guests get the panel shown and opened.
        """
        token = get_token_from_request(request)
        if not token:
            return "ℹ️ **访客模式**。如需使用翻译功能,请<a href='/login?redirect=/'>登录</a>或提供 API 密钥。", gr.update(visible=True, open=True)

        try:
            user_info = whoami(token=token)
            if not is_user_space_owner(user_info):
                return f"👋 你好, **{user_info.get('fullname', '用户')}**!请在下方提供你自己的翻译 API 密钥。", gr.update(visible=True, open=True)
            display_name = user_info.get('fullname', user_info.get('name'))
            return f"✅ 以所有者 **{display_name}** 身份登录,将使用空间配置的密钥。", gr.update(visible=False)
        except Exception as e:
            print(f"Error getting user info: {e}")
            return "⚠️ 无法验证您的登录状态。请提供 API 密钥。", gr.update(visible=True, open=True)
        
    def format_tags_html(tags_dict, translations_list, show_scores):
        """Render one tag category as the HTML list shown in a result tab.

        Args:
            tags_dict: ``{tag: score}`` in display order.
            translations_list: translations aligned by position with
                ``tags_dict``; may be shorter, and falsy entries are skipped.
            show_scores: whether to append each score with three decimals.

        Returns:
            HTML text. Clicking a tag's English name calls the page-level
            ``copyToClipboard`` JS function with the raw tag text.
        """
        from html import escape

        if not tags_dict:
            return "<p>暂无标签</p>"
        translations_list = translations_list or []
        parts = ['<div class="label-container">']
        for i, (tag, score) in enumerate(tags_dict.items()):
            # Tag names may contain quotes, backslashes or <> (emoticon tags).
            # json.dumps yields a valid JS string literal; escape() makes it
            # (and the visible text) safe inside the HTML attribute/body.
            onclick = escape(f"copyToClipboard({json.dumps(tag)})", quote=True)
            item = f'<span class="tag-en" onclick="{onclick}">{escape(tag)}</span>'
            translation = translations_list[i] if i < len(translations_list) else None
            if translation:
                item += f'<span class="tag-zh">({escape(str(translation))})</span>'
            parts.append(f'<div class="tag-item"><div>{item}</div>')
            if show_scores:
                parts.append(f'<span class="tag-score">{score:.3f}</span>')
            parts.append('</div>')
        parts.append('</div>')
        return "".join(parts)

    def generate_summary_text_content(current_res, translations, sum_cats, sep_type, show_zh):
        """Join the selected categories' tags into the summary textbox text.

        Each selected, non-empty category becomes one line. Tags are joined by
        the chosen separator (comma by default), with ``tag(translation)`` when
        *show_zh* is set and a translation exists. Placeholder messages are
        returned when there is no result or nothing matched.
        """
        if not current_res:
            return "请先分析图像。"
        separators = {"逗号": ", ", "换行": "\n", "空格": " "}
        joiner = separators.get(sep_type, ", ")
        category_keys = {"通用标签": "general", "角色标签": "characters", "评分标签": "ratings"}

        blocks = []
        for label in sum_cats:
            key = category_keys.get(label)
            if not key or not current_res.get(key):
                continue
            zh_list = translations.get(key, [])
            rendered = []
            for idx, en in enumerate(current_res[key]):
                zh = zh_list[idx] if idx < len(zh_list) else None
                rendered.append(f"{en}({zh})" if show_zh and zh else en)
            if rendered:
                blocks.append(joiner.join(rendered))

        if not blocks:
            return "选定的类别中没有找到标签。"
        return "\n".join(blocks)

    def process_image_and_generate_outputs(
        img, g_th, c_th, s_scores,
        user_tencent_id, user_tencent_key, user_baidu_json,
        sum_cats, s_sep, s_zh_in_sum,
        request: gr.Request
    ):
        """Tag *img*, translate the tags and stream the results to the UI.

        Yields 8-tuples in the order of the ``btn.click`` outputs: button,
        status text, general / character / rating HTML, summary text, raw
        result state and translation state. The first yield disables the
        button; every later exit path re-enables it.

        The Space owner uses the translation keys from the environment; other
        users use the keys typed into the custom-key panel. Translation is
        best-effort: if it fails, the tags are still shown untranslated.

        Raises:
            gr.Error: if no image was uploaded, the tagger failed to
                initialise, or tagging/rendering fails.
        """
        import traceback

        if img is None:
            raise gr.Error("请先上传图片。")
        if tagger_instance is None:
            raise gr.Error("分析器未成功初始化,请检查后台错误。")

        yield gr.update(interactive=False, value="🔄 处理中..."), gr.update(visible=True, value="🔄 正在分析..."), *["<p>分析中...</p>"]*3, "分析中...", {}, {}

        token = get_token_from_request(request)
        is_owner = False
        if token:
            try:
                is_owner = is_user_space_owner(whoami(token=token))
            except Exception as e:
                # An unverifiable login is treated as a non-owner.
                print(f"Owner check failed, using user-supplied keys: {e}")

        final_tencent_id, final_tencent_key, baidu_json_str = (
            (os.environ.get("TENCENT_SECRET_ID"), os.environ.get("TENCENT_SECRET_KEY"), os.environ.get("BAIDU_CREDENTIALS_JSON", "[]"))
            if is_owner else (user_tencent_id, user_tencent_key, user_baidu_json)
        )

        final_baidu_creds_list = []
        if baidu_json_str and baidu_json_str.strip():
            try:
                parsed_data = json.loads(baidu_json_str)
                if isinstance(parsed_data, list):
                    final_baidu_creds_list = parsed_data
            except json.JSONDecodeError:
                print("提供的百度凭证JSON无效。")

        translation_failed = False
        try:
            res, tag_cats_original = tagger_instance.predict(img, g_th, c_th)
            all_tags = [tag for cat in tag_cats_original.values() for tag in cat]

            translations_flat = []
            if all_tags:
                try:
                    translations_flat = translate_texts(
                        all_tags,
                        tencent_secret_id=final_tencent_id,
                        tencent_secret_key=final_tencent_key,
                        baidu_credentials_list=final_baidu_creds_list
                    ) or []
                except Exception as e:
                    # Translation is optional; keep the English tags instead of failing the run.
                    print(f"⚠️ Translation failed, showing untranslated tags: {e}")
                    translation_failed = True

            # translate_texts returns one flat list; split it back per category
            # in the same order all_tags was built.
            translations, offset = {}, 0
            for cat_key, tags in tag_cats_original.items():
                translations[cat_key] = translations_flat[offset : offset + len(tags)]
                offset += len(tags)

            outputs_html = {k: format_tags_html(res.get(k, {}), translations.get(k, []), s_scores) for k in ["general", "characters", "ratings"]}
            summary = generate_summary_text_content(res, translations, sum_cats, s_sep, s_zh_in_sum)
        except Exception as e:
            traceback.print_exc()
            # Re-enable the button before surfacing the error; the earlier
            # "processing" yield would otherwise leave it disabled for good.
            yield gr.update(interactive=True, value="🚀 开始分析"), gr.update(visible=True, value=f"❌ 处理时发生错误: {e}"), *["<p>暂无标签</p>"]*3, "", {}, {}
            raise gr.Error(f"处理时发生错误: {e}") from e

        status = "✅ 分析完成!(翻译失败,仅显示英文标签)" if translation_failed else "✅ 分析完成!"
        yield gr.update(interactive=True, value="🚀 开始分析"), gr.update(visible=True, value=status), outputs_html["general"], outputs_html["characters"], outputs_html["ratings"], summary, res, translations

    # Fill in the login status and show/hide the custom-key panel on page load.
    demo.load(fn=check_user_status, inputs=None, outputs=[user_status_md, api_key_accordion], queue=False)

    # Main analysis: streams a "processing" state first, then the results.
    btn.click(
        process_image_and_generate_outputs,
        inputs=[
            img_in, gen_slider, char_slider, show_tag_scores,
            tencent_id_in, tencent_key_in, baidu_json_in,
            sum_cats, sum_sep, sum_show_zh
        ],
        outputs=[
            btn, processing_info,
            out_general, out_char, out_rating,
            out_summary,
            state_res, state_translations_dict
        ],
    )

    # Rebuild only the summary when its options change, reusing the cached
    # results and translations instead of re-running the model. The inputs
    # line up positionally with generate_summary_text_content's parameters.
    summary_controls = [sum_cats, sum_sep, sum_show_zh]
    for ctrl in summary_controls:
        ctrl.change(
            fn=generate_summary_text_content,
            inputs=[state_res, state_translations_dict] + summary_controls,
            outputs=[out_summary],
        )
    
if __name__ == "__main__":
    # Start the server even if the model failed to load; analysis requests
    # then fail with a gr.Error, while the rest of the UI still works.
    if tagger_instance is None:
        print("CRITICAL: Tagger failed to initialize. App functionality will be limited.")
    # Listen on all interfaces on port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)