Spaces:
Running
Running
| import os | |
| import time | |
| import shutil | |
| from pathlib import Path | |
| from typing import Optional | |
| import gradio as gr | |
| from huggingface_hub import snapshot_download | |
| from PIL import Image | |
| # Import your existing inference endpoint implementation | |
| from handler import EndpointHandler | |
# ------------------------------------------------------------------------------
# Asset setup: download weights/tags/mapping so local filenames are unchanged
# ------------------------------------------------------------------------------
# Hugging Face repo that hosts the model assets.
REPO_ID = os.environ.get("ASSETS_REPO_ID", "pixai-labs/pixai-tagger-v0.9")
# Optional pin, e.g. "main" or a commit. A variable that is declared but left
# blank is treated as "not set" (None) rather than passed on as an empty
# revision string.
REVISION = os.environ.get("ASSETS_REVISION") or None
# Local directory where the handler will look for its files.
MODEL_DIR = os.environ.get("MODEL_DIR", "./assets")

# Optional: Hugging Face token for private repos. The first non-empty value
# among the commonly used variable names wins; None means anonymous access.
HF_TOKEN = (
    os.environ.get("HUGGINGFACE_HUB_TOKEN")
    or os.environ.get("HF_TOKEN")
    or os.environ.get("HUGGINGFACE_TOKEN")
    or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    or None
)

# Files that must exist in MODEL_DIR, under exactly these names.
REQUIRED_FILES = [
    "model_v0.9.pth",
    "tags_v0.9_13k.json",
    "char_ip_map.json",
]
def ensure_assets(repo_id: str, revision: Optional[str], target_dir: str) -> None:
    """Make sure every file in REQUIRED_FILES exists in ``target_dir``.

    1) If any required file is missing, snapshot_download the upstream repo,
       restricted to REQUIRED_FILES (cached by HF Hub).
    2) Copy the required files into ``target_dir`` with the exact filenames
       expected.

    When every required file already exists in ``target_dir`` this returns
    immediately. ``repo_id``/``revision`` are then not consulted, so changing
    the pin does not refresh files that are already in place.

    Args:
        repo_id: Hugging Face repo to download from.
        revision: Optional branch/tag/commit to pin; None uses the hub default.
        target_dir: Local directory to populate (created if it does not exist).

    Raises:
        FileNotFoundError: if the downloaded snapshot lacks a required file.
            Nothing is copied into ``target_dir`` in that case.
    """
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)

    # Only download if something is missing.
    if all((target / fname).exists() for fname in REQUIRED_FILES):
        return

    snapshot_path = Path(
        snapshot_download(
            repo_id=repo_id,
            revision=revision,
            allow_patterns=REQUIRED_FILES,  # only pull what we need
            token=HF_TOKEN,  # authenticate if repo is private
        )
    )

    # Validate the whole snapshot first so a bad snapshot never leaves a
    # partially populated asset directory behind.
    for fname in REQUIRED_FILES:
        if not (snapshot_path / fname).exists():
            raise FileNotFoundError(
                f"Expected '{fname}' not found in snapshot for {repo_id} @ {revision or 'default'}"
            )

    for fname in REQUIRED_FILES:
        dst = target / fname
        # The early-return check above trusts any file that exists under the
        # final name. Copy to a temporary sibling and rename it into place so an
        # interrupted copy can never leave a truncated file under that name.
        tmp = dst.with_name(dst.name + ".partial")
        try:
            shutil.copyfile(snapshot_path / fname, tmp)
            os.replace(tmp, dst)
        except BaseException:
            tmp.unlink(missing_ok=True)
            raise
# Fetch assets up front; this is a no-op when every required file is present.
ensure_assets(REPO_ID, REVISION, MODEL_DIR)

# ------------------------------------------------------------------------------
# Initialize the handler
# ------------------------------------------------------------------------------
# One handler instance, built from MODEL_DIR when this module is imported.
handler = EndpointHandler(MODEL_DIR)
# Upper-cased device name, shown in the UI header.
DEVICE_LABEL = "Device: {}".format(handler.device.upper())

# ------------------------------------------------------------------------------
# Gradio wiring
# ------------------------------------------------------------------------------
def run_inference(
    source_choice: str,
    image: Optional[Image.Image],
    url: str,
    general_threshold: float,
    character_threshold: float,
):
    """Tag one image with the shared handler and format the result for the UI.

    Args:
        source_choice: "Upload image" to use ``image``; any other value uses ``url``.
        image: Uploaded PIL image (only read in upload mode).
        url: Image URL (only read in URL mode); surrounding whitespace is stripped.
        general_threshold: Forwarded to the handler as ``general_threshold``.
        character_threshold: Forwarded to the handler as ``character_threshold``.

    Returns:
        A 5-tuple ``(features, characters, ips, meta, out)``:
        - comma-separated general tags (sorted);
        - comma-separated character tags (sorted);
        - comma-separated IP tags (in handler order);
        - a metadata dict with the device, total latency in seconds, and any
          ``_timings`` the handler reported;
        - the raw handler output.
        Each of the three tag strings is "—" when empty.

    Raises:
        gr.Error: if the selected input is missing or the handler raises.
    """
    if source_choice == "Upload image":
        if image is None:
            raise gr.Error("Please upload an image.")
        inputs = image
    else:
        url = (url or "").strip()
        if not url:
            raise gr.Error("Please provide an image URL.")
        inputs = {"url": url}

    data = {
        "inputs": inputs,
        "parameters": {
            "general_threshold": float(general_threshold),
            "character_threshold": float(character_threshold),
        },
    }

    # perf_counter is monotonic and high-resolution. time.time() follows the
    # wall clock and can jump if the system clock is adjusted mid-request.
    started = time.perf_counter()
    try:
        out = handler(data)
    except Exception as e:
        # Surface any handler failure to the UI as a Gradio error message.
        raise gr.Error(f"Inference error: {e}") from e
    latency = round(time.perf_counter() - started, 4)

    placeholder = "—"  # shown when a category has no tags
    features = ", ".join(sorted(out.get("feature", []))) or placeholder
    characters = ", ".join(sorted(out.get("character", []))) or placeholder
    ips = ", ".join(out.get("ip", [])) or placeholder

    meta = {
        "device": handler.device,
        "latency_s_total": latency,
        # `or {}` so a handler reporting `_timings: None` does not break unpacking.
        **(out.get("_timings") or {}),
    }
    return features, characters, ips, meta, out
# Default slider values, shared by the sliders, the example row and the Clear
# button so they cannot drift apart.
DEFAULT_GENERAL_THRESHOLD = 0.30
DEFAULT_CHARACTER_THRESHOLD = 0.85

with gr.Blocks(title="PixAI Tagger v0.9 — Demo", fill_height=True) as demo:
    gr.Markdown(
        """
# PixAI Tagger v0.9 — Gradio Demo

Downloads model assets from **pixai-labs/pixai-tagger-v0.9** on first run,
then uses your imported `EndpointHandler` to predict **general**, **character**, and **IP** tags.

Configure via env vars:
- `ASSETS_REPO_ID` (default: `pixai-labs/pixai-tagger-v0.9`)
- `ASSETS_REVISION` (optional)
- `MODEL_DIR` (default: `./assets`)
"""
    )

    with gr.Row():
        gr.Markdown(f"**{DEVICE_LABEL}**")

    with gr.Row():
        source_choice = gr.Radio(
            choices=["Upload image", "From URL"],
            value="Upload image",
            label="Image source",
        )

    with gr.Row(variant="panel"):
        with gr.Column(scale=2):
            image = gr.Image(label="Upload image", type="pil", visible=True, height="500px")
            url = gr.Textbox(label="Image URL", placeholder="https://…", visible=False)

            def toggle_inputs(choice):
                """Show only the input widget that matches the selected source."""
                return (
                    gr.update(visible=(choice == "Upload image")),
                    gr.update(visible=(choice == "From URL")),
                )

            source_choice.change(toggle_inputs, [source_choice], [image, url])

        with gr.Column(scale=1):
            general_threshold = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=DEFAULT_GENERAL_THRESHOLD,
                label="General threshold",
            )
            character_threshold = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=DEFAULT_CHARACTER_THRESHOLD,
                label="Character threshold",
            )
            run_btn = gr.Button("Run", variant="primary")
            clear_btn = gr.Button("Clear")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Predicted Tags")
            features_out = gr.Textbox(label="General tags", lines=4)
            characters_out = gr.Textbox(label="Character tags", lines=4)
            ip_out = gr.Textbox(label="IP tags", lines=2)
        with gr.Column():
            gr.Markdown("### Metadata & Raw Output")
            meta_out = gr.JSON(label="Timings/Device")
            raw_out = gr.JSON(label="Raw JSON")

    # Each example row supplies one value per component in `inputs`, in order:
    # the source radio ("From URL"), the image (none), the URL and both thresholds.
    examples = gr.Examples(
        label="Examples (URL mode)",
        examples=[
            [
                "From URL",
                None,
                "https://cdn.donmai.us/sample/50/b7/__komeiji_koishi_touhou_drawn_by_cui_ying__sample-50b7006f16e0144d5b5db44cadc2d22f.jpg",
                DEFAULT_GENERAL_THRESHOLD,
                DEFAULT_CHARACTER_THRESHOLD,
            ],
        ],
        inputs=[source_choice, image, url, general_threshold, character_threshold],
        cache_examples=False,
    )

    def clear():
        """Reset the inputs to their defaults and blank every output.

        The tuple order matches the ``outputs`` list wired to ``clear_btn``.
        The source radio is not part of that list, so it is not reset.
        """
        return (
            None,  # image
            "",  # url
            DEFAULT_GENERAL_THRESHOLD,
            DEFAULT_CHARACTER_THRESHOLD,
            "",  # general tags
            "",  # character tags
            "",  # IP tags
            {},  # timings/device
            {},  # raw JSON
        )

    run_btn.click(
        run_inference,
        inputs=[source_choice, image, url, general_threshold, character_threshold],
        outputs=[features_out, characters_out, ip_out, meta_out, raw_out],
        api_name="predict",
    )
    clear_btn.click(
        clear,
        inputs=None,
        outputs=[
            image, url, general_threshold, character_threshold,
            features_out, characters_out, ip_out, meta_out, raw_out,
        ],
    )
if __name__ == "__main__":
    # Enable Gradio's request queue (at most 8 queued events), then serve the app.
    app = demo.queue(max_size=8)
    app.launch()