| """Gradio demo for BiliSakura visual generative foundation models on Hugging Face ZeroGPU.""" |
|
|
| from __future__ import annotations |
|
|
try:
    import spaces
except ImportError:

    class _SpacesStub:
        """No-op stand-in used when the `spaces` package is not installed."""

        @staticmethod
        def GPU(*args, **kwargs):
            """Mimic the `spaces.GPU` decorator without doing anything.

            Supports both the bare form (`@spaces.GPU`) and the called form
            (`@spaces.GPU(size=..., duration=...)`). In either case the
            decorated function is returned unchanged and all options are
            ignored.
            """
            # Bare form: the decorated function arrives as the first positional arg.
            if args and callable(args[0]):
                return args[0]

            # Called form: hand back a decorator that returns its target as-is.
            def _passthrough(fn):
                return fn

            return _passthrough

    spaces = _SpacesStub()
|
|
| import traceback |
| from typing import Any |
|
|
| import gradio as gr |
| import torch |
|
|
| from model_catalog import ( |
| COLLECTIONS, |
| MODEL_LABELS, |
| get_profile_by_label, |
| parse_model_label, |
| scheduler_choices_for_profile, |
| uses_native_scheduler, |
| ) |
| from model_loader import ( |
| PIPELINE_MANAGER, |
| _to_float, |
| _to_int, |
| current_scheduler_name, |
| default_class_label_for_pipe, |
| run_inference, |
| scheduler_options_for_profile, |
| ) |
|
|
|
|
# Model pre-selected in the UI: the first label in the catalog, and its profile.
DEFAULT_MODEL = MODEL_LABELS[0]
DEFAULT_PROFILE = get_profile_by_label(DEFAULT_MODEL)

# Collections for which the "Advanced: guidance interval" controls are shown.
INTERVAL_COLLECTIONS = {"iMF-diffusers", "NiT-diffusers", "PixelFlow-diffusers"}
# Collection for which the "Advanced: pMF args" controls are shown.
PMF_COLLECTION = "pMF-diffusers"
|
|
|
|
def _model_info_markdown(profile, pipe=None) -> str:
    """Render the Markdown summary shown for a model.

    Args:
        profile: Catalog profile. Read for ``hub_model_id``, ``hub_model_url``,
            ``dtype``, ``default_height``, ``default_width``, ``gpu_size`` and
            ``extra_call_kwargs``.
        pipe: Loaded pipeline, or ``None`` when the model has not been loaded
            yet. When given, the name reported by ``current_scheduler_name``
            is shown; otherwise a placeholder is shown.

    Returns:
        Markdown text with the hub link, dtype, default resolution, scheduler,
        GPU size and, when the profile has any, its default extra arguments.
    """
    extras = profile.extra_call_kwargs
    extra_lines = ""
    if extras:
        bullets = "\n".join(f"- `{key}`: `{value}`" for key, value in extras.items())
        extra_lines = f"\n\n**Default extra args**\n{bullets}"

    if pipe is not None:
        scheduler_line = f"`{current_scheduler_name(pipe)}`"
    else:
        scheduler_line = "`checkpoint` (load model to see)"

    # Fall back to the inferred resolution when a default is unset, the same
    # way the height/width sliders are populated, so the summary never shows
    # "NonexNone".
    height = profile.default_height or profile.infer_resolution()
    width = profile.default_width or profile.infer_resolution()

    return (
        f"**Hub repo:** [`{profile.hub_model_id}`]({profile.hub_model_url})\n\n"
        f"- dtype: `{profile.dtype}`\n"
        f"- default resolution: `{height}x{width}`\n"
        f"- checkpoint scheduler: {scheduler_line}\n"
        f"- GPU size: `{profile.gpu_size}`"
        f"{extra_lines}"
    )
|
|
|
|
def _interval_defaults(profile) -> tuple[float, float]:
    """Return the default ``(start, end)`` guidance interval for a profile.

    Looks in ``profile.extra_call_kwargs`` for, in order:

    1. ``guidance_interval_start`` (with ``guidance_interval_end``; the end
       falls back to ``0.7`` when only the start is configured),
    2. a ``guidance_interval`` pair,
    3. otherwise ``(0.0, 0.7)``.

    Args:
        profile: Catalog profile exposing a mapping ``extra_call_kwargs``.

    Returns:
        The interval bounds as floats.
    """
    extras = profile.extra_call_kwargs
    if "guidance_interval_start" in extras:
        start = extras["guidance_interval_start"]
        # A profile that sets only the start must not raise KeyError while the
        # UI is being populated; reuse the end of the built-in default pair.
        end = extras.get("guidance_interval_end", 0.7)
        return float(start), float(end)
    interval = extras.get("guidance_interval", (0.0, 0.7))
    return float(interval[0]), float(interval[1])
|
|
|
|
def _build_extra_kwargs(
    profile,
    guidance_interval_start: float,
    guidance_interval_end: float,
    guidance_interval_min: float,
    guidance_interval_max: float,
    noise_scale: float,
) -> dict[str, Any]:
    """Assemble the collection-specific extra call kwargs from the UI values.

    Which inputs are used depends on ``profile.collection``:

    - ``"iMF-diffusers"``: start and end as two separate keys.
    - ``"NiT-diffusers"`` / ``"PixelFlow-diffusers"``: start and end packed
      into a single ``guidance_interval`` tuple.
    - ``PMF_COLLECTION``: interval min/max plus ``noise_scale``.
    - anything else: a copy of ``profile.extra_call_kwargs`` (the UI values
      are ignored).

    Returns:
        A new dict; the profile's own mapping is never returned directly.
    """
    family = profile.collection

    if family == "iMF-diffusers":
        return dict(
            guidance_interval_start=guidance_interval_start,
            guidance_interval_end=guidance_interval_end,
        )

    if family in ("NiT-diffusers", "PixelFlow-diffusers"):
        interval = (guidance_interval_start, guidance_interval_end)
        return dict(guidance_interval=interval)

    if family == PMF_COLLECTION:
        return dict(
            guidance_interval_min=guidance_interval_min,
            guidance_interval_max=guidance_interval_max,
            noise_scale=noise_scale,
        )

    # No UI-driven extras for this collection: copy the profile defaults.
    return dict(profile.extra_call_kwargs)
|
|
|
|
def _scheduler_config(profile, pipe=None):
    """Build the ``gr.update`` for the scheduler dropdown.

    Args:
        profile: Catalog profile of the selected model.
        pipe: Loaded pipeline, or ``None`` if the model is not loaded yet.

    Returns:
        A ``gr.update`` setting choices, value, interactivity and help text.
        For a native-scheduler profile that is not loaded, the dropdown is
        locked to the single ``"checkpoint"`` entry; otherwise the choices and
        default come from ``scheduler_options_for_profile``.
    """
    choices, default = scheduler_options_for_profile(profile, pipe)
    native = uses_native_scheduler(profile)
    not_loaded = pipe is None

    if native and not_loaded:
        return gr.update(
            choices=["checkpoint"],
            value="checkpoint",
            interactive=False,
            info="Uses the checkpoint scheduler (custom, not swappable until loaded)",
        )

    if not_loaded:
        hint = "Defaults to checkpoint scheduler after Load model; optional built-in swap"
    else:
        hint = "Checkpoint scheduler by default; pick another built-in diffusers scheduler to swap"

    return gr.update(
        choices=choices,
        value=default,
        interactive=not native,
        info=hint,
    )
|
|
|
|
def _config_from_profile(profile, pipe=None):
    """Build the updates that reset every config widget to a profile's defaults.

    Args:
        profile: Catalog profile of the selected model.
        pipe: Optional loaded pipeline, forwarded to the model-info and
            scheduler helpers.

    Returns:
        A 15-tuple: the model-info Markdown text, followed by ``gr.update``
        objects for class label, seed, steps, guidance scale, scheduler,
        height, width, guidance interval start/end, pMF interval min/max,
        noise scale, the guidance-interval accordion and the pMF accordion.
    """
    interval_start, interval_end = _interval_defaults(profile)
    extras = profile.extra_call_kwargs
    interval_shown = profile.collection in INTERVAL_COLLECTIONS
    pmf_shown = profile.collection == PMF_COLLECTION

    def _value(new_value):
        # Update that only replaces a widget's value.
        return gr.update(value=new_value)

    def _extra_float(key, fallback):
        # Value update taken from the profile's extra kwargs, coerced to float.
        return _value(float(extras.get(key, fallback)))

    def _accordion(shown):
        # An accordion is expanded exactly when it is visible.
        return gr.update(visible=shown, open=shown)

    return (
        _model_info_markdown(profile, pipe),
        _value(profile.default_class_label),
        _value(profile.default_seed),
        gr.update(value=profile.default_steps, maximum=profile.max_steps),
        _value(profile.default_guidance),
        _scheduler_config(profile, pipe),
        _value(profile.default_height or profile.infer_resolution()),
        _value(profile.default_width or profile.infer_resolution()),
        _value(interval_start),
        _value(interval_end),
        _extra_float("guidance_interval_min", 0.2),
        _extra_float("guidance_interval_max", 0.6),
        _extra_float("noise_scale", 4.0),
        _accordion(interval_shown),
        _accordion(pmf_shown),
    )
|
|
|
|
def on_model_change(model_label: str):
    """Return the widget updates for a newly selected model label.

    The profile is looked up by label and converted with
    ``_config_from_profile`` without a pipeline.
    """
    profile = get_profile_by_label(model_label)
    return _config_from_profile(profile)
|
|
|
|
def _coerce_inference_inputs(
    class_label: Any,
    seed: Any,
    num_steps: Any,
    guidance_scale: Any,
    height: Any,
    width: Any,
    guidance_interval_start: Any,
    guidance_interval_end: Any,
    guidance_interval_min: Any,
    guidance_interval_max: Any,
    noise_scale: Any,
) -> tuple[str, int, int, float, int, int, float, float, float, float, float]:
    """Normalize raw widget values into the types inference expects.

    The class label becomes a stripped string (a falsy value becomes ``""``).
    Every other value goes through ``_to_int`` / ``_to_float`` with a fixed
    ``default`` (presumably the fallback when conversion fails; confirm
    against ``model_loader``).

    Returns:
        The coerced values, in the same order as the parameters.
    """
    label_text = str(class_label or "").strip()

    seed_value = _to_int(seed, default=42)
    steps_value = _to_int(num_steps, default=50)
    guidance_value = _to_float(guidance_scale, default=4.0)
    height_value = _to_int(height, default=256)
    width_value = _to_int(width, default=256)

    interval_start = _to_float(guidance_interval_start, default=0.0)
    interval_end = _to_float(guidance_interval_end, default=0.7)
    interval_min = _to_float(guidance_interval_min, default=0.2)
    interval_max = _to_float(guidance_interval_max, default=0.6)
    noise_value = _to_float(noise_scale, default=4.0)

    return (
        label_text,
        seed_value,
        steps_value,
        guidance_value,
        height_value,
        width_value,
        interval_start,
        interval_end,
        interval_min,
        interval_max,
        noise_value,
    )
|
|
|
|
def _gpu_duration(
    model_label: str,
    class_label: str,
    seed: int,
    num_steps: int,
    guidance_scale: float,
    scheduler: str,
    height: int,
    width: int,
    guidance_interval_start: float,
    guidance_interval_end: float,
    guidance_interval_min: float,
    guidance_interval_max: float,
    noise_scale: float,
) -> int:
    """Compute the value passed as ``duration`` to ``spaces.GPU`` for generation.

    The parameter list mirrors ``_generate_on_gpu``; only ``model_label`` and
    ``num_steps`` are actually used.

    Returns:
        ``step_budget * 0.6 + 30`` truncated to an int, clamped to at least
        the profile's base (45 for ``gpu_size == "large"``, else 90) and at
        most 300. For profiles with ``steps_are_list`` the step budget is
        never below 40.
    """
    profile = get_profile_by_label(model_label)
    num_steps = _to_int(num_steps, default=profile.default_steps)

    if profile.steps_are_list:
        step_budget = max(num_steps, 40)
    else:
        step_budget = num_steps

    if profile.gpu_size == "large":
        floor = 45
    else:
        floor = 90

    estimate = step_budget * 0.6 + 30
    clamped = min(300, max(floor, estimate))
    return int(clamped)
|
|
|
|
def _load_model_core(model_label: str) -> tuple[str, str | None]:
    """Load the pipeline for ``model_label`` and move it to CUDA.

    Returns:
        ``(message, suggested_label)`` where ``message`` is the first element
        returned by ``PIPELINE_MANAGER.load`` and ``suggested_label`` is the
        result of ``default_class_label_for_pipe`` for the loaded pipeline,
        or ``None`` when the manager holds no pipeline afterwards.
    """
    collection, variant = parse_model_label(model_label)
    message, _unused = PIPELINE_MANAGER.load(collection, variant)
    PIPELINE_MANAGER.move_to_cuda()

    pipe = PIPELINE_MANAGER.pipe
    profile = get_profile_by_label(model_label)
    if pipe is None:
        return message, None
    return message, default_class_label_for_pipe(pipe, profile)
|
|
|
|
@spaces.GPU(size="xlarge", duration=120)
def _load_on_gpu(model_label: str) -> tuple[str, str | None]:
    """Run ``_load_model_core`` under the ``spaces.GPU`` decorator.

    Returns the ``(message, suggested_label)`` pair unchanged.
    """
    outcome = _load_model_core(model_label)
    return outcome
|
|
|
|
def load_model(model_label: str):
    """Handle the "Load model" button.

    Loads the model via ``_load_on_gpu`` and rebuilds the config widgets from
    its profile and the manager's current pipeline. When loading produced a
    suggested class label, it replaces the profile's default class label.

    Returns:
        ``(message, *config)`` where ``config`` is the tuple produced by
        ``_config_from_profile``.

    Raises:
        gr.Error: If loading fails for any reason; the original exception is
            chained as the cause.
    """
    try:
        message, suggested_label = _load_on_gpu(model_label)
    except Exception as exc:
        raise gr.Error(f"Failed to load `{model_label}`: {exc}") from exc

    profile = get_profile_by_label(model_label)
    config = _config_from_profile(profile, PIPELINE_MANAGER.pipe)
    if suggested_label:
        # Slot 1 of the config tuple is the class-label update.
        info, _default_label, *rest = config
        config = (info, gr.update(value=suggested_label), *rest)
    return (message, *config)
|
|
|
|
@spaces.GPU(size="xlarge", duration=_gpu_duration)
def _generate_on_gpu(
    model_label: str,
    class_label: str,
    seed: int,
    num_steps: int,
    guidance_scale: float,
    scheduler: str,
    height: int,
    width: int,
    guidance_interval_start: float,
    guidance_interval_end: float,
    guidance_interval_min: float,
    guidance_interval_max: float,
    noise_scale: float,
):
    """Run one generation under the ``spaces.GPU`` decorator.

    Coerces the raw widget values, loads the requested model if the manager
    does not already hold it, moves the pipeline to CUDA, builds the
    collection-specific extra kwargs and calls ``run_inference``.

    Returns:
        Whatever ``run_inference`` returns.

    Raises:
        gr.Error: If the manager holds no pipeline after loading.
    """
    coerced = _coerce_inference_inputs(
        class_label,
        seed,
        num_steps,
        guidance_scale,
        height,
        width,
        guidance_interval_start,
        guidance_interval_end,
        guidance_interval_min,
        guidance_interval_max,
        noise_scale,
    )
    (
        class_label,
        seed,
        num_steps,
        guidance_scale,
        height,
        width,
        guidance_interval_start,
        guidance_interval_end,
        guidance_interval_min,
        guidance_interval_max,
        noise_scale,
    ) = coerced

    profile = get_profile_by_label(model_label)
    collection, variant = parse_model_label(model_label)

    # Reload only when a different model (or none at all) is currently held.
    already_loaded = (
        PIPELINE_MANAGER.loaded_label == model_label
        and PIPELINE_MANAGER.pipe is not None
    )
    if not already_loaded:
        PIPELINE_MANAGER.load(collection, variant)
    PIPELINE_MANAGER.move_to_cuda()

    pipe = PIPELINE_MANAGER.pipe
    if pipe is None:
        raise gr.Error(f"Model `{model_label}` is not loaded.")

    call_extras = _build_extra_kwargs(
        profile,
        guidance_interval_start,
        guidance_interval_end,
        guidance_interval_min,
        guidance_interval_max,
        noise_scale,
    )
    return run_inference(
        profile,
        pipe,
        class_label=class_label,
        seed=seed,
        num_steps=num_steps,
        guidance_scale=guidance_scale,
        height=height,
        width=width,
        scheduler_name=scheduler,
        extra_kwargs=call_extras,
    )
|
|
|
|
def generate(
    model_label: str,
    class_label: str,
    seed: int,
    num_steps: int,
    guidance_scale: float,
    scheduler: str,
    height: int,
    width: int,
    guidance_interval_start: float,
    guidance_interval_end: float,
    guidance_interval_min: float,
    guidance_interval_max: float,
    noise_scale: float,
):
    """Handle the "Generate" button.

    Forwards all widget values, unchanged and in order, to
    ``_generate_on_gpu``.

    Returns:
        ``(status_text, image)``; the status names the manager's loaded label,
        or ``model_label`` when the manager reports none.

    Raises:
        gr.Error: If generation fails for any reason. The message includes
            the exception text and a traceback limited to 6 frames; the
            original exception is chained as the cause.
    """
    call_args = (
        model_label,
        class_label,
        seed,
        num_steps,
        guidance_scale,
        scheduler,
        height,
        width,
        guidance_interval_start,
        guidance_interval_end,
        guidance_interval_min,
        guidance_interval_max,
        noise_scale,
    )
    try:
        image = _generate_on_gpu(*call_args)
    except Exception as exc:
        detail = traceback.format_exc(limit=6).strip()
        raise gr.Error(f"Generation failed for `{model_label}`: {exc}\n\n{detail}") from exc

    shown_label = PIPELINE_MANAGER.loaded_label or model_label
    return f"Generated with `{shown_label}`.", image
|
|
|
|
def build_demo() -> gr.Blocks:
    """Build the Gradio Blocks UI and wire up its event handlers.

    Initial widget values come from ``DEFAULT_PROFILE``. ``on_model_change``
    runs on page load and on every model change to refresh the config
    widgets; the "Load model" and "Generate" buttons call ``load_model`` and
    ``generate``.

    Returns:
        The assembled ``gr.Blocks`` app (not launched).
    """
    # NOTE(review): confirm the row/accordion nesting below against the
    # rendered layout.
    g_start, g_end = _interval_defaults(DEFAULT_PROFILE)
    extras = DEFAULT_PROFILE.extra_call_kwargs
    # "checkpoint" is always offered first and is the initial selection.
    default_scheduler_choices = ["checkpoint", *scheduler_choices_for_profile(DEFAULT_PROFILE)]
    default_scheduler_value = "checkpoint"

    with gr.Blocks(title="BiliSakura Visual Generation Models") as demo:
        gr.Markdown(
            """
            # BiliSakura Visual Generative Foundation Models

            Class-conditional image generation for [`BiliSakura/*-diffusers`](https://huggingface.co/BiliSakura)
            on Hugging Face **ZeroGPU**.
            """
        )

        with gr.Row(equal_height=False):
            # Left column: model selection, inference controls, actions, status.
            with gr.Column(scale=5):
                model = gr.Dropdown(
                    MODEL_LABELS,
                    value=DEFAULT_MODEL,
                    label="Model",
                    info="Select a checkpoint, then configure inference args below",
                )
                model_info = gr.Markdown(_model_info_markdown(DEFAULT_PROFILE))

                with gr.Accordion("Inference config", open=True):
                    class_label = gr.Textbox(
                        label="class_labels",
                        value=DEFAULT_PROFILE.default_class_label,
                        info="ImageNet class id (e.g. 207) or any synonym from id2label (e.g. golden retriever, tabby)",
                    )
                    with gr.Row():
                        seed = gr.Number(label="seed", value=DEFAULT_PROFILE.default_seed, precision=0)
                        num_steps = gr.Slider(
                            label="num_inference_steps",
                            minimum=1,
                            maximum=DEFAULT_PROFILE.max_steps,
                            step=1,
                            value=DEFAULT_PROFILE.default_steps,
                        )
                    guidance_scale = gr.Slider(
                        label="guidance_scale",
                        minimum=0.0,
                        maximum=20.0,
                        step=0.1,
                        value=DEFAULT_PROFILE.default_guidance,
                    )
                    scheduler = gr.Dropdown(
                        label="scheduler",
                        choices=default_scheduler_choices or ["checkpoint"],
                        value=default_scheduler_value,
                        interactive=not uses_native_scheduler(DEFAULT_PROFILE),
                        info="Defaults to checkpoint scheduler after Load model",
                    )
                    with gr.Row():
                        height = gr.Slider(
                            label="height",
                            minimum=128,
                            maximum=1024,
                            step=16,
                            value=DEFAULT_PROFILE.default_height or 256,
                        )
                        width = gr.Slider(
                            label="width",
                            minimum=128,
                            maximum=1024,
                            step=16,
                            value=DEFAULT_PROFILE.default_width or 256,
                        )

                # Only visible for collections in INTERVAL_COLLECTIONS.
                with gr.Accordion(
                    "Advanced: guidance interval",
                    open=False,
                    visible=DEFAULT_PROFILE.collection in INTERVAL_COLLECTIONS,
                ) as interval_accordion:
                    with gr.Row():
                        guidance_interval_start = gr.Slider(
                            label="guidance_interval_start / [0]",
                            minimum=0.0,
                            maximum=1.0,
                            step=0.05,
                            value=g_start,
                        )
                        guidance_interval_end = gr.Slider(
                            label="guidance_interval_end / [1]",
                            minimum=0.0,
                            maximum=1.0,
                            step=0.05,
                            value=g_end,
                        )

                # Only visible for the PMF_COLLECTION collection.
                with gr.Accordion(
                    "Advanced: pMF args",
                    open=False,
                    visible=DEFAULT_PROFILE.collection == PMF_COLLECTION,
                ) as pmf_accordion:
                    with gr.Row():
                        guidance_interval_min = gr.Slider(
                            label="guidance_interval_min",
                            minimum=0.0,
                            maximum=1.0,
                            step=0.05,
                            value=float(extras.get("guidance_interval_min", 0.2)),
                        )
                        guidance_interval_max = gr.Slider(
                            label="guidance_interval_max",
                            minimum=0.0,
                            maximum=1.0,
                            step=0.05,
                            value=float(extras.get("guidance_interval_max", 0.6)),
                        )
                    noise_scale = gr.Slider(
                        label="noise_scale",
                        minimum=0.0,
                        maximum=10.0,
                        step=0.1,
                        value=float(extras.get("noise_scale", 4.0)),
                    )

                with gr.Row():
                    load_btn = gr.Button("Load model", variant="secondary")
                    generate_btn = gr.Button("Generate", variant="primary")

                status = gr.Textbox(label="Status", interactive=False, lines=2)
                gr.Markdown(
                    f"**Catalog:** {len(MODEL_LABELS)} variants · {len(COLLECTIONS)} families"
                )

            # Right column: generated image and quick examples.
            with gr.Column(scale=6):
                output = gr.Image(
                    label="Generated image",
                    type="pil",
                    height=720,
                    elem_classes=["output-image"],
                )
                # Each example fills only model, class label and seed.
                gr.Examples(
                    examples=[
                        [DEFAULT_MODEL, "golden retriever", 42],
                        [DEFAULT_MODEL, "207", 0],
                        [DEFAULT_MODEL, "tabby", 123],
                        [DEFAULT_MODEL, "tabby, tabby cat", 456],
                    ],
                    inputs=[model, class_label, seed],
                    label="Quick examples",
                )

        # Order must match the parameter order of `generate`.
        inference_inputs = [
            model,
            class_label,
            seed,
            num_steps,
            guidance_scale,
            scheduler,
            height,
            width,
            guidance_interval_start,
            guidance_interval_end,
            guidance_interval_min,
            guidance_interval_max,
            noise_scale,
        ]
        # Order must match the tuple returned by `_config_from_profile`.
        config_outputs = [
            model_info,
            class_label,
            seed,
            num_steps,
            guidance_scale,
            scheduler,
            height,
            width,
            guidance_interval_start,
            guidance_interval_end,
            guidance_interval_min,
            guidance_interval_max,
            noise_scale,
            interval_accordion,
            pmf_accordion,
        ]

        model.change(on_model_change, inputs=model, outputs=config_outputs)
        # `load_model` returns the status message followed by the config updates.
        load_btn.click(load_model, inputs=model, outputs=[status, *config_outputs])
        generate_btn.click(generate, inputs=inference_inputs, outputs=[status, output])
        # Refresh the config widgets once when the page loads.
        demo.load(on_model_change, inputs=model, outputs=config_outputs)

    return demo
|
|
|
|
# Module-level app object, built at import time.
demo = build_demo()


if __name__ == "__main__":
    # A missing local GPU only triggers a notice; the app is launched regardless.
    cuda_missing = not torch.cuda.is_available()
    if cuda_missing:
        print("CUDA is not available locally; ZeroGPU Spaces will provide GPU at inference time.")
    queued_app = demo.queue(max_size=8)
    queued_app.launch()
|
|