File size: 7,186 Bytes
1412dfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc07199
 
 
 
 
 
 
 
1412dfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc07199
1412dfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132d7fb
1412dfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
import os
import time
import shutil
from pathlib import Path
from typing import Optional

import gradio as gr
from huggingface_hub import snapshot_download
from PIL import Image

# Import your existing inference endpoint implementation
from handler import EndpointHandler


# ------------------------------------------------------------------------------
# Asset setup: download weights/tags/mapping so local filenames are unchanged
# ------------------------------------------------------------------------------

# Upstream Hub repo holding the tagger assets; override via env for forks/mirrors.
REPO_ID = os.environ.get("ASSETS_REPO_ID", "pixai-labs/pixai-tagger-v0.9")
REVISION = os.environ.get("ASSETS_REVISION")  # optional pin, e.g. "main" or a commit
MODEL_DIR = os.environ.get("MODEL_DIR", "./assets")  # where the handler will look

# Optional: Hugging Face token for private repos. The first *non-empty* value
# among the commonly-used env-var spellings wins; falls back to the last
# lookup's value (possibly None) when none are set.
_TOKEN_ENV_VARS = (
    "HUGGINGFACE_HUB_TOKEN",
    "HF_TOKEN",
    "HUGGINGFACE_TOKEN",
    "HUGGINGFACEHUB_API_TOKEN",
)
HF_TOKEN = next(
    (os.environ[name] for name in _TOKEN_ENV_VARS if os.environ.get(name)),
    os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
)

# Exact filenames the handler expects to find inside MODEL_DIR.
REQUIRED_FILES = [
    "model_v0.9.pth",
    "tags_v0.9_13k.json",
    "char_ip_map.json",
]

def ensure_assets(repo_id: str, revision: Optional[str], target_dir: str) -> None:
    """Ensure every file in REQUIRED_FILES exists inside `target_dir`.

    Downloads a (HF-Hub-cached) snapshot of the upstream repo only when at
    least one required file is missing locally, then copies just the missing
    files into `target_dir` under the exact filenames the handler expects.

    Args:
        repo_id: Hugging Face Hub repo to pull assets from.
        revision: Optional branch/tag/commit pin; None uses the repo default.
        target_dir: Directory the handler reads assets from; created if absent.

    Raises:
        FileNotFoundError: if a missing file is also absent from the snapshot.
    """
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)

    # Only download if something is missing locally.
    missing = [f for f in REQUIRED_FILES if not (target / f).exists()]
    if not missing:
        return

    # Download snapshot (optionally filtered to speed up)
    snapshot_path = snapshot_download(
        repo_id=repo_id,
        revision=revision,
        allow_patterns=REQUIRED_FILES,  # only pull what we need
        token=HF_TOKEN,  # authenticate if repo is private
    )

    # Copy only the *missing* files. The previous version iterated over all
    # REQUIRED_FILES here, which (a) pointlessly re-copied files already on
    # disk and (b) raised FileNotFoundError for a locally-present file that
    # happened not to exist in the snapshot.
    for fname in missing:
        src = Path(snapshot_path) / fname
        if not src.exists():
            raise FileNotFoundError(
                f"Expected '{fname}' not found in snapshot for {repo_id} @ {revision or 'default'}"
            )
        shutil.copyfile(src, target / fname)


# Fetch assets at import time (no-op if they already exist) so the handler
# construction below can assume all files are on disk.
ensure_assets(REPO_ID, REVISION, MODEL_DIR)


# ------------------------------------------------------------------------------
# Initialize the handler
# ------------------------------------------------------------------------------

# Load the model and tag mappings once at startup.
# NOTE(review): assumes EndpointHandler exposes a lowercase string `.device`
# attribute (e.g. "cuda"/"cpu") -- confirm against handler.py.
handler = EndpointHandler(MODEL_DIR)
DEVICE_LABEL = f"Device: {handler.device.upper()}"


# ------------------------------------------------------------------------------
# Gradio wiring
# ------------------------------------------------------------------------------

def run_inference(
    source_choice: str,
    image: Optional[Image.Image],
    url: str,
    general_threshold: float,
    character_threshold: float,
):
    """Validate the selected input source, call the handler, and format results.

    Returns a 5-tuple (general tags, character tags, IP tags, metadata dict,
    raw handler output) matching the Gradio output components.
    Raises gr.Error on missing input or any handler failure.
    """
    # Build the payload from whichever source the user picked.
    if source_choice == "Upload image":
        if image is None:
            raise gr.Error("Please upload an image.")
        payload = image
    else:
        cleaned_url = (url or "").strip()
        if not cleaned_url:
            raise gr.Error("Please provide an image URL.")
        payload = {"url": cleaned_url}

    request = {
        "inputs": payload,
        "parameters": {
            "general_threshold": float(general_threshold),
            "character_threshold": float(character_threshold),
        },
    }

    t0 = time.time()
    try:
        result = handler(request)
    except Exception as exc:  # surface any handler failure in the UI
        raise gr.Error(f"Inference error: {exc}") from exc
    elapsed = round(time.time() - t0, 4)

    # General/character tags are displayed sorted; IP tags keep handler order.
    features = ", ".join(sorted(result.get("feature", []))) or "β€”"
    characters = ", ".join(sorted(result.get("character", []))) or "β€”"
    ips = ", ".join(result.get("ip", [])) or "β€”"

    meta = {
        "device": handler.device,
        "latency_s_total": elapsed,
        **result.get("_timings", {}),
    }

    return features, characters, ips, meta, result


# Build the Gradio UI. Components are declared inside the Blocks context;
# event wiring (source toggle, run, clear) is attached where each control lives.
with gr.Blocks(title="PixAI Tagger v0.9 β€” Demo", fill_height=True) as demo:
    gr.Markdown(
        """
        # PixAI Tagger v0.9 β€” Gradio Demo
        Downloads model assets from **pixai-labs/pixai-tagger-v0.9** on first run,
        then uses your imported `EndpointHandler` to predict **general**, **character**, and **IP** tags.

        Configure via env vars:
        - `ASSETS_REPO_ID` (default: `pixai-labs/pixai-tagger-v0.9`)
        - `ASSETS_REVISION` (optional)
        - `MODEL_DIR` (default: `./assets`)
        """
    )
    # Show which device the handler resolved at startup.
    with gr.Row():
        gr.Markdown(f"**{DEVICE_LABEL}**")

    # Input-source selector: exactly one of the upload widget / URL textbox
    # below is visible at a time (see toggle_inputs).
    with gr.Row():
        source_choice = gr.Radio(
            choices=["Upload image", "From URL"],
            value="Upload image",
            label="Image source",
        )

    with gr.Row(variant="panel"):
        with gr.Column(scale=2):
            # NOTE(review): recent Gradio types `height` as int | str;
            # confirm "500px" renders as intended on the pinned version.
            image = gr.Image(label="Upload image", type="pil", visible=True, height="500px")
            url = gr.Textbox(label="Image URL", placeholder="https://…", visible=False)

            def toggle_inputs(choice):
                # Flip visibility so only the selected input widget shows.
                return (
                    gr.update(visible=(choice == "Upload image")),
                    gr.update(visible=(choice == "From URL")),
                )

            source_choice.change(toggle_inputs, [source_choice], [image, url])

        with gr.Column(scale=1):
            # Threshold defaults mirror run_inference's expected parameter range.
            general_threshold = gr.Slider(
                minimum=0.0, maximum=1.0, step=0.01, value=0.30, label="General threshold"
            )
            character_threshold = gr.Slider(
                minimum=0.0, maximum=1.0, step=0.01, value=0.85, label="Character threshold"
            )
            run_btn = gr.Button("Run", variant="primary")
            clear_btn = gr.Button("Clear")

    # Results: formatted tag strings on the left; timings and raw JSON on the right.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Predicted Tags")
            features_out = gr.Textbox(label="General tags", lines=4)
            characters_out = gr.Textbox(label="Character tags", lines=4)
            ip_out = gr.Textbox(label="IP tags", lines=2)

        with gr.Column():
            gr.Markdown("### Metadata & Raw Output")
            meta_out = gr.JSON(label="Timings/Device")
            raw_out = gr.JSON(label="Raw JSON")

    # One URL-mode example row; column order must match `inputs` below.
    examples = gr.Examples(
        label="Examples (URL mode)",
        examples=[
            ["From URL", None, "https://cdn.donmai.us/sample/50/b7/__komeiji_koishi_touhou_drawn_by_cui_ying__sample-50b7006f16e0144d5b5db44cadc2d22f.jpg", 0.30, 0.85],
        ],
        inputs=[source_choice, image, url, general_threshold, character_threshold],
        cache_examples=False,
    )

    def clear():
        # Reset inputs to their defaults and blank all outputs; tuple order
        # must match the `outputs` list on clear_btn.click below.
        return (None, "", 0.30, 0.85, "", "", "", {}, {})

    run_btn.click(
        run_inference,
        inputs=[source_choice, image, url, general_threshold, character_threshold],
        outputs=[features_out, characters_out, ip_out, meta_out, raw_out],
        api_name="predict",
    )
    clear_btn.click(
        clear,
        inputs=None,
        outputs=[
            image, url, general_threshold, character_threshold,
            features_out, characters_out, ip_out, meta_out, raw_out
        ],
    )

# Script entry point: enable request queuing (at most 8 pending) and launch.
if __name__ == "__main__":
    demo.queue(max_size=8).launch()