# Gradio demo app for the PixAI Tagger v0.9 Hugging Face Space.
import os
import time
import shutil
from pathlib import Path
from typing import Optional
import gradio as gr
from huggingface_hub import snapshot_download
from PIL import Image
# Import your existing inference endpoint implementation
from handler import EndpointHandler
# ------------------------------------------------------------------------------
# Asset setup: download weights/tags/mapping so local filenames are unchanged
# ------------------------------------------------------------------------------
# Upstream Hugging Face Hub repo that hosts the model assets.
REPO_ID = os.environ.get("ASSETS_REPO_ID", "pixai-labs/pixai-tagger-v0.9")
REVISION = os.environ.get("ASSETS_REVISION")  # optional pin, e.g. "main" or a commit
MODEL_DIR = os.environ.get("MODEL_DIR", "./assets")  # where the handler will look
# Optional: Hugging Face token for private repos.
# First non-empty value wins; all common env-var spellings are checked.
HF_TOKEN = (
    os.environ.get("HUGGINGFACE_HUB_TOKEN")
    or os.environ.get("HF_TOKEN")
    or os.environ.get("HUGGINGFACE_TOKEN")
    or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
)
# Files the handler expects to find, under these exact names, in MODEL_DIR.
REQUIRED_FILES = [
    "model_v0.9.pth",
    "tags_v0.9_13k.json",
    "char_ip_map.json",
]
def ensure_assets(repo_id: str, revision: Optional[str], target_dir: str) -> None:
    """Make sure every file in REQUIRED_FILES exists in *target_dir*.

    1) snapshot_download the upstream repo (cached by HF Hub), filtered to
       only the files that are actually missing locally.
    2) copy those files into `target_dir` with the exact filenames expected.

    Args:
        repo_id: Hugging Face Hub repo id hosting the assets.
        revision: optional revision pin (branch, tag, or commit); None uses
            the Hub default.
        target_dir: local directory the handler reads assets from.

    Raises:
        FileNotFoundError: if an expected file is absent from the snapshot.
    """
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)
    # Only download if something is missing.
    missing = [f for f in REQUIRED_FILES if not (target / f).exists()]
    if not missing:
        return
    # Download a snapshot filtered to just the missing files. Filtering on
    # `missing` (not REQUIRED_FILES) avoids re-downloading and re-copying —
    # i.e. clobbering — files that are already present locally.
    snapshot_path = snapshot_download(
        repo_id=repo_id,
        revision=revision,
        allow_patterns=missing,  # only pull what we need
        token=HF_TOKEN,          # authenticate if repo is private
    )
    # Copy the missing files into target_dir with the required names.
    for fname in missing:
        src = Path(snapshot_path) / fname
        dst = target / fname
        if not src.exists():
            raise FileNotFoundError(
                f"Expected '{fname}' not found in snapshot for {repo_id} @ {revision or 'default'}"
            )
        shutil.copyfile(src, dst)
# Fetch assets (no-op if they already exist)
ensure_assets(REPO_ID, REVISION, MODEL_DIR)
# ------------------------------------------------------------------------------
# Initialize the handler
# ------------------------------------------------------------------------------
# EndpointHandler loads the model and tag files from MODEL_DIR at construction.
handler = EndpointHandler(MODEL_DIR)
# Shown in the UI header; handler.device is presumably "cpu"/"cuda" — TODO confirm
# against the handler implementation.
DEVICE_LABEL = f"Device: {handler.device.upper()}"
# ------------------------------------------------------------------------------
# Gradio wiring
# ------------------------------------------------------------------------------
def run_inference(
    source_choice: str,
    image: Optional[Image.Image],
    url: str,
    general_threshold: float,
    character_threshold: float,
):
    """Run the tagger on an uploaded image or a URL and format the outputs.

    Args:
        source_choice: "Upload image" to use *image*, anything else uses *url*.
        image: PIL image from the upload widget (may be None).
        url: image URL from the textbox (used only in URL mode).
        general_threshold: score cutoff for general tags, forwarded to handler.
        character_threshold: score cutoff for character tags, forwarded to handler.

    Returns:
        5-tuple of (general tags str, character tags str, IP tags str,
        metadata dict, raw handler output).

    Raises:
        gr.Error: on missing input or any handler failure (surfaced in the UI).
    """
    if source_choice == "Upload image":
        if image is None:
            raise gr.Error("Please upload an image.")
        inputs = image
    else:
        if not url or not url.strip():
            raise gr.Error("Please provide an image URL.")
        inputs = {"url": url.strip()}
    data = {
        "inputs": inputs,
        "parameters": {
            "general_threshold": float(general_threshold),
            "character_threshold": float(character_threshold),
        },
    }
    # perf_counter is monotonic — unlike time.time(), it cannot jump with
    # NTP/clock adjustments, so the reported latency is always meaningful.
    started = time.perf_counter()
    try:
        out = handler(data)
    except Exception as e:
        # Surface the failure in the Gradio UI instead of crashing the app.
        raise gr.Error(f"Inference error: {e}") from e
    latency = round(time.perf_counter() - started, 4)
    # NOTE(review): "β" below looks like a mojibake'd "—" placeholder for
    # "no tags" — confirm the intended glyph; preserved byte-for-byte here.
    features = ", ".join(sorted(out.get("feature", []))) or "β"
    characters = ", ".join(sorted(out.get("character", []))) or "β"
    ips = ", ".join(out.get("ip", [])) or "β"
    meta = {
        "device": handler.device,
        "latency_s_total": latency,
        **out.get("_timings", {}),  # handler-reported per-stage timings, if any
    }
    return features, characters, ips, meta, out
with gr.Blocks(title="PixAI Tagger v0.9 β Demo", fill_height=True) as demo:
    # Intro / configuration notes rendered at the top of the page.
    gr.Markdown(
        """
# PixAI Tagger v0.9 β Gradio Demo
Downloads model assets from **pixai-labs/pixai-tagger-v0.9** on first run,
then uses your imported `EndpointHandler` to predict **general**, **character**, and **IP** tags.
Configure via env vars:
- `ASSETS_REPO_ID` (default: `pixai-labs/pixai-tagger-v0.9`)
- `ASSETS_REVISION` (optional)
- `MODEL_DIR` (default: `./assets`)
"""
    )
    with gr.Row():
        # e.g. "Device: CUDA" / "Device: CPU", computed once at startup.
        gr.Markdown(f"**{DEVICE_LABEL}**")
    with gr.Row():
        # Selects which of the two input widgets below is active.
        source_choice = gr.Radio(
            choices=["Upload image", "From URL"],
            value="Upload image",
            label="Image source",
        )
    with gr.Row(variant="panel"):
        with gr.Column(scale=2):
            # Exactly one of these two is visible at a time (see toggle_inputs).
            image = gr.Image(label="Upload image", type="pil", visible=True, height="500px")
            url = gr.Textbox(label="Image URL", placeholder="https://β¦", visible=False)

            def toggle_inputs(choice):
                # Show the widget matching the selected source, hide the other.
                return (
                    gr.update(visible=(choice == "Upload image")),
                    gr.update(visible=(choice == "From URL")),
                )

            source_choice.change(toggle_inputs, [source_choice], [image, url])
        with gr.Column(scale=1):
            # Score cutoffs forwarded to the handler as "parameters".
            general_threshold = gr.Slider(
                minimum=0.0, maximum=1.0, step=0.01, value=0.30, label="General threshold"
            )
            character_threshold = gr.Slider(
                minimum=0.0, maximum=1.0, step=0.01, value=0.85, label="Character threshold"
            )
            run_btn = gr.Button("Run", variant="primary")
            clear_btn = gr.Button("Clear")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Predicted Tags")
            features_out = gr.Textbox(label="General tags", lines=4)
            characters_out = gr.Textbox(label="Character tags", lines=4)
            ip_out = gr.Textbox(label="IP tags", lines=2)
        with gr.Column():
            gr.Markdown("### Metadata & Raw Output")
            meta_out = gr.JSON(label="Timings/Device")
            raw_out = gr.JSON(label="Raw JSON")
    # One pre-filled URL-mode example; fields mirror run_inference's inputs.
    examples = gr.Examples(
        label="Examples (URL mode)",
        examples=[
            ["From URL", None, "https://cdn.donmai.us/sample/50/b7/__komeiji_koishi_touhou_drawn_by_cui_ying__sample-50b7006f16e0144d5b5db44cadc2d22f.jpg", 0.30, 0.85],
        ],
        inputs=[source_choice, image, url, general_threshold, character_threshold],
        cache_examples=False,
    )

    def clear():
        # Reset inputs and outputs to their initial values; tuple order must
        # match the `outputs` list of clear_btn.click below.
        return (None, "", 0.30, 0.85, "", "", "", {}, {})

    run_btn.click(
        run_inference,
        inputs=[source_choice, image, url, general_threshold, character_threshold],
        outputs=[features_out, characters_out, ip_out, meta_out, raw_out],
        api_name="predict",  # also exposed via the Gradio API under /predict
    )
    clear_btn.click(
        clear,
        inputs=None,
        outputs=[
            image, url, general_threshold, character_threshold,
            features_out, characters_out, ip_out, meta_out, raw_out
        ],
    )
# Script entrypoint: enable request queueing (at most 8 pending requests),
# then start the Gradio server.
if __name__ == "__main__":
    queued_demo = demo.queue(max_size=8)
    queued_demo.launch()